Offline ALS and Real-Time Streaming Recommendations with Spark

Introduction to ALS

ALS is short for alternating least squares, and ALS-WR is short for alternating least squares with weighted-λ-regularization. Both are commonly used in recommender systems built on matrix factorization. For example, the matrix of user ratings for items is factored into two matrices: one describing each user's affinity for a set of latent (hidden) item features, the other describing how strongly each item exhibits those latent features. As a by-product of the factorization, the missing entries of the rating matrix get filled in, and these reconstructed scores can then be used to recommend items to users.

In a nutshell, ALS-WR works as follows
(data format: userId, itemId, rating, timestamp):
1. For each userId, randomly initialize N (e.g. 10) factor values; these values weight that user.
2. For each itemId, likewise randomly initialize N (e.g. 10) factor values.
3. Holding the user factors fixed, solve for the item factors from the rating matrix: conceptually [Item Factors Matrix] = [User Factors Matrix]^-1 * [Rating Matrix] (in practice a regularized least-squares solve; see the update equations below).
4. Holding the item factors fixed, solve for the user factors: [User Factors Matrix] = [Item Factors Matrix]^-1 * [Rating Matrix].
5. Repeat steps 3 and 4 until the user and item factors converge to stable values.
6. To predict a rating for a (user, item) pair, take the dot product of the corresponding factor vectors: userFactor · itemFactor = rating value.
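To make steps 3 and 4 precise: write x_u for user u's factor vector, y_i for item i's, and n_u, n_i for their respective numbers of ratings. ALS-WR minimizes the weighted-λ-regularized squared error over the set K of observed ratings, and each "fix one side, solve the other" step has a closed-form regularized least-squares solution. These are the standard ALS-WR equations, stated here for reference; the "matrix inverse" in steps 3 and 4 is shorthand for exactly this solve:

\min_{X,Y} \sum_{(u,i) \in K} \left( r_{ui} - x_u^\top y_i \right)^2 + \lambda \left( \sum_u n_u \|x_u\|^2 + \sum_i n_i \|y_i\|^2 \right)

x_u = \left( Y_u^\top Y_u + \lambda n_u I \right)^{-1} Y_u^\top r_u \qquad y_i = \left( X_i^\top X_i + \lambda n_i I \right)^{-1} X_i^\top r_i

Here Y_u stacks the factor vectors of the items user u has rated and r_u holds u's observed ratings (symmetrically for X_i and r_i). The prediction in step 6 is simply the dot product x_u^T y_i.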

Spark ships two machine-learning libraries, ML and MLlib. The official recommendation is ML: it is more complete and flexible (being DataFrame-based), and future development will focus on it.

 

ALS recommendation with the ML API:

import java.io.Serializable;

import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;

/**
 * @category ALS-WR
 * @author huangyueran
 *
 */
public class JavaALSExampleByMl {

	public static class Rating implements Serializable {
		// 0::2::3::1424380312
		private int userId; // 0
		private int movieId; // 2
		private float rating; // 3
		private long timestamp; // 1424380312

		public Rating() {
		}

		public Rating(int userId, int movieId, float rating, long timestamp) {
			this.userId = userId;
			this.movieId = movieId;
			this.rating = rating;
			this.timestamp = timestamp;
		}

		public int getUserId() {
			return userId;
		}

		public int getMovieId() {
			return movieId;
		}

		public float getRating() {
			return rating;
		}

		public long getTimestamp() {
			return timestamp;
		}

		public static Rating parseRating(String str) {
			String[] fields = str.split("::");
			if (fields.length != 4) {
				throw new IllegalArgumentException("Each line must contain 4 fields");
			}
			int userId = Integer.parseInt(fields[0]);
			int movieId = Integer.parseInt(fields[1]);
			float rating = Float.parseFloat(fields[2]);
			long timestamp = Long.parseLong(fields[3]);
			return new Rating(userId, movieId, rating, timestamp);
		}
	}

	public static void main(String[] args) {
		Logger.getLogger("org").setLevel(Level.WARN);
		SparkConf conf = new SparkConf().setAppName("JavaALSExample").setMaster("local");
		JavaSparkContext jsc = new JavaSparkContext(conf);
		SQLContext sqlContext = new SQLContext(jsc);

		JavaRDD<Rating> ratingsRDD = jsc.textFile("data/sample_movielens_ratings.txt")
				.map(new Function<String, Rating>() {
					public Rating call(String str) {
						return Rating.parseRating(str);
					}
				});
		DataFrame ratings = sqlContext.createDataFrame(ratingsRDD, Rating.class);
		DataFrame[] splits = ratings.randomSplit(new double[] { 0.8, 0.2 }); // split the data: 80% training, 20% test
		DataFrame training = splits[0];
		DataFrame test = splits[1];

		// Build the recommendation model using ALS on the training data
		ALS als = new ALS().setMaxIter(5) // number of ALS iterations
				.setRegParam(0.01) // regularization parameter; on this dataset 0.1 seems to give a slightly lower error
				.setUserCol("userId").setItemCol("movieId")
				.setRatingCol("rating");
		ALSModel model = als.fit(training); // train the model
		
		
		DataFrame itemFactors = model.itemFactors(); // learned item latent-factor matrix
		itemFactors.show(1500);
		DataFrame userFactors = model.userFactors(); // learned user latent-factor matrix
		userFactors.show();
		
		// Evaluate the model by computing the RMSE on the test data
		DataFrame rawPredictions = model.transform(test); // predict on the test set
		DataFrame predictions = rawPredictions
				.withColumn("rating", rawPredictions.col("rating").cast(DataTypes.DoubleType))
				.withColumn("prediction", rawPredictions.col("prediction").cast(DataTypes.DoubleType));

		RegressionEvaluator evaluator = new RegressionEvaluator().setMetricName("rmse").setLabelCol("rating")
				.setPredictionCol("prediction");
		Double rmse = evaluator.evaluate(predictions);
		System.out.println("Root-mean-square error = " + rmse); // 均方根误差;
		
		jsc.stop();
	}
}
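The regParam comment above notes that 0.1 seemed to give a lower error on this dataset. Below is a minimal sketch of how one might compare candidate values using the same split; it is hypothetical helper code, not part of the original example. It assumes static imports of col and not from org.apache.spark.sql.functions, and it filters out the NaN predictions that Spark 1.6 produces for users or movies absent from the training split:

		// Hypothetical grid search over regParam, reusing als, training, test and evaluator from above
		double[] candidateRegParams = { 0.01, 0.05, 0.1 };
		for (double reg : candidateRegParams) {
			ALSModel candidate = als.setRegParam(reg).fit(training);
			DataFrame preds = candidate.transform(test)
					.withColumn("rating", col("rating").cast(DataTypes.DoubleType))
					.withColumn("prediction", col("prediction").cast(DataTypes.DoubleType))
					.filter(not(col("prediction").isNaN())); // drop cold-start rows unseen in training
			System.out.println("regParam = " + reg + ", RMSE = " + evaluator.evaluate(preds));
		}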

 

ALS recommendation with the MLlib API:

import java.io.File;
import java.io.IOException;

import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaDoubleRDD;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;

import scala.Tuple2;

/**
 * @category ALS
 * @author huangyueran
 *
 */
public class JavaALSExampleByMlLib {

	public static void main(String[] args) {
		Logger.getLogger("org").setLevel(Level.WARN);
		SparkConf conf = new SparkConf().setAppName("JavaALSExample").setMaster("local[4]");
		JavaSparkContext jsc = new JavaSparkContext(conf);

		JavaRDD<String> data = jsc.textFile("data/sample_movielens_ratings.txt");

		JavaRDD<Rating> ratings = data.map(new Function<String, Rating>() {
			public Rating call(String s) {
				String[] sarray = StringUtils.split(StringUtils.trim(s), "::");
				return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]),
						Double.parseDouble(sarray[2]));
			}
		});

		// Build the recommendation model using ALS
		int rank = 10; // number of latent factors
		int numIterations = 6; // number of ALS iterations
		MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01);

		// Evaluate the model on rating data
		JavaRDD<Tuple2<Object, Object>> userProducts = ratings.map(new Function<Rating, Tuple2<Object, Object>>() {
			public Tuple2<Object, Object> call(Rating r) {
				return new Tuple2<Object, Object>(r.user(), r.product());
			}
		});

		// Predicted ratings, keyed by (user, product)
		JavaPairRDD<Tuple2<Integer, Integer>, Double> predictions = JavaPairRDD
				.fromJavaRDD(model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD()
						.map(new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Double>>() {
							public Tuple2<Tuple2<Integer, Integer>, Double> call(Rating r) {
								return new Tuple2<Tuple2<Integer, Integer>, Double>(
										new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
							}
						}));

		// Join the actual ratings with the predictions, keyed by (user, product)
		JavaPairRDD<Tuple2<Integer, Integer>, Tuple2<Double, Double>> ratesAndPreds = JavaPairRDD
				.fromJavaRDD(ratings.map(new Function<Rating, Tuple2<Tuple2<Integer, Integer>, Double>>() {
					public Tuple2<Tuple2<Integer, Integer>, Double> call(Rating r) {
						return new Tuple2<Tuple2<Integer, Integer>, Double>(
								new Tuple2<Integer, Integer>(r.user(), r.product()), r.rating());
					}
				})).join(predictions);

		// Predicted ratings keyed by user id: (userId -> (productId, predicted rating)), sorted by user id descending
		JavaPairRDD<Integer, Tuple2<Integer, Double>> fromJavaRDD = JavaPairRDD.fromJavaRDD(ratesAndPreds.map(
				new Function<Tuple2<Tuple2<Integer, Integer>, Tuple2<Double, Double>>, Tuple2<Integer, Tuple2<Integer, Double>>>() {

					@Override
					public Tuple2<Integer, Tuple2<Integer, Double>> call(
							Tuple2<Tuple2<Integer, Integer>, Tuple2<Double, Double>> t) throws Exception {
						return new Tuple2<Integer, Tuple2<Integer, Double>>(t._1._1,
								new Tuple2<Integer, Double>(t._1._2, t._2._2));
					}
				})).sortByKey(false);

//		List<Tuple2<Integer, Tuple2<Integer, Double>>> list = fromJavaRDD.collect();
//		for (Tuple2<Integer, Tuple2<Integer, Double>> t : list) {
//			System.out.println(t._1 + ":" + t._2._1 + "====" + t._2._2);
//		}

		JavaRDD<Tuple2<Double, Double>> ratesAndPredsValues = ratesAndPreds.values();

		double MSE = JavaDoubleRDD.fromRDD(ratesAndPredsValues.map(new Function<Tuple2<Double, Double>, Object>() {
			public Object call(Tuple2<Double, Double> pair) {
				Double err = pair._1() - pair._2();
				return err * err;
			}
		}).rdd()).mean();

		try {
			FileUtils.deleteDirectory(new File("result")); // clear output from previous runs
		} catch (IOException e) {
			e.printStackTrace();
		}

		ratesAndPreds.repartition(1).saveAsTextFile("result/ratesAndPreds.txt");

		// Recommend 10 products (movies) for a given user
		Rating[] recommendProducts = model.recommendProducts(2, 10);
		for (Rating r : recommendProducts) {
			System.out.println(r.toString());
		}
		
		// Recommend the top N products for every user
		// model.recommendProductsForUsers(10);

		// Recommend the top N users for every product
		// model.recommendUsersForProducts(10);
		
		model.userFeatures().saveAsTextFile("result/userFea.txt");
		model.productFeatures().saveAsTextFile("result/productFea.txt");
		System.out.println("Mean Squared Error = " + MSE);

	}

}
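The trained MatrixFactorizationModel can also be persisted and loaded back later; the streaming example below saves its model to HDFS the same way. A minimal sketch (the local output path is arbitrary):

		// Persist the trained model, then restore it and use the same API on the restored copy
		model.save(jsc.sc(), "result/alsModel");
		MatrixFactorizationModel restored = MatrixFactorizationModel.load(jsc.sc(), "result/alsModel");
		Rating[] again = restored.recommendProducts(2, 10);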

 

The two examples above perform offline ALS recommendation with Spark. A third approach uses Spark Streaming to compute recommendations online (in real time) over fresh data buffered in a message queue such as Kafka.

 

Real-time ALS recommendation with Spark Streaming:

Running ALS through Spark Streaming is only one link in the chain; a real project involves plenty of other machinery.

For example, user-behavior logs are instrumented in the application, collected and monitored by Flume, and stored in HDFS, while Kafka consumes and buffers the massive stream of behavior events.

On top of that, the model trained by Spark is stored offline, and the web tier pulls and caches it to serve recommendations to users, roughly as sketched below.
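A hypothetical sketch of that last step, i.e. the serving side loading the persisted model to answer requests (sc and userId are assumed to come from the surrounding service; the HDFS path matches the one the streaming job below writes to):

		// Hypothetical serving-side lookup: load the model the streaming job saved
		MatrixFactorizationModel servingModel = MatrixFactorizationModel.load(sc, "hdfs://master:8020/spark-als/model");
		Rating[] topN = servingModel.recommendProducts(userId, 10); // userId taken from the web request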

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;

import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.rdd.RDD;
import org.apache.spark.streaming.Durations;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairInputDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import org.apache.spark.streaming.kafka.KafkaUtils;

import kafka.serializer.StringDecoder;
import scala.Tuple2;

/**
 * @category Real-time recommendation template DEMO based on Spark Streaming and Kafka.
 *           The original system also involves the mall project, logback, Flume and Hadoop.
 * @author huangyueran
 *
 */
public final class SparkALSByStreaming {

	// Real-time recommendation DEMO based on Hadoop, Flume, Kafka, Spark Streaming,
	// logback and a mall (e-commerce) system.
	// Data format collected by the mall system:
	// UserID,ItemId,Rating,TimeStamp (user id, item id, behavior score, timestamp)
	// 53,1286513,9,1508221762
	// 53,1172348420,9,1508221762
	// 53,1179495514,12,1508221762
	// 53,1184890730,3,1508221762
	// 53,1210793742,159,1508221762
	// 53,1215837445,9,1508221762
	
	public static void main(String[] args) {
		System.setProperty("HADOOP_USER_NAME", "root"); // 设置权限用户
		
		SparkConf sparkConf = new SparkConf().setAppName("JavaKafkaDirectWordCount").setMaster("local[1]");

		final JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, Durations.seconds(6));

		Map<String, String> kafkaParams = new HashMap<String, String>(); // Kafka connection parameters
		kafkaParams.put("metadata.broker.list", "master:9092,slave1:9092,slave2:9092"); // where the brokers are
		HashSet<String> topicsSet = new HashSet<String>();
		topicsSet.add("taotao-server-recommend-logs"); // the topic to consume

		// Create a direct Kafka stream with the brokers and topics
		JavaPairInputDStream<String, String> messages = KafkaUtils.createDirectStream(jssc, String.class, String.class,
				StringDecoder.class, StringDecoder.class, kafkaParams, topicsSet);

		JavaDStream<String> lines = messages.map(new Function<Tuple2<String, String>, String>() {
			@Override
			public String call(Tuple2<String, String> tuple2) {
				return tuple2._2();
			}
		});

		JavaDStream<Rating> ratingsStream = lines.map(new Function<String, Rating>() {
			public Rating call(String s) {
				String[] sarray = StringUtils.split(StringUtils.trim(s), ",");
				return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]),
						Double.parseDouble(sarray[2]));
			}
		});

		// Streaming recommendation: retrain per micro-batch
		ratingsStream.foreachRDD(new Function<JavaRDD<Rating>, Void>() {

			@Override
			public Void call(JavaRDD<Rating> ratings) throws Exception {
				// Load the historical dataset
				SparkContext sc = ratings.context();

				RDD<String> textFileRDD = sc.textFile("hdfs://master:8020/flume/logs", 3); // raw behavior logs collected by Flume
				JavaRDD<String> originalTextFile = textFileRDD.toJavaRDD();

				final JavaRDD<Rating> originaldatas = originalTextFile.map(new Function<String, Rating>() {
					public Rating call(String s) {
						String[] sarray = StringUtils.split(StringUtils.trim(s), ",");
						return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]),
								Double.parseDouble(sarray[2]));
					}
				});
				System.out.println("====================");
				System.out.println("originalTextFile count:" + originalTextFile.count()); // behavior logs already stored in HDFS
				System.out.println("====================");

				// Merge the historical dataset with this batch's new behavior data
				JavaRDD<Rating> calculations = originaldatas.union(ratings);

				System.out.println("total training records: " + calculations.count());

				// Build the recommendation model using ALS
				int rank = 10; // number of latent factors in the model
				int numIterations = 6; // number of training iterations

				// Train a model only when the batch contains behavior data
				if (null != ratings && !ratings.isEmpty()) {
					MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(calculations), rank, numIterations, 0.01);
					// If the model directory already exists in HDFS, delete it first
					Configuration hadoopConfiguration = sc.hadoopConfiguration();
					hadoopConfiguration.set("fs.defaultFS", "hdfs://master:8020");
					FileSystem fs = FileSystem.get(hadoopConfiguration);
					Path outpath = new Path("/spark-als/model");
					if (fs.exists(outpath)) {
						//System.out.println("########### 删除"+outpath.getName()+" ###########");
						fs.delete(outpath, true);
					}
					
					// Save the model
					model.save(sc, "hdfs://master:8020/spark-als/model");
					// The model can later be loaded back:
					// MatrixFactorizationModel modelLoad = MatrixFactorizationModel.load(sc, "hdfs://master:8020/spark-als/model");

					// Recommend 10 products (movies) for the given user
					Rating[] recommendProducts = model.recommendProducts(53, 10);
					for (Rating r : recommendProducts) {
						System.out.println(r.toString());
					}
				}

				return null;
			}
		});

		// ==========================================================================================

		jssc.start();
		jssc.awaitTermination();

		// jssc.stop();
		// jssc.close();
	}

}
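To feed the stream, any producer that writes UserID,ItemId,Rating,TimeStamp lines to the topic will do. Below is a hypothetical smoke test using the Kafka 0.8 producer API that matches the spark-streaming-kafka_2.10 dependency further down; the broker list and topic are copied from the consumer above:

import java.util.Properties;

import kafka.javaapi.producer.Producer;
import kafka.producer.KeyedMessage;
import kafka.producer.ProducerConfig;

public class RecommendLogProducer {
	public static void main(String[] args) {
		Properties props = new Properties();
		props.put("metadata.broker.list", "master:9092,slave1:9092,slave2:9092");
		props.put("serializer.class", "kafka.serializer.StringEncoder");

		Producer<String, String> producer = new Producer<String, String>(new ProducerConfig(props));
		// One behavior record in the UserID,ItemId,Rating,TimeStamp format
		producer.send(new KeyedMessage<String, String>("taotao-server-recommend-logs", "53,1286513,9,1508221762"));
		producer.close();
	}
}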

 

User behavior dataset

Data format collected by the mall system:
UserID,ItemId,Rating,TimeStamp (user id, item id, behavior score, timestamp)
53,1286513,9,1508221762
53,1172348420,9,1508221762
53,1179495514,12,1508221762
53,1184890730,3,1508221762
53,1210793742,159,1508221762
53,1215837445,9,1508221762
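Note that the rating here is a behavior score accumulated from user actions (values like 159 are far outside an explicit 1-5 scale), i.e. implicit feedback. MLlib has an ALS variant designed for this case; a hedged sketch, reusing rank, numIterations and the ratings RDD from the MLlib example above (the lambda and alpha values are placeholders to tune, not values from the original project):

		// Implicit-feedback ALS: treats the scores as confidence weights rather than explicit ratings
		double lambda = 0.01; // regularization (placeholder)
		double alpha = 0.01; // confidence scaling for implicit feedback (placeholder)
		MatrixFactorizationModel implicitModel = ALS.trainImplicit(JavaRDD.toRDD(ratings), rank, numIterations, lambda, alpha);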

 

Maven dependencies

	<dependencies>
		<!-- spark-core -->
		<dependency>
			<groupId>org.apache.spark</groupId>
			<artifactId>spark-core_2.10</artifactId>
			<version>1.6.3</version>
		</dependency>
		<!-- spark-mllib -->
		<dependency>
			<groupId>org.apache.spark</groupId>
			<artifactId>spark-mllib_2.10</artifactId>
			<version>1.6.3</version>
			<scope>provided</scope>
		</dependency>
		<!-- spark-sql -->
		<dependency>
			<groupId>org.apache.spark</groupId>
			<artifactId>spark-sql_2.10</artifactId>
			<version>1.6.3</version>
		</dependency>
		<!-- spark-streaming -->
		<dependency>
			<groupId>org.apache.spark</groupId>
			<artifactId>spark-streaming_2.10</artifactId>
			<version>1.6.3</version>
			<scope>provided</scope>
		</dependency>
		<!-- spark-streaming-kafka -->
		<dependency>
			<groupId>org.apache.spark</groupId>
			<artifactId>spark-streaming-kafka_2.10</artifactId>
			<version>1.6.3</version>
		</dependency>
	</dependencies>

 

The code and dataset above can be found in the project on GitHub:

https://github.com/huangyueranbbc/Spark_ALS