/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.matfast.util

import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer}
import com.esotericsoftware.kryo.io.{Input, Output}

import org.apache.spark.sql.matfast.matrix._

/**
 * Kryo serializer for [[MLMatrix]] values.
 *
 * Wire format (unchanged from earlier revisions of this class):
 *  - a `short` type tag: `0` for [[SparseMatrix]], `1` for [[DenseMatrix]];
 *  - `numRows` and `numCols` as variable-length ints;
 *  - dense: the `values` array (length-prefixed), then `isTransposed`;
 *  - sparse: the `colPtrs`, `rowIndices` and `values` arrays (each
 *    length-prefixed), then `isTransposed`.
 *
 * All ints are written with Kryo's `optimizePositive = true` encoding.
 */
class KryoMLMatrixSerializer extends Serializer[MLMatrix] {

  /** Type tag written in front of a [[SparseMatrix]] payload. */
  private final val SparseTag = 0

  /** Type tag written in front of a [[DenseMatrix]] payload. */
  private final val DenseTag = 1

  /**
   * Writes `matrix` to `output` using the format described on the class.
   *
   * @param kryo   the Kryo instance driving serialization (not used here)
   * @param output destination stream
   * @param matrix the matrix to serialize
   * @throws KryoException if `matrix` is neither a [[DenseMatrix]] nor a
   *                       [[SparseMatrix]]; nothing is written in that case
   */
  override def write(kryo: Kryo, output: Output, matrix: MLMatrix): Unit = matrix match {
    case dense: DenseMatrix =>
      output.writeShort(DenseTag)
      output.writeInt(dense.numRows, true)
      output.writeInt(dense.numCols, true)
      writeDoubleArray(output, dense.values)
      output.writeBoolean(dense.isTransposed)

    case sp: SparseMatrix =>
      output.writeShort(SparseTag)
      output.writeInt(sp.numRows, true)
      output.writeInt(sp.numCols, true)
      writeIntArray(output, sp.colPtrs)
      writeIntArray(output, sp.rowIndices)
      writeDoubleArray(output, sp.values)
      output.writeBoolean(sp.isTransposed)

    case other =>
      // Reject unsupported implementations before emitting any bytes, so the
      // stream is not left with a dangling type tag.
      val typeName = if (other == null) "null" else other.getClass.getName
      throw new KryoException(s"Unsupported MLMatrix implementation: $typeName")
  }

  /**
   * Reads a matrix previously produced by [[write]].
   *
   * @param kryo  the Kryo instance driving deserialization (not used here)
   * @param input source stream
   * @param typ   the requested class (not used; the type tag in the stream decides)
   * @return the reconstructed [[DenseMatrix]] or [[SparseMatrix]]
   * @throws KryoException if the type tag read from the stream is not a known one
   */
  override def read(kryo: Kryo, input: Input, typ: Class[MLMatrix]): MLMatrix =
    input.readShort().toInt match {
      case DenseTag =>
        val numRows = input.readInt(true)
        val numCols = input.readInt(true)
        val values = readDoubleArray(input)
        val isTransposed = input.readBoolean()
        new DenseMatrix(numRows, numCols, values, isTransposed)

      case SparseTag =>
        val numRows = input.readInt(true)
        val numCols = input.readInt(true)
        val colPtrs = readIntArray(input)
        val rowIndices = readIntArray(input)
        val values = readDoubleArray(input)
        val isTransposed = input.readBoolean()
        new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)

      case unknown =>
        // A tag other than 0/1 is never produced by a successful write, so the
        // stream is corrupt or misaligned; fail here instead of returning null.
        throw new KryoException(s"Unknown MLMatrix type tag: $unknown")
    }

  /** Writes the array length followed by each element, all as variable-length ints. */
  private def writeIntArray(output: Output, xs: Array[Int]): Unit = {
    output.writeInt(xs.length, true)
    var i = 0
    while (i < xs.length) {
      output.writeInt(xs(i), true)
      i += 1
    }
  }

  /** Writes the array length as a variable-length int, then each double. */
  private def writeDoubleArray(output: Output, xs: Array[Double]): Unit = {
    output.writeInt(xs.length, true)
    var i = 0
    while (i < xs.length) {
      output.writeDouble(xs(i))
      i += 1
    }
  }

  /** Reads an int array written by [[writeIntArray]]. */
  private def readIntArray(input: Input): Array[Int] = {
    val n = input.readInt(true)
    val xs = new Array[Int](n)
    var i = 0
    while (i < n) {
      xs(i) = input.readInt(true)
      i += 1
    }
    xs
  }

  /** Reads a double array written by [[writeDoubleArray]]. */
  private def readDoubleArray(input: Input): Array[Double] = {
    val n = input.readInt(true)
    val xs = new Array[Double](n)
    var i = 0
    while (i < n) {
      xs(i) = input.readDouble()
      i += 1
    }
    xs
  }
}