package org.example.recommendation;

import com.google.common.collect.Sets;
import org.apache.predictionio.controller.java.PJavaAlgorithm;
import org.apache.predictionio.data.storage.Event;
import org.apache.predictionio.data.store.java.LJavaEventStore;
import org.apache.predictionio.data.store.java.OptionHelper;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.rdd.RDD;
import org.jblas.DoubleMatrix;
import org.joda.time.DateTime;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Option;
import scala.Tuple2;
import scala.concurrent.duration.Duration;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;

/**
 * Recommendation algorithm built on MLlib implicit-feedback ALS.
 *
 * <p>Training factorises the (user, item) view counts into user and item feature vectors and
 * additionally counts buy events per item as a popularity score. Prediction tries, in order:
 * <ol>
 *   <li>personalised scores (dot product of the user's features with every item's features);</li>
 *   <li>similarity to the items the user recently interacted with, read from the event store;</li>
 *   <li>the popularity scores computed during training.</li>
 * </ol>
 * Every candidate list is filtered by the query's whitelist, blacklist and categories, by the
 * "unavailableItems" constraint and, when {@code ap.isUnseenOnly()} is set, by the items the user
 * has already seen.
 */
public class Algorithm extends PJavaAlgorithm<PreparedData, Model, Query, PredictedResult> {

    private static final Logger logger = LoggerFactory.getLogger(Algorithm.class);

    /** Maximum number of recent user events fetched for the similar-items fallback. */
    private static final int RECENT_EVENT_LIMIT = 10;

    /** Timeout, in seconds, applied to every event-store lookup. */
    private static final long EVENT_STORE_TIMEOUT_SECONDS = 10;

    private final AlgorithmParams ap;

    /**
     * @param ap algorithm parameters (ALS hyper-parameters, app name, event names, unseen-only flag)
     */
    public Algorithm(AlgorithmParams ap) {
        this.ap = ap;
    }

    /**
     * Trains the ALS model and the popularity scores.
     *
     * @param sc           Spark context supplied by the framework (not used directly here)
     * @param preparedData wrapper around the training data (users, items, view and buy events)
     * @return the model holding user features, item features, both index mappings, popularity
     *         scores and the collected item map
     * @throws AssertionError if no view event refers to both a known user and a known item
     */
    @Override
    public Model train(SparkContext sc, PreparedData preparedData) {
        TrainingData data = preparedData.getTrainingData();

        // Users: map every user entity id to an integer index, since MLlib Rating uses int ids.
        // NOTE(review): the index RDDs below are not cached here and are re-used by the model at
        // query time; presumably Model persists them — confirm, as zipWithIndex indices are only
        // stable when the source RDD is recomputed in the same order.
        JavaPairRDD<String, Integer> userIndexRDD = data.getUsers().map(new Function<Tuple2<String, User>, String>() {
            @Override
            public String call(Tuple2<String, User> idUser) throws Exception {
                return idUser._1();
            }
        }).zipWithIndex().mapToPair(new PairFunction<Tuple2<String, Long>, String, Integer>() {
            @Override
            public Tuple2<String, Integer> call(Tuple2<String, Long> element) throws Exception {
                return new Tuple2<>(element._1(), element._2().intValue());
            }
        });
        final Map<String, Integer> userIndexMap = userIndexRDD.collectAsMap();

        // Items: same indexing, plus the reverse (index -> entity id) mapping.
        JavaPairRDD<String, Integer> itemIndexRDD = data.getItems().map(new Function<Tuple2<String, Item>, String>() {
            @Override
            public String call(Tuple2<String, Item> idItem) throws Exception {
                return idItem._1();
            }
        }).zipWithIndex().mapToPair(new PairFunction<Tuple2<String, Long>, String, Integer>() {
            @Override
            public Tuple2<String, Integer> call(Tuple2<String, Long> element) throws Exception {
                return new Tuple2<>(element._1(), element._2().intValue());
            }
        });
        final Map<String, Integer> itemIndexMap = itemIndexRDD.collectAsMap();
        JavaPairRDD<Integer, String> indexItemRDD = itemIndexRDD.mapToPair(new PairFunction<Tuple2<String, Integer>, Integer, String>() {
            @Override
            public Tuple2<Integer, String> call(Tuple2<String, Integer> element) throws Exception {
                return element.swap();
            }
        });
        final Map<Integer, String> indexItemMap = indexItemRDD.collectAsMap();

        // Ratings: number of view events per (user index, item index) pair.
        JavaRDD<Rating> ratings = data.getViewEvents().mapToPair(new PairFunction<UserItemEvent, Tuple2<Integer, Integer>, Integer>() {
            @Override
            public Tuple2<Tuple2<Integer, Integer>, Integer> call(UserItemEvent viewEvent) throws Exception {
                Integer userIndex = userIndexMap.get(viewEvent.getUser());
                Integer itemIndex = itemIndexMap.get(viewEvent.getItem());

                // Events that reference an unknown user or item are marked with null and dropped below.
                return (userIndex == null || itemIndex == null) ? null : new Tuple2<>(new Tuple2<>(userIndex, itemIndex), 1);
            }
        }).filter(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Boolean>() {
            @Override
            public Boolean call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception {
                return (element != null);
            }
        }).reduceByKey(new Function2<Integer, Integer, Integer>() {
            @Override
            public Integer call(Integer left, Integer right) throws Exception {
                return left + right;
            }
        }).map(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Rating>() {
            @Override
            public Rating call(Tuple2<Tuple2<Integer, Integer>, Integer> userItemCount) throws Exception {
                return new Rating(userItemCount._1()._1(), userItemCount._1()._2(), userItemCount._2().doubleValue());
            }
        });

        if (ratings.isEmpty()) {
            throw new AssertionError("Please check if your events contain valid user and item ID.");
        }

        // MLlib ALS: blocks = -1 (let MLlib choose), alpha = 1.0.
        MatrixFactorizationModel matrixFactorizationModel = ALS.trainImplicit(JavaRDD.toRDD(ratings), ap.getRank(), ap.getIteration(), ap.getLambda(), -1, 1.0, ap.getSeed());
        JavaPairRDD<Integer, double[]> userFeatures = matrixFactorizationModel.userFeatures().toJavaRDD().mapToPair(new PairFunction<Tuple2<Object, double[]>, Integer, double[]>() {
            @Override
            public Tuple2<Integer, double[]> call(Tuple2<Object, double[]> element) throws Exception {
                return new Tuple2<>((Integer) element._1(), element._2());
            }
        });
        JavaPairRDD<Integer, double[]> productFeaturesRDD = matrixFactorizationModel.productFeatures().toJavaRDD().mapToPair(new PairFunction<Tuple2<Object, double[]>, Integer, double[]>() {
            @Override
            public Tuple2<Integer, double[]> call(Tuple2<Object, double[]> element) throws Exception {
                return new Tuple2<>((Integer) element._1(), element._2());
            }
        });

        // Popularity: number of buy events per item, used when nothing is known about the user.
        JavaRDD<ItemScore> itemPopularityScore = data.getBuyEvents().mapToPair(new PairFunction<UserItemEvent, Tuple2<Integer, Integer>, Integer>() {
            @Override
            public Tuple2<Tuple2<Integer, Integer>, Integer> call(UserItemEvent buyEvent) throws Exception {
                Integer userIndex = userIndexMap.get(buyEvent.getUser());
                Integer itemIndex = itemIndexMap.get(buyEvent.getItem());

                return (userIndex == null || itemIndex == null) ? null : new Tuple2<>(new Tuple2<>(userIndex, itemIndex), 1);
            }
        }).filter(new Function<Tuple2<Tuple2<Integer, Integer>, Integer>, Boolean>() {
            @Override
            public Boolean call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception {
                return (element != null);
            }
        }).mapToPair(new PairFunction<Tuple2<Tuple2<Integer, Integer>, Integer>, Integer, Integer>() {
            @Override
            public Tuple2<Integer, Integer> call(Tuple2<Tuple2<Integer, Integer>, Integer> element) throws Exception {
                // Re-key by item index only; the user part is no longer needed.
                return new Tuple2<>(element._1()._2(), element._2());
            }
        }).reduceByKey(new Function2<Integer, Integer, Integer>() {
            @Override
            public Integer call(Integer left, Integer right) throws Exception {
                return left + right;
            }
        }).map(new Function<Tuple2<Integer, Integer>, ItemScore>() {
            @Override
            public ItemScore call(Tuple2<Integer, Integer> element) throws Exception {
                return new ItemScore(indexItemMap.get(element._1()), element._2().doubleValue());
            }
        });

        JavaPairRDD<Integer, Tuple2<String, double[]>> indexItemFeatures = indexItemRDD.join(productFeaturesRDD);

        return new Model(userFeatures, indexItemFeatures, userIndexRDD, itemIndexRDD, itemPopularityScore, data.getItems().collectAsMap());
    }

    /**
     * Answers a single query.
     *
     * <p>Uses the user's ALS features when the model has them; otherwise falls back to items
     * similar to the user's recent events, and finally to the most popular items.
     *
     * @param model trained model
     * @param query query carrying the user id, result size and filters
     * @return the ranked, filtered item scores
     */
    @Override
    public PredictedResult predict(Model model, final Query query) {
        double[] userFeature = findUserFeature(model, query.getUserEntityId());
        if (userFeature != null) {
            return new PredictedResult(topItemsForUser(userFeature, model, query));
        }

        List<double[]> recentProductFeatures = getRecentProductFeatures(query, model);
        if (recentProductFeatures.isEmpty()) {
            return new PredictedResult(mostPopularItems(model, query));
        }
        return new PredictedResult(similarItems(recentProductFeatures, model, query));
    }

    /**
     * Answers a batch of queries by collecting them on the driver and calling
     * {@link #predict(Model, Query)} for each one in turn ({@code predict} itself launches Spark
     * jobs on the model RDDs), then re-parallelising the results.
     *
     * @param model trained model
     * @param qs    (index, query) pairs
     * @return (index, result) pairs in the same order as the collected queries
     */
    @Override
    public RDD<Tuple2<Object, PredictedResult>> batchPredict(Model model, RDD<Tuple2<Object, Query>> qs) {
        List<Tuple2<Object, Query>> indexQueries = qs.toJavaRDD().collect();
        List<Tuple2<Object, PredictedResult>> results = new ArrayList<>(indexQueries.size());

        for (Tuple2<Object, Query> indexQuery : indexQueries) {
            results.add(new Tuple2<>(indexQuery._1(), predict(model, indexQuery._2())));
        }

        return new JavaSparkContext(qs.sparkContext()).parallelize(results).rdd();
    }

    /**
     * Looks up the ALS feature vector of a user.
     *
     * @return the feature vector, or {@code null} when the user is not in the model's user index
     *         or is indexed but has no feature vector (for example no usable view events)
     */
    private static double[] findUserFeature(Model model, final String userEntityId) {
        List<Tuple2<String, Integer>> matchedUsers = model.getUserIndex().filter(new Function<Tuple2<String, Integer>, Boolean>() {
            @Override
            public Boolean call(Tuple2<String, Integer> userIndex) throws Exception {
                return userIndex._1().equals(userEntityId);
            }
        }).take(1);
        if (matchedUsers.isEmpty()) {
            return null;
        }

        final Integer matchedUserIndex = matchedUsers.get(0)._2();
        // take(1) rather than first(): first() throws on an empty RDD, which would crash the
        // query instead of letting predict() fall back to the non-personalised strategies.
        List<Tuple2<Integer, double[]>> matchedFeatures = model.getUserFeatures().filter(new Function<Tuple2<Integer, double[]>, Boolean>() {
            @Override
            public Boolean call(Tuple2<Integer, double[]> element) throws Exception {
                return element._1().equals(matchedUserIndex);
            }
        }).take(1);

        return matchedFeatures.isEmpty() ? null : matchedFeatures.get(0)._2();
    }

    /**
     * Reads the user's most recent "similar item" events from the event store and returns the
     * feature vectors of the targeted items that the model knows about.
     *
     * @return feature vectors of recently touched items; empty when there are no such events or
     *         none of the items has features in the model
     * @throws RuntimeException wrapping any failure while reading the event store or the model
     */
    private List<double[]> getRecentProductFeatures(Query query, Model model) {
        try {
            List<Event> events = LJavaEventStore.findByEntity(
                    ap.getAppName(),
                    "user",
                    query.getUserEntityId(),
                    OptionHelper.<String>none(),
                    OptionHelper.some(ap.getSimilarItemEvents()),
                    OptionHelper.some(OptionHelper.some("item")),
                    OptionHelper.<Option<String>>none(),
                    OptionHelper.<DateTime>none(),
                    OptionHelper.<DateTime>none(),
                    OptionHelper.some(RECENT_EVENT_LIMIT),
                    true,
                    Duration.apply(EVENT_STORE_TIMEOUT_SECONDS, TimeUnit.SECONDS));

            List<double[]> result = new ArrayList<>();
            for (Event event : events) {
                if (!event.targetEntityId().isDefined()) {
                    continue;
                }
                // Capture only the id string in the closure, not the whole event.
                final String targetItemId = event.targetEntityId().get();

                List<Tuple2<String, Integer>> matchedItems = model.getItemIndex().filter(new Function<Tuple2<String, Integer>, Boolean>() {
                    @Override
                    public Boolean call(Tuple2<String, Integer> element) throws Exception {
                        return element._1().equals(targetItemId);
                    }
                }).take(1);
                if (matchedItems.isEmpty()) {
                    // The item is not part of the trained model; skip it instead of failing.
                    continue;
                }

                final Integer itemIndex = matchedItems.get(0)._2();
                List<Tuple2<Integer, Tuple2<String, double[]>>> matchedFeatures = model.getIndexItemFeatures().filter(new Function<Tuple2<Integer, Tuple2<String, double[]>>, Boolean>() {
                    @Override
                    public Boolean call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception {
                        return itemIndex.equals(element._1());
                    }
                }).take(1);
                if (!matchedFeatures.isEmpty()) {
                    result.add(matchedFeatures.get(0)._2()._2());
                }
            }

            return result;
        } catch (Exception e) {
            logger.error("Error reading recent events for user {}", query.getUserEntityId(), e);
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /**
     * Scores every item by the dot product of its features with the user's features, filters the
     * scores and returns the best {@code query.getNumber()} of them.
     */
    private List<ItemScore> topItemsForUser(double[] userFeature, Model model, Query query) {
        final DoubleMatrix userMatrix = new DoubleMatrix(userFeature);

        JavaRDD<ItemScore> itemScores = model.getIndexItemFeatures().map(new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() {
            @Override
            public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception {
                return new ItemScore(element._2()._1(), userMatrix.dot(new DoubleMatrix(element._2()._2())));
            }
        });

        itemScores = validScores(itemScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId());
        return sortAndTake(itemScores, query.getNumber());
    }

    /**
     * Scores every item by the sum of its cosine similarities to the given recent item features,
     * filters the scores and returns the best {@code query.getNumber()} of them.
     */
    private List<ItemScore> similarItems(final List<double[]> recentProductFeatures, Model model, Query query) {
        JavaRDD<ItemScore> itemScores = model.getIndexItemFeatures().map(new Function<Tuple2<Integer, Tuple2<String, double[]>>, ItemScore>() {
            @Override
            public ItemScore call(Tuple2<Integer, Tuple2<String, double[]>> element) throws Exception {
                double similarity = 0.0;
                for (double[] recentFeature : recentProductFeatures) {
                    similarity += cosineSimilarity(element._2()._2(), recentFeature);
                }

                return new ItemScore(element._2()._1(), similarity);
            }
        });

        itemScores = validScores(itemScores, query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId());
        return sortAndTake(itemScores, query.getNumber());
    }

    /** Filters the trained popularity scores and returns the best {@code query.getNumber()} of them. */
    private List<ItemScore> mostPopularItems(Model model, Query query) {
        JavaRDD<ItemScore> itemScores = validScores(model.getItemPopularityScore(), query.getWhitelist(), query.getBlacklist(), query.getCategories(), model.getItems(), query.getUserEntityId());
        return sortAndTake(itemScores, query.getNumber());
    }

    /**
     * Cosine similarity of two vectors.
     *
     * @return {@code dot(a, b) / (|a| * |b|)}, or {@code 0.0} when the norm product is zero, so
     *         that a degenerate vector contributes nothing instead of NaN (NaN would sort ahead
     *         of every real score in the descending ranking)
     */
    private static double cosineSimilarity(double[] a, double[] b) {
        DoubleMatrix matrixA = new DoubleMatrix(a);
        DoubleMatrix matrixB = new DoubleMatrix(b);

        double normProduct = matrixA.norm2() * matrixB.norm2();
        if (normProduct == 0.0) {
            return 0.0;
        }
        return matrixA.dot(matrixB) / normProduct;
    }

    /** Sorts the scores in descending order and returns the first {@code number} entries. */
    private List<ItemScore> sortAndTake(JavaRDD<ItemScore> all, int number) {
        Function<ItemScore, Double> byScore = new Function<ItemScore, Double>() {
            @Override
            public Double call(ItemScore itemScore) throws Exception {
                return itemScore.getScore();
            }
        };

        // ascending = false; keep the partition count of the input RDD.
        return all.sortBy(byScore, false, all.partitions().size()).take(number);
    }

    /**
     * Keeps only the scores whose item is known and passes every query and business filter.
     *
     * <p>The seen-item set and the unavailable-item set are read from the event store once per
     * call, before the RDD filter is built.
     */
    private JavaRDD<ItemScore> validScores(JavaRDD<ItemScore> all, final Set<String> whitelist, final Set<String> blacklist, final Set<String> categories, final Map<String, Item> items, String userEntityId) {
        final Set<String> seenItemEntityIds = seenItemEntityIds(userEntityId);
        final Set<String> unavailableItemEntityIds = unavailableItemEntityIds();

        return all.filter(new Function<ItemScore, Boolean>() {
            @Override
            public Boolean call(ItemScore itemScore) throws Exception {
                Item item = items.get(itemScore.getItemEntityId());

                return (item != null
                        && passWhitelistCriteria(whitelist, item.getEntityId())
                        && passBlacklistCriteria(blacklist, item.getEntityId())
                        && passCategoryCriteria(categories, item)
                        && passUnseenCriteria(seenItemEntityIds, item.getEntityId())
                        && passAvailabilityCriteria(unavailableItemEntityIds, item.getEntityId()));
            }
        });
    }

    /** An empty whitelist accepts everything; otherwise the item must be listed. */
    private static boolean passWhitelistCriteria(Set<String> whitelist, String itemEntityId) {
        return (whitelist.isEmpty() || whitelist.contains(itemEntityId));
    }

    /** The item must not be blacklisted. */
    private static boolean passBlacklistCriteria(Set<String> blacklist, String itemEntityId) {
        return !blacklist.contains(itemEntityId);
    }

    /** An empty category filter accepts everything; otherwise the item must share a category. */
    private static boolean passCategoryCriteria(Set<String> categories, Item item) {
        return (categories.isEmpty() || !Sets.intersection(categories, item.getCategories()).isEmpty());
    }

    /** The item must not be among the items the user has already seen. */
    private static boolean passUnseenCriteria(Set<String> seen, String itemEntityId) {
        return !seen.contains(itemEntityId);
    }

    /** The item must not be flagged as unavailable. */
    private static boolean passAvailabilityCriteria(Set<String> unavailableItemEntityIds, String entityId) {
        return !unavailableItemEntityIds.contains(entityId);
    }

    /**
     * Reads the latest {@code $set} event of the "constraint"/"unavailableItems" entity and
     * returns its "items" string list as a set.
     *
     * @return the unavailable item ids; empty when no such constraint event exists
     * @throws RuntimeException wrapping any failure while reading the event store
     */
    private Set<String> unavailableItemEntityIds() {
        try {
            List<Event> unavailableConstraintEvents = LJavaEventStore.findByEntity(
                    ap.getAppName(),
                    "constraint",
                    "unavailableItems",
                    OptionHelper.<String>none(),
                    OptionHelper.some(Collections.singletonList("$set")),
                    OptionHelper.<Option<String>>none(),
                    OptionHelper.<Option<String>>none(),
                    OptionHelper.<DateTime>none(),
                    OptionHelper.<DateTime>none(),
                    OptionHelper.some(1),
                    true,
                    Duration.apply(EVENT_STORE_TIMEOUT_SECONDS, TimeUnit.SECONDS));

            if (unavailableConstraintEvents.isEmpty()) {
                return Collections.emptySet();
            }

            Event unavailableConstraint = unavailableConstraintEvents.get(0);
            List<String> unavailableItems = unavailableConstraint.properties().getStringList("items");

            return new HashSet<>(unavailableItems);
        } catch (Exception e) {
            logger.error("Error reading constraint events", e);
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /**
     * Returns the ids of the items targeted by the user's "seen" events.
     *
     * @return the seen item ids; always empty when {@code ap.isUnseenOnly()} is false
     * @throws RuntimeException wrapping any failure while reading the event store
     */
    private Set<String> seenItemEntityIds(String userEntityId) {
        if (!ap.isUnseenOnly()) {
            return Collections.emptySet();
        }

        try {
            List<Event> seenEvents = LJavaEventStore.findByEntity(
                    ap.getAppName(),
                    "user",
                    userEntityId,
                    OptionHelper.<String>none(),
                    OptionHelper.some(ap.getSeenItemEvents()),
                    OptionHelper.some(OptionHelper.some("item")),
                    OptionHelper.<Option<String>>none(),
                    OptionHelper.<DateTime>none(),
                    OptionHelper.<DateTime>none(),
                    OptionHelper.<Integer>none(),
                    true,
                    Duration.apply(EVENT_STORE_TIMEOUT_SECONDS, TimeUnit.SECONDS));

            Set<String> result = new HashSet<>();
            for (Event event : seenEvents) {
                // Guard like getRecentProductFeatures does: an event without a target is skipped
                // rather than failing the whole query.
                if (event.targetEntityId().isDefined()) {
                    result.add(event.targetEntityId().get());
                }
            }

            return result;
        } catch (Exception e) {
            logger.error("Error reading seen events for user {}", userEntityId, e);
            throw new RuntimeException(e.getMessage(), e);
        }
    }
}