/* * Copyright Myrrix Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package net.myrrix.online.generation; import java.io.File; import java.io.IOException; import org.apache.commons.math3.linear.Array2DRowRealMatrix; import org.apache.commons.math3.linear.RealMatrix; import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; import net.myrrix.common.collection.FastByIDMap; import net.myrrix.common.collection.FastIDSet; import net.myrrix.common.math.MatrixUtils; /** * <p>Merges two model files into one model file. The models have to be "compatible" in order to make any sense, * in the sense that model 1 must map As to Bs and model 2, Bs to Cs, to make a model from As to Cs.</p> * * <p>The resulting model file can be plugged directly into another instance's working directory.</p> * * <p>Usage: MergeModels [model.bin.gz file 1] [model.bin.gz file 2] [merged model.bin.gz]</p> * * <p>This is a simple utility class and an experiment which may be removed.</p> * * @author Sean Owen * @since 1.0 */ public final class MergeModels { private MergeModels() { } public static void main(String[] args) throws Exception { File model1File = new File(args[0]); File model2File = new File(args[1]); File mergedModelFile = new File(args[2]); merge(model1File, model2File, mergedModelFile); } public static void merge(File model1File, File model2File, File mergedModelFile) throws IOException { Generation model1 = GenerationSerializer.readGeneration(model1File); Generation model2 = GenerationSerializer.readGeneration(model2File); FastByIDMap<float[]> x1 = model1.getX(); FastByIDMap<float[]> y1 = model1.getY(); FastByIDMap<float[]> x2 = model2.getX(); FastByIDMap<float[]> y2 = model2.getY(); RealMatrix translation = multiply(y1, x2); FastByIDMap<float[]> xMerged = MatrixUtils.multiply(translation.transpose(), x1); FastIDSet emptySet = new FastIDSet(); FastByIDMap<FastIDSet> knownItems = new FastByIDMap<FastIDSet>(); LongPrimitiveIterator it = xMerged.keySetIterator(); while (it.hasNext()) { knownItems.put(it.nextLong(), emptySet); } FastIDSet x1ItemTagIDs = model1.getItemTagIDs(); FastIDSet y2UserTagIDs = model2.getUserTagIDs(); Generation merged = new Generation(knownItems, xMerged, y2, x1ItemTagIDs, y2UserTagIDs); GenerationSerializer.writeGeneration(merged, mergedModelFile); } private static RealMatrix multiply(FastByIDMap<float[]> left, FastByIDMap<float[]> right) { int numRows = left.entrySet().iterator().next().getValue().length; int numCols = right.entrySet().iterator().next().getValue().length; double[][] translationData = new double[numRows][numCols]; for (FastByIDMap.MapEntry<float[]> entry1 : left.entrySet()) { float[] leftCol = entry1.getValue(); float[] rightRow = right.get(entry1.getKey()); if (rightRow != null) { for (int row = 0; row < numRows; row++) { float leftColAtRow = leftCol[row]; double[] translationDataAtRow = translationData[row]; for (int col = 0; col < numCols; col++) { translationDataAtRow[col] += leftColAtRow * rightRow[col]; } } } } return new Array2DRowRealMatrix(translationData); } }