package edu.neu.ccs.pyramid.regression.linear_regression;

import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorView;

import java.io.*;

/**
 * Holder for the coefficients of a linear model: one bias term stored at
 * index 0 followed by one weight per feature, i.e. {@code numFeatures + 1}
 * entries in total.
 *
 * <p>The backing {@link Vector} is kept {@code transient}; custom
 * {@code writeObject}/{@code readObject} hooks persist the values through a
 * plain {@code double[]} instead.
 *
 * Created by chengli on 2/18/15.
 */
public class Weights implements Serializable {
    private static final long serialVersionUID = 1L;

    private int numFeatures;

    /**
     * vector is not serializable
     */
    private transient Vector weightVector;

    /**
     * serialize this array instead
     */
    private double[] serializableWeights;

    /**
     * Creates weights backed by a new {@link DenseVector} of
     * {@code numFeatures + 1} entries (bias plus one entry per feature).
     *
     * @param numFeatures number of features, excluding the bias
     */
    public Weights(int numFeatures) {
        this.numFeatures = numFeatures;
        this.weightVector = new DenseVector(numFeatures + 1);
        this.serializableWeights = new double[numFeatures + 1];
    }

    /**
     * Wraps an existing vector. The vector is stored by reference, not copied.
     *
     * @param numFeatures  number of features, excluding the bias
     * @param weightVector vector of size {@code numFeatures + 1}, bias first
     * @throws IllegalArgumentException if the vector size is not {@code numFeatures + 1}
     */
    public Weights(int numFeatures, Vector weightVector) {
        checkSize(numFeatures, weightVector);
        this.numFeatures = numFeatures;
        this.weightVector = weightVector;
        this.serializableWeights = new double[numFeatures + 1];
    }

    /**
     * Enforces the invariant shared by the constructor and the setter:
     * the vector holds the bias plus exactly one entry per feature.
     */
    private static void checkSize(int numFeatures, Vector weightVector) {
        if (weightVector.size() != (numFeatures + 1)) {
            throw new IllegalArgumentException("weightVector.size()!=(numFeatures + 1)");
        }
    }

    /**
     *
     * @return weights including bias at the beginning; this is the internal
     *         vector itself, not a copy
     */
    public Vector getWeights() {
        return this.weightVector;
    }

    /**
     * Replaces the backing vector. The vector is stored by reference.
     *
     * @param weightVector vector of size {@code numFeatures + 1}, bias first
     * @throws IllegalArgumentException if the vector size is not {@code numFeatures + 1};
     *         a mismatched vector would otherwise break
     *         {@link #getWeightsWithoutBias()} and serialization later on
     */
    public void setWeightVector(Vector weightVector) {
        checkSize(this.numFeatures, weightVector);
        this.weightVector = weightVector;
    }

    /**
     *
     * @return weights, no bias: a {@link VectorView} over the backing vector
     *         starting at index 1 and spanning {@code numFeatures} entries
     */
    public Vector getWeightsWithoutBias() {
        return new VectorView(this.weightVector, 1, this.numFeatures);
    }

    /**
     *
     * @return bias (the entry at index 0)
     */
    public double getBias() {
        return this.weightVector.get(0);
    }

    /**
     * @param bias new value for the entry at index 0
     */
    public void setBias(double bias) {
        this.weightVector.set(0, bias);
    }

    /**
     * Sets the weight of a single feature.
     *
     * @param featureIndex zero-based feature index; shifted by one internally
     *                     because index 0 holds the bias
     * @param weight       new weight value
     */
    public void setWeight(int featureIndex, double weight) {
        this.weightVector.set(featureIndex + 1, weight);
    }

    /**
     * Serialization hook: copies the transient vector into the plain array and
     * writes {@code numFeatures} followed by that array.
     */
    private void writeObject(java.io.ObjectOutputStream out) throws IOException {
        for (int i = 0; i < serializableWeights.length; i++) {
            serializableWeights[i] = weightVector.get(i);
        }
        out.writeInt(numFeatures);
        out.writeObject(serializableWeights);
    }

    /**
     * Deserialization hook: reads {@code numFeatures} and the value array, then
     * rebuilds the transient vector as a {@link DenseVector}.
     *
     * @throws InvalidObjectException if the stored array is missing or its
     *         length is not {@code numFeatures + 1}
     */
    private void readObject(java.io.ObjectInputStream in) throws IOException, ClassNotFoundException {
        int storedNumFeatures = in.readInt();
        double[] storedWeights = (double[]) in.readObject();
        // Reject inconsistent streams up front rather than building a
        // partially filled vector or failing with an index error below.
        if (storedWeights == null || storedWeights.length != storedNumFeatures + 1) {
            throw new InvalidObjectException(
                    "weight array length does not match numFeatures + 1 (numFeatures="
                            + storedNumFeatures + ")");
        }
        numFeatures = storedNumFeatures;
        serializableWeights = storedWeights;
        weightVector = new DenseVector(numFeatures + 1);
        for (int i = 0; i < serializableWeights.length; i++) {
            weightVector.set(i, serializableWeights[i]);
        }
    }

    /**
     * Writes this object to {@code file} using Java serialization, creating
     * missing parent directories first.
     *
     * @param file destination file
     * @throws Exception if the parent directory cannot be created or writing fails
     */
    void serialize(File file) throws Exception {
        // getParentFile() returns null for a path without a parent component
        // (e.g. new File("weights.ser")); there is nothing to create then.
        File parent = file.getParentFile();
        if (parent != null && !parent.exists()) {
            // Re-check isDirectory() so a directory created concurrently by
            // someone else is not reported as a failure.
            if (!parent.mkdirs() && !parent.isDirectory()) {
                throw new IOException("could not create directory " + parent);
            }
        }
        try (
                FileOutputStream fileOutputStream = new FileOutputStream(file);
                BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(fileOutputStream);
                ObjectOutputStream objectOutputStream = new ObjectOutputStream(bufferedOutputStream)
        ) {
            objectOutputStream.writeObject(this);
        }
    }

    /**
     * Reads a {@code Weights} object previously written with Java serialization.
     * NOTE(review): Java-native deserialization should only be used on trusted files.
     *
     * @param file source file
     * @return the deserialized weights
     * @throws Exception if the file cannot be read or does not contain a {@code Weights} object
     */
    public static Weights deserialize(File file) throws Exception {
        try (
                FileInputStream fileInputStream = new FileInputStream(file);
                BufferedInputStream bufferedInputStream = new BufferedInputStream(fileInputStream);
                ObjectInputStream objectInputStream = new ObjectInputStream(bufferedInputStream)
        ) {
            return (Weights) objectInputStream.readObject();
        }
    }

    @Override
    public String toString() {
        return "Weights{" +
                "numFeatures=" + numFeatures +
                ", weightVector=" + weightVector +
                '}';
    }
}