/*
 * Copyright Myrrix Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.myrrix.online.som;

import java.util.Collections;
import java.util.Comparator;

import com.google.common.base.Preconditions;
import org.apache.commons.math3.distribution.IntegerDistribution;
import org.apache.commons.math3.distribution.PascalDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.FastMath;
import org.apache.commons.math3.util.Pair;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import net.myrrix.common.LangUtils;
import net.myrrix.common.collection.FastByIDMap;
import net.myrrix.common.math.SimpleVectorMath;
import net.myrrix.common.random.RandomManager;
import net.myrrix.common.random.RandomUtils;

/**
 * <p>This class implements a basic version of
 * <a href="http://en.wikipedia.org/wiki/Self-organizing_map">self-organizing maps</a>, or
 * <a href="http://www.scholarpedia.org/article/Kohonen_network">Kohonen network</a>. Self-organizing maps
 * bear some similarity to clustering techniques like
 * <a href="http://en.wikipedia.org/wiki/K-means_clustering">k-means</a>, in that they both try to discover
 * the centers of relatively close or similar groups of points in the input.</p>
 *
 * <p>K-means and other pure clustering algorithms try to find the centers which best reflect the input's
 * structure. Self-organizing maps have a different priority; the centers they are fitting are connected
 * together as part of a two-dimensional grid, and influence each other as they move. The result is like
 * fitting an elastic 2D grid of points to the input. This constraint results in less faithful clustering --
 * it is not even primarily a clustering. But it does result in a projection of points onto a 2D surface that
 * keeps similar things near to each other -- a sort of randomized ad-hoc 2D map of the space.</p>
 *
 * @author Sean Owen
 * @since 1.0
 */
public final class SelfOrganizingMaps {

  private static final Logger log = LoggerFactory.getLogger(SelfOrganizingMaps.class);

  /** Default value of the decay threshold below which the sketching pass stops. */
  public static final double DEFAULT_MIN_DECAY = 0.00001;
  /** Default value of the initial learning rate. */
  public static final double DEFAULT_INIT_LEARNING_RATE = 0.5;

  /**
   * Orders (score, ID) assignments by descending score. Built on {@link Double#compare(double, double)} so
   * that the ordering is a consistent total order for every {@code double} value, as
   * {@link Collections#sort(java.util.List, Comparator)} requires.
   */
  private static final Comparator<Pair<Double,Long>> BY_SCORE_DESCENDING =
      new Comparator<Pair<Double,Long>>() {
        @Override
        public int compare(Pair<Double,Long> a, Pair<Double,Long> b) {
          // Arguments are swapped to get descending order
          return Double.compare(b.getFirst(), a.getFirst());
        }
      };

  private final double minDecay;
  private final double initLearningRate;

  /**
   * Creates an instance using {@link #DEFAULT_MIN_DECAY} and {@link #DEFAULT_INIT_LEARNING_RATE}.
   */
  public SelfOrganizingMaps() {
    this(DEFAULT_MIN_DECAY, DEFAULT_INIT_LEARNING_RATE);
  }

  /**
   * @param minDecay learning rate decays over iterations; when the decay factor drops below this, stop iteration
   *   as further updates will do little. Must be positive.
   * @param initLearningRate initial learning rate, decaying over time, which controls how much a newly assigned
   *   vector will move vector centers. Must be in (0,1].
   * @throws IllegalArgumentException if either argument is out of range
   */
  public SelfOrganizingMaps(double minDecay, double initLearningRate) {
    // Guava's Preconditions only substitutes %s placeholders (not SLF4J-style {})
    Preconditions.checkArgument(minDecay > 0.0, "Min decay must be positive: %s", minDecay);
    Preconditions.checkArgument(initLearningRate > 0.0 && initLearningRate <= 1.0,
                                "Learning rate should be in (0,1]: %s", initLearningRate);
    this.minDecay = minDecay;
    this.initLearningRate = initLearningRate;
  }

  /**
   * Like {@link #buildSelfOrganizedMap(FastByIDMap, int, double)} but lets the sampling rate be chosen
   * automatically (passes {@link Double#NaN} as the sampling rate).
   */
  public Node[][] buildSelfOrganizedMap(FastByIDMap<float[]> vectors, int maxMapSize) {
    return buildSelfOrganizedMap(vectors, maxMapSize, Double.NaN);
  }

  /**
   * Builds the map in four phases: seed the grid from sampled input vectors, "sketch" the map by letting
   * input vectors pull node centers towards themselves, assign (sampled) input vectors to their
   * best-matching node, and finally compute a 3-component projection for every node.
   *
   * @param vectors user-feature or item-feature matrix from current computation generation; must not be
   *   null or empty
   * @param maxMapSize maximum desired dimension of the (square) 2D map; must be positive
   * @param samplingRate fraction of input to consider when creating the map, in (0,1], or
   *   {@link Double#NaN} to have a rate chosen that aims for about one assignment per node on average.
   *   When the rate is less than 1, not all vectors in the input will be assigned to a node.
   * @return a square, 2D array of {@link Node} representing the map. Its dimension is
   *   {@code sqrt(vectors.size() * samplingRate)} rounded down, capped at {@code maxMapSize} and never
   *   smaller than 1.
   * @throws NullPointerException if {@code vectors} is null
   * @throws IllegalArgumentException if any other argument is out of range
   */
  public Node[][] buildSelfOrganizedMap(FastByIDMap<float[]> vectors, int maxMapSize, double samplingRate) {

    Preconditions.checkNotNull(vectors);
    Preconditions.checkArgument(!vectors.isEmpty());
    Preconditions.checkArgument(maxMapSize > 0);
    Preconditions.checkArgument(Double.isNaN(samplingRate) || (samplingRate > 0.0 && samplingRate <= 1.0));

    if (Double.isNaN(samplingRate)) {
      // Compute a sampling rate that shoots for 1 assignment per node on average
      double expectedNodeSize = (double) vectors.size() / (maxMapSize * maxMapSize);
      samplingRate = expectedNodeSize > 1.0 ? 1.0 / expectedNodeSize : 1.0;
    }
    log.debug("Sampling rate: {}", samplingRate);

    // Truncation can yield 0 when fewer than one vector is expected to be sampled; always keep at least
    // one node so that later phases never operate on an empty map.
    int sampledSize = (int) FastMath.sqrt(vectors.size() * samplingRate);
    int mapSize = FastMath.max(1, FastMath.min(maxMapSize, sampledSize));
    Node[][] map = buildInitialMap(vectors, mapSize);

    sketchMapParallel(vectors, samplingRate, map);

    // Start the final assignment pass from a clean slate
    for (Node[] mapRow : map) {
      for (Node node : mapRow) {
        node.clearAssignedIDs();
      }
    }

    assignVectorsParallel(vectors, samplingRate, map);
    sortMembers(map);

    // All vectors are assumed to have the same length as the first one encountered
    int numFeatures = vectors.entrySet().iterator().next().getValue().length;
    buildProjections(numFeatures, map);

    return map;
  }

  /**
   * Training pass: feeds input vectors to the map one at a time, each time moving the best-matching node and
   * its neighbors towards the vector. The influence decays exponentially with the number of vectors seen,
   * and the pass ends once the decay factor falls below {@code minDecay} or the input is exhausted.
   * Despite the name, this implementation iterates sequentially.
   */
  private void sketchMapParallel(FastByIDMap<float[]> vectors, double samplingRate, Node[][] map) {
    int mapSize = map.length;
    // For a 1x1 map, log(mapSize) is 0 and sigma is infinite: the decay factor then stays at about 1 and
    // every input vector is visited.
    double sigma = (vectors.size() * samplingRate) / Math.log(mapSize);
    int t = 0;
    for (FastByIDMap.MapEntry<float[]> entry : vectors.entrySet()) {
      float[] V = entry.getValue();
      double decayFactor = FastMath.exp(-t / sigma);
      t++;
      if (decayFactor < minDecay) {
        break;
      }
      int[] bmuCoordinates = findBestMatchingUnit(V, map);
      if (bmuCoordinates != null) {
        updateNeighborhood(map, V, bmuCoordinates[0], bmuCoordinates[1], decayFactor);
      }
    }
  }

  /**
   * Assignment pass: each input vector (or a random sample of them, when {@code samplingRate < 1}) is
   * recorded on its best-matching node along with its cosine similarity to that node's center. Vectors for
   * which no node yields a finite score are skipped. Despite the name, this implementation iterates
   * sequentially.
   */
  private static void assignVectorsParallel(FastByIDMap<float[]> vectors, double samplingRate, Node[][] map) {
    boolean doSample = samplingRate < 1.0;
    RandomGenerator random = RandomManager.getRandom();
    for (FastByIDMap.MapEntry<float[]> entry : vectors.entrySet()) {
      if (doSample && random.nextDouble() > samplingRate) {
        continue;
      }
      float[] V = entry.getValue();
      int[] bmuCoordinates = findBestMatchingUnit(V, map);
      if (bmuCoordinates != null) {
        Node node = map[bmuCoordinates[0]][bmuCoordinates[1]];
        float[] center = node.getCenter();
        // Same cosine measure as used to pick the best-matching unit
        double currentScore =
            SimpleVectorMath.dot(V, center) / (SimpleVectorMath.norm(center) * SimpleVectorMath.norm(V));
        Pair<Double,Long> newAssignedID = new Pair<Double,Long>(currentScore, entry.getKey());
        node.addAssignedID(newAssignedID);
      }
    }
  }

  /**
   * @return map of initialized {@link Node}s, where each node is constructed from a randomly chosen
   *  input vector (see {@link Node} for how its center is derived from that vector)
   */
  private static Node[][] buildInitialMap(FastByIDMap<float[]> vectors, int mapSize) {
    // NOTE(review): the sampled float[] is handed to Node as-is. If Node keeps a reference rather than a
    // copy, the sketching pass would modify the caller's input vectors in place -- confirm against Node.

    double p = ((double) mapSize * mapSize) / vectors.size(); // Choose mapSize^2 out of # vectors
    IntegerDistribution pascalDistribution;
    if (p >= 1.0) {
      // No sampling at all, we can't fill the map with one pass even
      pascalDistribution = null;
    } else {
      // Number of un-selected elements to skip between selections is geometrically distributed with
      // parameter p; this is the same as a negative binomial / Pascal distribution with r=1:
      pascalDistribution = new PascalDistribution(RandomManager.getRandom(), 1, p);
    }

    LongPrimitiveIterator keyIterator = vectors.keySetIterator();
    Node[][] map = new Node[mapSize][mapSize];
    for (Node[] mapRow : map) {
      for (int j = 0; j < mapSize; j++) {
        if (pascalDistribution != null) {
          keyIterator.skip(pascalDistribution.sample());
        }
        while (!keyIterator.hasNext()) {
          keyIterator = vectors.keySetIterator(); // Start over, a little imprecise but affects it not much
          Preconditions.checkState(keyIterator.hasNext());
          if (pascalDistribution != null) {
            keyIterator.skip(pascalDistribution.sample());
          }
        }
        float[] sampledVector = vectors.get(keyIterator.nextLong());
        mapRow[j] = new Node(sampledVector);
      }
    }
    return map;
  }

  /**
   * @return coordinates {@code {i, j}} of {@link Node} in map whose center is "closest" to the given vector.
   *  Here closeness is defined as smallest angle between the vectors (largest cosine). Returns {@code null}
   *  if no node produces a finite score, for example because the vector has zero norm.
   */
  private static int[] findBestMatchingUnit(float[] vector, Node[][] map) {
    int mapSize = map.length;
    double vectorNorm = SimpleVectorMath.norm(vector);
    double bestScore = Double.NEGATIVE_INFINITY;
    int bestI = -1;
    int bestJ = -1;
    for (int i = 0; i < mapSize; i++) {
      Node[] mapRow = map[i];
      for (int j = 0; j < mapSize; j++) {
        float[] center = mapRow[j].getCenter();
        double currentScore = SimpleVectorMath.dot(vector, center) / (SimpleVectorMath.norm(center) * vectorNorm);
        // NaN / infinite scores (zero-norm operands) never win
        if (LangUtils.isFinite(currentScore) && currentScore > bestScore) {
          bestScore = currentScore;
          bestI = i;
          bestJ = j;
        }
      }
    }
    return bestI == -1 || bestJ == -1 ? null : new int[] {bestI, bestJ};
  }

  /**
   * Completes the update step after assigning an input vector tentatively to a {@link Node}. The assignment
   * causes nearby nodes (including the assigned one) to move their centers towards the vector.
   *
   * <p>The neighborhood is the square of nodes within {@code ceil(mapSize * decayFactor)} rows and columns
   * of the best-matching unit on every side, clipped to the map. Each node in it moves by a fraction
   * {@code initLearningRate * decayFactor * exp(-d^2 / (2 r^2))} of its difference from {@code V}, where
   * {@code d} is the grid distance to the best-matching unit and {@code r} the neighborhood radius.</p>
   *
   * @param map map whose node centers are updated in place
   * @param V input vector that was matched
   * @param bmuI row of the best-matching unit
   * @param bmuJ column of the best-matching unit
   * @param decayFactor current decay factor, which shrinks both the radius and the learning rate
   */
  private void updateNeighborhood(Node[][] map, float[] V, int bmuI, int bmuJ, double decayFactor) {
    int mapSize = map.length;
    double neighborhoodRadius = mapSize * decayFactor;

    // Both bounds are inclusive so that the neighborhood extends equally far on either side of the BMU
    int minI = FastMath.max(0, (int) FastMath.floor(bmuI - neighborhoodRadius));
    int maxI = FastMath.min(mapSize - 1, (int) FastMath.ceil(bmuI + neighborhoodRadius));
    int minJ = FastMath.max(0, (int) FastMath.floor(bmuJ - neighborhoodRadius));
    int maxJ = FastMath.min(mapSize - 1, (int) FastMath.ceil(bmuJ + neighborhoodRadius));

    // Loop invariants
    double learningRate = initLearningRate * decayFactor;
    double twoRadiusSquared = 2.0 * neighborhoodRadius * neighborhoodRadius;

    for (int i = minI; i <= maxI; i++) {
      Node[] mapRow = map[i];
      for (int j = minJ; j <= maxJ; j++) {
        double currentDistance = distance(i, j, bmuI, bmuJ);
        double theta = FastMath.exp(-(currentDistance * currentDistance) / twoRadiusSquared);
        double learningTheta = learningRate * theta;
        float[] center = mapRow[j].getCenter();
        int length = center.length;
        // Don't synchronize, for performance. Colliding updates once in a while does little.
        for (int k = 0; k < length; k++) {
          center[k] += (float) (learningTheta * (V[k] - center[k]));
        }
      }
    }
  }

  /**
   * Sorts every node's assigned (score, ID) pairs in place, best (highest) score first.
   */
  private static void sortMembers(Node[][] map) {
    for (Node[] mapRow : map) {
      for (Node node : mapRow) {
        Collections.sort(node.getAssignedIDs(), BY_SCORE_DESCENDING);
      }
    }
  }

  /**
   * Fills in each node's 3-component projection. Node centers are first centered on the mean of all
   * centers; each component is then the cosine between the centered vector and one of three random basis
   * vectors (presumably unit-length, per {@code RandomUtils.randomUnitVector}), rescaled from [-1,1] to
   * [0,1]. A node whose centered vector has no usable direction (zero or NaN norm, as always happens on a
   * 1x1 map) gets the neutral value 0.5 in every component instead of NaN.
   *
   * @param numFeatures length of the center vectors
   * @param map map whose nodes' projections are written
   */
  private static void buildProjections(int numFeatures, Node[][] map) {
    int mapSize = map.length;
    float[] mean = new float[numFeatures];
    for (Node[] mapRow : map) {
      for (int j = 0; j < mapSize; j++) {
        add(mapRow[j].getCenter(), mean);
      }
    }
    divide(mean, mapSize * mapSize);

    RandomGenerator random = RandomManager.getRandom();
    float[] rBasis = RandomUtils.randomUnitVector(numFeatures, random);
    float[] gBasis = RandomUtils.randomUnitVector(numFeatures, random);
    float[] bBasis = RandomUtils.randomUnitVector(numFeatures, random);

    for (Node[] mapRow : map) {
      for (int j = 0; j < mapSize; j++) {
        // Work on a copy; the node's own center must not be shifted
        float[] W = mapRow[j].getCenter().clone();
        subtract(mean, W);
        double norm = SimpleVectorMath.norm(W);
        float[] projection3D = mapRow[j].getProjection3D();
        if (norm > 0.0) {
          projection3D[0] = (float) ((1.0 + SimpleVectorMath.dot(W, rBasis) / norm) / 2.0);
          projection3D[1] = (float) ((1.0 + SimpleVectorMath.dot(W, gBasis) / norm) / 2.0);
          projection3D[2] = (float) ((1.0 + SimpleVectorMath.dot(W, bBasis) / norm) / 2.0);
        } else {
          // Center coincides with the mean: dividing by the norm would give 0/0 = NaN
          projection3D[0] = 0.5f;
          projection3D[1] = 0.5f;
          projection3D[2] = 0.5f;
        }
      }
    }
  }

  /** Adds {@code from} to {@code to}, element-wise and in place. */
  private static void add(float[] from, float[] to) {
    int length = from.length;
    for (int i = 0; i < length; i++) {
      to[i] += from[i];
    }
  }

  /** Subtracts {@code toSubtract} from {@code from}, element-wise and in place. */
  private static void subtract(float[] toSubtract, float[] from) {
    int length = toSubtract.length;
    for (int i = 0; i < length; i++) {
      from[i] -= toSubtract[i];
    }
  }

  /** Divides every element of {@code x} by {@code by}, in place. */
  private static void divide(float[] x, float by) {
    int length = x.length;
    for (int i = 0; i < length; i++) {
      x[i] /= by;
    }
  }

  /**
   * @return Euclidean distance between grid positions (i1,j1) and (i2,j2)
   */
  private static double distance(int i1, int j1, int i2, int j2) {
    // Widen before squaring so the sum of squares cannot overflow int
    long diff1 = (long) i1 - i2;
    long diff2 = (long) j1 - j2;
    return FastMath.sqrt(diff1 * diff1 + diff2 * diff2);
  }

}