package org.example.recommendation; import com.google.common.collect.Sets; import org.apache.predictionio.controller.java.PJavaAlgorithm; import org.apache.predictionio.data.storage.Event; import org.apache.predictionio.data.store.java.LJavaEventStore; import org.apache.predictionio.data.store.java.OptionHelper; import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.mllib.recommendation.ALS; import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; import org.apache.spark.mllib.recommendation.Rating; import org.apache.spark.rdd.RDD; import org.jblas.DoubleMatrix; import org.joda.time.DateTime; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import scala.Option; import scala.Tuple2; import scala.concurrent.duration.Duration; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; public class Algorithm extends PJavaAlgorithm<PreparedData, Model, Query, PredictedResult> { private static final Logger logger = LoggerFactory.getLogger(Algorithm.class); private final AlgorithmParams ap; public Algorithm(AlgorithmParams ap) { this.ap = ap; } @Override public Model train(SparkContext sc, PreparedData preparedData) { TrainingData data = preparedData.getTrainingData(); // user stuff JavaPairRDD<String, Integer> userIndexRDD = data.getUsers().map(new Function<Tuple2<String, User>, String>() { @Override public String call(Tuple2<String, User> idUser) throws Exception { return idUser._1(); } }).zipWithIndex().mapToPair(new PairFunction<Tuple2<String, Long>, String, Integer>() { @Override public Tuple2<String, Integer> call(Tuple2<String, Long> element) throws Exception { return new Tuple2<>(element._1(), element._2().intValue()); } }); final Map<String, Integer> userIndexMap = userIndexRDD.collectAsMap(); // item stuff JavaPairRDD<String, Integer> itemIndexRDD = data.getItems().map(new Function<Tuple2<String, Item>, String>() { @Override public String call(Tuple2<String, Item> idItem) throws Exception { return idItem._1(); } }).zipWithIndex().mapToPair(new PairFunction<Tuple2<String, Long>, String, Integer>() { @Override public Tuple2<String, Integer> call(Tuple2<String, Long> element) throws Exception { return new Tuple2<>(element._1(), element._2().intValue()); } }); final Map<String, Integer> itemIndexMap = itemIndexRDD.collectAsMap(); JavaPairRDD<Integer, String> indexItemRDD = itemIndexRDD.mapToPair(new PairFunction<Tuple2<String, Integer>, Integer, String>() { @Override public Tuple2<Integer, String> call(Tuple2<String, Integer> element) throws Exception { return element.swap(); } }); final Map<Integer, String> indexItemMap = indexItemRDD.collectAsMap(); // ratings stuff JavaRDD<Rating> ratings = data.getViewEvents().mapToPair(new PairFunction<UserItemEvent, Tuple2<Integer, Integer>, Integer>() { @Override public Tuple2<Tuple2<Integer, Integer>, Integer> call(UserItemEvent viewEvent) throws Exception { Integer userIndex = userIndexMap.get(viewEvent.getUser()); Integer itemIndex = itemIndexMap.get(viewEvent.getItem()); return (userIndex == null || itemIndex == null) ? null : new Tuple2<>(new Tuple2<>(userIndex, itemIndex), 1); } }).filter(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Boolean>() { @Override public Boolean call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception { return (element != null); } }).reduceByKey(new Function2<Integer, Integer, Integer>() { @Override public Integer call(Integer integer, Integer integer2) throws Exception { return integer + integer2; } }).map(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Rating>() { @Override public Rating call(Tuple2<Tuple2<Integer, Integer>, Integer> userItemCount) throws Exception { return new Rating(userItemCount._1()._1(), userItemCount._1()._2(), userItemCount._2().doubleValue()); } }); if (ratings.isEmpty()) throw new AssertionError("Please check if your events contain valid user and item ID."); // MLlib ALS stuff MatrixFactorizationModel matrixFactorizationModel = ALS.trainImplicit(JavaRDD.toRDD(ratings), ap.getRank(), ap.getIteration(), ap.getLambda(), -1, 1.0, ap.getSeed()); JavaPairRDD<Integer, double[]> userFeatures = matrixFactorizationModel.userFeatures().toJavaRDD().mapToPair(new PairFunction<Tuple2<Object, double[]>, Integer, double[]>() { @Override public Tuple2<Integer, double[]> call(Tuple2<Object, double[]> element) throws Exception { return new Tuple2<>((Integer) element._1(), element._2()); } }); JavaPairRDD<Integer, double[]> productFeaturesRDD = matrixFactorizationModel.productFeatures().toJavaRDD().mapToPair(new PairFunction<Tuple2<Object, double[]>, Integer, double[]>() { @Override public Tuple2<Integer, double[]> call(Tuple2<Object, double[]> element) throws Exception { return new Tuple2<>((Integer) element._1(), element._2()); } }); // popularity scores JavaRDD<ItemScore> itemPopularityScore = data.getBuyEvents().mapToPair(new PairFunction<UserItemEvent, Tuple2<Integer, Integer>, Integer>() { @Override public Tuple2<Tuple2<Integer, Integer>, Integer> call(UserItemEvent buyEvent) throws Exception { Integer userIndex = userIndexMap.get(buyEvent.getUser()); Integer itemIndex = itemIndexMap.get(buyEvent.getItem()); return (userIndex == null || itemIndex == null) ? null : new Tuple2<>(new Tuple2<>(userIndex, itemIndex), 1); } }).filter(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Boolean>() { @Override public Boolean call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception { return (element != null); } }).mapToPair(new PairFunction<Tuple2<Tuple2<Integer, Integer>, Integer>, Integer, Integer>() { @Override public Tuple2<Integer, Integer> call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception { return new Tuple2<>(element._1()._2(), element._2()); } }).reduceByKey(new Function2<Integer, Integer, Integer>() { @Override public Integer call(Integer integer, Integer integer2) throws Exception { return integer + integer2; } }).map(new Function<Tuple2<Integer, Integer>, ItemScore>() { @Override public ItemScore call(Tuple2<Integer, Integer> element) throws Exception { return new ItemScore(indexItemMap.get(element._1()), element._2().doubleValue()); } }); JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures = indexItemRDD.join(productFeaturesRDD); return new Model(userFeatures, indexItemFeatures, userIndexRDD, itemIndexRDD, itemPopularityScore, data.getItems().collectAsMap()); } @Override public PredictedResult predict(Model model, final Query query) { final JavaPairRDD<String, Integer> matchedUser = model.getUserIndex().filter(new Function<Tuple2<String, Integer>, Boolean>() { @Override public Boolean call(Tuple2<String, Integer> userIndex) throws Exception { return userIndex._1().equals(query.getUserEntityId()); } }); double[] userFeature = null; if (!matchedUser.isEmpty()) { final Integer matchedUserIndex = matchedUser.first()._2(); userFeature = model.getUserFeatures().filter(new Function<Tuple2<Integer, double[]>, Boolean>() { @Override public Boolean call(Tuple2<Integer, double[]> element) throws Exception { return element._1().equals(matchedUserIndex); } }).first()._2(); } if (userFeature != null) { return new PredictedResult(topItemsForUser(userFeature, model, query)); } else { List<double[]> recentProductFeatures = getRecentProductFeatures(query, model); if (recentProductFeatures.isEmpty()) { return new PredictedResult(mostPopularItems(model, query)); } else { return new PredictedResult(similarItems(recentProductFeatures, model, query)); } } } @Override public RDD<Tuple2<Object, PredictedResult>> batchPredict(Model model, RDD<Tuple2<Object, Query>> qs) { List<Tuple2<Object, Query>> indexQueries = qs.toJavaRDD().collect(); List<Tuple2<Object, PredictedResult>> results = new ArrayList<>(); for (Tuple2<Object, Query> indexQuery : indexQueries) { results.add(new Tuple2<>(indexQuery._1(), predict(model, indexQuery._2()))); } return new JavaSparkContext(qs.sparkContext()).parallelize(results).rdd(); } private List<double[]> getRecentProductFeatures(Query query, Model model) { try { List<double[]> result = new ArrayList<>(); List<Event> events = LJavaEventStore.findByEntity( ap.getAppName(), "user", query.getUserEntityId(), OptionHelper.<String>none(), OptionHelper.some(ap.getSimilarItemEvents()), OptionHelper.some(OptionHelper.some("item")), OptionHelper.<Option<String>>none(), OptionHelper.<DateTime>none(), OptionHelper.<DateTime>none(), OptionHelper.some(10), true, Duration.apply(10, TimeUnit.SECONDS)); for (final Event event : events) { if (event.targetEntityId().isDefined()) { JavaPairRDD<String, Integer> filtered = model.getItemIndex().filter(new Function<Tuple2<String, Integer>, Boolean>() { @Override public Boolean call(Tuple2<String, Integer> element) throws Exception { return element._1().equals(event.targetEntityId().get()); } }); final Integer itemIndex = filtered.first()._2(); if (!filtered.isEmpty()) { JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures = model.getIndexItemFeatures().filter(new Function<Tuple2<Integer, Tuple2<String, double[]>>, Boolean>() { @Override public Boolean call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception { return itemIndex.equals(element._1()); } }); List<Tuple2<Integer, Tuple2<String, double[]>>> oneIndexItemFeatures = indexItemFeatures.collect(); if (oneIndexItemFeatures.size() > 0) { result.add(oneIndexItemFeatures.get(0)._2()._2()); } } } } return result; } catch (Exception e) { logger.error("Error reading recent events for user " + query.getUserEntityId()); throw new RuntimeException(e.getMessage(), e); } } private List<ItemScore> topItemsForUser(double[] userFeature, Model model, Query query) { final DoubleMatrix userMatrix = new DoubleMatrix(userFeature); JavaRDD<ItemScore> itemScores = model.getIndexItemFeatures().map(new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() { @Override public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception { return new ItemScore(element._2()._1(), userMatrix.dot(new DoubleMatrix(element._2()._2()))); } }); itemScores = validScores(itemScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId()); return sortAndTake(itemScores, query.getNumber()); } private List<ItemScore> similarItems(final List<double[]> recentProductFeatures, Model model, Query query) { JavaRDD<ItemScore> itemScores = model.getIndexItemFeatures().map(new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() { @Override public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception { double similarity = 0.0; for (double[] recentFeature : recentProductFeatures) { similarity += cosineSimilarity(element._2()._2(), recentFeature); } return new ItemScore(element._2()._1(), similarity); } }); itemScores = validScores(itemScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId()); return sortAndTake(itemScores, query.getNumber()); } private List<ItemScore> mostPopularItems(Model model, Query query) { JavaRDD<ItemScore> itemScores = validScores(model.getItemPopularityScore(), query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId()); return sortAndTake(itemScores, query.getNumber()); } private double cosineSimilarity(double[] a, double[] b) { DoubleMatrix matrixA = new DoubleMatrix(a); DoubleMatrix matrixB = new DoubleMatrix(b); return matrixA.dot(matrixB) / (matrixA.norm2() * matrixB.norm2()); } private List<ItemScore> sortAndTake(JavaRDD<ItemScore> all, int number) { return all.sortBy(new Function<ItemScore, Double>() { @Override public Double call(ItemScore itemScore) throws Exception { return itemScore.getScore(); } }, false, all.partitions().size()).take(number); } private JavaRDD<ItemScore> validScores(JavaRDD<ItemScore> all, final Set<String> whitelist, final Set<String> blacklist, final Set<String> categories, final Map<String, Item> items, String userEntityId) { final Set<String> seenItemEntityIds = seenItemEntityIds(userEntityId); final Set<String> unavailableItemEntityIds = unavailableItemEntityIds(); return all.filter(new Function<ItemScore, Boolean>() { @Override public Boolean call(ItemScore itemScore) throws Exception { Item item = items.get(itemScore.getItemEntityId()); return (item != null && passWhitelistCriteria(whitelist, item.getEntityId()) && passBlacklistCriteria(blacklist, item.getEntityId()) && passCategoryCriteria(categories, item) && passUnseenCriteria(seenItemEntityIds, item.getEntityId()) && passAvailabilityCriteria(unavailableItemEntityIds, item.getEntityId())); } }); } private boolean passWhitelistCriteria(Set<String> whitelist, String itemEntityId) { return (whitelist.isEmpty() || whitelist.contains(itemEntityId)); } private boolean passBlacklistCriteria(Set<String> blacklist, String itemEntityId) { return !blacklist.contains(itemEntityId); } private boolean passCategoryCriteria(Set<String> categories, Item item) { return (categories.isEmpty() || Sets.intersection(categories, item.getCategories()).size() > 0); } private boolean passUnseenCriteria(Set<String> seen, String itemEntityId) { return !seen.contains(itemEntityId); } private boolean passAvailabilityCriteria(Set<String> unavailableItemEntityIds, String entityId) { return !unavailableItemEntityIds.contains(entityId); } private Set<String> unavailableItemEntityIds() { try { List<Event> unavailableConstraintEvents = LJavaEventStore.findByEntity( ap.getAppName(), "constraint", "unavailableItems", OptionHelper.<String>none(), OptionHelper.some(Collections.singletonList("$set")), OptionHelper.<Option<String>>none(), OptionHelper.<Option<String>>none(), OptionHelper.<DateTime>none(), OptionHelper.<DateTime>none(), OptionHelper.some(1), true, Duration.apply(10, TimeUnit.SECONDS)); if (unavailableConstraintEvents.isEmpty()) return Collections.emptySet(); Event unavailableConstraint = unavailableConstraintEvents.get(0); List<String> unavailableItems = unavailableConstraint.properties().getStringList("items"); return new HashSet<>(unavailableItems); } catch (Exception e) { logger.error("Error reading constraint events"); throw new RuntimeException(e.getMessage(), e); } } private Set<String> seenItemEntityIds(String userEntityId) { if (!ap.isUnseenOnly()) return Collections.emptySet(); try { Set<String> result = new HashSet<>(); List<Event> seenEvents = LJavaEventStore.findByEntity( ap.getAppName(), "user", userEntityId, OptionHelper.<String>none(), OptionHelper.some(ap.getSeenItemEvents()), OptionHelper.some(OptionHelper.some("item")), OptionHelper.<Option<String>>none(), OptionHelper.<DateTime>none(), OptionHelper.<DateTime>none(), OptionHelper.<Integer>none(), true, Duration.apply(10, TimeUnit.SECONDS)); for (Event event : seenEvents) { result.add(event.targetEntityId().get()); } return result; } catch (Exception e) { logger.error("Error reading seen events for user " + userEntityId); throw new RuntimeException(e.getMessage(), e); } } }