/*
 * Copyright Myrrix Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.myrrix.online.generation;

import java.io.File;
import java.io.IOException;

import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;

import net.myrrix.common.collection.FastByIDMap;
import net.myrrix.common.collection.FastIDSet;
import net.myrrix.common.math.MatrixUtils;

/**
 * <p>Merges two model files into one model file. The models must be "compatible" for the result to make
 * sense: model 1 must map As to Bs and model 2 must map Bs to Cs, so that the merged model maps As to Cs.
 * B IDs are matched between model 1's item (Y) vectors and model 2's user (X) vectors.</p>
 *
 * <p>The resulting model file can be plugged directly into another instance's working directory.</p>
 *
 * <p>Usage: MergeModels [model.bin.gz file 1] [model.bin.gz file 2] [merged model.bin.gz]</p>
 *
 * <p>This is a simple utility class and an experiment which may be removed.</p>
 *
 * @author Sean Owen
 * @since 1.0
 */
public final class MergeModels {

  private MergeModels() {
  }

  /**
   * Command-line entry point; see the class documentation for usage.
   *
   * @param args model 1 file, model 2 file, and output file for the merged model, in that order
   * @throws IllegalArgumentException if fewer than three arguments are given
   * @throws Exception if reading or writing a model fails
   */
  public static void main(String[] args) throws Exception {
    if (args.length < 3) {
      throw new IllegalArgumentException(
          "Usage: MergeModels [model.bin.gz file 1] [model.bin.gz file 2] [merged model.bin.gz]");
    }
    File model1File = new File(args[0]);
    File model2File = new File(args[1]);
    File mergedModelFile = new File(args[2]);
    merge(model1File, model2File, mergedModelFile);
  }

  /**
   * Reads two models, composes them, and writes the merged model to {@code mergedModelFile}.
   *
   * <p>A translation matrix T is built as the sum, over B IDs present in both model 1's Y and model 2's X,
   * of y1[b] * x2[b]<sup>T</sup>. The merged X is {@code MatrixUtils.multiply(T^T, x1)}, presumably
   * T<sup>T</sup> applied to each model 1 user vector. Model 2's Y is reused unchanged. If estimates are dot
   * products of user and item vectors, then x1[a]<sup>T</sup> T y2[c] equals the sum over shared b of
   * (x1[a]·y1[b]) * (x2[b]·y2[c]).</p>
   *
   * <p>The merged model records no known items for any user.</p>
   *
   * @param model1File model mapping As to Bs
   * @param model2File model mapping Bs to Cs
   * @param mergedModelFile destination for the merged model mapping As to Cs
   * @throws IOException if reading either model or writing the merged model fails
   * @throws IllegalArgumentException if model 1 has no item vectors, model 2 has no user vectors,
   *     or vectors within one of those have inconsistent lengths
   */
  public static void merge(File model1File, File model2File, File mergedModelFile) throws IOException {

    Generation model1 = GenerationSerializer.readGeneration(model1File);
    Generation model2 = GenerationSerializer.readGeneration(model2File);

    FastByIDMap<float[]> x1 = model1.getX();
    FastByIDMap<float[]> y1 = model1.getY();
    FastByIDMap<float[]> x2 = model2.getX();
    FastByIDMap<float[]> y2 = model2.getY();

    // (model 1 features) x (model 2 features) matrix: sum over shared B IDs of y1[b] * x2[b]^T
    RealMatrix translation = multiply(y1, x2);

    FastByIDMap<float[]> xMerged = MatrixUtils.multiply(translation.transpose(), x1);

    // Every user starts with no known items. Give each user its own empty set rather than one shared
    // instance, so a later mutation of one user's set cannot leak into every other user's.
    FastByIDMap<FastIDSet> knownItems = new FastByIDMap<FastIDSet>();
    LongPrimitiveIterator it = xMerged.keySetIterator();
    while (it.hasNext()) {
      knownItems.put(it.nextLong(), new FastIDSet());
    }

    // The merged X derives from model 1's X and the merged Y is model 2's Y, so item tag IDs come from
    // model 1 and user tag IDs from model 2.
    FastIDSet x1ItemTagIDs = model1.getItemTagIDs();
    FastIDSet y2UserTagIDs = model2.getUserTagIDs();

    Generation merged = new Generation(knownItems, xMerged, y2, x1ItemTagIDs, y2UserTagIDs);
    GenerationSerializer.writeGeneration(merged, mergedModelFile);
  }

  /**
   * Computes the sum, over IDs present in both maps, of the outer product left[id] * right[id]<sup>T</sup>.
   * IDs present in only one of the maps contribute nothing.
   *
   * @param left vectors keyed by ID, all of the same length (the result's row count)
   * @param right vectors keyed by ID, all of the same length (the result's column count)
   * @return the accumulated matrix, all zeros if the maps share no IDs
   * @throws IllegalArgumentException if either map is empty, or a vector used in the sum has a different
   *     length than the first vector of its map
   */
  private static RealMatrix multiply(FastByIDMap<float[]> left, FastByIDMap<float[]> right) {
    int numRows = firstVectorLength(left, "left");
    int numCols = firstVectorLength(right, "right");
    double[][] translationData = new double[numRows][numCols];
    for (FastByIDMap.MapEntry<float[]> entry1 : left.entrySet()) {
      float[] rightRow = right.get(entry1.getKey());
      if (rightRow == null) {
        continue;
      }
      float[] leftCol = entry1.getValue();
      checkLength(leftCol, numRows, entry1.getKey());
      checkLength(rightRow, numCols, entry1.getKey());
      for (int row = 0; row < numRows; row++) {
        // Widen before multiplying: a float * float product would be rounded to float precision
        // before reaching the double accumulator.
        double leftColAtRow = leftCol[row];
        double[] translationDataAtRow = translationData[row];
        for (int col = 0; col < numCols; col++) {
          translationDataAtRow[col] += leftColAtRow * rightRow[col];
        }
      }
    }
    return new Array2DRowRealMatrix(translationData);
  }

  /** Returns the length of the first vector in {@code vectors}, failing clearly if there is none. */
  private static int firstVectorLength(FastByIDMap<float[]> vectors, String description) {
    for (FastByIDMap.MapEntry<float[]> entry : vectors.entrySet()) {
      return entry.getValue().length;
    }
    throw new IllegalArgumentException("No feature vectors in " + description + " operand");
  }

  /** Throws if {@code vector}, stored under {@code id}, does not have exactly {@code expectedLength} elements. */
  private static void checkLength(float[] vector, int expectedLength, long id) {
    if (vector.length != expectedLength) {
      throw new IllegalArgumentException(
          "Vector for ID " + id + " has length " + vector.length + ", expected " + expectedLength);
    }
  }

}