org.apache.avro.SchemaBuilder Scala Examples

The following examples show how to use org.apache.avro.SchemaBuilder. They are drawn from open-source projects; the source file, originating project, and license are noted above each example.
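Before the full examples, here is a minimal, self-contained sketch of the fluent SchemaBuilder API that most of them rely on: building a record schema with required, optional, and nullable fields. The record name, namespace, and field names are illustrative placeholders only, not taken from any project below.

import org.apache.avro.{Schema, SchemaBuilder}

object SchemaBuilderSketch extends App {
  // Hypothetical record name and namespace, used only for illustration.
  val personSchema: Schema = SchemaBuilder
    .record("Person").namespace("com.example")
    .fields()
    .requiredString("name")                                      // non-null string field
    .optionalInt("age")                                          // union [null, int] with default null
    .name("email").`type`().nullable().stringType().noDefault()  // union [string, null], no default
    .endRecord()

  // Print the resulting Avro schema as pretty-printed JSON.
  println(personSchema.toString(true))
}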
Example 1
Source File: AvroParquetSourceTest.scala    From eel-sdk   with Apache License 2.0
package io.eels.component.parquet

import java.nio.file.Paths

import io.eels.component.parquet.avro.AvroParquetSource
import io.eels.component.parquet.util.ParquetLogMute
import io.eels.schema._
import org.apache.avro.SchemaBuilder
import org.apache.avro.generic.{GenericData, GenericRecord}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.parquet.avro.AvroParquetWriter
import org.scalatest.{Matchers, WordSpec}

class AvroParquetSourceTest extends WordSpec with Matchers {
  ParquetLogMute()

  private implicit val conf = new Configuration()
  private implicit val fs = FileSystem.get(conf)

  private val personFile = Paths.get(getClass.getResource("/io/eels/component/parquet/person.avro.pq").toURI)
  private val resourcesDir = personFile.getParent

  "AvroParquetSource" should {
    "read schema" in {
      val people = AvroParquetSource(personFile)
      people.schema shouldBe StructType(
        Field("name", StringType, nullable = false),
        Field("job", StringType, nullable = false),
        Field("location", StringType, nullable = false)
      )
    }
    "read parquet files" in {
      val people = AvroParquetSource(personFile.toAbsolutePath()).toDataStream().toSet.map(_.values)
      people shouldBe Set(
        Vector("clint eastwood", "actor", "carmel"),
        Vector("elton john", "musician", "pinner")
      )
    }
    "read multiple parquet files using file expansion" in {
      import io.eels.FilePattern._
      val people = AvroParquetSource(s"${resourcesDir.toUri.toString}/*.pq").toDataStream().toSet.map(_.values)
      people shouldBe Set(
        Vector("clint eastwood", "actor", "carmel"),
        Vector("elton john", "musician", "pinner"),
        Vector("clint eastwood", "actor", "carmel"),
        Vector("elton john", "musician", "pinner")
      )
    }
    // todo add merge to parquet source
    "merge schemas" ignore {

      try {
        fs.delete(new Path("merge1.pq"), false)
      } catch {
        case t: Throwable =>
      }
      try {
        fs.delete(new Path("merge2.pq"), false)
      } catch {
        case t: Throwable =>
      }

      val schema1 = SchemaBuilder.builder().record("schema1").fields().requiredString("a").requiredDouble("b").endRecord()
      val schema2 = SchemaBuilder.builder().record("schema2").fields().requiredInt("a").requiredBoolean("c").endRecord()

      val writer1 = AvroParquetWriter.builder[GenericRecord](new Path("merge1.pq")).withSchema(schema1).build()
      val record1 = new GenericData.Record(schema1)
      record1.put("a", "aaaaa")
      record1.put("b", 124.3)
      writer1.write(record1)
      writer1.close()

      val writer2 = AvroParquetWriter.builder[GenericRecord](new Path("merge2.pq")).withSchema(schema2).build()
      val record2 = new GenericData.Record(schema2)
      record2.put("a", 111)
      record2.put("c", true)
      writer2.write(record2)
      writer2.close()

      ParquetSource(new Path("merge*")).schema shouldBe
        StructType(
          Field("a", StringType, nullable = false),
          Field("b", DoubleType, nullable = false),
          Field("c", BooleanType, nullable = false)
        )

      fs.delete(new Path(".merge1.pq.crc"), false)
      fs.delete(new Path(".merge2.pq.crc"), false)
      fs.delete(new Path("merge1.pq"), false)
      fs.delete(new Path("merge2.pq"), false)
    }
  }
} 
Example 2
Source File: IDF.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml.feature

import com.ibm.aardpfark.pfa.document.{Cell, PFABuilder, PFADocument}
import com.ibm.aardpfark.pfa.expression._
import com.ibm.aardpfark.pfa.types.WithSchema
import com.ibm.aardpfark.spark.ml.PFAModel
import com.sksamuel.avro4s.{AvroNamespace, AvroSchema}
import org.apache.avro.{Schema, SchemaBuilder}

import org.apache.spark.ml.feature.IDFModel

@AvroNamespace("com.ibm.aardpfark.exec.spark.ml.feature")
case class IDFData(idf: Seq[Double]) extends WithSchema {
  override def schema: Schema = AvroSchema[this.type]
}




  // NOTE: the PFAModel subclass declaration and its member definitions (inputCol,
  // outputCol, inputExpr, idfRef, multFn, modelCell and the input/output schemas)
  // are truncated in this excerpt; the MaxAbsScaler example below follows the same pattern.
  override def action: PFAExpression = {
    NewRecord(outputSchema, Map(outputCol -> a.zipmap(inputExpr, idfRef, multFn.ref)))
  }

  override def pfa: PFADocument = {
    PFABuilder()
      .withName(sparkTransformer.uid)
      .withMetadata(getMetadata)
      .withInput(inputSchema)
      .withOutput(outputSchema)
      .withCell(modelCell)
      .withFunction(multFn)
      .withAction(action)
      .pfa
  }

} 
Example 3
Source File: GenericAvroSerializerSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.serializer

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.ByteBuffer

import com.esotericsoftware.kryo.io.{Output, Input}
import org.apache.avro.{SchemaBuilder, Schema}
import org.apache.avro.generic.GenericData.Record

import org.apache.spark.{SparkFunSuite, SharedSparkContext}

class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
  conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")

  val schema : Schema = SchemaBuilder
    .record("testRecord").fields()
    .requiredString("data")
    .endRecord()
  val record = new Record(schema)
  record.put("data", "test data")

  test("schema compression and decompression") {
    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
    assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema))))
  }

  test("record serialization and deserialization") {
    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)

    val outputStream = new ByteArrayOutputStream()
    val output = new Output(outputStream)
    genericSer.serializeDatum(record, output)
    output.flush()
    output.close()

    val input = new Input(new ByteArrayInputStream(outputStream.toByteArray))
    assert(genericSer.deserializeDatum(input) === record)
  }

  test("uses schema fingerprint to decrease message size") {
    val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema)

    val output = new Output(new ByteArrayOutputStream())

    val beginningNormalPosition = output.total()
    genericSerFull.serializeDatum(record, output)
    output.flush()
    val normalLength = output.total - beginningNormalPosition

    conf.registerAvroSchemas(schema)
    val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema)
    val beginningFingerprintPosition = output.total()
    genericSerFinger.serializeDatum(record, output)
    val fingerprintLength = output.total - beginningFingerprintPosition

    assert(fingerprintLength < normalLength)
  }

  test("caches previously seen schemas") {
    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
    val compressedSchema = genericSer.compress(schema)
    val decompressedSchema = genericSer.decompress(ByteBuffer.wrap(compressedSchema))

    assert(compressedSchema.eq(genericSer.compress(schema)))
    assert(decompressedSchema.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema))))
  }
} 
Example 4
Source File: TestStreamlets.scala    From cloudflow   with Apache License 2.0
package cloudflow.streamlets.descriptors

import scala.collection.immutable

import org.apache.avro.SchemaBuilder

import com.typesafe.config.Config

import cloudflow.streamlets._
import cloudflow.streamlets.avro.AvroUtil

case class Coffee(espressos: Int)

object Schemas {
  val coffeeSchema = SchemaBuilder
    .record("Coffee")
    .namespace("cloudflow.sbt")
    .fields()
    .name("expressos")
    .`type`()
    .nullable()
    .intType()
    .noDefault()
    .endRecord()
}

case object TestRuntime extends StreamletRuntime {
  override val name = "test-runtime"
}

trait TestStreamlet extends Streamlet[StreamletContext] {
  override def runtime: StreamletRuntime                                 = TestRuntime
  def logStartRunnerMessage(buildInfo: String): Unit                     = ???
  override protected def createContext(config: Config): StreamletContext = ???
  override def run(context: StreamletContext): StreamletExecution        = ???

}

class CoffeeIngress extends Streamlet[StreamletContext] with TestStreamlet {
  case class TestOutlet(name: String, schemaDefinition: SchemaDefinition) extends Outlet
  override val shape                                = StreamletShape(TestOutlet("out", AvroUtil.createSchemaDefinition(Schemas.coffeeSchema)))
  override val labels: immutable.IndexedSeq[String] = Vector("test", "coffee")
  override val description: String                  = "Coffee Ingress Test"
} 
Example 5
Source File: MaxAbsScaler.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml.feature

import com.ibm.aardpfark.pfa.document.{Cell, PFABuilder, PFADocument}
import com.ibm.aardpfark.pfa.expression._
import com.ibm.aardpfark.pfa.types.WithSchema
import com.ibm.aardpfark.spark.ml.PFAModel
import com.sksamuel.avro4s.{AvroNamespace, AvroSchema}
import org.apache.avro.SchemaBuilder

import org.apache.spark.ml.feature.MaxAbsScalerModel

@AvroNamespace("com.ibm.aardpfark.exec.spark.ml.feature")
case class MaxAbsScalerModelData(maxAbs: Array[Double]) extends WithSchema {
  def schema = AvroSchema[this.type]
}

class PFAMaxAbsScalerModel(override val sparkTransformer: MaxAbsScalerModel) extends PFAModel[MaxAbsScalerModelData] {
  import com.ibm.aardpfark.pfa.dsl._

  private val inputCol = sparkTransformer.getInputCol
  private val outputCol = sparkTransformer.getOutputCol
  private val inputExpr = StringExpr(s"input.${inputCol}")

  // cell data
  private val scalerData = MaxAbsScalerModelData(sparkTransformer.maxAbs.toArray)
  override def cell = Cell(scalerData)
  // references to cell variables
  private val maxAbs = modelCell.ref("maxAbs")

  override def inputSchema = {
    SchemaBuilder.record(withUid(inputBaseName)).fields()
      .name(inputCol).`type`().array().items().doubleType().noDefault()
      .endRecord()
  }

  override def outputSchema = {
    SchemaBuilder.record(withUid(outputBaseName)).fields()
      .name(outputCol).`type`().array().items().doubleType().noDefault()
      .endRecord()
  }

  val ifFn = If (core.gt(StringExpr("s"), 0.0)) Then core.div("i", "s") Else StringExpr("i")
  val divDoubleFn = NamedFunctionDef(
    "divDouble",
    FunctionDef[Double, Double](Seq("i", "s"), Seq(ifFn))
  )

  override def action: PFAExpression = {
    NewRecord(outputSchema, Map(outputCol -> a.zipmap(inputExpr, maxAbs, divDoubleFn.ref)))
  }

  override def pfa: PFADocument = {
    PFABuilder()
      .withName(sparkTransformer.uid)
      .withMetadata(getMetadata)
      .withInput(inputSchema)
      .withOutput(outputSchema)
      .withCell(modelCell)
      .withFunction(divDoubleFn)
      .withAction(action)
      .pfa
  }
} 
Example 6
Source File: ElementwiseProduct.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml.feature

import com.ibm.aardpfark.pfa.document.{Cell, PFABuilder, PFADocument}
import com.ibm.aardpfark.pfa.expression._
import com.ibm.aardpfark.pfa.types.WithSchema
import com.ibm.aardpfark.spark.ml.PFAModel
import com.sksamuel.avro4s.{AvroNamespace, AvroSchema}
import org.apache.avro.SchemaBuilder

import org.apache.spark.ml.feature.ElementwiseProduct

@AvroNamespace("com.ibm.aardpfark.exec.spark.ml.feature")
case class ElementwiseProductData(scalingVec: Seq[Double]) extends WithSchema {
  def schema = AvroSchema[this.type]
}

class PFAElementwiseProduct(override val sparkTransformer: ElementwiseProduct) extends PFAModel[ElementwiseProductData] {
  import com.ibm.aardpfark.pfa.dsl._

  private val inputCol = sparkTransformer.getInputCol
  private val outputCol = sparkTransformer.getOutputCol
  private val inputExpr = StringExpr(s"input.${inputCol}")

  // cell data
  private val scalingVec = sparkTransformer.getScalingVec.toArray

  override def cell = Cell(ElementwiseProductData(scalingVec))
  // references to cell variables
  private val scalingVecRef = modelCell.ref("scalingVec")

  override def inputSchema = {
    SchemaBuilder.record(withUid(inputBaseName)).fields()
      .name(inputCol).`type`().array().items().doubleType().noDefault()
      .endRecord()
  }

  override def outputSchema = {
    SchemaBuilder.record(withUid(outputBaseName)).fields()
      .name(outputCol).`type`().array().items().doubleType().noDefault()
      .endRecord()
  }

  val scaleFn = NamedFunctionDef("doubleMult", FunctionDef[Double, Double](
    Seq("x", "y"),
    Seq(core.mult("x", "y"))
  ))

  override def action: PFAExpression = {
    val scale = a.zipmap(inputExpr, scalingVecRef, scaleFn.ref)
    NewRecord(outputSchema, Map(outputCol -> scale))
  }

  override def pfa: PFADocument = {
    PFABuilder()
      .withName(sparkTransformer.uid)
      .withMetadata(getMetadata)
      .withInput(inputSchema)
      .withOutput(outputSchema)
      .withCell(modelCell)
      .withFunction(scaleFn)
      .withAction(action)
      .pfa
  }
} 
Example 7
Source File: VectorSelector.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml.feature

import com.ibm.aardpfark.pfa.document.{Cell, PFABuilder, PFADocument}
import com.ibm.aardpfark.pfa.expression._
import com.ibm.aardpfark.pfa.types.WithSchema
import com.ibm.aardpfark.spark.ml.PFAModel
import com.sksamuel.avro4s.{AvroNamespace, AvroSchema}
import org.apache.avro.SchemaBuilder

@AvroNamespace("com.ibm.aardpfark.exec.spark.ml.feature")
case class VectorSelectorData(indices: Seq[Int]) extends WithSchema {
  override def schema = AvroSchema[this.type]
}

abstract class PFAVectorSelector extends PFAModel[VectorSelectorData] {
  import com.ibm.aardpfark.pfa.dsl._

  protected val inputCol: String
  protected val outputCol: String
  protected lazy val inputExpr = StringExpr(s"input.${inputCol}")

  protected val indices: Seq[Int]

  override protected def cell = Cell(VectorSelectorData(indices))

  override def inputSchema = {
    SchemaBuilder.record(withUid(inputBaseName)).fields()
      .name(inputCol).`type`().array().items().doubleType().noDefault()
      .endRecord()
  }

  override def outputSchema = {
    SchemaBuilder.record(withUid(outputBaseName)).fields()
      .name(outputCol).`type`().array().items().doubleType().noDefault()
      .endRecord()
  }

  private def filterFn = FunctionDef[Boolean](
    Seq(Param[Int]("idx"), Param[Double]("x")),
    Seq(a.contains(modelCell.ref("indices"), "idx"))
  )
  override def action: PFAExpression = {
    NewRecord(outputSchema, Map(outputCol -> a.filterWithIndex(inputExpr, filterFn)))
  }

  override def pfa: PFADocument = {
    PFABuilder()
      .withName(sparkTransformer.uid)
      .withMetadata(getMetadata)
      .withInput(inputSchema)
      .withOutput(outputSchema)
      .withCell(modelCell)
      .withAction(action)
      .pfa
  }
} 
Example 8
Source File: Normalizer.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml.feature

import com.ibm.aardpfark.pfa.document.{PFABuilder, PFADocument}
import com.ibm.aardpfark.pfa.expression._
import com.ibm.aardpfark.spark.ml.PFATransformer
import org.apache.avro.SchemaBuilder

import org.apache.spark.ml.feature.Normalizer


class PFANormalizer(override val sparkTransformer: Normalizer) extends PFATransformer {
  import com.ibm.aardpfark.pfa.dsl._

  private val inputCol = sparkTransformer.getInputCol
  private val outputCol = sparkTransformer.getOutputCol
  private val inputExpr = StringExpr(s"input.${inputCol}")

  private val p = sparkTransformer.getP

  override def inputSchema = {
    SchemaBuilder.record(withUid(inputBaseName)).fields()
      .name(inputCol).`type`().array().items().doubleType().noDefault()
      .endRecord()
  }

  override def outputSchema = {
    SchemaBuilder.record(withUid(outputBaseName)).fields()
      .name(outputCol).`type`().array().items().doubleType().noDefault()
      .endRecord()
  }

  private def absPow(p: Double) = FunctionDef[Double, Double](
    Seq("x"),
    Seq(core.pow(m.abs("x"), p))
  )

  private val sq = FunctionDef[Double, Double](
    Seq("x"),
    Seq(core.pow("x", 2.0))
  )

  private val absVal = FunctionDef[Double, Double](
    Seq("x"),
    Seq(m.abs("x"))
  )

  override def action: PFAExpression = {
    val fn = p match {
      case 1.0 =>
        a.sum(a.map(inputExpr, absVal))
      case 2.0 =>
        m.sqrt(a.sum(a.map(inputExpr, sq)))
      case Double.PositiveInfinity =>
        a.max(a.map(inputExpr, absVal))
      case _ =>
        core.pow(a.sum(a.map(inputExpr, absPow(p))), 1.0 / p)

    }
    val norm = Let("norm", fn)
    val invNorm = core.div(1.0, norm.ref)
    val scale = la.scale(inputExpr, invNorm)
    Action(
      norm,
      NewRecord(outputSchema, Map(outputCol -> scale))
    )
  }

  override def pfa: PFADocument = {
    PFABuilder()
      .withName(sparkTransformer.uid)
      .withMetadata(getMetadata)
      .withInput(inputSchema)
      .withOutput(outputSchema)
      //.withFunction(pow(p))
      .withAction(action)
      .pfa
  }
} 
Example 9
Source File: PCAModel.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml.feature

import com.ibm.aardpfark.pfa.document.{Cell, PFABuilder, PFADocument}
import com.ibm.aardpfark.pfa.expression.PFAExpression
import com.ibm.aardpfark.pfa.types.WithSchema
import com.ibm.aardpfark.spark.ml.PFAModel
import com.sksamuel.avro4s.{AvroNamespace, AvroSchema}
import org.apache.avro.{Schema, SchemaBuilder}

import org.apache.spark.ml.feature.PCAModel

@AvroNamespace("com.ibm.aardpfark.exec.spark.ml.feature")
case class PCAData(pc: Array[Array[Double]]) extends WithSchema {
  override def schema: Schema = AvroSchema[this.type]
}

class PFAPCAModel(override val sparkTransformer: PCAModel) extends PFAModel[PCAData] {
  import com.ibm.aardpfark.pfa.dsl._

  private val inputCol = sparkTransformer.getInputCol
  private val outputCol = sparkTransformer.getOutputCol
  private val inputExpr = StringExpr(s"input.${inputCol}")

  override def inputSchema = {
    SchemaBuilder.record(withUid(inputBaseName)).fields()
      .name(inputCol).`type`().array().items().doubleType().noDefault()
      .endRecord()
  }

  override def outputSchema = {
    SchemaBuilder.record(withUid(outputBaseName)).fields()
      .name(outputCol).`type`().array().items().doubleType().noDefault()
      .endRecord()
  }

  override protected def cell = {
    val pc = sparkTransformer.pc.transpose.rowIter.map(v => v.toArray).toArray
    Cell(PCAData(pc))
  }

  override def action: PFAExpression = {
    val dot = la.dot(modelCell.ref("pc"), inputExpr)
    NewRecord(outputSchema, Map(outputCol -> dot))
  }

  override def pfa: PFADocument = {
    PFABuilder()
      .withName(sparkTransformer.uid)
      .withMetadata(getMetadata)
      .withInput(inputSchema)
      .withOutput(outputSchema)
      .withCell(modelCell)
      .withAction(action)
      .pfa
  }

} 
Example 10
Source File: StringIndexerModel.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml.feature

import com.ibm.aardpfark.pfa.document.{Cell, PFABuilder, PFADocument}
import com.ibm.aardpfark.pfa.expression._
import com.ibm.aardpfark.pfa.types.WithSchema
import com.ibm.aardpfark.spark.ml.PFAModel
import com.sksamuel.avro4s.{AvroNamespace, AvroSchema}
import org.apache.avro.SchemaBuilder

import org.apache.spark.ml.feature.StringIndexerModel

@AvroNamespace("com.ibm.aardpfark.exec.spark.ml.feature")
case class Vocab(vocab: Map[String, Int]) extends WithSchema {
  def schema = AvroSchema[this.type]
}

class PFAStringIndexerModel(override val sparkTransformer: StringIndexerModel) extends PFAModel[Vocab] {
  import com.ibm.aardpfark.pfa.dsl._

  private val inputCol = sparkTransformer.getInputCol
  private val outputCol = sparkTransformer.getOutputCol
  private val inputExpr = StringExpr(s"input.${inputCol}")

  private val handleInvalid = sparkTransformer.getHandleInvalid
  private val unknownLabel = sparkTransformer.labels.length

  override def inputSchema = {
    SchemaBuilder.record(withUid(inputBaseName)).fields()
      .name(inputCol).`type`().stringType().noDefault()
      .endRecord()
  }

  override def outputSchema = {
    val bldr = SchemaBuilder.record(withUid(outputBaseName)).fields()
      .name(outputCol).`type`()
    if (handleInvalid == "skip") {
      bldr.nullable().doubleType().noDefault().endRecord()
    } else {
      bldr.doubleType().noDefault().endRecord()
    }
  }

  override protected def cell = {
    val vocab = sparkTransformer.labels.zipWithIndex.toMap
    Cell(Vocab(vocab))
  }

  private val vocabRef = modelCell.ref("vocab")

  override def action: PFAExpression = {
    val inputAsStr = s.strip(cast.json(inputExpr), StringLiteral("\""))
    val mapper = If (map.containsKey(vocabRef, inputAsStr)) Then Attr(vocabRef, inputAsStr) Else {
      if (handleInvalid == "error") {
        Error("StringIndexer encountered unseen label")
      } else if (handleInvalid == "keep") {
        IntLiteral(unknownLabel)
      } else {
        NullLiteral
      }
    }
    NewRecord(outputSchema, Map(outputCol -> mapper), true)
  }

  override def pfa: PFADocument = {
    PFABuilder()
      .withName(sparkTransformer.uid)
      .withMetadata(getMetadata)
      .withInput(inputSchema)
      .withOutput(outputSchema)
      .withCell(modelCell)
      .withAction(action)
      .pfa
  }

} 
Example 11
Source File: MinMaxScaler.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml.feature

import com.ibm.aardpfark.pfa.document.{Cell, PFABuilder, PFADocument}
import com.ibm.aardpfark.pfa.expression._
import com.ibm.aardpfark.pfa.types.WithSchema
import com.ibm.aardpfark.spark.ml.PFAModel
import com.sksamuel.avro4s.{AvroNamespace, AvroSchema}
import org.apache.avro.SchemaBuilder

import org.apache.spark.ml.feature.MinMaxScalerModel

@AvroNamespace("com.ibm.aardpfark.exec.spark.ml.feature")
case class MinMaxScalerModelData(
  originalMin: Seq[Double],
  originalRange: Seq[Double]) extends WithSchema {
  def schema = AvroSchema[this.type]
}

class PFAMinMaxScalerModel(override val sparkTransformer: MinMaxScalerModel) extends PFAModel[MinMaxScalerModelData] {
  import com.ibm.aardpfark.pfa.dsl._
  import com.ibm.aardpfark.pfa.dsl.core._

  private val inputCol = sparkTransformer.getInputCol
  private val outputCol = sparkTransformer.getOutputCol
  private val inputExpr = StringExpr(s"input.${inputCol}")

  // cell data
  private val originalMin = sparkTransformer.originalMin.toArray
  private val originalRange = sparkTransformer.originalMax.toArray.zip(originalMin).map { case (max, min) =>
    max - min
  }

  // references to cell variables
  private val originalMinRef = modelCell.ref("originalMin")
  private val originalRangeRef = modelCell.ref("originalRange")

  override def inputSchema = {
    SchemaBuilder.record(withUid(inputBaseName)).fields()
      .name(inputCol).`type`().array().items().doubleType().noDefault()
      .endRecord()
  }

  override def outputSchema = {
    SchemaBuilder.record(withUid(outputBaseName)).fields()
      .name(outputCol).`type`().array().items().doubleType().noDefault()
      .endRecord()
  }

  override def cell = {
    val scalerData = MinMaxScalerModelData(originalMin, originalRange)
    Cell(scalerData)
  }

  // local double literals
  private val scale = sparkTransformer.getMax - sparkTransformer.getMin
  private val min = sparkTransformer.getMin
  val cond = If (net(StringExpr("range"), 0.0)) Then div(minus("i", "min"), "range") Else 0.5
  val minMaxScaleFn = NamedFunctionDef("minMaxScale", FunctionDef[Double, Double](
    Seq("i", "min", "range"),
    Seq(plus(mult(cond, scale), min))))

  override def action: PFAExpression = {
    NewRecord(outputSchema,
      Map(outputCol -> a.zipmap(inputExpr, originalMinRef, originalRangeRef, minMaxScaleFn.ref)))
  }

  override def pfa: PFADocument = {
    PFABuilder()
      .withName(sparkTransformer.uid)
      .withMetadata(getMetadata)
      .withInput(inputSchema)
      .withOutput(outputSchema)
      .withCell(modelCell)
      .withFunction(minMaxScaleFn)
      .withAction(action)
      .pfa
  }
} 
Example 12
Source File: RegexTokenizer.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml.feature

import com.ibm.aardpfark.pfa.document.{PFABuilder, PFADocument}
import com.ibm.aardpfark.pfa.expression.PFAExpression
import com.ibm.aardpfark.spark.ml.PFATransformer
import org.apache.avro.SchemaBuilder

import org.apache.spark.ml.feature.RegexTokenizer


// TODO missing token count filter and gaps vs tokens
class PFARegexTokenizer(override val sparkTransformer: RegexTokenizer) extends PFATransformer {
  import com.ibm.aardpfark.pfa.dsl._

  private val inputCol = sparkTransformer.getInputCol
  private val outputCol = sparkTransformer.getOutputCol
  private val inputExpr = StringExpr(s"input.${inputCol}")

  private val pattern = sparkTransformer.getPattern
  private val gaps = sparkTransformer.getGaps
  private val minTokenLength = sparkTransformer.getMinTokenLength
  private val toLowerCase = sparkTransformer.getToLowercase

  override def inputSchema = {
    SchemaBuilder.record(withUid(inputBaseName)).fields()
      .name(inputCol).`type`().stringType().noDefault()
      .endRecord()
  }

  override def outputSchema = {
    SchemaBuilder.record(withUid(outputBaseName)).fields()
      .name(outputCol).`type`().array().items().stringType().noDefault()
      .endRecord()
  }

  override def action: PFAExpression = {
    val a = if (toLowerCase) {
      re.split(s.lower(inputExpr), pattern)
    } else {
      re.split(inputExpr, pattern)
    }
    NewRecord(outputSchema, Map(outputCol -> a))
  }

  override def pfa: PFADocument = {
    PFABuilder()
      .withName(sparkTransformer.uid)
      .withMetadata(getMetadata)
      .withInput(inputSchema)
      .withOutput(outputSchema)
      .withAction(action)
      .pfa
  }
} 
Example 13
Source File: VectorAssembler.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml.feature

import com.ibm.aardpfark.pfa.document.{PFABuilder, PFADocument}
import com.ibm.aardpfark.pfa.expression.PFAExpression
import com.ibm.aardpfark.spark.ml.PFATransformer
import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.spark.ml.feature.VectorAssembler
import org.json4s.DefaultFormats

class PFAVectorAssembler(override val sparkTransformer: VectorAssembler) extends PFATransformer {

  import com.ibm.aardpfark.pfa.dsl._
  implicit val formats = DefaultFormats

  private val inputCols = sparkTransformer.getInputCols
  private val outputCol = sparkTransformer.getOutputCol

  type DorSeqD = Either[Double, Seq[Double]]

  override protected def inputSchema: Schema = {
    val builder = SchemaBuilder.record(withUid(inputBaseName)).fields()
    for (inputCol <- inputCols) {
      builder.name(inputCol).`type`()
        .unionOf()
        .doubleType().and()
        .array().items().doubleType()
        .endUnion().noDefault()
    }
    builder.endRecord()
  }

  override protected def outputSchema: Schema = {
    SchemaBuilder.record(withUid(outputBaseName)).fields()
      .name(outputCol).`type`().array().items().doubleType().noDefault()
      .endRecord()
  }

  private val asDouble = As[Double]("x", x => NewArray[Double](x))
  private val asArray = As[Array[Double]]("x", x => x)

  private val castFn = NamedFunctionDef("castToArray",
    FunctionDef[DorSeqD, Seq[Double]]("x") { x =>
      Cast(x, asDouble, asArray)
    }
  )

  override protected def action: PFAExpression = {
    val cols = Let("cols", NewArray[DorSeqD](inputCols.map(c => StringExpr(s"input.$c"))))
    Action(
      cols,
      NewRecord(outputSchema, Map(outputCol -> a.flatten(a.map(cols.ref, castFn.ref))))
    )
  }

  override def pfa: PFADocument = {
    PFABuilder()
      .withName(sparkTransformer.uid)
      .withMetadata(getMetadata)
      .withInput(inputSchema)
      .withOutput(outputSchema)
      .withAction(action)
      .withFunction(castFn)
      .pfa
  }
} 
Example 14
Source File: GenericAvroSerializerSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.serializer

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.ByteBuffer

import com.esotericsoftware.kryo.io.{Input, Output}
import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.generic.GenericData.Record

import org.apache.spark.{SharedSparkContext, SparkFunSuite}

class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
  conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")

  val schema : Schema = SchemaBuilder
    .record("testRecord").fields()
    .requiredString("data")
    .endRecord()
  val record = new Record(schema)
  record.put("data", "test data")

  test("schema compression and decompression") {
    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
    assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema))))
  }

  test("record serialization and deserialization") {
    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)

    val outputStream = new ByteArrayOutputStream()
    val output = new Output(outputStream)
    genericSer.serializeDatum(record, output)
    output.flush()
    output.close()

    val input = new Input(new ByteArrayInputStream(outputStream.toByteArray))
    assert(genericSer.deserializeDatum(input) === record)
  }

  test("uses schema fingerprint to decrease message size") {
    val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema)

    val output = new Output(new ByteArrayOutputStream())

    val beginningNormalPosition = output.total()
    genericSerFull.serializeDatum(record, output)
    output.flush()
    val normalLength = output.total - beginningNormalPosition

    conf.registerAvroSchemas(schema)
    val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema)
    val beginningFingerprintPosition = output.total()
    genericSerFinger.serializeDatum(record, output)
    val fingerprintLength = output.total - beginningFingerprintPosition

    assert(fingerprintLength < normalLength)
  }

  test("caches previously seen schemas") {
    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
    val compressedSchema = genericSer.compress(schema)
    val decompressedSchema = genericSer.decompress(ByteBuffer.wrap(compressedSchema))

    assert(compressedSchema.eq(genericSer.compress(schema)))
    assert(decompressedSchema.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema))))
  }
} 
Example 15
Source File: Binarizer.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml.feature

import com.ibm.aardpfark.pfa.dsl.StringExpr
import com.ibm.aardpfark.pfa.document.{PFABuilder, PFADocument}
import com.ibm.aardpfark.pfa.dsl._
import com.ibm.aardpfark.pfa.expression.PFAExpression
import com.ibm.aardpfark.spark.ml.PFATransformer
import org.apache.avro.SchemaBuilder

import org.apache.spark.ml.feature.Binarizer

class PFABinarizer(override val sparkTransformer: Binarizer) extends PFATransformer {

  private val inputCol = sparkTransformer.getInputCol
  private val outputCol = sparkTransformer.getOutputCol
  private val inputExpr = StringExpr(s"input.${inputCol}")

  override def inputSchema = {
    SchemaBuilder.record(withUid(inputBaseName)).fields()
      .name(inputCol).`type`().unionOf().array().items().doubleType()
      .and()
      .doubleType().endUnion()
      .noDefault()
      .endRecord()
  }

  override def outputSchema = {
    SchemaBuilder.record(withUid(outputBaseName)).fields()
      .name(outputCol).`type`().unionOf().array().items().doubleType()
      .and()
      .doubleType().endUnion()
      .noDefault()
      .endRecord()
  }

  private val th = sparkTransformer.getThreshold
  private val doubleBin = NamedFunctionDef("doubleBin", FunctionDef[Double, Double]("d",
    If (core.gt(StringExpr("d"), th)) Then 1.0 Else 0.0)
  )

  override def action: PFAExpression = {
    val asDouble = As[Double]("x", x => doubleBin.call(x))
    val asArray = As[Array[Double]]("x", x => a.map(x, doubleBin.ref))
    val cast = Cast(inputExpr,
      Seq(asDouble, asArray))
    NewRecord(outputSchema, Map(outputCol -> cast))
  }

  override def pfa: PFADocument = {
    PFABuilder()
      .withName(sparkTransformer.uid)
      .withMetadata(getMetadata)
      .withInput(inputSchema)
      .withOutput(outputSchema)
      .withFunction(doubleBin)
      .withAction(action)
      .pfa
  }
} 
Example 16
Source File: StopWordsRemover.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml.feature

import com.ibm.aardpfark.pfa.document.{Cell, PFABuilder, PFADocument}
import com.ibm.aardpfark.pfa.expression.PFAExpression
import com.ibm.aardpfark.pfa.types.WithSchema
import com.ibm.aardpfark.spark.ml.PFAModel
import com.sksamuel.avro4s.{AvroNamespace, AvroSchema}
import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.spark.ml.feature.StopWordsRemover

@AvroNamespace("com.ibm.aardpfark.exec.spark.spark.ml.feature")
case class StopWords(words: Seq[String]) extends WithSchema {
  def schema = AvroSchema[this.type]
}

class PFAStopWordsRemover(override val sparkTransformer: StopWordsRemover) extends PFAModel[StopWords] {
  import com.ibm.aardpfark.pfa.dsl._

  private val inputCol = sparkTransformer.getInputCol
  private val outputCol = sparkTransformer.getOutputCol
  private val inputExpr = StringExpr(s"input.${inputCol}")

  private val stopWords = sparkTransformer.getStopWords
  private val caseSensitive = sparkTransformer.getCaseSensitive

  private def filterFn = FunctionDef[String, Boolean]("word") { w =>
    Seq(core.not(a.contains(wordsRef, if (caseSensitive) w else s.lower(w))))
  }

  override def inputSchema: Schema = {
    SchemaBuilder.record(withUid(inputBaseName)).fields()
      .name(inputCol).`type`().array().items().stringType().noDefault()
      .endRecord()
  }

  override def outputSchema: Schema =  {
    SchemaBuilder.record(withUid(outputBaseName)).fields()
      .name(outputCol).`type`().array().items().stringType().noDefault()
      .endRecord()
  }

  override protected def cell = {
    Cell(StopWords(stopWords))
  }

  private val wordsRef = modelCell.ref("words")

  override def action: PFAExpression = {
    NewRecord(outputSchema, Map(outputCol -> a.filter(inputExpr, filterFn)))
  }

  override def pfa: PFADocument =
    PFABuilder()
      .withName(sparkTransformer.uid)
      .withMetadata(getMetadata)
      .withInput(inputSchema)
      .withOutput(outputSchema)
      .withCell(modelCell)
      .withAction(action)
      .pfa
} 
Example 17
Source File: StandardScaler.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml.feature

import com.ibm.aardpfark.pfa.document.{Cell, PFABuilder, PFADocument}
import com.ibm.aardpfark.pfa.expression._
import com.ibm.aardpfark.pfa.types.WithSchema
import com.ibm.aardpfark.spark.ml.PFAModel
import com.sksamuel.avro4s.{AvroNamespace, AvroSchema}
import org.apache.avro.SchemaBuilder

import org.apache.spark.ml.feature.StandardScalerModel

@AvroNamespace("com.ibm.aardpfark.exec.spark.ml.feature")
case class StandardScalerModelData(mean: Seq[Double], std: Seq[Double]) extends WithSchema {
  def schema = AvroSchema[this.type]
}

class PFAStandardScalerModel(override val sparkTransformer: StandardScalerModel) extends PFAModel[StandardScalerModelData] {
  import com.ibm.aardpfark.pfa.dsl._
  import com.ibm.aardpfark.pfa.dsl.core._

  private val inputCol = sparkTransformer.getInputCol
  private val outputCol = sparkTransformer.getOutputCol
  private val inputExpr = StringExpr(s"input.${inputCol}")

  // references to cell variables
  private val meanRef = modelCell.ref("mean")
  private val stdRef = modelCell.ref("std")

  override def inputSchema = {
    SchemaBuilder.record(withUid(inputBaseName)).fields()
      .name(inputCol).`type`().array().items().doubleType().noDefault()
      .endRecord()
  }

  override def outputSchema = {
    SchemaBuilder.record(withUid(outputBaseName)).fields()
      .name(outputCol).`type`().array().items().doubleType().noDefault()
      .endRecord()
  }

  override def cell = {
    val scalerData = StandardScalerModelData(sparkTransformer.mean.toArray, sparkTransformer.std.toArray)
    Cell(scalerData)
  }

  def partFn(name: String, p: Seq[String], e: PFAExpression) = {
    NamedFunctionDef(name, FunctionDef[Double, Double](p, Seq(e)))
  }

  // function schema
  val (scaleFnDef, scaleFnRef) = if (sparkTransformer.getWithMean) {
    if (sparkTransformer.getWithStd) {
      val meanStdScale = partFn("meanStdScale", Seq("i", "m", "s"), div(minus("i", "m"), "s"))
      (Some(meanStdScale), a.zipmap(inputExpr, meanRef, stdRef, meanStdScale.ref))
    } else {
      val meanScale = partFn("meanScale", Seq("i", "m"), minus("i", "m"))
      (Some(meanScale), a.zipmap(inputExpr, meanRef, meanScale.ref))
    }
  } else {
    if (sparkTransformer.getWithStd) {
      val stdScale = partFn("stdScale", Seq("i", "s"), div("i", "s"))
      (Some(stdScale), a.zipmap(inputExpr, stdRef, stdScale.ref))
    } else {
      (None, inputExpr)
    }
  }

  override def action: PFAExpression = {
    NewRecord(outputSchema, Map(outputCol -> scaleFnRef))
  }

  override def pfa: PFADocument = {
    val builder = PFABuilder()
      .withName(sparkTransformer.uid)
      .withMetadata(getMetadata)
      .withInput(inputSchema)
      .withOutput(outputSchema)
      .withCell(modelCell)
      .withAction(action)
    scaleFnDef.foreach(fnDef => builder.withFunction(fnDef))
    builder.pfa
  }
} 
Example 18
Source File: NGram.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml.feature

import com.ibm.aardpfark.pfa.document.{PFABuilder, PFADocument}
import com.ibm.aardpfark.pfa.expression.{PFAExpression, PartialFunctionRef}
import com.ibm.aardpfark.spark.ml.PFATransformer
import org.apache.avro.SchemaBuilder

import org.apache.spark.ml.feature.NGram


class PFANGram(override val sparkTransformer: NGram) extends PFATransformer {
  import com.ibm.aardpfark.pfa.dsl._

  private val inputCol = sparkTransformer.getInputCol
  private val outputCol = sparkTransformer.getOutputCol
  private val inputExpr = StringExpr(s"input.${inputCol}")

  private val n = sparkTransformer.getN

  override def inputSchema = {
    SchemaBuilder.record(withUid(inputBaseName)).fields()
      .name(inputCol).`type`().array().items().stringType().noDefault()
      .endRecord()
  }

  override def outputSchema = {
    SchemaBuilder.record(withUid(outputBaseName)).fields()
      .name(outputCol).`type`().array().items().stringType().noDefault()
      .endRecord()
  }

  override def action: PFAExpression = {
    // TODO - this partial fn reference is an ugly workaround for now - add support for builtin lib
    val partialFn = new PartialFunctionRef("s.join", Seq(("sep", " ")))
    val mapExpr = a.map(a.slidingWindow(inputExpr, n, 1), partialFn)
    NewRecord(outputSchema, Map(outputCol -> mapExpr))
  }

  override def pfa: PFADocument = {
    PFABuilder()
      .withName(sparkTransformer.uid)
      .withMetadata(getMetadata)
      .withInput(inputSchema)
      .withOutput(outputSchema)
      .withAction(action)
      .pfa
  }
} 
Example 19
Source File: KMeans.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml.clustering

import com.ibm.aardpfark.pfa.dsl.StringExpr
import com.ibm.aardpfark.pfa.document.{Cell, PFABuilder, PFADocument}
import com.ibm.aardpfark.pfa.dsl._
import com.ibm.aardpfark.pfa.expression.PFAExpression
import com.ibm.aardpfark.pfa.types.WithSchema
import com.ibm.aardpfark.spark.ml.PFAModel
import com.sksamuel.avro4s.{AvroNamespace, AvroSchema}
import org.apache.avro.{Schema, SchemaBuilder}

import org.apache.spark.ml.clustering.KMeansModel

@AvroNamespace("com.ibm.aardpfark.exec.spark.ml.clustering")
case class Cluster(id: Int, center: Seq[Double])

@AvroNamespace("com.ibm.aardpfark.exec.spark.ml.clustering")
case class KMeansModelData(clusters: Seq[Cluster]) extends WithSchema {
  override def schema: Schema = AvroSchema[this.type]
}

class PFAKMeansModel(override val sparkTransformer: KMeansModel) extends PFAModel[KMeansModelData] {

  private val inputCol = sparkTransformer.getFeaturesCol
  private val outputCol = sparkTransformer.getPredictionCol
  private val inputExpr = StringExpr(s"input.${inputCol}")

  override def inputSchema = {
    SchemaBuilder.record(withUid(inputBaseName)).fields()
      .name(inputCol).`type`().array().items().doubleType().noDefault()
      .endRecord()
  }

  override def outputSchema = SchemaBuilder.record(withUid(outputBaseName)).fields()
    .name(outputCol).`type`().intType().noDefault()
    .endRecord()

  override def cell =  {
    val clusters = sparkTransformer.clusterCenters.zipWithIndex.map { case (v, i) =>
      Cluster(i, v.toArray)
    }
    Cell(KMeansModelData(clusters))
  }

  override def action: PFAExpression = {
    val closest = model.cluster.closest(inputExpr, modelCell.ref("clusters"))
    NewRecord(outputSchema, Map(outputCol -> Attr(closest, "id")))
  }

  override def pfa: PFADocument = {
    PFABuilder()
      .withName(sparkTransformer.uid)
      .withMetadata(getMetadata)
      .withInput(inputSchema)
      .withOutput(outputSchema)
      .withCell(modelCell)
      .withAction(action)
      .pfa
  }

} 
Example 20
Source File: Merge.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml

import com.ibm.aardpfark.pfa.document.Cell
import com.ibm.aardpfark.pfa.expression._
import org.apache.avro.{Schema, SchemaBuilder}

import org.apache.spark.ml.PipelineModel


    // NOTE: the enclosing object Merge and the method declaration that binds `docs`
    // (the per-stage PFADocuments) and `is` (the pipeline input schema) are truncated
    // in this excerpt.
    val first = docs.head
    val last = docs.last
    var name = "merged"
    var version = 0L
    val inputSchema = is
    val outputSchema = last.output
    var meta: Map[String, String] = Map()
    var cells: Map[String, Cell[_]] = Map()
    var action: PFAExpression = StringExpr("input")
    var fcns: Map[String, FunctionDef] = Map()
    var currentSchema = inputSchema

    docs.zipWithIndex.foreach { case (doc, idx) =>

      val inputParam = Param("input", currentSchema)

      val inputFields = currentSchema.getFields.toSeq
      val newFields = doc.output.getFields.toSeq
      val outputFields = inputFields ++ newFields

      val bldr = SchemaBuilder.record(s"Stage_${idx + 1}_output_schema").fields()
      outputFields.foreach { field =>
        bldr
          .name(field.name())
          .`type`(field.schema())
          .noDefault()
      }

      currentSchema = bldr.endRecord()

      val let = Let(s"Stage_${idx + 1}_action_output", Do(doc.action))
      val inputExprs = inputFields.map { field =>
        field.name -> StringExpr(s"input.${field.name}")
      }
      val newExprs = newFields.map { field =>
        field.name -> StringExpr(s"${let.x}.${field.name}")

      }
      val exprs = inputExprs ++ newExprs
      val stageOutput = NewRecord(currentSchema, exprs.toMap)

      val le = new LetExpr(Seq((let.x, let.`type`, let.expr)))

      val stageActionFn = NamedFunctionDef(s"Stage_${idx + 1}_action", FunctionDef(
        Seq(inputParam), currentSchema, Seq(le, stageOutput)
      ))

      fcns = fcns ++ doc.fcns + (stageActionFn.name -> stageActionFn.fn)
      cells = cells ++ doc.cells
      meta = meta ++ doc.metadata
      action = stageActionFn.call(action)
    }

    first.copy(
      name = Some(name),
      version = Some(version),
      metadata = meta,
      cells = cells,
      fcns = fcns,
      action = action,
      input = inputSchema,
      output = currentSchema
    )

  }
} 
Example 21
Source File: LogisticRegressionModel.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml.classification

import com.ibm.aardpfark.pfa.document.{PFABuilder, PFADocument}
import com.ibm.aardpfark.pfa.dsl._
import com.ibm.aardpfark.spark.ml.PFALinearPredictionModel
import org.apache.avro.SchemaBuilder

import org.apache.spark.ml.classification.LogisticRegressionModel


class PFALogisticRegressionModel(override val sparkTransformer: LogisticRegressionModel)
  extends PFALinearPredictionModel {

  private val rawPredictionCol = sparkTransformer.getRawPredictionCol
  private val probabilityCol = sparkTransformer.getProbabilityCol
  private val isBinary = sparkTransformer.numClasses == 2

  override def outputSchema = SchemaBuilder.record(withUid(outputBaseName)).fields()
    .name(rawPredictionCol).`type`().array().items().doubleType().noDefault()
    .name(predictionCol).`type`.doubleType().noDefault()
    .name(probabilityCol).`type`().array().items().doubleType().noDefault()
    .endRecord()

  private val safeDoubleDiv = NamedFunctionDef("safeDoubleDiv", FunctionDef[Double, Double]("x", "y") {
    case Seq(x, y) =>
      val result = Let("result", core.div(x, y))
      val cond = If (impute.isnan(result.ref)) Then {
        core.addinv(core.pow(10.0, 320.0))
      } Else {
        result.ref
      }
      Seq(
        result,
        cond
      )
  })

  private val rawPredFn = if (isBinary) {
    NewArray[Double](Seq(core.addinv(margin.ref), margin.ref))
  } else {
    margin.ref
  }
  private val probFn = if (isBinary) {
    m.link.logit(rawPredFn)
  } else {
    m.link.softmax(rawPredFn)
  }

  private val rawPred = Let("rawPred", rawPredFn)
  private val prob = Let("prob", probFn)

  private val predFn = if (isBinary) {
    val threshold = sparkTransformer.getThreshold
    val probAttr = Attr(prob.ref, 1)
    If (core.lte(probAttr, threshold)) Then 0.0 Else 1.0
  } else {
    val scaled = if (sparkTransformer.isDefined(sparkTransformer.thresholds)) {
      val thresholds = NewArray[Double](sparkTransformer.getThresholds.map(DoubleLiteral))
      a.zipmap(prob.ref, thresholds, safeDoubleDiv.ref)
    } else {
      prob.ref
    }
    a.argmax(scaled)
  }

  private val pred = Let("pred", predFn)

  override def action = {
    Action(
      margin,
      rawPred,
      prob,
      pred,
      NewRecord(outputSchema, Map(
        probabilityCol -> prob.ref,
        rawPredictionCol -> rawPred.ref,
        predictionCol -> pred.ref)
      )
    )
  }

  override def pfa: PFADocument = {
    val bldr = PFABuilder()
      .withName(sparkTransformer.uid)
      .withMetadata(getMetadata)
      .withInput(inputSchema)
      .withOutput(outputSchema)
      .withCell(modelCell)
      .withAction(action)
    if (!isBinary) {
      bldr.withFunction(safeDoubleDiv)
    }
    bldr.pfa
  }

} 
Example 22
Source File: LinearSVCModel.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml.classification

import com.ibm.aardpfark.pfa.dsl._
import com.ibm.aardpfark.spark.ml.PFALinearPredictionModel
import org.apache.avro.SchemaBuilder

import org.apache.spark.ml.classification.LinearSVCModel

class PFALinearSVCModel(override val sparkTransformer: LinearSVCModel)
  extends PFALinearPredictionModel {

  private val rawPredictionCol = sparkTransformer.getRawPredictionCol
  private val threshold = sparkTransformer.getThreshold

  override def outputSchema = SchemaBuilder.record(withUid(outputBaseName)).fields()
    .name(rawPredictionCol).`type`().array().items().doubleType().noDefault()
    .name(predictionCol).`type`.doubleType().noDefault()
    .endRecord()

  private val rawSchema = outputSchema.getField(rawPredictionCol).schema()
  private val rawPred = NewArray(rawSchema, Seq(core.addinv(margin.ref), margin.ref))
  private val pred = If (core.lte(margin.ref, threshold)) Then 0.0 Else 1.0
  override def action = {
    Action(
      margin,
      NewRecord(outputSchema, Map(
        predictionCol -> pred,
        rawPredictionCol -> rawPred)
      )
    )
  }
} 
Example 23
Source File: SparkSupport.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml

import com.ibm.aardpfark.avro.SchemaConverters
import com.ibm.aardpfark.pfa.document.{PFADocument, ToPFA}
import org.apache.avro.SchemaBuilder

import org.apache.spark.ml.{PipelineModel, Transformer}
import org.apache.spark.sql.types.StructType


object SparkSupport {


  def toPFA(t: Transformer, pretty: Boolean): String = {
    toPFATransformer(t).pfa.toJSON(pretty)
  }

  def toPFA(p: PipelineModel, s: StructType, pretty: Boolean): String = {
    val inputFields = s.map { f => f.copy(nullable = false) }
    val inputSchema = StructType(inputFields)
    val pipelineInput = SchemaBuilder.record(s"Input_${p.uid}")
    val inputAvroSchema = SchemaConverters.convertStructToAvro(inputSchema, pipelineInput, "")
    Merge.mergePipeline(p, inputAvroSchema).toJSON(pretty)
  }

  // testing implicit conversions for Spark ML PipelineModel and Transformer to PFA / JSON

  implicit private[aardpfark] def toPFATransformer(transformer: org.apache.spark.ml.Transformer): ToPFA = {

    val pkg = transformer.getClass.getPackage.getName
    val name = transformer.getClass.getSimpleName
    val pfaPkg = pkg.replace("org.apache", "com.ibm.aardpfark")
    val pfaClass = Class.forName(s"$pfaPkg.PFA$name")

    val ctor = pfaClass.getConstructors()(0)
    ctor.newInstance(transformer).asInstanceOf[ToPFA]
  }
} 
Example 24
Source File: CastsSuite.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.pfa.expression

import com.ibm.aardpfark.pfa.DSLSuiteBase
import com.ibm.aardpfark.pfa.dsl._
import com.ibm.aardpfark.pfa.document.PFABuilder
import org.apache.avro.SchemaBuilder

class CastsSuite extends DSLSuiteBase {

  test("DSL: Casts") {

//    val fromNull = As[Null]("input", _ => "Null")
    val fromDouble = As[Double]("input", _ => "Double")
    val fromInt = As[Int]("input", _ => "Int")
    val cast = Cast(inputExpr, Seq(fromDouble, fromInt))
    val action = Action(cast)

    val schema = SchemaBuilder.unionOf().doubleType().and().intType().endUnion()

    val pfaDoc = new PFABuilder()
      .withInput(schema)
      .withOutput[String]
      .withAction(action)
      .pfa

    val engine = getPFAEngine(pfaDoc.toJSON())

    val doubleResult = engine.action(engine.jsonInput("""{"double": 1.0}"""))
    assert(doubleResult == "Double")
    val intResult = engine.action(engine.jsonInput("""{"int":1}"""))
    assert(intResult == "Int")
  }

} 
Example 25
Source File: IndexWithCompleteDocument.scala    From CM-Well   with Apache License 2.0
package cmwell.analytics.data

import com.fasterxml.jackson.databind.JsonNode
import com.typesafe.config.ConfigFactory
import org.apache.avro.generic.GenericRecord
import org.apache.avro.{Schema, SchemaBuilder}

case class IndexWithCompleteDocument(uuid: String, document: String) extends GenericRecord with CsvGenerator {

  override def put(key: String, v: scala.Any): Unit = ???

  override def get(key: String): AnyRef = key match {
    case "uuid" => uuid
    case "document" => document
    case _ => throw new IllegalArgumentException
  }

  override def put(i: Int, v: scala.Any): Unit = ???

  override def get(i: Int): AnyRef = i match {
    case 0 => uuid
    case 1 => document
    case _ => throw new IllegalArgumentException
  }

  override def getSchema: Schema = IndexWithCompleteDocument.schema

  // Specifically don't implement CsvGenerator.csv since it is guaranteed to be invalid CSV - force use of Parquet.
}

object IndexWithCompleteDocument extends ObjectExtractor[IndexWithCompleteDocument] {

  val schema: Schema = SchemaBuilder
    .record("IndexWithCompleteDocument").namespace("cmwell.analytics")
    .fields
    .name("uuid").`type`.unionOf.stringType.and.nullType.endUnion.noDefault
    .name("document").`type`.unionOf.stringType.and.nullType.endUnion.noDefault
    .endRecord

  private val config = ConfigFactory.load
  val infotonSize: Int = config.getInt("extract-index-from-es.fetch-size-index-with-complete-document")

  def includeFields: String = s""""_source": "*""""

  def extractFromJson(hit: JsonNode): IndexWithCompleteDocument =
    IndexWithCompleteDocument(
      uuid = hit.findValue("_id").asText,
      document = hit.findValue("_source").toString)
} 
Example 26
Source File: ConverterTest.scala    From eel-sdk   with Apache License 2.0
package io.eels.component.avro

import org.apache.avro.SchemaBuilder
import org.scalatest.{Matchers, WordSpec}

class ConverterTest extends WordSpec with Matchers {

  "Converter" should {
    "convert to long" in {
      AvroSerializer(SchemaBuilder.builder().longType()).serialize("123") shouldBe 123l
      AvroSerializer(SchemaBuilder.builder().longType()).serialize(14555) shouldBe 14555l
    }
    "convert to String" in {
      AvroSerializer(SchemaBuilder.builder().stringType()).serialize(123l) shouldBe "123"
      AvroSerializer(SchemaBuilder.builder().stringType).serialize(124) shouldBe "124"
      AvroSerializer(SchemaBuilder.builder().stringType).serialize("Qweqwe") shouldBe "Qweqwe"
    }
    "convert to boolean" in {
      AvroSerializer(SchemaBuilder.builder().booleanType).serialize(true) shouldBe true
      AvroSerializer(SchemaBuilder.builder().booleanType).serialize(false) shouldBe false
      AvroSerializer(SchemaBuilder.builder().booleanType).serialize("true") shouldBe true
      AvroSerializer(SchemaBuilder.builder().booleanType()).serialize("false") shouldBe false
    }
    "convert to Double" in {
      AvroSerializer(SchemaBuilder.builder().doubleType).serialize("213.4") shouldBe 213.4d
      AvroSerializer(SchemaBuilder.builder().doubleType).serialize("345.11") shouldBe 345.11d
      AvroSerializer(SchemaBuilder.builder().doubleType()).serialize(345) shouldBe 345.0
    }
  }
} 
Example 27
Source File: AvroSchemaFnsTest.scala    From eel-sdk   with Apache License 2.0
package io.eels.component.avro

import java.util

import io.eels.schema._
import org.apache.avro.SchemaBuilder
import org.codehaus.jackson.node.NullNode
import org.scalatest.{Matchers, WordSpec}

import scala.collection.JavaConverters._

class AvroSchemaFnsTest extends WordSpec with Matchers {

  "toAvro" should {
    "use a union of [null, type] for a nullable column" in {
      val schema = StructType(Field("a", StringType, true))
      val fields = AvroSchemaFns.toAvroSchema(schema).getFields.asScala
      fields.head.schema().getType shouldBe org.apache.avro.Schema.Type.UNION
      fields.head.schema().getTypes.get(0).getType shouldBe org.apache.avro.Schema.Type.NULL
      fields.head.schema().getTypes.get(1).getType shouldBe org.apache.avro.Schema.Type.STRING
    }
    "set default type of NullNode for a nullable column" in {
      val schema = StructType(Field("a", StringType, true))
      val fields = AvroSchemaFns.toAvroSchema(schema).getFields
      fields.get(0).defaultValue() shouldBe NullNode.getInstance()
    }
    "not set a default value for a non null column" in {
      val schema = StructType(Field("a", IntType(true), false))
      val fields = AvroSchemaFns.toAvroSchema(schema).getFields
      (fields.get(0).defaultVal() == null) shouldBe true
      fields.get(0).schema().getType shouldBe org.apache.avro.Schema.Type.INT
    }
  }

  "fromAvroSchema" should {
    "convert avro unions [null, string] to nullable columns" in {
      val avro = SchemaBuilder.record("dummy").fields().optionalString("str").endRecord()
      AvroSchemaFns.fromAvroSchema(avro) shouldBe StructType(Field("str", StringType, true))
    }
    "convert avro unions [null, double] to nullable double columns" in {
      val union = org.apache.avro.Schema.createUnion(util.Arrays.asList(SchemaBuilder.builder().doubleType(), SchemaBuilder.builder().nullType()))
      val avro = SchemaBuilder.record("dummy").fields().name("u").`type`(union).noDefault().endRecord()
      AvroSchemaFns.fromAvroSchema(avro) shouldBe StructType(Field("u", DoubleType, true))
    }
  }
} 
Example 28
Source File: AvroSchemaMergeTest.scala    From eel-sdk   with Apache License 2.0
package io.eels.component.avro

import org.apache.avro.SchemaBuilder
import org.scalatest.{Matchers, WordSpec}

class AvroSchemaMergeTest extends WordSpec with Matchers {
  "AvroSchemaMerge" should {
    "merge all fields" in {
      val schema1 = SchemaBuilder.record("record1").fields().nullableString("str1", "moo").requiredFloat("f").endRecord()
      val schema2 = SchemaBuilder.record("record2").fields().nullableString("str2", "foo").requiredFloat("g").endRecord()
      AvroSchemaMerge("finalname", "finalnamespace", List(schema1, schema2)) shouldBe
          SchemaBuilder.record("finalname").namespace("finalnamespace")
              .fields()
              .nullableString("str1", "moo")
              .requiredFloat("f")
              .nullableString("str2", "foo")
              .requiredFloat("g")
              .endRecord()
    }

    "drop duplicates" in {
      val schema1 = SchemaBuilder.record("record1").fields().nullableString("str1", "moo").requiredFloat("f").endRecord()
      val schema2 = SchemaBuilder.record("record2").fields().nullableString("str2", "foo").requiredFloat("f").endRecord()
      AvroSchemaMerge("finalname", "finalnamespace", List(schema1, schema2)) shouldBe
          SchemaBuilder.record("finalname").namespace("finalnamespace")
              .fields()
              .nullableString("str1", "moo")
              .requiredFloat("f")
              .nullableString("str2", "foo")
              .endRecord()
    }
  }
} 
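The merge tests above use SchemaBuilder's nullableString, which behaves differently from the optionalString seen in the previous example: the optional* methods build a union with null first and default the field to null, while the nullable* methods put the concrete type first and take an explicit default value. A small illustrative sketch of the difference:

import org.apache.avro.SchemaBuilder

object NullableVsOptional extends App {
  val withOptional = SchemaBuilder.record("r1").fields().optionalString("s").endRecord()
  val withNullable = SchemaBuilder.record("r2").fields().nullableString("s", "moo").endRecord()

  // optionalString -> union ["null","string"], default null
  println(withOptional.getField("s").schema())
  // nullableString -> union ["string","null"], default "moo"
  println(withNullable.getField("s").schema())
}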
Example 29
Source File: AvroSerializerTest.scala    From eel-sdk   with Apache License 2.0 5 votes vote down vote up
package io.eels.component.avro

import io.eels.schema.{ArrayType, Field, IntType, StructType}
import io.eels.Row
import org.apache.avro.SchemaBuilder
import org.scalatest.{Matchers, WordSpec}
import scala.collection.JavaConverters._

class AvroSerializerTest extends WordSpec with Matchers {

  private val avroSchema = SchemaBuilder.record("row").fields().requiredString("s").requiredLong("l").requiredBoolean("b").endRecord()
  private val serializer = new RowSerializer(avroSchema)

  "AvroRecordMarshaller" should {
    "createReader field from values in row" in {
      val eelSchema = StructType(Field("s"), Field("l"), Field("b"))
      val record = serializer.serialize(Row(eelSchema, "a", 1L, false))
      record.get("s") shouldBe "a"
      record.get("l") shouldBe 1L
      record.get("b") shouldBe false
    }
    "only accept rows with same number of values as schema fields" in {
      intercept[IllegalArgumentException] {
        val eelSchema = StructType(Field("a"), Field("b"))
        serializer.serialize(Row(eelSchema, "a", 1L))
      }
      intercept[IllegalArgumentException] {
        val eelSchema = StructType(Field("a"), Field("b"), Field("c"), Field("d"))
        serializer.serialize(Row(eelSchema, "1", "2", "3", "4"))
      }
    }
    "support rows with a different ordering to the write schema" in {
      val eelSchema = StructType(Field("l"), Field("b"), Field("s"))
      val record = serializer.serialize(Row(eelSchema, 1L, false, "a"))
      record.get("s") shouldBe "a"
      record.get("l") shouldBe 1L
      record.get("b") shouldBe false
    }
    "convert strings to longs" in {
      val record = serializer.serialize(Row(AvroSchemaFns.fromAvroSchema(avroSchema), "1", "2", "true"))
      record.get("l") shouldBe 2L
    }
    "convert strings to booleans" in {
      val record = serializer.serialize(Row(AvroSchemaFns.fromAvroSchema(avroSchema), "1", "2", "true"))
      record.get("b") shouldBe true
    }
    "convert longs to strings" in {
      val record = serializer.serialize(Row(AvroSchemaFns.fromAvroSchema(avroSchema), 1L, "2", "true"))
      record.get("s") shouldBe "1"
    }
    "convert booleans to strings" in {
      val record = serializer.serialize(Row(AvroSchemaFns.fromAvroSchema(avroSchema), true, "2", "true"))
      record.get("s") shouldBe "true"
    }
    "support arrays" in {
      val schema = StructType(Field("a", ArrayType(IntType.Signed)))
      val serializer = new RowSerializer(AvroSchemaFns.toAvroSchema(schema))
      val record = serializer.serialize(Row(schema, Array(1, 2)))
      record.get("a").asInstanceOf[java.util.List[_]].asScala.toList shouldBe List(1, 2)
    }
    "support lists" in {
      val schema = StructType(Field("a", ArrayType(IntType.Signed)))
      val serializer = new RowSerializer(AvroSchemaFns.toAvroSchema(schema))
      val record = serializer.serialize(Row(schema, Array(1, 2)))
      record.get("a").asInstanceOf[java.util.List[_]].asScala.toList shouldBe List(1, 2)
    }
    "support sets" in {
      val schema = StructType(Field("a", ArrayType(IntType(true))))
      val serializer = new RowSerializer(AvroSchemaFns.toAvroSchema(schema))
      val record = serializer.serialize(Row(schema, Set(1, 2)))
      record.get("a").asInstanceOf[java.util.List[_]].asScala.toList shouldBe List(1, 2)
    }
    "support iterables" in {
      val schema = StructType(Field("a", ArrayType(IntType(true))))
      val serializer = new RowSerializer(AvroSchemaFns.toAvroSchema(schema))
      val record = serializer.serialize(Row(schema, Iterable(1, 2)))
      record.get("a").asInstanceOf[java.util.List[_]].asScala.toList shouldBe List(1, 2)
    }
  }
} 
Example 30
Source File: AvroParquetReaderFnTest.scala    From eel-sdk   with Apache License 2.0 5 votes vote down vote up
package io.eels.component.parquet

import java.util.UUID

import io.eels.component.avro.AvroSchemaFns
import io.eels.component.parquet.avro.AvroParquetReaderFn
import io.eels.schema.{DoubleType, Field, LongType, StructType}
import org.apache.avro.SchemaBuilder
import org.apache.avro.generic.{GenericData, GenericRecord}
import org.apache.avro.util.Utf8
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.parquet.avro.AvroParquetWriter
import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpec}

class AvroParquetReaderFnTest extends WordSpec with Matchers with BeforeAndAfterAll {

  private implicit val conf = new Configuration()
  private implicit val fs = FileSystem.get(new Configuration())

  private val path = new Path(UUID.randomUUID().toString())

  override def afterAll(): Unit = {
    val fs = FileSystem.get(new Configuration())
    fs.delete(path, false)
  }

  private val avroSchema = SchemaBuilder.record("com.chuckle").fields()
    .requiredString("str").requiredLong("looong").requiredDouble("dooble").endRecord()

  private val writer = AvroParquetWriter.builder[GenericRecord](path)
    .withSchema(avroSchema)
    .build()

  private val record = new GenericData.Record(avroSchema)
  record.put("str", "wibble")
  record.put("looong", 999L)
  record.put("dooble", 12.34)
  writer.write(record)
  writer.close()

  val schema = StructType(Field("str"), Field("looong", LongType(true), true), Field("dooble", DoubleType, true))

  "AvroParquetReaderFn" should {
    "support projections on doubles" in {

      val reader = AvroParquetReaderFn(path, None, Option(AvroSchemaFns.toAvroSchema(schema.removeField("looong"))))
      val record = reader.read()
      reader.close()

      record.get("str").asInstanceOf[Utf8].toString shouldBe "wibble"
      record.get("dooble") shouldBe 12.34
    }
    "support projections on longs" in {

      val reader = AvroParquetReaderFn(path, None, Option(AvroSchemaFns.toAvroSchema(schema.removeField("str"))))
      val record = reader.read()
      reader.close()

      record.get("looong") shouldBe 999L
    }
    "support full projections" in {

      val reader = AvroParquetReaderFn(path, None, Option(AvroSchemaFns.toAvroSchema(schema)))
      val record = reader.read()
      reader.close()

      record.get("str").asInstanceOf[Utf8].toString shouldBe "wibble"
      record.get("looong") shouldBe 999L
      record.get("dooble") shouldBe 12.34

    }
    "support non projections" in {

      val reader = AvroParquetReaderFn(path, None, None)
      val group = reader.read()
      reader.close()

      group.get("str").asInstanceOf[Utf8].toString shouldBe "wibble"
      group.get("looong") shouldBe 999L
      group.get("dooble") shouldBe 12.34

    }
  }
} 
Example 31
Source File: IngestionFlowSpec.scala    From hydra   with Apache License 2.0 5 votes vote down vote up
package hydra.ingest.services

import cats.effect.{Concurrent, ContextShift, IO}
import hydra.avro.registry.SchemaRegistry
import hydra.core.ingest.HydraRequest
import hydra.core.ingest.RequestParams.{HYDRA_KAFKA_TOPIC_PARAM,HYDRA_RECORD_KEY_PARAM}
import hydra.ingest.services.IngestionFlow.MissingTopicNameException
import hydra.kafka.algebras.KafkaClientAlgebra
import org.apache.avro.{Schema, SchemaBuilder}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

import scala.concurrent.ExecutionContext

class IngestionFlowSpec extends AnyFlatSpec with Matchers {

  private implicit val contextShift: ContextShift[IO] = IO.contextShift(ExecutionContext.global)
  private implicit val concurrentEffect: Concurrent[IO] = IO.ioConcurrentEffect
  private implicit val mode: scalacache.Mode[IO] = scalacache.CatsEffect.modes.async

  private val testSubject: String = "test_subject"

  private val testSubjectNoKey: String = "test_subject_no_key"

  private val testKey: String = "test"

  private val testPayload: String =
    s"""{"id": "$testKey", "testField": true}"""

  private val testSchema: Schema = SchemaBuilder.record("TestRecord")
    .prop("hydra.key", "id")
    .fields().requiredString("id").requiredBoolean("testField").endRecord()

  private val testSchemaNoKey: Schema = SchemaBuilder.record("TestRecordNoKey")
    .fields().requiredString("id").requiredBoolean("testField").endRecord()

  private def ingest(request: HydraRequest): IO[KafkaClientAlgebra[IO]] = for {
    schemaRegistry <- SchemaRegistry.test[IO]
    _ <- schemaRegistry.registerSchema(testSubject + "-value", testSchema)
    _ <- schemaRegistry.registerSchema(testSubjectNoKey + "-value", testSchemaNoKey)
    kafkaClient <- KafkaClientAlgebra.test[IO]
    ingestFlow <- IO(new IngestionFlow[IO](schemaRegistry, kafkaClient, "https://schemaRegistry.notreal"))
    _ <- ingestFlow.ingest(request)
  } yield kafkaClient

  it should "ingest a message" in {
    val testRequest = HydraRequest("correlationId", testPayload, metadata = Map(HYDRA_KAFKA_TOPIC_PARAM -> testSubject))
    ingest(testRequest).flatMap { kafkaClient =>
      kafkaClient.consumeStringKeyMessages(testSubject, "test-consumer").take(1).compile.toList.map { publishedMessages =>
        val firstMessage = publishedMessages.head
        (firstMessage._1, firstMessage._2.get.toString) shouldBe (Some(testKey), testPayload)
      }
    }.unsafeRunSync()
  }

  it should "ingest a message with a null key" in {
    val testRequest = HydraRequest("correlationId", testPayload, metadata = Map(HYDRA_KAFKA_TOPIC_PARAM -> testSubjectNoKey))
    ingest(testRequest).flatMap { kafkaClient =>
      kafkaClient.consumeStringKeyMessages(testSubjectNoKey, "test-consumer").take(1).compile.toList.map { publishedMessages =>
        val firstMessage = publishedMessages.head
        (firstMessage._1, firstMessage._2.get.toString) shouldBe (None, testPayload)
      }
    }.unsafeRunSync()
  }

  it should "return an error when no topic name is provided" in {
    val testRequest = HydraRequest("correlationId", testPayload)
    ingest(testRequest).attempt.unsafeRunSync() shouldBe Left(MissingTopicNameException(testRequest))
  }

  it should "take the key from the header if present" in {
    val headerKey = "someDifferentKey"
    val testRequest = HydraRequest("correlationId", testPayload, metadata = Map(HYDRA_RECORD_KEY_PARAM -> headerKey, HYDRA_KAFKA_TOPIC_PARAM -> testSubject))
    ingest(testRequest).flatMap { kafkaClient =>
      kafkaClient.consumeStringKeyMessages(testSubject, "test-consumer").take(1).compile.toList.map { publishedMessages =>
        val firstMessage = publishedMessages.head
        (firstMessage._1, firstMessage._2.get.toString) shouldBe (Some(headerKey), testPayload)
      }
    }.unsafeRunSync()

  }

} 
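The "hydra.key" entry set above via prop() is a plain Avro schema property; it rides along with the schema and can be read back with getProp. A short standalone sketch using only the Avro API:

import org.apache.avro.SchemaBuilder

object SchemaPropSketch extends App {
  val schema = SchemaBuilder.record("TestRecord")
    .prop("hydra.key", "id")
    .fields().requiredString("id").requiredBoolean("testField").endRecord()
  // Custom props are stored on the schema and are visible to any consumer of it.
  println(schema.getProp("hydra.key")) // prints: id
}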
Example 32
Source File: AvroKeyRecordSpec.scala    From hydra   with Apache License 2.0 5 votes vote down vote up
package hydra.kafka.producer

import hydra.core.transport.AckStrategy
import org.apache.avro.SchemaBuilder
import org.apache.avro.generic.GenericRecordBuilder
import org.scalatest.matchers.should.Matchers
import org.scalatest.flatspec.AnyFlatSpecLike

class AvroKeyRecordSpec extends AnyFlatSpecLike with Matchers {

  it must "construct an AvroKeyRecord" in {
    def schema(name: String) =
      SchemaBuilder
        .record(name)
        .fields()
        .name(name)
        .`type`
        .stringType()
        .noDefault()
        .endRecord()
    def json(name: String) =
      s"""
         |{
         |  "$name":"test"
         |}
         |""".stripMargin
    val avroKeyRecord = AvroKeyRecord.apply(
      "dest",
      schema("key"),
      schema("value"),
      json("key"),
      json("value"),
      AckStrategy.Replicated
    )
    def genRecord(name: String) =
      new GenericRecordBuilder(schema(name)).set(name, "test").build()

    avroKeyRecord shouldBe AvroKeyRecord(
      "dest",
      schema("key"),
      schema("value"),
      genRecord("key"),
      genRecord("value"),
      AckStrategy.Replicated
    )
  }

} 
Example 33
Source File: Main.scala    From sbt-avrohugger   with Apache License 2.0 5 votes vote down vote up
package example

import org.apache.avro.SchemaBuilder

object Main extends App {
  override def main(args: Array[String]): Unit = {
    println(SchemaBuilder
     .record("HandshakeRequest").namespace("org.apache.avro.ipc")
     .fields()
       .name("clientHash").`type`().fixed("MD5").size(16).noDefault()
       .name("clientProtocol").`type`().nullable().stringType().noDefault()
     .endRecord().toString)
  }
} 
Example 34
Source File: GenericAvroSerializerSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.serializer

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.ByteBuffer

import com.esotericsoftware.kryo.io.{Input, Output}
import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.generic.GenericData.Record

import org.apache.spark.{SharedSparkContext, SparkFunSuite}

class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
  conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")

  val schema : Schema = SchemaBuilder
    .record("testRecord").fields()
    .requiredString("data")
    .endRecord()
  val record = new Record(schema)
  record.put("data", "test data")

  test("schema compression and decompression") {
    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
    assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema))))
  }

  test("record serialization and deserialization") {
    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)

    val outputStream = new ByteArrayOutputStream()
    val output = new Output(outputStream)
    genericSer.serializeDatum(record, output)
    output.flush()
    output.close()

    val input = new Input(new ByteArrayInputStream(outputStream.toByteArray))
    assert(genericSer.deserializeDatum(input) === record)
  }

  test("uses schema fingerprint to decrease message size") {
    val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema)

    val output = new Output(new ByteArrayOutputStream())

    val beginningNormalPosition = output.total()
    genericSerFull.serializeDatum(record, output)
    output.flush()
    val normalLength = output.total - beginningNormalPosition

    conf.registerAvroSchemas(schema)
    val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema)
    val beginningFingerprintPosition = output.total()
    genericSerFinger.serializeDatum(record, output)
    val fingerprintLength = output.total - beginningFingerprintPosition

    assert(fingerprintLength < normalLength)
  }

  test("caches previously seen schemas") {
    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
    val compressedSchema = genericSer.compress(schema)
    val decompressedSchema = genericSer.decompress(ByteBuffer.wrap(compressedSchema))

    assert(compressedSchema.eq(genericSer.compress(schema)))
    assert(decompressedSchema.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema))))
  }
} 
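The fingerprint test above depends on Avro being able to stand in a compact fingerprint for the full JSON schema. Independent of Spark's GenericAvroSerializer, the underlying Avro call is SchemaNormalization.parsingFingerprint64; a hedged sketch, not the Spark code:

import org.apache.avro.{Schema, SchemaBuilder, SchemaNormalization}

object FingerprintSketch extends App {
  val schema: Schema = SchemaBuilder.record("testRecord").fields().requiredString("data").endRecord()
  // 64-bit Rabin fingerprint of the canonical form - a few bytes instead of the whole JSON schema.
  val fingerprint: Long = SchemaNormalization.parsingFingerprint64(schema)
  println(s"schema json length: ${schema.toString.length}, fingerprint: $fingerprint")
}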
Example 35
Source File: IndexWithKeyFields.scala    From CM-Well   with Apache License 2.0 5 votes vote down vote up
package cmwell.analytics.data

import com.fasterxml.jackson.databind.JsonNode
import com.typesafe.config.ConfigFactory
import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder}
import org.apache.avro.generic.GenericRecord
import org.apache.log4j.LogManager
import org.joda.time.format.ISODateTimeFormat

import scala.util.control.NonFatal


case class IndexWithKeyFields(uuid: String,
                              lastModified: java.sql.Timestamp,
                              path: String) extends GenericRecord with CsvGenerator {

  override def put(key: String, v: scala.Any): Unit = ???

  override def get(key: String): AnyRef = key match {
    case "uuid" => uuid
    case "lastModified" => java.lang.Long.valueOf(lastModified.getTime)
    case "path" => path
  }

  override def put(i: Int, v: scala.Any): Unit = ???

  override def get(i: Int): AnyRef = i match {
    case 0 => uuid
    case 1 => java.lang.Long.valueOf(lastModified.getTime)
    case 2 => path
    case _ => throw new IllegalArgumentException
  }

  override def getSchema: Schema = IndexWithSystemFields.schema

  override def csv: String =
    (if (uuid == null) "" else uuid) + "," +
      (if (lastModified == null) "" else ISODateTimeFormat.dateTime.print(lastModified.getTime)) + "," +
      (if (path == null) "" else path)
}

object IndexWithKeyFields extends ObjectExtractor[IndexWithKeyFields] {

  private val logger = LogManager.getLogger(IndexWithSystemFields.getClass)

  // AVRO-2065: Avro doesn't allow a union over a logical type, so the timestamp column can't be made nullable.
  val timestampMilliType: Schema = LogicalTypes.timestampMillis.addToSchema(Schema.create(Schema.Type.LONG))

  val schema: Schema = SchemaBuilder
    .record("IndexWithSystemFields").namespace("cmwell.analytics")
    .fields
    .name("uuid").`type`.unionOf.stringType.and.nullType.endUnion.noDefault
    .name("lastModified").`type`(timestampMilliType).noDefault
    .name("path").`type`.unionOf.stringType.and.nullType.endUnion.noDefault
    .endRecord

  private val config = ConfigFactory.load
  val infotonSize: Int = config.getInt("extract-index-from-es.fetch-size-index-with-uuid-lastModified-path")

  def includeFields: String = {
    // Note that 'quad' is not included in this list
    val fields = "uuid,lastModified,path"
      .split(",")
      .map(name => s""""system.$name"""")
      .mkString(",")

    s""""_source": [$fields]"""
  }

  def extractFromJson(hit: JsonNode): IndexWithKeyFields = {

    val system = hit.findValue("_source").findValue("system")

    def extractString(name: String): String = system.findValue(name) match {
      case x: JsonNode => x.asText
      case _ => null
    }

    // Extracting date values as Long - as a java.sql.Date might be better
    def extractDate(name: String): java.sql.Timestamp = system.findValue(name) match {
      case x: JsonNode =>
        try {
          new java.sql.Timestamp(ISODateTimeFormat.dateTime.parseDateTime(x.asText).getMillis)
        }
        catch {
          case NonFatal(ex) =>
            logger.warn(s"Failed conversion of date value: $x", ex)
            throw ex
        }
      case _ => null
    }

    IndexWithKeyFields(
      uuid = extractString("uuid"),
      lastModified = extractDate("lastModified"),
      path = extractString("path"))
  }
} 
Example 36
Source File: GenericAvroSerializerSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.serializer

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.ByteBuffer

import com.esotericsoftware.kryo.io.{Input, Output}
import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.generic.GenericData.Record

import org.apache.spark.{SharedSparkContext, SparkFunSuite}

class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
  conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")

  val schema : Schema = SchemaBuilder
    .record("testRecord").fields()
    .requiredString("data")
    .endRecord()
  val record = new Record(schema)
  record.put("data", "test data")

  test("schema compression and decompression") {
    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
    assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema))))
  }

  test("record serialization and deserialization") {
    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)

    val outputStream = new ByteArrayOutputStream()
    val output = new Output(outputStream)
    genericSer.serializeDatum(record, output)
    output.flush()
    output.close()

    val input = new Input(new ByteArrayInputStream(outputStream.toByteArray))
    assert(genericSer.deserializeDatum(input) === record)
  }

  test("uses schema fingerprint to decrease message size") {
    val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema)

    val output = new Output(new ByteArrayOutputStream())

    val beginningNormalPosition = output.total()
    genericSerFull.serializeDatum(record, output)
    output.flush()
    val normalLength = output.total - beginningNormalPosition

    conf.registerAvroSchemas(schema)
    val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema)
    val beginningFingerprintPosition = output.total()
    genericSerFinger.serializeDatum(record, output)
    val fingerprintLength = output.total - beginningFingerprintPosition

    assert(fingerprintLength < normalLength)
  }

  test("caches previously seen schemas") {
    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
    val compressedSchema = genericSer.compress(schema)
    val decompressedSchema = genericSer.decompress(ByteBuffer.wrap(compressedSchema))

    assert(compressedSchema.eq(genericSer.compress(schema)))
    assert(decompressedSchema.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema))))
  }
} 
Example 37
Source File: GenericAvroSerializerSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.serializer

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.ByteBuffer

import com.esotericsoftware.kryo.io.{Input, Output}
import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.generic.GenericData.Record

import org.apache.spark.{SharedSparkContext, SparkFunSuite}

class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
  conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")

  val schema : Schema = SchemaBuilder
    .record("testRecord").fields()
    .requiredString("data")
    .endRecord()
  val record = new Record(schema)
  record.put("data", "test data")

  test("schema compression and decompression") {
    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
    assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema))))
  }

  test("record serialization and deserialization") {
    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)

    val outputStream = new ByteArrayOutputStream()
    val output = new Output(outputStream)
    genericSer.serializeDatum(record, output)
    output.flush()
    output.close()

    val input = new Input(new ByteArrayInputStream(outputStream.toByteArray))
    assert(genericSer.deserializeDatum(input) === record)
  }

  test("uses schema fingerprint to decrease message size") {
    val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema)

    val output = new Output(new ByteArrayOutputStream())

    val beginningNormalPosition = output.total()
    genericSerFull.serializeDatum(record, output)
    output.flush()
    val normalLength = output.total - beginningNormalPosition

    conf.registerAvroSchemas(schema)
    val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema)
    val beginningFingerprintPosition = output.total()
    genericSerFinger.serializeDatum(record, output)
    val fingerprintLength = output.total - beginningFingerprintPosition

    assert(fingerprintLength < normalLength)
  }

  test("caches previously seen schemas") {
    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
    val compressedSchema = genericSer.compress(schema)
    val decompressedSchema = genericSer.decompress(ByteBuffer.wrap(compressedSchema))

    assert(compressedSchema.eq(genericSer.compress(schema)))
    assert(decompressedSchema.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema))))
  }
} 
Example 38
Source File: GenericAvroSerializerSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.serializer

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.ByteBuffer

import com.esotericsoftware.kryo.io.{Output, Input}
import org.apache.avro.{SchemaBuilder, Schema}
import org.apache.avro.generic.GenericData.Record

import org.apache.spark.{SparkFunSuite, SharedSparkContext}

class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext {
  conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")

  val schema : Schema = SchemaBuilder
    .record("testRecord").fields()
    .requiredString("data")
    .endRecord()
  val record = new Record(schema)
  record.put("data", "test data")

  test("schema compression and decompression") {//模式压缩与解压缩
    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
    assert(schema === genericSer.decompress(ByteBuffer.wrap(genericSer.compress(schema))))
  }

  test("record serialization and deserialization") {//记录序列化和反序列化
    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)

    val outputStream = new ByteArrayOutputStream()
    val output = new Output(outputStream)
    genericSer.serializeDatum(record, output)
    output.flush()
    output.close()

    val input = new Input(new ByteArrayInputStream(outputStream.toByteArray))
    assert(genericSer.deserializeDatum(input) === record)
  }
  // use the schema fingerprint to reduce message size
  test("uses schema fingerprint to decrease message size") {
    val genericSerFull = new GenericAvroSerializer(conf.getAvroSchema)

    val output = new Output(new ByteArrayOutputStream())

    val beginningNormalPosition = output.total()
    genericSerFull.serializeDatum(record, output)
    output.flush()
    val normalLength = output.total - beginningNormalPosition

    conf.registerAvroSchemas(schema)
    val genericSerFinger = new GenericAvroSerializer(conf.getAvroSchema)
    val beginningFingerprintPosition = output.total()
    genericSerFinger.serializeDatum(record, output)
    val fingerprintLength = output.total - beginningFingerprintPosition

    assert(fingerprintLength < normalLength)
  }

  test("caches previously seen schemas") {//缓存之前模式
    val genericSer = new GenericAvroSerializer(conf.getAvroSchema)
    val compressedSchema = genericSer.compress(schema)
    val decompressedScheam = genericSer.decompress(ByteBuffer.wrap(compressedSchema))

    assert(compressedSchema.eq(genericSer.compress(schema)))
    assert(decompressedScheam.eq(genericSer.decompress(ByteBuffer.wrap(compressedSchema))))
  }
} 
Example 39
Source File: SparkSchemas.scala    From avro4s   with Apache License 2.0 5 votes vote down vote up
package com.sksamuel.avro4s

import org.apache.avro.{LogicalTypes, SchemaBuilder}

import scala.language.implicitConversions

object SparkSchemas {

  // see https://github.com/sksamuel/avro4s/issues/271
  implicit def BigDecimalSchemaFor(sp: ScalePrecision): SchemaFor[BigDecimal] =
    SchemaFor[BigDecimal](
      if (0 <= sp.precision && sp.precision <= 9) {
        LogicalTypes.decimal(sp.precision, sp.scale).addToSchema(SchemaBuilder.builder.intType)
      } else if (10 <= sp.precision && sp.precision <= 18) {
        LogicalTypes.decimal(sp.precision, sp.scale).addToSchema(SchemaBuilder.builder.longType)
      } else {
        LogicalTypes.decimal(sp.precision, sp.scale).addToSchema(SchemaBuilder.builder.bytesType)
      }
    )
} 
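For reference, the Avro specification defines the decimal logical type over bytes and fixed; the int/long-backed variants above presumably exist to line up with how Spark stores small decimals (see the linked issue). A standalone sketch of the standard bytes/fixed forms, using only the Avro API:

import org.apache.avro.{LogicalTypes, SchemaBuilder}

object DecimalSchemaSketch extends App {
  // decimal(precision, scale) attached to the two carrier types the Avro spec defines.
  val asBytes = LogicalTypes.decimal(20, 4).addToSchema(SchemaBuilder.builder().bytesType())
  val asFixed = LogicalTypes.decimal(10, 2).addToSchema(SchemaBuilder.fixed("Dec").size(8))
  println(asBytes.toString(true))
  println(asFixed.toString(true))
}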
Example 40
Source File: ByteArrayEncoderTest.scala    From avro4s   with Apache License 2.0 5 votes vote down vote up
package com.sksamuel.avro4s.record.encoder

import java.nio.ByteBuffer

import com.sksamuel.avro4s.{AvroSchema, Encoder, SchemaFor}
import org.apache.avro.SchemaBuilder
import org.apache.avro.generic.{GenericFixed, GenericRecord}
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers

class ByteArrayEncoderTest extends AnyFunSuite with Matchers {

  test("encode byte arrays as BYTES type") {
    case class Test(z: Array[Byte])
    val schema = AvroSchema[Test]
    Encoder[Test].encode(Test(Array[Byte](1, 4, 9)))
      .asInstanceOf[GenericRecord]
      .get("z")
      .asInstanceOf[ByteBuffer]
      .array().toList shouldBe List[Byte](1, 4, 9)
  }

  test("encode byte vectors as BYTES type") {
    case class Test(z: Vector[Byte])
    val schema = AvroSchema[Test]
    Encoder[Test].encode(Test(Vector[Byte](1, 4, 9)))
      .asInstanceOf[GenericRecord]
      .get("z")
      .asInstanceOf[ByteBuffer]
      .array().toList shouldBe List[Byte](1, 4, 9)
  }

  test("encode byte seq as BYTES type") {
    case class Test(z: Seq[Byte])
    val schema = AvroSchema[Test]
    Encoder[Test].encode(Test(Seq[Byte](1, 4, 9)))
      .asInstanceOf[GenericRecord]
      .get("z")
      .asInstanceOf[ByteBuffer]
      .array().toList shouldBe List[Byte](1, 4, 9)
  }

  test("encode byte list as BYTES type") {
    case class Test(z: List[Byte])
    val schema = AvroSchema[Test]
    Encoder[Test].encode(Test(List[Byte](1, 4, 9)))
      .asInstanceOf[GenericRecord]
      .get("z")
      .asInstanceOf[ByteBuffer]
      .array().toList shouldBe List[Byte](1, 4, 9)
  }

  test("encode top level byte arrays") {
    val schema = AvroSchema[Array[Byte]]
    Encoder[Array[Byte]].encode(Array[Byte](1, 4, 9))
      .asInstanceOf[ByteBuffer]
      .array().toList shouldBe List[Byte](1, 4, 9)
  }

  test("encode ByteBuffers as BYTES type") {
    case class Test(z: ByteBuffer)
    val schema = AvroSchema[Test]
    Encoder[Test].encode(Test(ByteBuffer.wrap(Array[Byte](1, 4, 9))))
      .asInstanceOf[GenericRecord]
      .get("z")
      .asInstanceOf[ByteBuffer]
      .array().toList shouldBe List[Byte](1, 4, 9)
  }

  test("encode top level ByteBuffers") {
    val schema = AvroSchema[ByteBuffer]
    Encoder[ByteBuffer].encode(ByteBuffer.wrap(Array[Byte](1, 4, 9)))
      .asInstanceOf[ByteBuffer]
      .array().toList shouldBe List[Byte](1, 4, 9)
  }

  test("support FIXED") {
    val schema = SchemaBuilder.fixed("foo").size(7)
    val fixed = Encoder.ByteArrayEncoder.withSchema(SchemaFor(schema)).encode("hello".getBytes).asInstanceOf[GenericFixed]
    fixed.bytes().toList shouldBe Seq(104, 101, 108, 108, 111, 0, 0)
    fixed.bytes().length shouldBe 7
  }
} 
Example 41
Source File: EitherDecoderTest.scala    From avro4s   with Apache License 2.0 5 votes vote down vote up
package com.sksamuel.avro4s.record.decoder

import com.sksamuel.avro4s._
import org.apache.avro.SchemaBuilder
import org.apache.avro.util.Utf8
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers

case class Test(either: Either[String, Double])
case class Goo(s: String)
case class Foo(b: Boolean)
case class Test2(either: Either[Goo, Foo])

class EitherDecoderTest extends AnyFunSuite with Matchers {

  case class Voo(s: String)
  case class Woo(b: Boolean)
  case class Test3(either: Either[Voo, Woo])

  @AvroName("w")
  case class Wobble(s: String)

  @AvroName("t")
  case class Topple(b: Boolean)

  case class Test4(either: Either[Wobble, Topple])

  @AvroNamespace("market")
  case class Apple(s: String)

  @AvroNamespace("market")
  case class Orange(b: Boolean)

  case class Test5(either: Either[Apple, Orange])

  test("decode union:T,U for Either[T,U] of primitives") {
    val schema = AvroSchema[Test]
    Decoder[Test].decode(ImmutableRecord(schema, Vector(new Utf8("foo")))) shouldBe Test(Left("foo"))
    Decoder[Test].decode(ImmutableRecord(schema, Vector(java.lang.Double.valueOf(234.4D)))) shouldBe Test(Right(234.4D))
  }

  test("decode union:T,U for Either[T,U] of top level classes") {
    val schema = AvroSchema[Test2]
    Decoder[Test2].decode(ImmutableRecord(schema, Vector(ImmutableRecord(AvroSchema[Goo], Vector(new Utf8("zzz")))))) shouldBe Test2(Left(Goo("zzz")))
    Decoder[Test2].decode(ImmutableRecord(schema, Vector(ImmutableRecord(AvroSchema[Foo], Vector(java.lang.Boolean.valueOf(true)))))) shouldBe Test2(Right(Foo(true)))
  }

  test("decode union:T,U for Either[T,U] of nested classes") {
    val schema = AvroSchema[Test3]
    Decoder[Test3].decode(ImmutableRecord(schema, Vector(ImmutableRecord(AvroSchema[Voo], Vector(new Utf8("zzz")))))) shouldBe Test3(Left(Voo("zzz")))
    Decoder[Test3].decode(ImmutableRecord(schema, Vector(ImmutableRecord(AvroSchema[Woo], Vector(java.lang.Boolean.valueOf(true)))))) shouldBe Test3(Right(Woo(true)))
  }

  test("use @AvroName defined on a class when choosing which Either to decode") {

    val wschema = SchemaBuilder.record("w").namespace("com.sksamuel.avro4s.record.decoder.EitherDecoderTest").fields().requiredString("s").endRecord()
    val tschema = SchemaBuilder.record("t").namespace("com.sksamuel.avro4s.record.decoder.EitherDecoderTest").fields().requiredBoolean("b").endRecord()
    val union = SchemaBuilder.unionOf().`type`(wschema).and().`type`(tschema).endUnion()
    val schema = SchemaBuilder.record("Test4").fields().name("either").`type`(union).noDefault().endRecord()

    Decoder[Test4].decode(ImmutableRecord(schema, Vector(ImmutableRecord(tschema, Vector(java.lang.Boolean.valueOf(true)))))) shouldBe Test4(Right(Topple(true)))
    Decoder[Test4].decode(ImmutableRecord(schema, Vector(ImmutableRecord(wschema, Vector(new Utf8("zzz")))))) shouldBe Test4(Left(Wobble("zzz")))
  }

  test("use @AvroNamespace when choosing which Either to decode") {

    val appleschema = SchemaBuilder.record("Apple").namespace("market").fields().requiredString("s").endRecord()
    val orangeschema = SchemaBuilder.record("Orange").namespace("market").fields().requiredBoolean("b").endRecord()
    val union = SchemaBuilder.unionOf().`type`(appleschema).and().`type`(orangeschema).endUnion()
    val schema = SchemaBuilder.record("Test5").fields().name("either").`type`(union).noDefault().endRecord()

    Decoder[Test5].decode(ImmutableRecord(schema, Vector(ImmutableRecord(orangeschema, Vector(java.lang.Boolean.valueOf(true)))))) shouldBe Test5(Right(Orange(true)))
    Decoder[Test5].decode(ImmutableRecord(schema, Vector(ImmutableRecord(appleschema, Vector(new Utf8("zzz")))))) shouldBe Test5(Left(Apple("zzz")))
  }
} 
Example 42
Source File: BigDecimalDecoderTest.scala    From avro4s   with Apache License 2.0 5 votes vote down vote up
package com.sksamuel.avro4s.record.decoder

import com.sksamuel.avro4s._
import org.apache.avro.generic.GenericData
import org.apache.avro.{Conversions, LogicalTypes, SchemaBuilder}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

case class WithBigDecimal(decimal: BigDecimal)
case class OptionalBigDecimal(big: Option[BigDecimal])

class BigDecimalDecoderTest extends AnyFlatSpec with Matchers {

  "Decoder" should "convert byte array to decimal" in {
    val schema = AvroSchema[WithBigDecimal]
    val record = new GenericData.Record(schema)
    val bytes =
      new Conversions.DecimalConversion().toBytes(BigDecimal(123.45).bigDecimal, null, LogicalTypes.decimal(8, 2))
    record.put("decimal", bytes)
    Decoder[WithBigDecimal].decode(record) shouldBe WithBigDecimal(BigDecimal(123.45))
  }

  it should "support optional big decimals" in {
    val schema = AvroSchema[OptionalBigDecimal]
    val bytes =
      new Conversions.DecimalConversion().toBytes(BigDecimal(123.45).bigDecimal, null, LogicalTypes.decimal(8, 2))
    val record = new GenericData.Record(schema)
    record.put("big", bytes)
    Decoder[OptionalBigDecimal].decode(record) shouldBe OptionalBigDecimal(Option(BigDecimal(123.45)))

    val emptyRecord = new GenericData.Record(schema)
    emptyRecord.put("big", null)
    Decoder[OptionalBigDecimal].decode(emptyRecord) shouldBe OptionalBigDecimal(None)
  }

  it should "be able to decode strings as bigdecimals" in {
    val schemaFor = BigDecimals.AsString
    Decoder[BigDecimal].withSchema(schemaFor).decode("123.45") shouldBe BigDecimal(123.45)
  }

  it should "be able to decode generic fixed as bigdecimals" in {
    val schemaFor = SchemaFor[BigDecimal](
      LogicalTypes.decimal(10, 8).addToSchema(SchemaBuilder.fixed("BigDecimal").size(8))
    )

    val fixed =
      GenericData.get().createFixed(null, Array[Byte](0, 4, 98, -43, 55, 43, -114, 0), schemaFor.schema)
    Decoder[BigDecimal].withSchema(schemaFor).decode(fixed) shouldBe BigDecimal(12345678)
  }

//  it should "be able to decode longs as bigdecimals" in {
//    val schema = LogicalTypes.decimal(5, 2).addToSchema(SchemaBuilder.builder().longType())
//    BigDecimalDecoder.decode(12345, schema) shouldBe ""
//    BigDecimalDecoder.decode(9999, schema) shouldBe ""
//    BigDecimalDecoder.decode(java.lang.Long.valueOf(99887766), schema) shouldBe ""
//    BigDecimalDecoder.decode(java.lang.Integer.valueOf(654), schema) shouldBe ""
//  }
} 
Example 43
Source File: SchemaEvolutionTest.scala    From avro4s   with Apache License 2.0 5 votes vote down vote up
package com.sksamuel.avro4s.record.decoder

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import com.sksamuel.avro4s._
import org.apache.avro.SchemaBuilder
import org.apache.avro.generic.GenericData
import org.apache.avro.util.Utf8
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers

class SchemaEvolutionTest extends AnyFunSuite with Matchers {

  case class Version1(original: String)
  case class Version2(@AvroAlias("original") renamed: String)

  case class P1(name: String, age: Int = 18)
  case class P2(name: String)

  case class OptionalStringTest(a: String, b: Option[String])
  case class DefaultStringTest(a: String, b: String = "foo")

  ignore("@AvroAlias should be used when a reader schema has a field missing from the write schema") {

    val v1schema = AvroSchema[Version1]
    val v1 = Version1("hello")
    val baos = new ByteArrayOutputStream()
    val output = AvroOutputStream.data[Version1].to(baos).build()
    output.write(v1)
    output.close()

    // we load using a v2 schema
    val is = new AvroDataInputStream[Version2](new ByteArrayInputStream(baos.toByteArray), Some(v1schema))
    val v2 = is.iterator.toList.head

    v2.renamed shouldBe v1.original
  }

  test("when decoding, if the record and schema are missing a field and the target has a scala default, use that") {

    val f1 = RecordFormat[P1]
    val f2 = RecordFormat[P2]

    f1.from(f2.to(P2("foo"))) shouldBe P1("foo")
  }

  test("when decoding, if the record is missing a field that is present in the schema with a default, use the default from the schema") {
    val schema = SchemaBuilder.record("foo").fields().requiredString("a").endRecord()
    val record = new GenericData.Record(schema)
    record.put("a", new Utf8("hello"))
    Decoder[DefaultStringTest].decode(record) shouldBe DefaultStringTest("hello")
  }

  test("when decoding, if the record is missing a field that is present in the schema and the type is option, then set to None") {
    val schema1 = SchemaBuilder.record("foo").fields().requiredString("a").endRecord()
    val schema2 = SchemaBuilder.record("foo").fields().requiredString("a").optionalString("b").endRecord()
    val record = new GenericData.Record(schema1)
    record.put("a", new Utf8("hello"))
    Decoder[OptionalStringTest].decode(record) shouldBe OptionalStringTest("hello", None)
  }
} 
Example 44
Source File: DateDecoderTest.scala    From avro4s   with Apache License 2.0 5 votes vote down vote up
package com.sksamuel.avro4s.record.decoder

import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDate, LocalDateTime, LocalTime}

import com.sksamuel.avro4s.SchemaFor.TimestampNanosLogicalType
import com.sksamuel.avro4s.{AvroSchema, Decoder, SchemaFor}
import org.apache.avro.generic.GenericData
import org.apache.avro.{LogicalTypes, SchemaBuilder}
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers

//noinspection ScalaDeprecation
class DateDecoderTest extends AnyFunSuite with Matchers {

  case class WithLocalTime(z: LocalTime)
  case class WithLocalDate(z: LocalDate)
  case class WithDate(z: Date)
  case class WithLocalDateTime(z: LocalDateTime)
  case class WithTimestamp(z: Timestamp)
  case class WithInstant(z: Instant)

  test("decode int to LocalTime") {
    val schema = AvroSchema[WithLocalTime]
    val record = new GenericData.Record(schema)
    record.put("z", 46245000000L)
    Decoder[WithLocalTime].decode(record) shouldBe WithLocalTime(LocalTime.of(12, 50, 45))
  }

  test("decode int to LocalDate") {
    val schema = AvroSchema[WithLocalDate]
    val record = new GenericData.Record(schema)
    record.put("z", 17784)
    Decoder[WithLocalDate].decode(record) shouldBe WithLocalDate(LocalDate.of(2018, 9, 10))
  }

  test("decode int to java.sql.Date") {
    val schema = AvroSchema[WithDate]
    val record = new GenericData.Record(schema)
    record.put("z", 17784)
    Decoder[WithDate].decode(record) shouldBe WithDate(Date.valueOf(LocalDate.of(2018, 9, 10)))
  }

  test("decode timestamp-millis to LocalDateTime") {
    val dateSchema = LogicalTypes.timestampMillis().addToSchema(SchemaBuilder.builder.longType)
    val schema = SchemaBuilder.record("foo").fields().name("z").`type`(dateSchema).noDefault().endRecord()
    val record = new GenericData.Record(schema)
    record.put("z", 1572707106376L)
    Decoder[WithLocalDateTime].withSchema(SchemaFor(schema)).decode(record) shouldBe WithLocalDateTime(
      LocalDateTime.of(2019, 11, 2, 15, 5, 6, 376000000))
  }

  test("decode timestamp-micros to LocalDateTime") {
    val dateSchema = LogicalTypes.timestampMicros().addToSchema(SchemaBuilder.builder.longType)
    val schema = SchemaBuilder.record("foo").fields().name("z").`type`(dateSchema).noDefault().endRecord()
    val record = new GenericData.Record(schema)
    record.put("z", 1572707106376001L)
    Decoder[WithLocalDateTime].withSchema(SchemaFor(schema)).decode(record) shouldBe WithLocalDateTime(
      LocalDateTime.of(2019, 11, 2, 15, 5, 6, 376001000))
  }

  test("decode timestamp-nanos to LocalDateTime") {
    val dateSchema = TimestampNanosLogicalType.addToSchema(SchemaBuilder.builder.longType)
    val schema = SchemaBuilder.record("foo").fields().name("z").`type`(dateSchema).noDefault().endRecord()
    val record = new GenericData.Record(schema)
    record.put("z", 1572707106376000002L)
    Decoder[WithLocalDateTime].decode(record) shouldBe WithLocalDateTime(
      LocalDateTime.of(2019, 11, 2, 15, 5, 6, 376000002))
  }

  test("decode long to Timestamp") {
    val schema = AvroSchema[WithTimestamp]
    val record = new GenericData.Record(schema)
    record.put("z", 1538312231000L)
    Decoder[WithTimestamp].decode(record) shouldBe WithTimestamp(new Timestamp(1538312231000L))
  }

  test("decode long to Instant") {
    val schema = AvroSchema[WithInstant]
    val record = new GenericData.Record(schema)
    record.put("z", 1538312231000L)
    Decoder[WithInstant].decode(record) shouldBe WithInstant(Instant.ofEpochMilli(1538312231000L))
  }
} 
Example 45
Source File: AvroSchemaMergeTest.scala    From avro4s   with Apache License 2.0 5 votes vote down vote up
package com.sksamuel.avro4s.schema

import com.sksamuel.avro4s.AvroSchemaMerge
import org.apache.avro.SchemaBuilder
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

class AvroSchemaMergeTest extends AnyWordSpec with Matchers {
  "AvroSchemaMerge" should {
    "merge schemas with union type" in {
      val schemaOne = SchemaBuilder
        .builder("test")
        .record("s1")
        .fields()
        .requiredString("f1")
        .nullableLong("f2", 0)
        .endRecord()

      val schemaTwo = SchemaBuilder
        .builder("test")
        .record("s2")
        .fields()
        .optionalString("f1")
        .requiredLong("f2")
        .endRecord()

      val expected = SchemaBuilder
        .builder("test")
        .record("s3")
        .fields()
        .optionalString("f1")
        .nullableLong("f2", 0)
        .endRecord()

      AvroSchemaMerge.apply("s3", "test", List(schemaOne, schemaTwo)).toString shouldBe expected.toString
    }
  }
} 
Example 46
Source File: SchemaForTypeclassOverrideTest.scala    From avro4s   with Apache License 2.0 5 votes vote down vote up
package com.sksamuel.avro4s.schema

import com.sksamuel.avro4s.{AvroSchema, SchemaFor}
import org.apache.avro.SchemaBuilder
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers

class SchemaForTypeclassOverrideTest extends AnyFunSuite with Matchers {

  test("allow overriding built in SchemaFor implicit for a basic type") {

    implicit val StringSchemaFor = SchemaFor[String] {
      val schema = SchemaBuilder.builder().bytesType()
      schema.addProp("foo", "bar": AnyRef)
      schema
    }

    case class OverrideTest(s: String, i: Int)

    val expected =
      new org.apache.avro.Schema.Parser().parse(getClass.getResourceAsStream("/schema_override_basic.json"))
    val schema = AvroSchema[OverrideTest]
    schema.toString(true) shouldBe expected.toString(true)
  }

  test("allow overriding built in SchemaFor implicit for a complex type") {

    implicit val FooSchemaFor = SchemaFor[Foo] {
      val schema = SchemaBuilder.builder().doubleType()
      schema.addProp("foo", "bar": AnyRef)
      schema
    }

    case class Foo(s: String, b: Boolean)
    case class OverrideTest(s: String, f: Foo)

    val expected =
      new org.apache.avro.Schema.Parser().parse(getClass.getResourceAsStream("/schema_override_complex.json"))
    val schema = AvroSchema[OverrideTest]
    schema.toString(true) shouldBe expected.toString(true)
  }

  test("allow overriding built in SchemaFor implicit for a value type") {

    implicit val FooValueTypeSchemaFor = SchemaFor[FooValueType] {
      val schema = SchemaBuilder.builder().intType()
      schema.addProp("foo", "bar": AnyRef)
      schema
    }

    case class OverrideTest(s: String, foo: FooValueType)

    val expected =
      new org.apache.avro.Schema.Parser().parse(getClass.getResourceAsStream("/schema_override_value_type.json"))
    val schema = AvroSchema[FooValueType]
    schema.toString(true) shouldBe expected.toString(true)
  }

  test("allow overriding built in SchemaFor implicit for a top level value type") {

    implicit val FooValueTypeSchemaFor = SchemaFor[FooValueType] {
      val schema = SchemaBuilder.builder().intType()
      schema.addProp("foo", "bar": AnyRef)
      schema
    }

    val expected = new org.apache.avro.Schema.Parser()
      .parse(getClass.getResourceAsStream("/schema_override_top_level_value_type.json"))
    val schema = AvroSchema[FooValueType]
    schema.toString(true) shouldBe expected.toString(true)
  }
}

case class FooValueType(s: String) extends AnyVal