com.google.protobuf.ByteString Scala Examples

The following examples show how to use com.google.protobuf.ByteString. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
Example 1
Source File: Queries.scala (from daml, Apache License 2.0; 7 votes)
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.on.sql.queries

import java.io.InputStream
import java.sql.{Blob, Connection, PreparedStatement}

import anorm.{
  BatchSql,
  Column,
  MetaDataItem,
  NamedParameter,
  RowParser,
  SqlMappingError,
  SqlParser,
  SqlRequestError,
  ToStatement
}
import com.google.protobuf.ByteString

/** A single type combining [[ReadQueries]] and [[WriteQueries]]. */
trait Queries extends ReadQueries with WriteQueries

object Queries {
  val TablePrefix = "ledger"
  val LogTable = s"${TablePrefix}_log"
  val MetaTable = s"${TablePrefix}_meta"
  val StateTable = s"${TablePrefix}_state"

  // By explicitly writing a value to a "table_key" column, we ensure we only ever have one row in
  // the meta table. An attempt to write a second row will result in a key conflict.
  private[queries] val MetaTableKey = 0

  /** Executes `query` as one JDBC batch, with one batch entry per parameter set in `params`.
    *
    * `BatchSql` is given the first parameter set separately from the remaining ones, so an
    * empty `params` is treated as a no-op instead of being passed through.
    *
    * @param query the SQL statement, using named placeholders
    * @param params one sequence of named parameters per batch entry
    */
  def executeBatchSql(
      query: String,
      params: Iterable[Seq[NamedParameter]],
  )(implicit connection: Connection): Unit = {
    if (params.nonEmpty)
      BatchSql(query, params.head, params.drop(1).toArray: _*).execute()
    ()
  }

  /** Binds a [[ByteString]] parameter as a binary stream of exactly `v.size()` bytes. */
  implicit def byteStringToStatement: ToStatement[ByteString] = new ToStatement[ByteString] {
    override def set(s: PreparedStatement, index: Int, v: ByteString): Unit =
      s.setBinaryStream(index, v.newInput(), v.size())
  }

  /** Reads a non-null column into a [[ByteString]].
    *
    * Accepts a [[Blob]], a byte array or an [[InputStream]]; any other value is reported as an
    * `SqlMappingError` naming the column.
    */
  implicit def columnToByteString: Column[ByteString] =
    Column.nonNull { (value: Any, meta: MetaDataItem) =>
      value match {
        case blob: Blob => Right(readAndClose(blob.getBinaryStream))
        case byteArray: Array[Byte] => Right(ByteString.copyFrom(byteArray))
        case inputStream: InputStream => Right(ByteString.readFrom(inputStream))
        case _ =>
          Left[SqlRequestError, ByteString](
            SqlMappingError(s"Cannot convert value of column ${meta.column} to ByteString"))
      }
    }

  /** Parses the column named `columnName` of a row as a [[ByteString]]. */
  def getBytes(columnName: String): RowParser[ByteString] =
    SqlParser.get(columnName)(columnToByteString)

  // `Blob.getBinaryStream` opens a fresh stream on every call, so the stream is owned here and
  // must be closed once it has been fully read, even if reading fails.
  private def readAndClose(stream: InputStream): ByteString =
    try ByteString.readFrom(stream)
    finally stream.close()
}
Example 2
Source File: InMemoryState.scala (from daml, Apache License 2.0; 5 votes)
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.on.memory

import java.util.concurrent.locks.StampedLock

import com.daml.ledger.on.memory.InMemoryState._
import com.daml.ledger.participant.state.kvutils.Bytes
import com.daml.ledger.participant.state.kvutils.api.LedgerRecord
import com.daml.ledger.participant.state.v1.Offset
import com.google.protobuf.ByteString

import scala.collection.mutable
import scala.concurrent.{ExecutionContext, Future, blocking}

private[memory] class InMemoryState private (log: MutableLog, state: MutableState) {
  private val lockCurrentState = new StampedLock()
  @volatile private var lastLogEntryIndex = 0

  /** Applies `action` to the log without taking the write lock. */
  def readLog[A](action: ImmutableLog => A): A =
    action(log) // `log` is mutable, but the interface is immutable

  /** Index of the last log entry as recorded by the most recently completed write (0 initially). */
  def newHeadSinceLastWrite(): Int = lastLogEntryIndex

  /** Runs `action` on the log and state while holding the exclusive write lock.
    *
    * The lock is acquired inside `blocking`, because `writeLock()` may park the thread. Once the
    * future produced by `action` completes, `lastLogEntryIndex` is updated and the lock is
    * released. This happens on success, on failure, and when `action` throws before returning a
    * future.
    */
  def write[A](action: (MutableLog, MutableState) => Future[A])(
      implicit executionContext: ExecutionContext
  ): Future[A] =
    for {
      stamp <- Future {
        blocking {
          lockCurrentState.writeLock()
        }
      }
      // Calling `action` inside `flatMap` turns a synchronous throw into a failed future, so the
      // `andThen` below is always attached and the write lock can never be leaked.
      result <- Future.unit
        .flatMap(_ => action(log, state))
        .andThen {
          case _ =>
            lastLogEntryIndex = log.size - 1
            lockCurrentState.unlock(stamp)
        }
    } yield result
}

object InMemoryState {
  type StateKey = Bytes
  type StateValue = Bytes

  // Read-only views handed to readers.
  type ImmutableLog = IndexedSeq[LedgerRecord]
  type ImmutableState = collection.Map[StateKey, StateValue]

  // Mutable backing stores that also satisfy the read-only views above.
  type MutableLog = mutable.Buffer[LedgerRecord] with ImmutableLog
  type MutableState = mutable.Map[StateKey, StateValue] with ImmutableState

  // The first element will never be read because begin offsets are exclusive.
  private val Beginning = LedgerRecord(Offset.beforeBegin, ByteString.EMPTY, ByteString.EMPTY)

  /** Creates a state whose log contains only the `Beginning` record and whose key/value map is empty. */
  def empty: InMemoryState = {
    val initialLog: MutableLog = mutable.ArrayBuffer(Beginning)
    val initialState: MutableState = mutable.Map.empty
    new InMemoryState(log = initialLog, state = initialState)
  }
}
Example 3
Source File: InMemoryLedgerReaderWriterSpec.scala (from daml, Apache License 2.0; 5 votes)
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.on.memory

import com.codahale.metrics.MetricRegistry
import com.daml.ledger.api.testing.utils.AkkaBeforeAndAfterAll
import com.daml.ledger.participant.state.kvutils.api.CommitMetadata
import com.daml.ledger.participant.state.v1.{ParticipantId, SubmissionResult}
import com.daml.ledger.validator.{BatchedValidatingCommitter, LedgerStateOperations}
import com.daml.lf.data.Ref
import com.daml.metrics.Metrics
import com.daml.platform.akkastreams.dispatcher.Dispatcher
import com.google.protobuf.ByteString
import org.mockito.ArgumentMatchers._
import org.mockito.Mockito.{times, verify, when}
import org.scalatest.mockito.MockitoSugar
import org.scalatest.{AsyncWordSpec, Matchers}

import scala.concurrent.{ExecutionContext, Future}

class InMemoryLedgerReaderWriterSpec
    extends AsyncWordSpec
    with AkkaBeforeAndAfterAll
    with Matchers
    with MockitoSugar {
  "commit" should {
    "not signal new head in case of failure" in {
      // Built before stubbing so that no value is created between Mockito matcher registrations.
      val failedCommit: Future[SubmissionResult] =
        Future.successful(SubmissionResult.InternalError("Validation failed with an exception"))
      val dispatcher = mock[Dispatcher[Index]]
      val committer = mock[BatchedValidatingCommitter[Index]]
      when(
        committer.commit(
          anyString(),
          any[ByteString](),
          any[ParticipantId](),
          any[LedgerStateOperations[Index]])(any[ExecutionContext]()))
        .thenReturn(failedCommit)

      val readerWriter = new InMemoryLedgerReaderWriter(
        Ref.ParticipantId.assertFromString("participant ID"),
        "ledger ID",
        dispatcher,
        InMemoryState.empty,
        committer,
        new Metrics(new MetricRegistry)
      )
      val submissionBytes = ByteString.copyFromUtf8("some bytes")

      readerWriter.commit("correlation ID", submissionBytes, CommitMetadata.Empty).map { result =>
        // A failed commit must never advance the dispatcher's head.
        verify(dispatcher, times(0)).signalNewHead(anyInt())
        result should be(a[SubmissionResult.InternalError])
      }
    }
  }
}
Example 4
Source File: PackageManagementClient.scala (from daml, Apache License 2.0; 5 votes)
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.client.services.admin

import com.daml.ledger.api.v1.admin.package_management_service.PackageManagementServiceGrpc.PackageManagementServiceStub
import com.daml.ledger.api.v1.admin.package_management_service.{
  ListKnownPackagesRequest,
  PackageDetails,
  UploadDarFileRequest
}
import com.daml.ledger.client.LedgerClient
import com.google.protobuf.ByteString

import scala.concurrent.{ExecutionContext, Future}

object PackageManagementClient {

  // Built once and reused for every call; all of its fields keep their default values.
  private val listKnownPackagesRequest: ListKnownPackagesRequest = ListKnownPackagesRequest()

}

final class PackageManagementClient(service: PackageManagementServiceStub)(
    implicit ec: ExecutionContext) {

  /** Lists the package details reported by the package management service.
    *
    * @param token optional access token passed to `LedgerClient.stub`
    */
  def listKnownPackages(token: Option[String] = None): Future[Seq[PackageDetails]] = {
    val response =
      stubFor(token).listKnownPackages(PackageManagementClient.listKnownPackagesRequest)
    response.map(_.packageDetails)
  }

  /** Uploads the given DAR bytes, completing with `()` once the service responds.
    *
    * @param darFile the raw DAR archive
    * @param token optional access token passed to `LedgerClient.stub`
    */
  def uploadDarFile(darFile: ByteString, token: Option[String] = None): Future[Unit] = {
    val response = stubFor(token).uploadDarFile(UploadDarFileRequest(darFile))
    response.map(_ => ())
  }

  // The service stub as prepared by `LedgerClient.stub` for the given optional token.
  private def stubFor(token: Option[String]) =
    LedgerClient.stub(service, token)

}
Example 5
Source File: PackageManagementServiceIT.scala (from daml, Apache License 2.0; 5 votes)
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.api.testtool.tests

import com.daml.ledger.api.testtool.infrastructure.Allocation._
import com.daml.ledger.api.testtool.infrastructure.Assertions._
import com.daml.ledger.api.testtool.infrastructure.LedgerTestSuite
import com.daml.ledger.packagemanagementtest.PackageManagementTest.PackageManagementTestTemplate
import com.daml.ledger.packagemanagementtest.PackageManagementTest.PackageManagementTestTemplate._
import com.google.protobuf.ByteString
import io.grpc.Status

import scala.concurrent.{ExecutionContext, Future}

final class PackageManagementServiceIT extends LedgerTestSuite {
  private[this] val testPackageResourcePath =
    "/ledger/ledger-api-test-tool/PackageManagementTest.dar"

  /** Loads the test DAR from the classpath.
    *
    * The resource stream is read and closed inside the same future. It is therefore always
    * closed, even when reading fails, before the returned future completes, and a failure to
    * close surfaces as a failed future instead of being silently dropped.
    */
  private def loadTestPackage()(implicit ec: ExecutionContext): Future[ByteString] =
    Future {
      val in = getClass.getResourceAsStream(testPackageResourcePath)
      assert(in != null, s"Unable to load test package resource at '$testPackageResourcePath'")
      try ByteString.readFrom(in)
      finally in.close()
    }

  test(
    "PackageManagementEmptyUpload",
    "An attempt at uploading an empty payload should fail",
    allocate(NoParties),
  )(implicit ec => {
    case Participants(Participant(ledger)) =>
      for {
        failure <- ledger.uploadDarFile(ByteString.EMPTY).failed
      } yield {
        assertGrpcError(
          failure,
          Status.Code.INVALID_ARGUMENT,
          "Invalid argument: Invalid DAR: package-upload",
        )
      }
  })

  test(
    "PackageManagementLoad",
    "Concurrent uploads of the same package should be idempotent and result in the package being available for use",
    allocate(SingleParty),
  )(implicit ec => {
    case Participants(Participant(ledger, party)) =>
      for {
        testPackage <- loadTestPackage()
        _ <- Future.sequence(Vector.fill(8)(ledger.uploadDarFile(testPackage)))
        knownPackages <- ledger.listKnownPackages()
        contract <- ledger.create(party, new PackageManagementTestTemplate(party))
        acsBefore <- ledger.activeContracts(party)
        _ <- ledger.exercise(party, contract.exerciseTestChoice)
        acsAfter <- ledger.activeContracts(party)
      } yield {
        val duplicatePackageIds =
          knownPackages.groupBy(_.packageId).mapValues(_.size).filter(_._2 > 1)
        val duplicatesDescription = duplicatePackageIds
          .map { case (name, count) => s"$name ($count)" }
          .mkString(", ")
        assert(
          duplicatePackageIds.isEmpty,
          s"There are duplicate package identifiers: $duplicatesDescription",
        )
        // Report the actual count: the size may be 0 or greater than 1 when this fails.
        assert(
          acsBefore.size == 1,
          s"After the contract has been created there should be one active contract but there are ${acsBefore.size}",
        )
        assert(
          acsAfter.isEmpty,
          s"There should be no active contract after the contract has been consumed: ${acsAfter.map(_.contractId).mkString(", ")}",
        )
      }
  })
}
Example 6
Source File: SequentialLogEntryId.scala (from daml, Apache License 2.0; 5 votes)
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.participant.state.kvutils

import java.util.concurrent.atomic.AtomicLong

import com.daml.ledger.participant.state.kvutils.DamlKvutils.DamlLogEntryId
import com.google.protobuf.ByteString

class SequentialLogEntryId(prefix: String) {
  private val counter = new AtomicLong()
  private val prefixBytes = ByteString.copyFromUtf8(prefix)

  /** Returns the next log entry ID: `prefix` followed by the hexadecimal form of a counter.
    *
    * The counter starts at 0 and advances atomically on every call, so concurrent callers each
    * get a distinct value.
    */
  def next(): DamlLogEntryId = {
    val suffix = ByteString.copyFromUtf8(counter.getAndIncrement().toHexString)
    DamlLogEntryId.newBuilder.setEntryId(prefixBytes.concat(suffix)).build
  }
}
Example 7
Source File: Serialization.scala (from daml, Apache License 2.0; 5 votes)
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.participant.state.kvutils.export

import java.io.{DataInputStream, DataOutputStream}
import java.time.Instant

import com.daml.ledger.participant.state
import com.daml.ledger.participant.state.kvutils.export.FileBasedLedgerDataExporter.{
  SubmissionInfo,
  WriteSet
}
import com.google.protobuf.ByteString

object Serialization {

  /** Writes one export entry: the submission info followed by its write set.
    *
    * Layout, as read back by `readEntry`:
    *  - correlation ID (`writeUTF`)
    *  - submission envelope (4-byte length, then bytes)
    *  - record time as epoch milliseconds (sub-millisecond precision is dropped)
    *  - participant ID (`writeUTF`)
    *  - number of write-set pairs, then each pair sorted by key as length-prefixed key and value
    */
  def serializeEntry(
      submissionInfo: SubmissionInfo,
      writeSet: WriteSet,
      out: DataOutputStream): Unit = {
    serializeSubmissionInfo(submissionInfo, out)
    serializeWriteSet(writeSet, out)
  }

  /** Reads one entry in the layout produced by `serializeEntry`. */
  def readEntry(input: DataInputStream): (SubmissionInfo, WriteSet) = {
    val info = readSubmissionInfo(input)
    val pairs = readWriteSet(input)
    info -> pairs
  }

  private def serializeSubmissionInfo(
      submissionInfo: SubmissionInfo,
      out: DataOutputStream): Unit = {
    out.writeUTF(submissionInfo.correlationId)
    writeSizedBytes(submissionInfo.submissionEnvelope, out)
    out.writeLong(submissionInfo.recordTimeInstant.toEpochMilli)
    out.writeUTF(submissionInfo.participantId)
  }

  private def readSubmissionInfo(input: DataInputStream): SubmissionInfo = {
    val correlationId = input.readUTF()
    val envelope = readSizedBytes(input)
    val epochMillis = input.readLong()
    val rawParticipantId = input.readUTF()
    SubmissionInfo(
      envelope,
      correlationId,
      Instant.ofEpochMilli(epochMillis),
      state.v1.ParticipantId.assertFromString(rawParticipantId)
    )
  }

  private def serializeWriteSet(writeSet: WriteSet, out: DataOutputStream): Unit = {
    out.writeInt(writeSet.size)
    // Presumably sorted so the output does not depend on the order the write set was built in.
    writeSet.sortBy(_._1.asReadOnlyByteBuffer()).foreach {
      case (key, value) =>
        writeSizedBytes(key, out)
        writeSizedBytes(value, out)
    }
  }

  private def readWriteSet(input: DataInputStream): WriteSet = {
    val pairCount = input.readInt()
    (1 to pairCount).map { _ =>
      val key = readSizedBytes(input)
      val value = readSizedBytes(input)
      key -> value
    }
  }

  // Writes `bytes` as a 4-byte length prefix followed by the raw bytes.
  private def writeSizedBytes(bytes: ByteString, out: DataOutputStream): Unit = {
    out.writeInt(bytes.size())
    bytes.writeTo(out)
  }

  // Reads a value written by `writeSizedBytes`: a length prefix, then exactly that many bytes.
  private def readSizedBytes(input: DataInputStream): ByteString = {
    val buffer = new Array[Byte](input.readInt())
    input.readFully(buffer)
    ByteString.copyFrom(buffer)
  }
}
Example 8
Source File: FileBasedLedgerDataExporter.scala (from daml, Apache License 2.0; 5 votes)
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.participant.state.kvutils.export

import java.io.DataOutputStream
import java.time.Instant
import java.util.concurrent.locks.StampedLock

import com.daml.ledger.participant.state.v1.ParticipantId
import com.daml.ledger.validator.LedgerStateOperations.{Key, Value}
import com.google.protobuf.ByteString

import scala.collection.mutable
import scala.collection.mutable.ListBuffer


class FileBasedLedgerDataExporter(output: DataOutputStream) extends LedgerDataExporter {
  import FileBasedLedgerDataExporter._

  // Serializes writes to `output` so that entries from concurrent callers do not interleave.
  private val outputLock = new StampedLock

  // child correlation ID -> parent correlation ID
  private[export] val correlationIdMapping = mutable.Map.empty[String, String]
  // correlation ID -> submission recorded by `addSubmission`
  private[export] val inProgressSubmissions = mutable.Map.empty[String, SubmissionInfo]
  // parent correlation ID -> key/value pairs reported for it (via its children)
  private[export] val bufferedKeyValueDataPerCorrelationId =
    mutable.Map.empty[String, mutable.ListBuffer[(Key, Value)]]

  /** Records a submission under `correlationId` until `finishedProcessing` is called for it. */
  def addSubmission(
      submissionEnvelope: ByteString,
      correlationId: String,
      recordTimeInstant: Instant,
      participantId: ParticipantId): Unit =
    this.synchronized {
      inProgressSubmissions.put(
        correlationId,
        SubmissionInfo(submissionEnvelope, correlationId, recordTimeInstant, participantId))
      ()
    }

  /** Registers `childCorrelationId` so its write-set data is buffered under `parentCorrelationId`. */
  def addParentChild(parentCorrelationId: String, childCorrelationId: String): Unit =
    this.synchronized {
      correlationIdMapping.put(childCorrelationId, parentCorrelationId)
      ()
    }

  /** Appends `data` to the buffer of the parent of `correlationId`.
    *
    * Data for a correlation ID with no registered parent is dropped.
    */
  def addToWriteSet(correlationId: String, data: Iterable[(Key, Value)]): Unit =
    this.synchronized {
      correlationIdMapping
        .get(correlationId)
        .foreach { parentCorrelationId =>
          // `getOrElseUpdate` already stores a newly created buffer in the map, so appending to
          // the returned (mutable) buffer is sufficient.
          bufferedKeyValueDataPerCorrelationId
            .getOrElseUpdate(parentCorrelationId, ListBuffer.empty)
            .appendAll(data)
        }
    }

  /** Writes out and then forgets the submission recorded under `correlationId`.
    *
    * If a submission was added for `correlationId` and data was buffered for it, the entry is
    * written and flushed. All state for `correlationId` is then removed: the submission, its
    * buffer and every child mapping pointing at it. Cleanup happens even if writing throws; the
    * exception still propagates. Unknown correlation IDs are ignored.
    */
  def finishedProcessing(correlationId: String): Unit = {
    val (submissionInfo, bufferedData) = this.synchronized {
      (
        inProgressSubmissions.get(correlationId),
        bufferedKeyValueDataPerCorrelationId.get(correlationId))
    }
    submissionInfo.foreach { submission =>
      try bufferedData.foreach(writeSubmissionData(submission, _))
      finally this.synchronized {
        inProgressSubmissions.remove(correlationId)
        bufferedKeyValueDataPerCorrelationId.remove(correlationId)
        // `collect` builds a separate collection, so removing from the map afterwards is safe.
        correlationIdMapping
          .collect {
            case (key, value) if value == correlationId => key
          }
          .foreach(correlationIdMapping.remove)
      }
    }
  }

  // Serializes one entry and flushes it while holding the exclusive output lock.
  private def writeSubmissionData(
      submissionInfo: SubmissionInfo,
      writeSet: ListBuffer[(Key, Value)]): Unit = {
    val stamp = outputLock.writeLock()
    try {
      Serialization.serializeEntry(submissionInfo, writeSet, output)
      output.flush()
    } finally {
      outputLock.unlock(stamp)
    }
  }
}

object FileBasedLedgerDataExporter {

  /** The key/value pairs exported together with one submission. */
  type WriteSet = Seq[(Key, Value)]

  /** The submission data captured by `addSubmission` and written by `finishedProcessing`. */
  case class SubmissionInfo(
      submissionEnvelope: ByteString,
      correlationId: String,
      recordTimeInstant: Instant,
      participantId: ParticipantId,
  )
}
Example 9
Source File: LedgerDataExporter.scala (from daml, Apache License 2.0; 5 votes)
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.participant.state.kvutils.export

import java.io.{DataOutputStream, FileOutputStream}
import java.time.Instant

import com.daml.ledger.participant.state.v1.ParticipantId
import com.daml.ledger.validator.LedgerStateOperations.{Key, Value}
import com.google.protobuf.ByteString
import org.slf4j.LoggerFactory

trait LedgerDataExporter {

  /** Called with the correlation ID of a submission, presumably once its processing is complete
    * (inferred from the name — confirm with callers). The file-based implementation writes the
    * recorded submission and its buffered write set, then discards that state.
    */
  def finishedProcessing(correlationId: String): Unit
}

object LedgerDataExporter {
  val EnvironmentVariableName = "KVUTILS_LEDGER_EXPORT"

  private val logger = LoggerFactory.getLogger(this.getClass)

  // Defined only when `EnvironmentVariableName` is set; the named file is opened on first use.
  private lazy val outputStreamMaybe: Option[DataOutputStream] = {
    Option(System.getenv(EnvironmentVariableName))
      .map { filename =>
        logger.info(s"Enabled writing ledger entries to $filename")
        // `DataOutputStream.writeInt`/`writeLong` hand the underlying stream one byte at a time,
        // which on a bare `FileOutputStream` means one file write per byte. Buffering is safe
        // because the file-based exporter flushes after every entry.
        new DataOutputStream(new java.io.BufferedOutputStream(new FileOutputStream(filename)))
      }
  }

  // File-based exporter when an output file is configured, otherwise the no-op exporter.
  private lazy val instance = outputStreamMaybe
    .map(new FileBasedLedgerDataExporter(_))
    .getOrElse(NoopLedgerDataExporter)

  /** Returns the shared exporter instance, which is created on first call. */
  def apply(): LedgerDataExporter = instance
}
Example 10
Source File: TestHelper.scala (from daml, Apache License 2.0; 5 votes)
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.validator

import com.daml.ledger.participant.state.kvutils.DamlKvutils.{
  DamlCommandDedupKey,
  DamlContractKey,
  DamlLogEntry,
  DamlLogEntryId,
  DamlPartyAllocationEntry,
  DamlStateKey,
  DamlSubmission,
  DamlSubmissionDedupKey
}
import com.daml.ledger.participant.state.v1.ParticipantId
import com.daml.lf.value.ValueOuterClass
import com.google.protobuf.ByteString

private[validator] object TestHelper {

  lazy val aParticipantId: ParticipantId = ParticipantId.assertFromString("aParticipantId")

  /** A log entry holding a single party allocation for "aParty" on [[aParticipantId]]. */
  lazy val aLogEntry: DamlLogEntry = {
    val partyAllocation =
      DamlPartyAllocationEntry.newBuilder().setParty("aParty").setParticipantId(aParticipantId)
    DamlLogEntry.newBuilder().setPartyAllocationEntry(partyAllocation).build()
  }

  /** One state key per key kind set below, each with a placeholder value. */
  lazy val allDamlStateKeyTypes: Seq[DamlStateKey] = {
    // Every call yields a fresh builder, so each entry is independent of the others.
    def newKey = DamlStateKey.newBuilder
    val templateId = ValueOuterClass.Identifier.newBuilder.addName("a name")
    Seq(
      newKey.setPackageId("a package ID"),
      newKey.setContractId("a contract ID"),
      newKey.setCommandDedup(DamlCommandDedupKey.newBuilder.setCommandId("an ID")),
      newKey.setParty("a party"),
      newKey.setContractKey(DamlContractKey.newBuilder.setTemplateId(templateId)),
      newKey.setConfiguration(com.google.protobuf.Empty.getDefaultInstance),
      newKey.setSubmissionDedup(
        DamlSubmissionDedupKey.newBuilder.setSubmissionId("a submission ID"))
    ).map(_.build)
  }

  lazy val anInvalidEnvelope: ByteString = ByteString.copyFromUtf8("invalid data")

  /** Builds a party allocation submission for `party`.
    *
    * Its inputs are the party itself and a submission dedup key, and its ID is
    * `"<party>-submission"`.
    */
  def makePartySubmission(party: String): DamlSubmission = {
    val submissionId = s"$party-submission"
    val submission = DamlSubmission.newBuilder
    submission.setSubmissionSeed(ByteString.EMPTY)
    submission.addInputDamlStateBuilder().setParty(party)
    submission
      .addInputDamlStateBuilder()
      .getSubmissionDedupBuilder
      .setParticipantId(aParticipantId)
      .setSubmissionId(submissionId)
    submission.getPartyAllocationEntryBuilder
      .setSubmissionId(submissionId)
      .setParticipantId(aParticipantId)
      .setDisplayName(party)
      .setParty(party)
    submission.build
  }

  def aLogEntryId(): DamlLogEntryId = SubmissionValidator.allocateRandomLogEntryId()
}
Example 11
Source File: BatchedValidatingCommitterSpec.scala (from daml, Apache License 2.0; 5 votes)
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.validator

import java.time.Instant

import akka.stream.Materializer
import com.daml.ledger.api.testing.utils.AkkaBeforeAndAfterAll
import com.daml.ledger.participant.state.v1.{ParticipantId, SubmissionResult}
import com.daml.ledger.validator.TestHelper.aParticipantId
import com.daml.ledger.validator.batch.BatchedSubmissionValidator
import com.google.protobuf.ByteString
import org.mockito.ArgumentMatchers.{any, anyString}
import org.mockito.Mockito.when
import org.scalatest.mockito.MockitoSugar
import org.scalatest.{AsyncWordSpec, Matchers}

import scala.concurrent.{ExecutionContext, Future}

class BatchedValidatingCommitterSpec
    extends AsyncWordSpec
    with AkkaBeforeAndAfterAll
    with Matchers
    with MockitoSugar {
  "commit" should {
    "return Acknowledged in case of success" in {
      val committer =
        BatchedValidatingCommitter[Unit](() => Instant.now(), validatorReturning(Future.unit))

      committer
        .commit("", ByteString.EMPTY, aParticipantId, mock[LedgerStateOperations[Unit]])
        .map(_ shouldBe SubmissionResult.Acknowledged)
    }

    "return InternalError in case of an exception" in {
      val failingValidator =
        validatorReturning(Future.failed(new IllegalArgumentException("Validation failure")))
      val committer = BatchedValidatingCommitter[Unit](() => Instant.now(), failingValidator)

      committer
        .commit("", ByteString.EMPTY, aParticipantId, mock[LedgerStateOperations[Unit]])
        .map(_ shouldBe SubmissionResult.InternalError("Validation failure"))
    }
  }

  // A mock validator whose `validateAndCommit` yields `outcome` for any arguments.
  private def validatorReturning(outcome: Future[Unit]): BatchedSubmissionValidator[Unit] = {
    val validator = mock[BatchedSubmissionValidator[Unit]]
    when(
      validator.validateAndCommit(
        any[ByteString](),
        anyString(),
        any[Instant](),
        any[ParticipantId](),
        any[DamlLedgerStateReader](),
        any[CommitStrategy[Unit]]())(any[Materializer](), any[ExecutionContext]()))
      .thenReturn(outcome)
    validator
  }
}
Example 12
Source File: LogAppendingCommitStrategySpec.scala (from daml, Apache License 2.0; 5 votes)
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.validator

import org.scalatest.mockito.MockitoSugar
import org.scalatest.{AsyncWordSpec, Matchers}
import TestHelper._
import com.daml.ledger.participant.state.kvutils.DamlKvutils.{DamlStateKey, DamlStateValue}
import com.daml.ledger.participant.state.kvutils.Envelope
import com.daml.ledger.validator.LedgerStateOperations.{Key, Value}
import com.google.protobuf.ByteString
import org.mockito.ArgumentCaptor
import org.mockito.ArgumentMatchers._
import org.mockito.Mockito.{times, verify, when}

import scala.concurrent.Future

class LogAppendingCommitStrategySpec extends AsyncWordSpec with Matchers with MockitoSugar {
  "commit" should {
    "return index from appendToLog" in {
      val expectedIndex = 1234L
      val ledgerStateOperations = mock[LedgerStateOperations[Long]]
      when(ledgerStateOperations.appendToLog(any[Key](), any[Value]()))
        .thenReturn(Future.successful(expectedIndex))
      val strategy = new LogAppendingCommitStrategy[Long](
        ledgerStateOperations,
        DefaultStateKeySerializationStrategy)

      strategy
        .commit(aParticipantId, "a correlation ID", aLogEntryId(), aLogEntry, Map.empty, Map.empty)
        .map { index =>
          // With no output state, only the log append must happen.
          verify(ledgerStateOperations, times(1)).appendToLog(any[Key](), any[Value]())
          verify(ledgerStateOperations, times(0)).writeState(any[Seq[(Key, Value)]]())
          index should be(expectedIndex)
        }
    }

    "write keys serialized according to strategy" in {
      val serializedKey = ByteString.copyFromUtf8("some key")
      val expectedWrites = Seq((serializedKey, Envelope.enclose(aStateValue)))

      val ledgerStateOperations = mock[LedgerStateOperations[Long]]
      val writeStateCaptor = ArgumentCaptor
        .forClass(classOf[Seq[(Key, Value)]])
        .asInstanceOf[ArgumentCaptor[Seq[(Key, Value)]]]
      when(ledgerStateOperations.writeState(writeStateCaptor.capture()))
        .thenReturn(Future.unit)
      when(ledgerStateOperations.appendToLog(any[Key](), any[Value]()))
        .thenReturn(Future.successful(0L))

      val keySerializationStrategy = mock[StateKeySerializationStrategy]
      when(keySerializationStrategy.serializeStateKey(any[DamlStateKey]()))
        .thenReturn(serializedKey)

      val strategy =
        new LogAppendingCommitStrategy[Long](ledgerStateOperations, keySerializationStrategy)

      strategy
        .commit(
          aParticipantId,
          "a correlation ID",
          aLogEntryId(),
          aLogEntry,
          Map.empty,
          Map(aStateKey -> aStateValue))
        .map { _: Long =>
          // The key written must be exactly the one produced by the injected strategy.
          verify(keySerializationStrategy, times(1)).serializeStateKey(aStateKey)
          verify(ledgerStateOperations, times(1)).writeState(any[Seq[(Key, Value)]]())
          writeStateCaptor.getValue should be(expectedWrites)
        }
    }
  }

  private val aStateKey: DamlStateKey = DamlStateKey
    .newBuilder()
    .setContractId(1.toString)
    .build

  private val aStateValue: DamlStateValue = DamlStateValue.getDefaultInstance
}
Example 13
Source File: EnvelopeSpec.scala (from daml, Apache License 2.0; 5 votes)
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.participant.state.kvutils

import com.daml.ledger.participant.state.kvutils.{DamlKvutils => Proto}
import com.google.protobuf.ByteString
import org.scalatest.{Matchers, WordSpec}

class EnvelopeSpec extends WordSpec with Matchers {
  "envelope" should {

    "be able to enclose and open" in {
      val submission = Proto.DamlSubmission.getDefaultInstance
      val logEntry = Proto.DamlLogEntry.getDefaultInstance
      val stateValue = Proto.DamlStateValue.getDefaultInstance

      // Each message kind must survive an enclose/open round trip wrapped in its own message type.
      Envelope.open(Envelope.enclose(submission)) shouldEqual
        Right(Envelope.SubmissionMessage(submission))
      Envelope.open(Envelope.enclose(logEntry)) shouldEqual
        Right(Envelope.LogEntryMessage(logEntry))
      Envelope.open(Envelope.enclose(stateValue)) shouldEqual
        Right(Envelope.StateValueMessage(stateValue))
    }

    "be able to enclose and open batch submission batch message" in {
      val correlatedSubmission = Proto.DamlSubmissionBatch.CorrelatedSubmission.newBuilder
        .setCorrelationId("anId")
        .setSubmission(ByteString.copyFromUtf8("a submission"))
      val submissionBatch = Proto.DamlSubmissionBatch.newBuilder
        .addSubmissions(correlatedSubmission)
        .build

      Envelope.open(Envelope.enclose(submissionBatch)) shouldEqual
        Right(Envelope.SubmissionBatchMessage(submissionBatch))
    }
  }
}
Example 14
Source File: FileBasedLedgerDataExportSpec.scala (from daml, Apache License 2.0; 5 votes)
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.ledger.participant.state.kvutils.export

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.time.Instant

import com.daml.ledger.participant.state.v1
import com.google.protobuf.ByteString
import org.scalatest.mockito.MockitoSugar
import org.scalatest.{Matchers, WordSpec}

class FileBasedLedgerDataExportSpec extends WordSpec with Matchers with MockitoSugar {
  // XXX SC remove in Scala 2.13; see notes in ConfSpec
  import scala.collection.GenTraversable, org.scalatest.enablers.Containing
  private[this] implicit def `fixed sig containingNatureOfGenTraversable`[
      E: org.scalactic.Equality,
      TRAV]: Containing[TRAV with GenTraversable[E]] =
    Containing.containingNatureOfGenTraversable[E, GenTraversable]

  "addParentChild" should {
    "add entry to correlation ID mapping" in {
      val instance = new FileBasedLedgerDataExporter(mock[DataOutputStream])
      instance.addParentChild("parent", "child")

      instance.correlationIdMapping should contain("child" -> "parent")
    }
  }

  "addToWriteSet" should {
    "append to existing data" in {
      val instance = new FileBasedLedgerDataExporter(mock[DataOutputStream])
      instance.addParentChild("parent", "child")
      instance.addToWriteSet("child", Seq(keyValuePairOf("a", "b")))
      instance.addToWriteSet("child", Seq(keyValuePairOf("c", "d")))

      // Writes recorded against the child correlation ID are buffered under its parent.
      instance.bufferedKeyValueDataPerCorrelationId should contain(
        "parent" ->
          Seq(keyValuePairOf("a", "b"), keyValuePairOf("c", "d")))
    }
  }

  "finishedProcessing" should {
    "remove all data such as submission info, write-set and child correlation IDs" in {
      val dataOutputStream = new DataOutputStream(new ByteArrayOutputStream())
      val instance = new FileBasedLedgerDataExporter(dataOutputStream)
      instance.addSubmission(
        ByteString.copyFromUtf8("an envelope"),
        "parent",
        Instant.now(),
        v1.ParticipantId.assertFromString("id"))
      instance.addParentChild("parent", "parent")
      instance.addToWriteSet("parent", Seq(keyValuePairOf("a", "b")))

      instance.finishedProcessing("parent")

      instance.inProgressSubmissions shouldBe empty
      instance.bufferedKeyValueDataPerCorrelationId shouldBe empty
      instance.correlationIdMapping shouldBe empty
    }
  }

  "serialized submission" should {
    "be readable back" in {
      val baos = new ByteArrayOutputStream()
      val dataOutputStream = new DataOutputStream(baos)
      val instance = new FileBasedLedgerDataExporter(dataOutputStream)
      // The same values are written and then asserted on, so a mismatch between the
      // input and the expectation cannot hide a serialization bug.
      val expectedEnvelope = ByteString.copyFromUtf8("an envelope")
      val expectedRecordTimeInstant = Instant.now()
      val expectedParticipantId = v1.ParticipantId.assertFromString("id")
      instance.addSubmission(
        expectedEnvelope,
        "parent",
        expectedRecordTimeInstant,
        expectedParticipantId)
      instance.addParentChild("parent", "parent")
      instance.addToWriteSet("parent", Seq(keyValuePairOf("a", "b")))

      instance.finishedProcessing("parent")

      val dataInputStream = new DataInputStream(new ByteArrayInputStream(baos.toByteArray))
      val (actualSubmissionInfo, actualWriteSet) = Serialization.readEntry(dataInputStream)
      actualSubmissionInfo.submissionEnvelope should be(expectedEnvelope)
      actualSubmissionInfo.correlationId should be("parent")
      actualSubmissionInfo.recordTimeInstant should be(expectedRecordTimeInstant)
      actualSubmissionInfo.participantId should be(expectedParticipantId)
      actualWriteSet should be(Seq(keyValuePairOf("a", "b")))
    }
  }

  /** Builds a (key, value) pair of UTF-8 encoded ByteStrings. */
  private def keyValuePairOf(key: String, value: String): (ByteString, ByteString) =
    ByteString.copyFromUtf8(key) -> ByteString.copyFromUtf8(value)
}
Example 15
Source File: ResultAssertions.scala    From daml   with Apache License 2.0 5 votes vote down vote up
// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

package com.daml.grpc.adapter.client

import com.daml.platform.hello.HelloResponse
import com.google.protobuf.ByteString
import io.grpc.{Status, StatusRuntimeException}
import org.scalatest.{Assertion, Matchers}

import scala.util.Random

trait ResultAssertions { self: Matchers =>

  /** Number of elements used by the streaming test scenarios. */
  protected def elemCount: Int = 1024
  /** The values 1 to elemCount, inclusive. */
  protected lazy val elemRange: Range = 1 to elemCount
  /** Half of elemCount, rounded down. */
  protected lazy val halfCount: Int = elemCount / 2
  /** The first halfCount values of elemRange. */
  protected lazy val halfRange: Range = elemRange.take(halfCount)

  /** Asserts that `err` is a gRPC StatusRuntimeException carrying the CANCELLED status code. */
  protected def isCancelledException(err: Throwable): Assertion = {
    err shouldBe a[StatusRuntimeException]
    val actualCode = err.asInstanceOf[StatusRuntimeException].getStatus.getCode
    actualCode shouldEqual Status.CANCELLED.getCode
  }

  /** Asserts exactly `expectedCount` responses whose respInt values run 1, 2, ..., expectedCount. */
  protected def assertElementsAreInOrder(expectedCount: Long)(
      results: Seq[HelloResponse]
  ): Assertion = {
    results should have length expectedCount
    val expectedValues = 1 to expectedCount.toInt
    results.map(_.respInt) shouldEqual expectedValues
  }

  /** Asserts a single response whose respInt equals the sum of elemRange. */
  protected def elementsAreSummed(results: Seq[HelloResponse]): Assertion = {
    results should have length 1
    val total = results.map(_.respInt).sum
    total shouldEqual elemRange.sum
  }

  /** Asserts elemCount responses whose respInt values are elemRange doubled, in the same order. */
  protected def everyElementIsDoubled(results: Seq[HelloResponse]): Assertion = {
    results should have length elemCount.toLong
    // Ordering is part of the check: the comparison is element by element.
    val doubled = elemRange.map(_ * 2)
    results.map(_.respInt) shouldEqual doubled
  }

  /** Returns a ByteString holding 1024 random bytes. */
  protected def genPayload(): ByteString = {
    val randomBytes = Array.ofDim[Byte](1024)
    Random.nextBytes(randomBytes)
    ByteString.copyFrom(randomBytes)
  }
}
Example 16
Source File: TFTensorNumeric.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.utils.tf

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.tensor.{ConvertableFrom, StringType, TensorDataType}
import com.intel.analytics.bigdl.tensor.TensorNumericMath.UndefinedTensorNumeric

import scala.language.implicitConversions

object TFTensorNumeric {

  /**
   * TensorNumeric instance for ByteString tensors (reported as StringType).
   *
   * Only the operations overridden here are meaningful for byte strings: "addition" is
   * concatenation and equality is exact content equality. Every other numeric operation falls
   * back to UndefinedTensorNumeric.
   */
  implicit object NumericByteString extends UndefinedTensorNumeric[ByteString]("ByteString") {

    override def getType(): TensorDataType = StringType

    /** Concatenation: the bytes of `x` followed by the bytes of `y`. */
    override def plus(x: ByteString, y: ByteString): ByteString = x.concat(y)


    /** Encodes `k.toString` as UTF-8. */
    override def fromType[K](k: K)(implicit c: ConvertableFrom[K]): ByteString = {
      ByteString.copyFromUtf8(k.toString)
    }

    /**
     * BLAS-style axpy adapted to byte strings: for i in [0, n),
     * `dy(yi) = dx(xi) concat dy(yi)` where `xi = _dx_offset + i * incx` and
     * `yi = _dy_offset + i * incy`.
     *
     * `da` is ignored because scaling has no meaning for byte strings. The strides are
     * honoured so that non-contiguous views are updated correctly; with unit strides this is
     * identical to walking both arrays element by element.
     * NOTE(review): negative increments are not given BLAS's reverse-order meaning; presumably
     * callers only pass non-negative strides — confirm.
     */
    override def axpy(n: Int, da: ByteString, dx: Array[ByteString],
                      _dx_offset: Int, incx: Int, dy: Array[ByteString],
                      _dy_offset: Int, incy: Int): Unit = {
      var i = 0
      while (i < n) {
        val xIndex = _dx_offset + i * incx
        val yIndex = _dy_offset + i * incy
        dy(yIndex) = dx(xIndex).concat(dy(yIndex))
        i += 1
      }
    }

    /** Byte strings have no notion of distance, so `epsilon` is ignored and equality is exact. */
    override def nearlyEqual(a: ByteString, b: ByteString, epsilon: Double): Boolean = {
      a == b
    }

  }
}
Example 17
Source File: ParseSingleExample.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.utils.tf.loaders

import java.nio.ByteOrder

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.nn.tf.{ParseSingleExample => ParseSingleExampleOperation}
import com.intel.analytics.bigdl.tensor._
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.tf.Context
import org.tensorflow.framework.{DataType, NodeDef}

import collection.JavaConverters._
import scala.reflect.ClassTag

class ParseSingleExample extends TensorflowOpsLoader {

  import Utils._

  /**
   * Builds a BigDL ParseSingleExample operation from a TensorFlow `ParseSingleExample` node.
   *
   * Reads three attributes of `nodeDef`: `Tdense` (element type of each dense feature),
   * `dense_keys` (the dense feature names) and `dense_shapes` (one shape per dense feature).
   *
   * @throws IllegalArgumentException if `Tdense` contains a type other than int64, int32,
   *                                  float, double or string
   */
  override def build[T: ClassTag](nodeDef: NodeDef, byteOrder: ByteOrder,
    context: Context[T])(implicit ev: TensorNumeric[T]): Module[T] = {
    val attrs = nodeDef.getAttrMap

    val denseTypes = attrs.get("Tdense")
      .getList.getTypeList.asScala
      .map {
        case DataType.DT_INT64 => LongType
        case DataType.DT_INT32 => IntType
        case DataType.DT_FLOAT => FloatType
        case DataType.DT_DOUBLE => DoubleType
        case DataType.DT_STRING => StringType
        case other => throw new IllegalArgumentException(
          s"ParseSingleExample: unsupported Tdense type $other")
      }

    // Copy each key into a fresh ByteString (kept from the original implementation so the
    // operation does not hold references into the NodeDef's internal list).
    val denseKeysByteArray = attrs.get("dense_keys").getList.
      getSList.asScala.map(_.toByteArray)
    val denseKeys = denseKeysByteArray.map(ByteString.copyFrom(_))

    val denseShapes = attrs.get("dense_shapes")
      .getList.getShapeList.asScala
      .map { shapeProto =>
        shapeProto.getDimList.asScala.map(_.getSize.toInt).toArray
      }

    new ParseSingleExampleOperation[T](denseTypes, denseKeys, denseShapes)
  }
}
Example 18
Source File: Const.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.utils.tf.loaders

import java.nio.ByteOrder

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.nn.tf.{Const => ConstOps}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric.{NumericBoolean, NumericChar, NumericDouble, NumericFloat, NumericInt, NumericLong, NumericShort, NumericString}
import com.intel.analytics.bigdl.utils.tf.TFTensorNumeric.NumericByteString
import com.intel.analytics.bigdl.utils.tf.{Context, TFUtils}
import org.tensorflow.framework.NodeDef

import scala.reflect.ClassTag

class Const extends TensorflowOpsLoader {

  /**
   * Builds a constant op from the TensorProto stored in the node's "value" attribute.
   *
   * The element type of the resulting ConstOps is chosen from the parsed tensor's numeric
   * instance. A numeric not listed below makes the match fail with scala.MatchError.
   */
  override def build[T: ClassTag](nodeDef: NodeDef, byteOrder: ByteOrder,
    context: Context[T])(implicit ev: TensorNumeric[T]): Module[T] = {
    val tensorProto = nodeDef.getAttrMap.get("value").getTensor
    val parsed = TFUtils.parseTensor(tensorProto, byteOrder)

    val constOp = parsed.getTensorNumeric() match {
      case NumericFloat => ConstOps[T, Float](parsed.asInstanceOf[Tensor[Float]])
      case NumericDouble => ConstOps[T, Double](parsed.asInstanceOf[Tensor[Double]])
      case NumericInt => ConstOps[T, Int](parsed.asInstanceOf[Tensor[Int]])
      case NumericLong => ConstOps[T, Long](parsed.asInstanceOf[Tensor[Long]])
      case NumericChar => ConstOps[T, Char](parsed.asInstanceOf[Tensor[Char]])
      case NumericBoolean => ConstOps[T, Boolean](parsed.asInstanceOf[Tensor[Boolean]])
      case NumericShort => ConstOps[T, Short](parsed.asInstanceOf[Tensor[Short]])
      case NumericString => ConstOps[T, String](parsed.asInstanceOf[Tensor[String]])
      case NumericByteString =>
        ConstOps[T, ByteString](parsed.asInstanceOf[Tensor[ByteString]])
    }

    constOp.asInstanceOf[Module[T]]
  }
}
Example 19
Source File: Types.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.utils.serializer

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity}
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric.{NumericBoolean, NumericChar, NumericDouble, NumericFloat, NumericInt, NumericLong, NumericString}
import com.intel.analytics.bigdl.utils.tf.TFTensorNumeric.NumericByteString
import com.intel.analytics.bigdl.serialization.Bigdl.BigDLModule

import scala.collection.mutable
import scala.reflect.ClassTag


/**
 * Marker for how tensor storages are handled during (de)serialization.
 * NOTE(review): the exact difference between the two variants is defined by the serializers,
 * not here — presumably ProtoStorageType keeps storages in protobuf and BigDLStorage uses
 * BigDL's own format; confirm.
 */
trait StorageType
object ProtoStorageType extends StorageType
object BigDLStorage extends StorageType

/**
 * State passed to a module serializer.
 *
 * @param moduleData        the module to serialize plus its predecessor/successor names
 * @param storages          shared storages keyed by an Int id (presumably so shared weights are
 *                          written once — confirm)
 * @param storageType       how storages are to be written
 * @param copyWeightAndBias defaults to true
 * @param groupType         defaults to null (no group type)
 */
case class SerializeContext[T: ClassTag](moduleData: ModuleData[T],
                                         storages: mutable.HashMap[Int, Any],
                                         storageType: StorageType,
                                         copyWeightAndBias : Boolean = true,
                                         groupType : String = null)
/**
 * State passed to a module deserializer: the protobuf module being read, the storages
 * collected so far (keyed by Int id), the storage type and whether weights/biases are copied
 * (defaults to true).
 */
case class DeserializeContext(bigdlModule : BigDLModule,
                              storages: mutable.HashMap[Int, Any],
                              storageType: StorageType,
                              copyWeightAndBias : Boolean = true)

/** Output of serializing one module: the protobuf builder and the storages map. */
case class SerializeResult(bigDLModule: BigDLModule.Builder, storages: mutable.HashMap[Int, Any])

/**
 * A module together with the names of the modules before (`pre`) and after (`next`) it.
 */
case class ModuleData[T: ClassTag](module : AbstractModule[Activity, Activity, T],
                                   pre : Seq[String], next : Seq[String])

/**
 * Enumeration of element data types known to the BigDL serializer.
 * Values are declared in this order; their ids follow declaration order as usual for
 * scala.Enumeration, so do not reorder them if ids are persisted (NOTE(review): confirm).
 */
object BigDLDataType extends Enumeration{
  type BigDLDataType = Value
  val FLOAT, DOUBLE, CHAR, BOOL, STRING, INT, SHORT, LONG, BYTESTRING, BYTE = Value
}

/** Constants shared by the BigDL model serializer. */
object SerConst {
  /** Magic number identifying a serialized BigDL model. */
  val MAGIC_NO: Int = 3721
  /** Name of the message-digest algorithm used by the serializer. */
  val DIGEST_TYPE: String = "MD5"
  /** Key under which the global storage is recorded. */
  val GLOBAL_STORAGE: String = "global_storage"
  /** Key under which module tags are recorded. */
  val MODULE_TAGES: String = "module_tags"
  /** Key under which module numeric types are recorded. */
  val MODULE_NUMERICS: String = "module_numerics"
  /** Key under which the group type is recorded. */
  val GROUP_TYPE: String = "group_type"
}

/** Converts between ClassTags and the type names stored in serialized models. */
object ClassTagMapper {
  /**
   * Resolves a stored type name back to its ClassTag.
   *
   * The names are the `toString` forms produced by the other `apply` (e.g. "Float", "Short",
   * "com.google.protobuf.ByteString"). "Short" is included so that names written for Short
   * tensors round-trip. Unknown names raise scala.MatchError, as before.
   */
  def apply(tpe : String): ClassTag[_] = {
    tpe match {
      case "Float" => scala.reflect.classTag[Float]
      case "Double" => scala.reflect.classTag[Double]
      case "Char" => scala.reflect.classTag[Char]
      case "Boolean" => scala.reflect.classTag[Boolean]
      case "String" => scala.reflect.classTag[String]
      case "Int" => scala.reflect.classTag[Int]
      case "Short" => scala.reflect.classTag[Short]
      case "Long" => scala.reflect.classTag[Long]
      case "com.google.protobuf.ByteString" => scala.reflect.classTag[ByteString]
    }
  }

  /** The stored name of a ClassTag: simply its `toString`. */
  def apply(classTag: ClassTag[_]): String = classTag.toString
}
/** Converts between TensorNumeric instances and the type names stored in serialized models. */
object TensorNumericMapper {
  /**
   * Resolves a stored type name to its TensorNumeric instance.
   * Unknown names raise scala.MatchError. "Short" maps to NumericShort so that Short tensors
   * (supported by the TF Const loader) can be (de)serialized in both directions.
   */
  def apply(tpe : String): TensorNumeric[_] = {
    tpe match {
      case "Float" => NumericFloat
      case "Double" => NumericDouble
      case "Char" => NumericChar
      case "Boolean" => NumericBoolean
      case "String" => NumericString
      case "Int" => NumericInt
      case "Short" => TensorNumeric.NumericShort
      case "Long" => NumericLong
      case "ByteString" => NumericByteString
    }
  }

  /**
   * Returns the stored type name of a TensorNumeric instance; the inverse of the other `apply`.
   * Unknown instances raise scala.MatchError.
   */
  def apply(tensorNumeric: TensorNumeric[_]): String = {
    tensorNumeric match {
      case NumericFloat => "Float"
      case NumericDouble => "Double"
      case NumericChar => "Char"
      case NumericBoolean => "Boolean"
      case NumericString => "String"
      case NumericInt => "Int"
      case TensorNumeric.NumericShort => "Short"
      case NumericLong => "Long"
      case NumericByteString => "ByteString"
    }
  }
}
Example 20
Source File: Assert.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn.tf

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.bigdl.nn.ops.Operation
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.Table

import scala.reflect.ClassTag


/**
 * TensorFlow-style Assert op.
 *
 * Input table: entry 1 is a Boolean tensor whose `value()` is the predicate, and entry 2 is a
 * ByteString tensor whose `value()` is the message. When the predicate is false this throws
 * java.lang.AssertionError("assertion failed: " + message). It produces no output (returns null).
 */
private[bigdl] class Assert[T: ClassTag]()
  (implicit ev: TensorNumeric[T]) extends Operation[Table, Activity, T] {
  override def updateOutput(input: Table): Tensor[T] = {
    val predicateTensor = input(1).asInstanceOf[Tensor[Boolean]]
    val messageTensor = input(2).asInstanceOf[Tensor[ByteString]]

    val predicate = predicateTensor.value()
    val message = messageTensor.value()

    // Throw explicitly instead of calling Predef.assert: that method is @elidable and can be
    // removed at compile time, which would silently disable a graph-level assertion. The
    // message format is the same one Predef.assert produces.
    if (!predicate) {
      throw new java.lang.AssertionError("assertion failed: " + message.toStringUtf8)
    }
    null
  }
}
Example 21
Source File: ApproximateEqual.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn.ops

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.bigdl.tensor._
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.Table

import scala.reflect.ClassTag

/**
 * Element-wise `|a - b| < tolerance` comparison.
 *
 * Boolean and ByteString inputs are rejected with UnsupportedOperationException.
 * The integral comparisons are overflow-safe: a difference that does not fit in the operand
 * type (or equals its MinValue, whose `abs` is negative) used to wrap around and could make
 * far-apart values look "approximately equal".
 */
class ApproximateEqual[T: ClassTag](tolerance: Float)
                        (implicit ev: TensorNumeric[T]) extends Compare[T] {
  override def compareFloat(a: Float, b: Float): Boolean = math.abs(a - b) < tolerance

  override def compareDouble(a: Double, b: Double): Boolean = math.abs(a - b) < tolerance

  // Char operands are promoted to Int before subtracting, so this cannot overflow.
  override def compareChar(a: Char, b: Char): Boolean = math.abs(a - b) < tolerance

  // Long - Long can overflow, so compute the difference exactly with BigInt. Converting the
  // exact difference to Float matches the original Long-to-Float comparison whenever no
  // overflow occurs.
  override def compareLong(a: Long, b: Long): Boolean =
    (BigInt(a) - BigInt(b)).abs.toFloat < tolerance

  // Short operands are promoted to Int before subtracting, so this cannot overflow.
  override def compareShort(a: Short, b: Short): Boolean = math.abs(a - b) < tolerance

  // Widen to Long so that a - b cannot overflow (e.g. Int.MaxValue - Int.MinValue).
  override def compareInt(a: Int, b: Int): Boolean =
    math.abs(a.toLong - b.toLong) < tolerance

  override def compareBoolean(a: Boolean, b: Boolean): Boolean = {
    throw new UnsupportedOperationException("Does not support ApproximateEqual on Boolean")
  }

  override def compareByteString(a: ByteString, b: ByteString): Boolean = {
    throw new UnsupportedOperationException("Does not support ApproximateEqual on ByteString")
  }
}

object ApproximateEqual {
  /** Wraps a new ApproximateEqual with the given tolerance as a generic Activity operation. */
  def apply[T: ClassTag](tolerance: Float)
     (implicit ev: TensorNumeric[T]): Operation[Activity, Activity, T] = {
    val comparison = new ApproximateEqual[T](tolerance)
    ModuleToOperation[T](comparison)
  }
}
Example 22
Source File: NotEqual.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn.ops

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.bigdl.tensor._
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.Table

import scala.reflect.ClassTag

/**
 * Element-wise `!=` of two tensors, producing a Boolean tensor.
 *
 * Input table entries 1 and 2 are the operands. The element type is taken from the first
 * tensor's `getType()`, and both tensors are read as that type. StringType tensors are
 * compared as ByteString. Any other type throws RuntimeException("Unsupported tensor type").
 */
class NotEqual[T: ClassTag]()
  (implicit ev: TensorNumeric[T]) extends Operation[Table, Tensor[Boolean], T] {

  // Boolean output buffer, reused across calls and resized to the first input every forward.
  output = Activity.allocate[Tensor[Boolean], Boolean]()

  override def updateOutput(input: Table): Tensor[Boolean] = {
    output.resizeAs(input(1))
    // Dispatch on the first input's element type; each branch fills `output` pairwise.
    input[Tensor[_]](1).getType() match {
      case FloatType =>
        output.zipWith[Float, Float](
          input[Tensor[Float]](1),
          input[Tensor[Float]](2),
          (a, b) => a != b)
      case BooleanType =>
        output.zipWith[Boolean, Boolean](
          input[Tensor[Boolean]](1),
          input[Tensor[Boolean]](2),
          (a, b) => a != b)
      case DoubleType =>
        output.zipWith[Double, Double](
          input[Tensor[Double]](1),
          input[Tensor[Double]](2),
          (a, b) => a != b)
      case CharType =>
        output.zipWith[Char, Char](
          input[Tensor[Char]](1),
          input[Tensor[Char]](2),
          (a, b) => a != b)
      case StringType =>
        // String tensors hold ByteString elements; `!=` compares their contents.
        output.zipWith[ByteString, ByteString](
          input[Tensor[ByteString]](1),
          input[Tensor[ByteString]](2),
          (a, b) => a != b)
      case LongType =>
        output.zipWith[Long, Long](
          input[Tensor[Long]](1),
          input[Tensor[Long]](2),
          (a, b) => a != b)
      case ShortType =>
        output.zipWith[Short, Short](
          input[Tensor[Short]](1),
          input[Tensor[Short]](2),
          (a, b) => a != b)
      case IntType =>
        output.zipWith[Int, Int](
          input[Tensor[Int]](1),
          input[Tensor[Int]](2),
          (a, b) => a != b)
      case _ => throw new RuntimeException("Unsupported tensor type")
    }

    output
  }
}

object NotEqual {
  /** Wraps a new NotEqual as a generic Activity-in/Activity-out operation. */
  def apply[T: ClassTag]()(implicit ev: TensorNumeric[T]): Operation[Activity, Activity, T] = {
    val notEqual = new NotEqual[T]()
    ModuleToOperation[T](notEqual)
  }
}
Example 23
Source File: Compare.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn.ops

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.bigdl.tensor._
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.Table

import scala.reflect.ClassTag

/**
 * Base class for element-wise binary comparisons that produce a Boolean tensor.
 *
 * `updateOutput` expects a Table whose entries 1 and 2 are tensors. It dispatches on the first
 * tensor's `getType()`, reads both tensors as that element type and applies the matching
 * `compareXxx` hook pairwise via `zipWith`. Any other type throws
 * RuntimeException("Unsupported tensor type").
 * NOTE(review): no shape check between the two inputs is done here — presumably `zipWith`
 * enforces it; confirm.
 */
abstract class Compare[T: ClassTag]()
(implicit ev: TensorNumeric[T]) extends Operation[Table, Tensor[Boolean], T] {

  /** Comparison applied to Float elements. */
  def compareFloat(a: Float, b: Float): Boolean

  /** Comparison applied to Double elements. */
  def compareDouble(a: Double, b: Double): Boolean

  /** Comparison applied to Char elements. */
  def compareChar(a: Char, b: Char): Boolean

  /** Comparison applied to Long elements. */
  def compareLong(a: Long, b: Long): Boolean

  /** Comparison applied to Short elements. */
  def compareShort(a: Short, b: Short): Boolean

  /** Comparison applied to Int elements. */
  def compareInt(a: Int, b: Int): Boolean

  /** Comparison applied to Boolean elements. */
  def compareBoolean(a: Boolean, b: Boolean): Boolean

  /** Comparison applied to StringType tensors, whose elements are ByteStrings. */
  def compareByteString(a: ByteString, b: ByteString): Boolean

  // Boolean output buffer, reused across calls and resized to the first input every forward.
  output = Activity.allocate[Tensor[Boolean], Boolean]()

  override def updateOutput(input: Table): Tensor[Boolean] = {
    output.resizeAs(input(1))
    input[Tensor[_]](1).getType() match {
      case FloatType =>
        output.zipWith[Float, Float](
          input[Tensor[Float]](1),
          input[Tensor[Float]](2),
          (a, b) => compareFloat(a, b))
      case DoubleType =>
        output.zipWith[Double, Double](
          input[Tensor[Double]](1),
          input[Tensor[Double]](2),
          (a, b) => compareDouble(a, b))
      case CharType =>
        output.zipWith[Char, Char](
          input[Tensor[Char]](1),
          input[Tensor[Char]](2),
          (a, b) => compareChar(a, b))
      case LongType =>
        output.zipWith[Long, Long](
          input[Tensor[Long]](1),
          input[Tensor[Long]](2),
          (a, b) => compareLong(a, b))
      case ShortType =>
        output.zipWith[Short, Short](
          input[Tensor[Short]](1),
          input[Tensor[Short]](2),
          (a, b) => compareShort(a, b))
      case IntType =>
        output.zipWith[Int, Int](
          input[Tensor[Int]](1),
          input[Tensor[Int]](2),
          (a, b) => compareInt(a, b))
      case BooleanType =>
        output.zipWith[Boolean, Boolean](
          input[Tensor[Boolean]](1),
          input[Tensor[Boolean]](2),
          (a, b) => compareBoolean(a, b))
      case StringType =>
        output.zipWith[ByteString, ByteString](
          input[Tensor[ByteString]](1),
          input[Tensor[ByteString]](2),
          (a, b) => compareByteString(a, b))
      case _ => throw new RuntimeException("Unsupported tensor type")
    }

    output
  }
}
Example 24
Source File: Substr.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn.ops

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.tensor._
import com.intel.analytics.bigdl.utils.Table

import scala.reflect.ClassTag

/**
 * TensorFlow-style Substr on a scalar byte string.
 *
 * Input table: entry 1 is a scalar ByteString tensor (the data), entry 2 a scalar Int `pos`
 * and entry 3 a scalar Int `len`. The output is a scalar holding the bytes starting at `pos`.
 * If `pos + len` runs past the end of the data, as many bytes as are available are taken
 * (matching TensorFlow's Substr) instead of failing. An out-of-range `pos` or a negative
 * `len` still raises IndexOutOfBoundsException from ByteString.substring.
 */
class Substr[T: ClassTag]()
   (implicit ev: TensorNumeric[T]) extends Operation[Table, Tensor[ByteString], T] {

  override def updateOutput(input: Table): Tensor[ByteString] = {
    val data = input[Tensor[ByteString]](1).value()
    val pos = input[Tensor[Int]](2).value()
    val len = input[Tensor[Int]](3).value()
    import com.intel.analytics.bigdl.utils.tf.TFTensorNumeric.NumericByteString

    // Clamp the end to the data size. Comparing `len` against the remaining length
    // (instead of computing `pos + len` first) avoids Int overflow for very large `len`.
    val end = if (len > data.size() - pos) data.size() else pos + len
    output = Tensor.scalar(data.substring(pos, end))
    output
  }
}

object Substr {
  /** Creates a Substr op and exposes it as a generic Activity-in/Activity-out operation. */
  def apply[T: ClassTag]()(implicit ev: TensorNumeric[T]): Operation[Activity, Activity, T] = {
    val substr: Substr[T] = new Substr[T]()
    substr.asInstanceOf[Operation[Activity, Activity, T]]
  }
}
Example 25
Source File: Activity.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn.abstractnn

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.{T, Table}

import scala.reflect._


  /**
   * Allocates an empty activity buffer of type D.
   *
   * If D's ClassTag equals that of Table, an empty Table (`T()`) is returned. If it equals that
   * of Tensor[_], an empty Tensor with element type T is returned. For any other D, null is
   * returned (cast to D).
   *
   * @tparam D the activity type to create
   * @tparam T the tensor element type; when D is a Tensor it must be one of Boolean, Char,
   *           Short, Int, Long, Float, Double, String or ByteString
   * @throws IllegalArgumentException if D is a Tensor and T is none of the types above
   */
  def allocate[D <: Activity: ClassTag, T : ClassTag](): D = {
    val buffer = if (classTag[D] == classTag[Table]) {
      T()
    } else if (classTag[D] == classTag[Tensor[_]]) {
      // Each branch imports the TensorNumeric implicit that Tensor[X]() needs for its type.
      if (classTag[Boolean] == classTag[T]) {
        import com.intel.analytics.bigdl.numeric.NumericBoolean
        Tensor[Boolean]()
      } else if (classTag[Char] == classTag[T]) {
        import com.intel.analytics.bigdl.numeric.NumericChar
        Tensor[Char]()
      } else if (classTag[Short] == classTag[T]) {
        import com.intel.analytics.bigdl.numeric.NumericShort
        Tensor[Short]()
      } else if (classTag[Int] == classTag[T]) {
        import com.intel.analytics.bigdl.numeric.NumericInt
        Tensor[Int]()
      } else if (classTag[Long] == classTag[T]) {
        import com.intel.analytics.bigdl.numeric.NumericLong
        Tensor[Long]()
      } else if (classTag[Float] == classTag[T]) {
        import com.intel.analytics.bigdl.numeric.NumericFloat
        Tensor[Float]()
      } else if (classTag[Double] == classTag[T]) {
        import com.intel.analytics.bigdl.numeric.NumericDouble
        Tensor[Double]()
      } else if (classTag[String] == classTag[T]) {
        import com.intel.analytics.bigdl.numeric.NumericString
        Tensor[String]()
      } else if (classTag[ByteString] == classTag[T]) {
        import com.intel.analytics.bigdl.utils.tf.TFTensorNumeric.NumericByteString
        Tensor[ByteString]()
      } else {
        throw new IllegalArgumentException("Type T activity is not supported")
      }
    } else {
      null
    }
    buffer.asInstanceOf[D]
  }

  /** Creates an EmptyGradInput with the given name. */
  def emptyGradInput(name: String): EmptyGradInput = new EmptyGradInput(name)
} 
Example 26
Source File: TFTensorNumericSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.utils.tf

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.tensor.Tensor
import org.scalatest.{FlatSpec, Matchers}

class TFTensorNumericSpec extends FlatSpec with Matchers {

  import TFTensorNumeric.NumericByteString

  /** Builds a 1-D ByteString tensor from the UTF-8 encodings of `values`. */
  private def utf8Vector(values: String*): Tensor[ByteString] =
    Tensor[ByteString](values.map(ByteString.copyFromUtf8(_)).toArray, Array(values.length))

  "String Tensor" should "works correctly" in {
    // Adding string tensors concatenates them element-wise.
    val a = utf8Vector("a", "b")
    val b = utf8Vector("a", "b")
    val sum = utf8Vector("aa", "bb")

    a + b should be (sum)
  }

}
Example 27
Source File: TFUtilsSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.utils.tf

import java.io.File
import java.nio.ByteOrder

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.T
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}
import org.tensorflow.framework.TensorProto

import scala.collection.JavaConverters._

class TFUtilsSpec extends FlatSpec with Matchers with BeforeAndAfter {

  /** Const TensorProtos from the `tf/consts.pbtxt` test resource, keyed by node name. */
  private var constTensors: Map[String, TensorProto] = null
  before {
    constTensors = getConstTensorProto()
  }

  /** Maps every node name in `tf/consts.pbtxt` to the TensorProto in its "value" attribute. */
  private def getConstTensorProto(): Map[String, TensorProto] = {
    val resource = getClass.getClassLoader.getResource("tf")
    // URL.getPath keeps percent-escapes (e.g. "%20" for a space), which breaks when the
    // checkout path contains such characters; going through a URI yields a real file path.
    val path = new File(resource.toURI).getPath + File.separator + "consts.pbtxt"
    val nodes = TensorflowLoader.parseTxt(path)
    nodes.asScala.map(node => node.getName -> node.getAttrMap.get("value").getTensor).toMap
  }

  /** Parses the named const TensorProto with little-endian byte order. */
  private def parseConst(name: String) =
    TFUtils.parseTensor(constTensors(name), ByteOrder.LITTLE_ENDIAN)

  "parseTensor " should "work with bool TensorProto" in {
    val bigdlTensor = parseConst("bool_const")
    bigdlTensor should be (Tensor[Boolean](T(true, false, true, false)))
  }

  "parseTensor " should "work with float TensorProto" in {
    val bigdlTensor = parseConst("float_const")
    bigdlTensor should be (Tensor[Float](T(1.0f, 2.0f, 3.0f, 4.0f)))
  }

  "parseTensor " should "work with double TensorProto" in {
    val bigdlTensor = parseConst("double_const")
    bigdlTensor should be (Tensor[Double](T(1.0, 2.0, 3.0, 4.0)))
  }

  "parseTensor " should "work with int TensorProto" in {
    val bigdlTensor = parseConst("int_const")
    bigdlTensor should be (Tensor[Int](T(1, 2, 3, 4)))
  }

  "parseTensor " should "work with long TensorProto" in {
    val bigdlTensor = parseConst("long_const")
    bigdlTensor should be (Tensor[Long](T(1, 2, 3, 4)))
  }

  // The narrow integer types below are all expected to come back as Int tensors.
  "parseTensor " should "work with int8 TensorProto" in {
    val bigdlTensor = parseConst("int8_const")
    bigdlTensor should be (Tensor[Int](T(1, 2, 3, 4)))
  }

  "parseTensor " should "work with uint8 TensorProto" in {
    val bigdlTensor = parseConst("uint8_const")
    bigdlTensor should be (Tensor[Int](T(1, 2, 3, 4)))
  }

  "parseTensor " should "work with int16 TensorProto" in {
    val bigdlTensor = parseConst("int16_const")
    bigdlTensor should be (Tensor[Int](T(1, 2, 3, 4)))
  }

  "parseTensor " should "work with uint16 TensorProto" in {
    val bigdlTensor = parseConst("uint16_const")
    bigdlTensor should be (Tensor[Int](T(1, 2, 3, 4)))
  }

  "parseTensor " should "work with string TensorProto" in {
    import TFTensorNumeric.NumericByteString
    val bigdlTensor = parseConst("string_const")
    val data = Array(
      ByteString.copyFromUtf8("a"),
      ByteString.copyFromUtf8("b"),
      ByteString.copyFromUtf8("c"),
      ByteString.copyFromUtf8("d")
    )
    bigdlTensor should be (Tensor[ByteString](data, Array[Int](4)))
  }
}
Example 28
Source File: Conv3DBackpropInputV2Spec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.utils.tf.loaders

import java.nio.charset.Charset

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.tf.{PaddingType, TensorflowSpecHelper}
import org.tensorflow.framework.{AttrValue, DataType, NodeDef}
import com.intel.analytics.bigdl.utils.tf.Tensorflow._

class Conv3DBackpropInputV2Spec extends TensorflowSpecHelper {

  /**
   * Builds a float Conv3DBackpropInputV2 node with NDHWC data format, strides
   * Seq(1, 1, 2, 3, 1) and the given padding attribute.
   */
  private def conv3DBackpropInputV2Node(padding: AttrValue): NodeDef.Builder = {
    // Encode the layout name explicitly as UTF-8; Charset.defaultCharset() is platform-dependent.
    val dataFormat = AttrValue.newBuilder()
      .setS(ByteString.copyFrom("NDHWC", Charset.forName("UTF-8")))
      .build()

    NodeDef.newBuilder()
      .setName("Conv3DBackpropInputV2Test")
      .setOp("Conv3DBackpropInputV2")
      .putAttr("T", typeAttr(DataType.DT_FLOAT))
      .putAttr("strides", listIntAttr(Seq(1, 1, 2, 3, 1)))
      .putAttr("padding", padding)
      .putAttr("data_format", dataFormat)
  }

  "Conv3DBackpropInputV2 forward with VALID padding" should "be correct" in {
    val builder = conv3DBackpropInputV2Node(PaddingType.PADDING_VALID.value)

    val inputSize = Tensor[Int](Array(4, 20, 30, 40, 3), Array(5))
    val filter = Tensor[Float](2, 3, 4, 3, 4).rand()
    val outputBackprop = Tensor[Float](4, 19, 14, 13, 4).rand()

    compare[Float](
      builder,
      Seq(inputSize, filter, outputBackprop),
      0,
      1e-4
    )
  }

  "Conv3DBackpropInputV2 forward with SAME padding" should "be correct" in {
    val builder = conv3DBackpropInputV2Node(PaddingType.PADDING_SAME.value)

    val inputSize = Tensor[Int](Array(4, 20, 30, 40, 3), Array(5))
    val filter = Tensor[Float](2, 3, 4, 3, 4).rand()
    val outputBackprop = Tensor[Float](4, 20, 15, 14, 4).rand()

    compare[Float](
      builder,
      Seq(inputSize, filter, outputBackprop),
      0,
      1e-4
    )
  }


}
Example 29
Source File: Conv3DSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.utils.tf.loaders

import java.nio.charset.Charset

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.tf.{PaddingType, TensorflowSpecHelper}
import org.tensorflow.framework.{AttrValue, DataType, NodeDef}
import com.intel.analytics.bigdl.utils.tf.Tensorflow._

class Conv3DSpec extends TensorflowSpecHelper {

  /**
   * Builds the `Conv3D` node under test (float, strides 1x1x2x3x1, NDHWC layout).
   *
   * @param padding the padding attribute to attach, e.g. `PaddingType.PADDING_VALID.value`
   * @return a node builder ready to be passed to `compare`
   */
  private def conv3DNode(padding: AttrValue): NodeDef.Builder = {
    // AttrValue string attributes are raw bytes; encode them as UTF-8 explicitly instead of the
    // platform default charset so the attribute is byte-identical on every JVM.
    val dataFormat = AttrValue.newBuilder().setS(ByteString.copyFromUtf8("NDHWC")).build()

    NodeDef.newBuilder()
      .setName("Conv3DTest")
      .setOp("Conv3D")
      .putAttr("T", typeAttr(DataType.DT_FLOAT))
      .putAttr("strides", listIntAttr(Seq(1, 1, 2, 3, 1)))
      .putAttr("padding", padding)
      .putAttr("data_format", dataFormat)
  }

  "Conv3D forward with VALID padding" should "be correct" in {
    val input = Tensor[Float](4, 20, 30, 40, 3).rand()
    val filter = Tensor[Float](2, 3, 4, 3, 4).rand()

    compare[Float](
      conv3DNode(PaddingType.PADDING_VALID.value),
      Seq(input, filter),
      0,
      1e-4
    )
  }

  "Conv3D forward with SAME padding" should "be correct" in {
    val input = Tensor[Float](4, 20, 30, 40, 3).rand()
    val filter = Tensor[Float](2, 3, 4, 3, 4).rand()

    compare[Float](
      conv3DNode(PaddingType.PADDING_SAME.value),
      Seq(input, filter),
      0,
      1e-4
    )
  }
}
Example 30
Source File: Conv3DBackpropFilterV2Spec.scala · From BigDL · Apache License 2.0 · 5 votes
package com.intel.analytics.bigdl.utils.tf.loaders

import java.nio.charset.Charset

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.tf.{PaddingType, TensorflowSpecHelper}
import org.tensorflow.framework.{AttrValue, DataType, NodeDef}
import com.intel.analytics.bigdl.utils.tf.Tensorflow._

class Conv3DBackpropFilterV2Spec extends TensorflowSpecHelper {

  /**
   * Builds the `Conv3DBackpropFilterV2` node under test (float, strides 1x1x2x3x1, NDHWC layout).
   *
   * @param padding the padding attribute to attach, e.g. `PaddingType.PADDING_VALID.value`
   * @return a node builder ready to be passed to `compare`
   */
  private def conv3DBackpropFilterNode(padding: AttrValue): NodeDef.Builder = {
    // AttrValue string attributes are raw bytes; encode them as UTF-8 explicitly instead of the
    // platform default charset so the attribute is byte-identical on every JVM.
    val dataFormat = AttrValue.newBuilder().setS(ByteString.copyFromUtf8("NDHWC")).build()

    NodeDef.newBuilder()
      .setName("Conv3DBackpropFilterV2Test")
      .setOp("Conv3DBackpropFilterV2")
      .putAttr("T", typeAttr(DataType.DT_FLOAT))
      .putAttr("strides", listIntAttr(Seq(1, 1, 2, 3, 1)))
      .putAttr("padding", padding)
      .putAttr("data_format", dataFormat)
  }

  "Conv3DBackpropFilter forward with VALID padding" should "be correct" in {
    val input = Tensor[Float](4, 20, 30, 40, 3).rand()
    // 1-D int tensor of 5 values; appears to hold the filter shape rather than filter weights.
    val filter = Tensor[Int](Array(2, 3, 4, 3, 4), Array(5))
    val outputBackprop = Tensor[Float](4, 19, 14, 13, 4).rand()

    // The output in this case is typically on the scale of thousands,
    // so a 1e-2 absolute error tolerance is acceptable.
    compare[Float](
      conv3DBackpropFilterNode(PaddingType.PADDING_VALID.value),
      Seq(input, filter, outputBackprop),
      0,
      1e-2
    )
  }

  "Conv3DBackpropFilter forward with SAME padding" should "be correct" in {
    val input = Tensor[Float](4, 20, 30, 40, 3).rand()
    // 1-D int tensor of 5 values; appears to hold the filter shape rather than filter weights.
    val filter = Tensor[Int](Array(2, 3, 4, 3, 4), Array(5))
    val outputBackprop = Tensor[Float](4, 20, 15, 14, 4).rand()

    // The output in this case is typically on the scale of thousands,
    // so a 1e-2 absolute error tolerance is acceptable.
    compare[Float](
      conv3DBackpropFilterNode(PaddingType.PADDING_SAME.value),
      Seq(input, filter, outputBackprop),
      0,
      1e-2
    )
  }
}
Example 31
Source File: DecodeJpegSpec.scala · From BigDL · Apache License 2.0 · 5 votes
package com.intel.analytics.bigdl.nn.tf

import java.io.File

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.serializer.ModuleSerializationTest
import com.intel.analytics.bigdl.utils.tf.TFRecordIterator
import org.tensorflow.example.Example

class DecodeJpegSerialTest extends ModuleSerializationTest {
  override def test(): Unit = {
    val decodeJpeg = new DecodeJpeg[Float](1).setName("decodeJpeg")
    val input = getInputs("jpeg")
    runSerializationTest(decodeJpeg, input)
  }

  /**
   * Loads one encoded image from the `tf/decode_image_test_case.tfrecord` test resource.
   *
   * @param name which record to load: "png", "jpeg", "gif" or "raw"
   * @return a scalar tensor holding the record's `image/encoded` bytes
   * @throws IllegalArgumentException if `name` is unknown or the resource directory is missing
   */
  private def getInputs(name: String): Tensor[ByteString] = {
    import com.intel.analytics.bigdl.utils.tf.TFTensorNumeric.NumericByteString
    // Position of each image kind inside the tfrecord file.
    val index = name match {
      case "png" => 0
      case "jpeg" => 1
      case "gif" => 2
      case "raw" => 3
      case other => throw new IllegalArgumentException(s"unknown test image kind: $other")
    }

    val resource = getClass.getClassLoader.getResource("tf")
    require(resource != null, "test resource directory 'tf' not found on the classpath")
    // URL.getPath keeps percent-escapes (e.g. %20 for a space); going through the URI decodes
    // them, so the file is found even when the checkout path contains such characters.
    val file = new File(new File(resource.toURI), "decode_image_test_case.tfrecord")

    val records = TFRecordIterator(file).toVector
    val recordBytes = records(index)

    val example = Example.parseFrom(recordBytes)
    val imageByteString = example.getFeatures.getFeatureMap.get("image/encoded")
      .getBytesList.getValueList.get(0)

    Tensor[ByteString](Array(imageByteString), Array[Int]())
  }
}
Example 32
Source File: DecodeGifSpec.scala · From BigDL · Apache License 2.0 · 5 votes
package com.intel.analytics.bigdl.nn.tf

import java.io.{File => JFile}

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.serializer.ModuleSerializationTest
import com.intel.analytics.bigdl.utils.tf.TFRecordIterator
import org.tensorflow.example.Example

class DecodeGifSerialTest extends ModuleSerializationTest {
  /**
   * Loads one encoded image from the `tf/decode_image_test_case.tfrecord` test resource.
   *
   * @param name which record to load: "png", "jpeg", "gif" or "raw"
   * @return a scalar tensor holding the record's `image/encoded` bytes
   * @throws IllegalArgumentException if `name` is unknown or the resource directory is missing
   */
  private def getInputs(name: String): Tensor[ByteString] = {
    import com.intel.analytics.bigdl.utils.tf.TFTensorNumeric.NumericByteString
    // Position of each image kind inside the tfrecord file.
    val index = name match {
      case "png" => 0
      case "jpeg" => 1
      case "gif" => 2
      case "raw" => 3
      case other => throw new IllegalArgumentException(s"unknown test image kind: $other")
    }

    val resource = getClass.getClassLoader.getResource("tf")
    require(resource != null, "test resource directory 'tf' not found on the classpath")
    // URL.getPath keeps percent-escapes (e.g. %20 for a space); going through the URI decodes
    // them, so the file is found even when the checkout path contains such characters.
    val file = new JFile(new JFile(resource.toURI), "decode_image_test_case.tfrecord")

    val records = TFRecordIterator(file).toVector
    val recordBytes = records(index)

    val example = Example.parseFrom(recordBytes)
    val imageByteString = example.getFeatures.getFeatureMap.get("image/encoded")
      .getBytesList.getValueList.get(0)

    Tensor[ByteString](Array(imageByteString), Array[Int]())
  }

  override def test(): Unit = {
    val decodeGif = new DecodeGif[Float]().setName("decodeGif")
    val input = getInputs("gif")
    runSerializationTest(decodeGif, input)
  }
}
Example 33
Source File: DecodePngSpec.scala · From BigDL · Apache License 2.0 · 5 votes
package com.intel.analytics.bigdl.nn.tf

import java.io.{File => JFile}

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.serializer.ModuleSerializationTest
import com.intel.analytics.bigdl.utils.tf.TFRecordIterator
import org.tensorflow.example.Example

class DecodePngSerialTest extends ModuleSerializationTest {
  /**
   * Loads one encoded image from the `tf/decode_image_test_case.tfrecord` test resource.
   *
   * @param name which record to load: "png", "jpeg", "gif" or "raw"
   * @return a scalar tensor holding the record's `image/encoded` bytes
   * @throws IllegalArgumentException if `name` is unknown or the resource directory is missing
   */
  private def getInputs(name: String): Tensor[ByteString] = {
    import com.intel.analytics.bigdl.utils.tf.TFTensorNumeric.NumericByteString
    // Position of each image kind inside the tfrecord file.
    val index = name match {
      case "png" => 0
      case "jpeg" => 1
      case "gif" => 2
      case "raw" => 3
      case other => throw new IllegalArgumentException(s"unknown test image kind: $other")
    }

    val resource = getClass.getClassLoader.getResource("tf")
    require(resource != null, "test resource directory 'tf' not found on the classpath")
    // URL.getPath keeps percent-escapes (e.g. %20 for a space); going through the URI decodes
    // them, so the file is found even when the checkout path contains such characters.
    val file = new JFile(new JFile(resource.toURI), "decode_image_test_case.tfrecord")

    val records = TFRecordIterator(file).toVector
    val recordBytes = records(index)

    val example = Example.parseFrom(recordBytes)
    val imageByteString = example.getFeatures.getFeatureMap.get("image/encoded")
      .getBytesList.getValueList.get(0)

    Tensor[ByteString](Array(imageByteString), Array[Int]())
  }

  override def test(): Unit = {
    val decodePng = new DecodePng[Float](1).setName("decodePng")
    val input = getInputs("png")
    runSerializationTest(decodePng, input)
  }
}
Example 34
Source File: DecodeBmpSpec.scala · From BigDL · Apache License 2.0 · 5 votes
package com.intel.analytics.bigdl.nn.tf

import java.io.{File => JFile}

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.serializer.ModuleSerializationTest
import com.intel.analytics.bigdl.utils.tf.TFRecordIterator
import org.tensorflow.example.Example

class DecodeBmpSerialTest extends ModuleSerializationTest {
  /**
   * Loads one encoded image from the `tf/decode_image_test_case.tfrecord` test resource.
   *
   * @param name which record to load: "png", "jpeg", "gif", "raw" or "bmp"
   * @return a scalar tensor holding the record's `image/encoded` bytes
   * @throws IllegalArgumentException if `name` is unknown or the resource directory is missing
   */
  private def getInputs(name: String): Tensor[ByteString] = {
    import com.intel.analytics.bigdl.utils.tf.TFTensorNumeric.NumericByteString

    // Position of each image kind inside the tfrecord file.
    // NOTE(review): "bmp" reuses record 0, the same record as "png"; the file seems to have no
    // dedicated BMP entry. Confirm this is intended.
    val index = name match {
      case "png" => 0
      case "jpeg" => 1
      case "gif" => 2
      case "raw" => 3
      case "bmp" => 0
      case other => throw new IllegalArgumentException(s"unknown test image kind: $other")
    }

    val resource = getClass.getClassLoader.getResource("tf")
    require(resource != null, "test resource directory 'tf' not found on the classpath")
    // URL.getPath keeps percent-escapes (e.g. %20 for a space); going through the URI decodes
    // them, so the file is found even when the checkout path contains such characters.
    val file = new JFile(new JFile(resource.toURI), "decode_image_test_case.tfrecord")

    val records = TFRecordIterator(file).toVector
    val bmpBytes = records(index)

    val example = Example.parseFrom(bmpBytes)
    val imageByteString = example.getFeatures.getFeatureMap.get("image/encoded")
      .getBytesList.getValueList.get(0)

    Tensor[ByteString](Array(imageByteString), Array[Int]())
  }

  override def test(): Unit = {
    val decodeBmp = new DecodeBmp[Float](1).setName("decodeBmp")
    val input = getInputs("bmp")
    runSerializationTest(decodeBmp, input)
  }
}
Example 35
Source File: AssertSpec.scala · From BigDL · Apache License 2.0 · 5 votes
package com.intel.analytics.bigdl.nn.tf

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.T
import com.intel.analytics.bigdl.utils.serializer.ModuleSerializationTest
import com.intel.analytics.bigdl.utils.tf.TFTensorNumeric.NumericByteString

class AssertSerialTest extends ModuleSerializationTest {
  /** Round-trips an `Assert` module fed a true condition and a string message. */
  override def test(): Unit = {
    val assertion = new Assert[Float]().setName("assert")

    // One-element boolean condition tensor, set to true.
    val condition = Tensor[Boolean](Array(1))
    condition.setValue(1, true)

    // One-element message tensor.
    val message = Tensor[ByteString](Array(1))
    message.setValue(1, ByteString.copyFromUtf8("must be true"))

    runSerializationTest(assertion, T(condition, message))
  }
}
Example 36
Source File: DecodeRawSpec.scala · From BigDL · Apache License 2.0 · 5 votes
package com.intel.analytics.bigdl.nn.tf

import java.io.File

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.serializer.ModuleSerializationTest
import com.intel.analytics.bigdl.utils.tf.TFRecordIterator
import org.tensorflow.example.Example
import org.tensorflow.framework.DataType

class DecodeRawSerialTest extends ModuleSerializationTest {
  override def test(): Unit = {
    val decodeRaw = new DecodeRaw[Float](DataType.DT_UINT8, true).setName("decodeRaw")
    val input = getInputs("raw")
    runSerializationTest(decodeRaw, input)
  }

  /**
   * Loads one encoded record from the `tf/decode_image_test_case.tfrecord` test resource.
   *
   * @param name which record to load: "png", "jpeg", "gif" or "raw"
   * @return a scalar tensor holding the record's `image/encoded` bytes
   * @throws IllegalArgumentException if `name` is unknown or the resource directory is missing
   */
  private def getInputs(name: String): Tensor[ByteString] = {
    import com.intel.analytics.bigdl.utils.tf.TFTensorNumeric.NumericByteString
    // Position of each image kind inside the tfrecord file.
    val index = name match {
      case "png" => 0
      case "jpeg" => 1
      case "gif" => 2
      case "raw" => 3
      case other => throw new IllegalArgumentException(s"unknown test image kind: $other")
    }

    val resource = getClass.getClassLoader.getResource("tf")
    require(resource != null, "test resource directory 'tf' not found on the classpath")
    // URL.getPath keeps percent-escapes (e.g. %20 for a space); going through the URI decodes
    // them, so the file is found even when the checkout path contains such characters.
    val file = new File(new File(resource.toURI), "decode_image_test_case.tfrecord")

    val records = TFRecordIterator(file).toVector
    val recordBytes = records(index)

    val example = Example.parseFrom(recordBytes)
    val imageByteString = example.getFeatures.getFeatureMap.get("image/encoded")
      .getBytesList.getValueList.get(0)

    Tensor[ByteString](Array(imageByteString), Array[Int]())
  }
}
Example 37
Source File: ParseSingleExampleSpec.scala · From BigDL · Apache License 2.0 · 5 votes
package com.intel.analytics.bigdl.nn.tf

import com.intel.analytics.bigdl.tensor.{FloatType, LongType, StringType, Tensor}
import com.google.protobuf.{ByteString, CodedOutputStream}
import com.intel.analytics.bigdl.utils.T
import com.intel.analytics.bigdl.utils.serializer.ModuleSerializationTest
import org.scalatest.{FlatSpec, Matchers}
import org.tensorflow.example._
import com.intel.analytics.bigdl.utils.tf.TFTensorNumeric.NumericByteString

class ParseSingleExampleSpec extends FlatSpec with Matchers {

  "ParseSingleExample" should "be able to parse a example" in {

    // One feature of each kind: a float list, an int64 list and a bytes list.
    val floatBuilder = FloatList.newBuilder()
      .addValue(0.0f).addValue(1.0f).addValue(2.0f)
    val floatFeature = Feature.newBuilder().setFloatList(floatBuilder).build()

    val longBuilder = Int64List.newBuilder()
      .addValue(0).addValue(1).addValue(2)
    val longFeature = Feature.newBuilder().setInt64List(longBuilder).build()

    val bytesBuilder = BytesList.newBuilder().addValue(ByteString.copyFromUtf8("abcd"))
    val bytesFeature = Feature.newBuilder().setBytesList(bytesBuilder).build()

    val features = Features.newBuilder()
      .putFeature("floatFeature", floatFeature)
      .putFeature("longFeature", longFeature)
      .putFeature("bytesFeature", bytesFeature)
    val example = Example.newBuilder().setFeatures(features).build()

    // Dense feature keys to extract; the parser's outputs follow this order.
    val denseKeys = Seq("floatFeature", "longFeature", "bytesFeature").map(ByteString.copyFromUtf8)

    // Expected type and shape per key: 3 floats, 3 longs and a scalar string.
    val exampleParser = new ParseSingleExample[Float](
      Seq(FloatType, LongType, StringType), denseKeys, Seq(Array(3), Array(3), Array()))

    // toByteString yields the message's wire-format bytes directly; no manual
    // CodedOutputStream/array plumbing is needed.
    val serialized = Tensor[ByteString](Array(example.toByteString), Array[Int](1))

    val input = T(serialized)

    val output = exampleParser.forward(input)

    val floatTensor = output(1).asInstanceOf[Tensor[Float]]
    val longTensor = output(2).asInstanceOf[Tensor[Long]]
    val stringTensor = output(3).asInstanceOf[Tensor[ByteString]]

    floatTensor should be (Tensor[Float](T(0.0f, 1.0f, 2.0f)))
    longTensor should be (Tensor[Long](T(0L, 1L, 2L)))
    stringTensor should be (Tensor.scalar(ByteString.copyFromUtf8("abcd")))
  }

}

class ParseSingleExampleSerialTest extends ModuleSerializationTest {
  /** Round-trips a `ParseSingleExample` module fed one serialized tf.Example. */
  override def test(): Unit = {
    import com.intel.analytics.bigdl.utils.tf.TFTensorNumeric.NumericByteString

    // One feature of each kind: a float list, an int64 list and a bytes list.
    val floatBuilder = FloatList.newBuilder()
      .addValue(0.0f).addValue(1.0f).addValue(2.0f)
    val floatFeature = Feature.newBuilder().setFloatList(floatBuilder).build()

    val longBuilder = Int64List.newBuilder()
      .addValue(0).addValue(1).addValue(2)
    val longFeature = Feature.newBuilder().setInt64List(longBuilder).build()

    val bytesBuilder = BytesList.newBuilder().addValue(ByteString.copyFromUtf8("abcd"))
    val bytesFeature = Feature.newBuilder().setBytesList(bytesBuilder).build()

    val features = Features.newBuilder()
      .putFeature("floatFeature", floatFeature)
      .putFeature("longFeature", longFeature)
      .putFeature("bytesFeature", bytesFeature)
    val example = Example.newBuilder().setFeatures(features).build()

    // Dense feature keys to extract; the parser's outputs follow this order.
    val denseKeys = Seq("floatFeature", "longFeature", "bytesFeature").map(ByteString.copyFromUtf8)

    val exampleParser = new ParseSingleExample[Float](Seq(FloatType, LongType, StringType),
      denseKeys, Seq(Array(3), Array(3), Array())).setName("parseSingleExample")

    // toByteString yields the message's wire-format bytes directly; no manual
    // CodedOutputStream/array plumbing is needed.
    val serialized = Tensor[ByteString](Array(example.toByteString), Array[Int](1))

    val input = T(serialized)
    runSerializationTest(exampleParser, input)
  }
}
Example 38
Source File: SubstrSpec.scala · From BigDL · Apache License 2.0 · 5 votes
package com.intel.analytics.bigdl.nn.ops

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.T
import com.intel.analytics.bigdl.utils.serializer.ModuleSerializationTest
import org.scalatest.{FlatSpec, Matchers}

class SubstrSpec extends FlatSpec with Matchers {
  "Substr operation" should "works correctly" in {
    import com.intel.analytics.bigdl.utils.tf.TFTensorNumeric.NumericByteString
    // Inputs are (string, position, length): taking 2 characters of "abc" from offset 0 gives "ab".
    val input = T(
      Tensor.scalar(ByteString.copyFromUtf8("abc")),
      Tensor.scalar(0),
      Tensor.scalar(2)
    )

    Substr().forward(input) should be(Tensor.scalar(ByteString.copyFromUtf8("ab")))
  }
}

class SubstrSerialTest extends ModuleSerializationTest {
  /** Round-trips a `Substr` module fed (string, position, length) scalar tensors. */
  override def test(): Unit = {
    import com.intel.analytics.bigdl.utils.tf.TFTensorNumeric.NumericByteString
    val text = Tensor.scalar[ByteString](ByteString.copyFromUtf8("HelloBigDL"))
    val position = Tensor.scalar[Int](0)
    val length = Tensor.scalar[Int](5)

    runSerializationTest(Substr[Float]().setName("subStr"), T(text, position, length))
  }
}
Example 39
Source File: WavesToPbConversions.scala · From matcher · MIT License · 5 votes
package com.wavesplatform.dex.grpc.integration.protobuf

import com.google.protobuf.ByteString
import com.wavesplatform.account.Address
import com.wavesplatform.common.state.ByteStr
import com.wavesplatform.dex.grpc.integration.services._
import com.wavesplatform.protobuf.Amount
import com.wavesplatform.protobuf.order.{AssetPair, Order}
import com.wavesplatform.protobuf.transaction.ExchangeTransactionData
import com.wavesplatform.transaction.Asset
import com.wavesplatform.transaction.assets.{exchange => ve}
import com.wavesplatform.{account => va}

/** Implicit `toPB` syntax converting vanilla Waves domain objects into their protobuf messages. */
object WavesToPbConversions {

  implicit class VanillaExchangeTransactionOps(tx: ve.ExchangeTransaction) {
    /**
     * Converts the exchange transaction and its proofs into a `SignedExchangeTransaction`.
     * When the transaction carries no chain byte, the current address scheme's chain id is used.
     */
    def toPB: SignedExchangeTransaction = {
      val (feeAsset, feeAmount) = tx.assetFee

      val exchangeData = ExchangeTransactionData(
        amount = tx.amount,
        price = tx.price,
        buyMatcherFee = tx.buyMatcherFee,
        sellMatcherFee = tx.sellMatcherFee,
        orders = Seq(tx.buyOrder.toPB, tx.sellOrder.toPB)
      )

      val unsigned = ExchangeTransaction(
        chainId = tx.chainByte.getOrElse(va.AddressScheme.current.chainId).toInt,
        senderPublicKey = tx.sender.toPB,
        fee = Some(Amount(assetId = feeAsset.toPB, amount = feeAmount)),
        timestamp = tx.timestamp,
        version = tx.version,
        data = ExchangeTransaction.Data.Exchange(exchangeData)
      )

      SignedExchangeTransaction(
        transaction = Some(unsigned),
        proofs = tx.proofs.proofs.map(_.toPB)
      )
    }
  }

  implicit class VanillaAssetOps(self: Asset) {
    /** Issued assets are encoded by their id; the native Waves asset is an empty byte string. */
    def toPB: ByteString = self match {
      case Asset.Waves                => ByteString.EMPTY
      case Asset.IssuedAsset(assetId) => assetId.toPB
    }
  }

  implicit class VanillaAddressOps(self: Address) {
    /** Encodes the address by its raw bytes. */
    def toPB: ByteString = self.bytes.toPB
  }

  implicit class VanillaOrderOps(order: ve.Order) {
    /** Converts the order; the chain id always comes from the current address scheme. */
    def toPB: Order = {
      val side = order.orderType match {
        case ve.OrderType.BUY  => Order.Side.BUY
        case ve.OrderType.SELL => Order.Side.SELL
      }
      val pair = AssetPair(order.assetPair.amountAsset.toPB, order.assetPair.priceAsset.toPB)

      Order(
        chainId = va.AddressScheme.current.chainId.toInt,
        senderPublicKey = order.senderPublicKey.toPB,
        matcherPublicKey = order.matcherPublicKey.toPB,
        assetPair = Some(pair),
        orderSide = side,
        amount = order.amount,
        price = order.price,
        timestamp = order.timestamp,
        expiration = order.expiration,
        matcherFee = Some(Amount(order.matcherFeeAssetId.toPB, order.matcherFee)),
        version = order.version,
        proofs = order.proofs.map(_.toPB)
      )
    }
  }

  implicit class VanillaByteStrOps(val self: ByteStr) extends AnyVal {
    /** Copies the wrapped byte array into an immutable protobuf `ByteString`. */
    def toPB: ByteString = ByteString.copyFrom(self.arr)
  }
}
Example 40
Source File: ConversionUtils.scala · From scala-serialization · MIT License · 5 votes
package com.komanov.serialization.converters

import java.nio.ByteBuffer
import java.time.Instant
import java.util.UUID

import com.google.protobuf.ByteString

/** Helpers converting UUIDs and instants to and from their byte / numeric encodings. */
object ConversionUtils {

  /**
   * Encodes a UUID as 16 bytes: most-significant half first, each half big-endian
   * (the default `ByteBuffer` order).
   *
   * @return the encoded bytes, or `ByteString.EMPTY` when `uuid` is null
   */
  def uuidToBytes(uuid: UUID): ByteString = {
    val bb = uuidToByteBuffer(uuid)
    if (bb == null) ByteString.EMPTY else ByteString.copyFrom(bb)
  }

  /**
   * Encodes a UUID into a fresh 16-byte buffer positioned at 0, ready for reading.
   *
   * @return the buffer, or null when `uuid` is null
   */
  def uuidToByteBuffer(uuid: UUID): ByteBuffer =
    if (uuid == null) null
    else {
      val buffer = ByteBuffer.allocate(16)
      buffer.putLong(uuid.getMostSignificantBits)
      buffer.putLong(uuid.getLeastSignificantBits)
      buffer.rewind()
      buffer
    }

  /**
   * Decodes a UUID from the buffer's remaining bytes, reading two longs with relative gets
   * (so the buffer's position advances by 16).
   *
   * @return the UUID, or null when `bb` is null or has no remaining bytes
   * @throws IllegalArgumentException if fewer than 16 bytes remain
   */
  def bytesToUuid(bb: ByteBuffer): UUID =
    if (bb == null) null
    else {
      val length = bb.remaining()
      if (length == 0) null
      else {
        // Report the actual remaining length; capacity/limit alone do not reveal it when the
        // position is non-zero.
        require(
          length >= 16,
          s"expected 16 bytes but got $length (capacity: ${bb.capacity()}, limit: ${bb.limit()})"
        )
        new UUID(bb.getLong, bb.getLong)
      }
    }

  /** Decodes a UUID from a `ByteString`; see the `ByteBuffer` overload for null/empty handling. */
  def bytesToUuid(bs: ByteString): UUID =
    bytesToUuid(bs.asReadOnlyByteBuffer())

  /** Milliseconds since the Unix epoch. */
  def instantToLong(v: Instant): Long = v.toEpochMilli

  /** Inverse of [[instantToLong]]. The name is a historical misspelling of `longToInstant`. */
  def longToInstance(v: Long): Instant = Instant.ofEpochMilli(v)

  /** Inverse of [[instantToLong]]; correctly spelled alias of [[longToInstance]]. */
  def longToInstant(v: Long): Instant = Instant.ofEpochMilli(v)

}
Example 41
Source File: BigtableTypeSpec.scala · From magnolify · Apache License 2.0 · 5 votes
package magnolify.bigtable.test

import java.net.URI
import java.time.Duration

import cats._
import cats.instances.all._
import com.google.bigtable.v2.Row
import com.google.protobuf.ByteString
import magnolify.bigtable._
import magnolify.cats.auto._
import magnolify.scalacheck.auto._
import magnolify.shared.CaseMapper
import magnolify.test.Simple._
import magnolify.test._
import org.scalacheck._

import scala.reflect._

/** Property-based round-trip tests for derived `BigtableType` instances. */
object BigtableTypeSpec extends MagnolifySpec("BigtableType") {
  /**
   * Registers a property for `T`. Encoding a value to mutations in column family "cf" and
   * decoding the resulting Row must give back an equal value. Rebuilding a Row from that Row's
   * own mutations must give an identical Row.
   */
  private def test[T: Arbitrary: ClassTag](implicit t: BigtableType[T], eq: Eq[T]): Unit = {
    val tpe = ensureSerializable(t)
    property(className[T]) = Prop.forAll { t: T =>
      val mutations = tpe(t, "cf")
      val row = BigtableType.mutationsToRow(ByteString.EMPTY, mutations)
      val copy = tpe(row, "cf")
      val rowCopy = BigtableType.mutationsToRow(ByteString.EMPTY, BigtableType.rowToMutations(row))

      Prop.all(
        eq.eqv(t, copy),
        row == rowCopy
      )
    }
  }

  test[Numbers]
  test[Required]
  test[Nullable]
  test[BigtableNested]

  {
    // Local Arbitrary/Eq instances for ByteString and Array[Byte]. Arrays compare by reference,
    // so Array[Byte] equality goes through toList. Keep these implicit names unchanged: in
    // Scala 2 a local implicit's name can shadow an imported one.
    implicit val arbByteString: Arbitrary[ByteString] =
      Arbitrary(Gen.alphaNumStr.map(ByteString.copyFromUtf8))
    implicit val eqByteString: Eq[ByteString] = Eq.instance(_ == _)
    implicit val eqByteArray: Eq[Array[Byte]] = Eq.by(_.toList)
    test[BigtableTypes]
  }

  {
    // Custom field mappings: URI is stored via its String form, Duration as milliseconds.
    import Custom._
    implicit val btfUri: BigtableField[URI] =
      BigtableField.from[String](x => URI.create(x))(_.toString)
    implicit val btfDuration: BigtableField[Duration] =
      BigtableField.from[Long](Duration.ofMillis)(_.toMillis)

    test[Custom]
  }

  {
    // Default values: decoding an empty Row must yield the case-class defaults, and values
    // that differ from the defaults must survive a mutations -> Row -> value round trip.
    val it = BigtableType[DefaultInner]
    ensureSerializable(it)
    require(it(Row.getDefaultInstance, "cf") == DefaultInner())
    val inner = DefaultInner(2, Some(2))
    require(it(BigtableType.mutationsToRow(ByteString.EMPTY, it(inner, "cf")), "cf") == inner)

    val ot = BigtableType[DefaultOuter]
    ensureSerializable(ot)
    require(ot(Row.getDefaultInstance, "cf") == DefaultOuter())
    val outer =
      DefaultOuter(DefaultInner(3, Some(3)), Some(DefaultInner(3, Some(3))))
    require(ot(BigtableType.mutationsToRow(ByteString.EMPTY, ot(outer, "cf")), "cf") == outer)
  }

  {
    // CaseMapper: with an upper-casing mapper, column qualifiers are the upper-cased field
    // names, and the nested field appears flattened as "INNERFIELD.INNERFIRST".
    implicit val bt = BigtableType[LowerCamel](CaseMapper(_.toUpperCase))
    test[LowerCamel]

    val fields = LowerCamel.fields
      .map(_.toUpperCase)
      .map(l => if (l == "INNERFIELD") "INNERFIELD.INNERFIRST" else l)
    val record = bt(LowerCamel.default, "cf")
    require(record.map(_.getSetCell.getColumnQualifier.toStringUtf8) == fields)
  }
}

/**
 * Nested record: primitives plus a required and an optional nested record.
 * Collections are not supported by `BigtableType`, so none appear here.
 */
case class BigtableNested(
  b: Boolean,
  i: Int,
  s: String,
  r: Required,
  o: Option[Required]
)

/** Record covering Byte, Char, Short, ByteString and raw byte-array fields. */
case class BigtableTypes(
  b: Byte,
  c: Char,
  s: Short,
  bs: ByteString,
  ba: Array[Byte]
)

/** Record with default field values; decoding an empty Row is expected to reproduce them. */
case class DefaultInner(
  i: Int = 1,
  o: Option[Int] = Some(1)
)

/** Record whose defaults are themselves non-default `DefaultInner` values. */
case class DefaultOuter(
  i: DefaultInner = DefaultInner(2, Some(2)),
  o: Option[DefaultInner] = Some(DefaultInner(2, Some(2)))
)
Example 42
Source File: ExampleTypeSpec.scala · From magnolify · Apache License 2.0 · 5 votes
package magnolify.tensorflow.test

import java.net.URI
import java.time.Duration

import cats._
import cats.instances.all._
import com.google.protobuf.ByteString
import magnolify.cats.auto._
import magnolify.scalacheck.auto._
import magnolify.shared.CaseMapper
import magnolify.shims.JavaConverters._
import magnolify.tensorflow._
import magnolify.tensorflow.unsafe._
import magnolify.test.Simple._
import magnolify.test._
import org.scalacheck._

import scala.reflect._

/** Property-based round-trip tests for derived `ExampleType` (tf.Example) instances. */
object ExampleTypeSpec extends MagnolifySpec("ExampleType") {
  /**
   * Registers a property for `T`: converting a value to an Example and back must give an
   * equal value according to `eq`.
   */
  private def test[T: Arbitrary: ClassTag](implicit t: ExampleType[T], eq: Eq[T]): Unit = {
    val tpe = ensureSerializable(t)
    property(className[T]) = Prop.forAll { t: T =>
      val r = tpe(t)
      val copy = tpe(r)
      eq.eqv(t, copy)
    }
  }

  test[Integers]
  test[Required]
  test[Nullable]
  test[Repeated]
  test[ExampleNested]

  {
    // workaround for Double to Float precision loss: only generate Doubles that are exactly
    // representable as Float, so the round trip compares equal.
    // Keep the implicit name unchanged; in Scala 2 local implicit names can shadow imported ones.
    implicit val arbDouble: Arbitrary[Double] =
      Arbitrary(Arbitrary.arbFloat.arbitrary.map(_.toDouble))
    test[Unsafe]
  }

  {
    import Collections._
    test[Collections]
    test[MoreCollections]
  }

  {
    // Custom field mappings: URI is stored as UTF-8 bytes of its String form,
    // Duration as milliseconds.
    import Custom._
    implicit val efUri: ExampleField.Primitive[URI] =
      ExampleField.from[ByteString](x => URI.create(x.toStringUtf8))(x =>
        ByteString.copyFromUtf8(x.toString)
      )
    implicit val efDuration: ExampleField.Primitive[Duration] =
      ExampleField.from[Long](Duration.ofMillis)(_.toMillis)

    test[Custom]
  }

  {
    // Local Arbitrary/Eq instances for ByteString and Array[Byte]. Arrays compare by reference,
    // so Array[Byte] equality goes through toList.
    implicit val arbByteString: Arbitrary[ByteString] =
      Arbitrary(Gen.alphaNumStr.map(ByteString.copyFromUtf8))
    implicit val eqByteString: Eq[ByteString] = Eq.instance(_ == _)
    implicit val eqByteArray: Eq[Array[Byte]] = Eq.by(_.toList)
    test[ExampleTypes]
  }

  {
    // CaseMapper: with an upper-casing mapper, feature keys are the upper-cased field names,
    // and the nested field appears flattened as "INNERFIELD.INNERFIRST".
    implicit val et = ExampleType[LowerCamel](CaseMapper(_.toUpperCase))
    test[LowerCamel]

    val fields = LowerCamel.fields
      .map(_.toUpperCase)
      .map(l => if (l == "INNERFIELD") "INNERFIELD.INNERFIRST" else l)
    val record = et(LowerCamel.default)
    require(record.getFeatures.getFeatureMap.keySet().asScala == fields.toSet)
  }
}

/**
 * Nested record: primitives plus a required and an optional nested record.
 * Option[T] and Seq[T] not supported, so no such fields appear here.
 */
case class ExampleNested(
  b: Boolean,
  i: Int,
  s: String,
  r: Required,
  o: Option[Required]
)
/** Record covering Float, ByteString and raw byte-array fields. */
case class ExampleTypes(
  f: Float,
  bs: ByteString,
  ba: Array[Byte]
)

/** Record of primitive types checked by the `Unsafe` property (Doubles limited to Float precision). */
case class Unsafe(
  b: Byte,
  c: Char,
  s: Short,
  i: Int,
  d: Double,
  bool: Boolean,
  str: String
)
Example 43
Source File: ShadowNode.scala · From EncryCore · GNU General Public License v3.0 · 5 votes
package encry.view.state.avlTree

import NodeMsg.NodeProtoMsg.NodeTypes.ShadowNodeProto
import cats.Monoid
import com.google.protobuf.ByteString
import com.typesafe.scalalogging.StrictLogging
import encry.storage.VersionalStorage
import encry.storage.VersionalStorage.StorageKey
import encry.view.state.avlTree.utils.implicits.{Hashable, Serializer}
import io.iohk.iodb.ByteArrayWrapper
import org.encryfoundation.common.utils.Algos

import scala.util.Try

/**
 * Lightweight stand-in for an AVL tree node; implementations rebuild the full node from
 * storage via `restoreFullNode`.
 */
abstract class ShadowNode[K: Serializer: Monoid, V: Serializer: Monoid] extends Node[K, V] {

  /** A shadow node's self-inspection is the shadow itself. */
  final override def selfInspection = this

  /** Rebuilds the full node that this shadow stands in for, reading it from `storage`. */
  def restoreFullNode(storage: VersionalStorage): Node[K, V]
}

object ShadowNode {

  /**
   * Replaces a node with its shadow. Internal and leaf nodes keep only hash, height, balance
   * and key. An empty node becomes an `EmptyShadowNode`. A node that is already a shadow is
   * returned as is.
   */
  def nodeToShadow[K: Serializer : Hashable : Monoid, V: Serializer : Monoid](node: Node[K, V]): ShadowNode[K, V] =
    node match {
      case branch: InternalNode[K, V] =>
        NonEmptyShadowNode(branch.hash, branch.height, branch.balance, branch.key)
      case leaf: LeafNode[K, V] =>
        NonEmptyShadowNode(leaf.hash, leaf.height, leaf.balance, leaf.key)
      case _: EmptyNode[K, V] =>
        new EmptyShadowNode[K, V]
      case alreadyShadow: ShadowNode[K, V] =>
        alreadyShadow
    }

  /**
   * For an internal node, replaces both children with their shadows (the node itself stays
   * full). Any other node is returned unchanged.
   */
  def childsToShadowNode[K: Serializer : Hashable : Monoid, V: Serializer : Monoid](node: Node[K, V]): Node[K, V] =
    node match {
      case branch: InternalNode[K, V] =>
        val shadowLeft = nodeToShadow(branch.leftChild)
        val shadowRight = nodeToShadow(branch.rightChild)
        branch.copy(leftChild = shadowLeft, rightChild = shadowRight)
      case other =>
        other
    }
}
Example 44
Source File: LeafNode.scala · From EncryCore · GNU General Public License v3.0 · 5 votes
package encry.view.state.avlTree

import NodeMsg.NodeProtoMsg.NodeTypes.LeafNodeProto
import cats.Monoid
import com.google.common.primitives.Bytes
import com.google.protobuf.ByteString
import encry.view.state.avlTree.utils.implicits.{ Hashable, Serializer }
import org.encryfoundation.common.utils.Algos
import scala.util.Try

/** AVL tree leaf holding one key/value pair; its height and balance are always 0. */
final case class LeafNode[K: Serializer: Monoid, V: Serializer: Monoid](key: K, value: V)(implicit hashK: Hashable[K])
    extends Node[K, V] {

  // Digest of the serialized key bytes followed by the serialized value bytes.
  override lazy val hash: Array[Byte] = {
    val keyBytes = implicitly[Serializer[K]].toBytes(key)
    val valueBytes = implicitly[Serializer[V]].toBytes(value)
    Algos.hash(Bytes.concat(keyBytes, valueBytes))
  }

  override val balance: Int = 0

  override val height: Int = 0

  override def toString: String = {
    val encodedKey = Algos.encode(implicitly[Serializer[K]].toBytes(key))
    s"($encodedKey, $value, height: 0, balance: 0, hash: ${Algos.encode(hash)})"
  }

  override def selfInspection = this
}

object LeafNode {

  /** Serializes a leaf to protobuf, storing key and value as their serialized bytes. */
  def toProto[K, V](leaf: LeafNode[K, V])(implicit kSer: Serializer[K], vSer: Serializer[V]): LeafNodeProto = {
    val keyBytes = ByteString.copyFrom(kSer.toBytes(leaf.key))
    val valueBytes = ByteString.copyFrom(vSer.toBytes(leaf.value))
    LeafNodeProto().withKey(keyBytes).withValue(valueBytes)
  }

  /** Rebuilds a leaf from protobuf; any deserialization failure is captured in the `Try`. */
  def fromProto[K: Hashable: Monoid, V: Monoid](
    leafProto: LeafNodeProto
  )(implicit kSer: Serializer[K], vSer: Serializer[V]): Try[LeafNode[K, V]] = Try {
    val key = kSer.fromBytes(leafProto.key.toByteArray)
    val value = vSer.fromBytes(leafProto.value.toByteArray)
    LeafNode(key, value)
  }
}
Example 45
Source File: NonEmptyShadowNode.scala · From EncryCore · GNU General Public License v3.0 · 5 votes
package encry.view.state.avlTree

import NodeMsg.NodeProtoMsg.NodeTypes.{NonEmptyShadowNodeProto, ShadowNodeProto}
import cats.Monoid
import com.google.protobuf.ByteString
import com.typesafe.scalalogging.StrictLogging
import encry.storage.VersionalStorage
import encry.storage.VersionalStorage.StorageKey
import encry.view.state.avlTree.utils.implicits.{Hashable, Serializer}
import org.encryfoundation.common.utils.Algos

import scala.util.Try

/** Lightweight stand-in for a stored, non-empty AVL node.
  *
  * It keeps only the node's hash, height, balance and key. `value` is always the `Monoid` empty
  * value. The full node can be reloaded from a [[VersionalStorage]] on demand.
  */
case class NonEmptyShadowNode[K: Serializer: Hashable, V: Serializer](nodeHash: Array[Byte],
                                                                      height: Int,
                                                                      balance: Int,
                                                                      key: K)
                                                                     (implicit kM: Monoid[K], vM: Monoid[V]) extends ShadowNode[K, V] with StrictLogging {

  override val value: V = vM.empty

  override val hash = nodeHash

  /** Loads the full node referenced by this shadow node.
    *
    * An empty `nodeHash` yields an [[EmptyNode]] without touching storage. Otherwise the bytes
    * stored under `AvlTree.nodeKey(hash)` are decoded with `NodeSerilalizer.fromBytes`.
    *
    * @throws NoSuchElementException if storage has no entry for this node's key
    */
  def restoreFullNode(storage: VersionalStorage): Node[K, V] = if (nodeHash.nonEmpty) {
    val stored = storage.get(StorageKey @@ AvlTree.nodeKey(hash))
    if (stored.isEmpty) {
      val message = s"Empty node at key: ${Algos.encode(hash)}"
      logger.warn(message)
      // Fail with a message that identifies the missing key instead of an anonymous "None.get".
      throw new NoSuchElementException(message)
    }
    NodeSerilalizer.fromBytes[K, V](stored.get)
  } else EmptyNode[K, V]()

  /** Like [[restoreFullNode]], but returns `None` instead of throwing on any failure. */
  def tryRestore(storage: VersionalStorage): Option[Node[K, V]] =
    Try(restoreFullNode(storage)).toOption

  override def toString: String = s"ShadowNode(Hash:${Algos.encode(hash)}, height: ${height}, balance: ${balance})"
}

object NonEmptyShadowNode {

  /** Encodes a shadow node as protobuf: raw hash bytes, serialized key, height and balance. */
  def toProto[K: Serializer, V](node: NonEmptyShadowNode[K, V]): NonEmptyShadowNodeProto = {
    val serializedKey = ByteString.copyFrom(implicitly[Serializer[K]].toBytes(node.key))
    val hashBytes     = ByteString.copyFrom(node.hash)
    NonEmptyShadowNodeProto()
      .withKey(serializedKey)
      .withHash(hashBytes)
      .withHeight(node.height)
      .withBalance(node.balance)
  }

  /** Decodes a shadow node from protobuf; a key that fails to deserialize yields a `Failure`. */
  def fromProto[K: Hashable : Monoid : Serializer, V: Monoid : Serializer](protoNode: NonEmptyShadowNodeProto): Try[NonEmptyShadowNode[K, V]] =
    Try {
      val decodedKey = implicitly[Serializer[K]].fromBytes(protoNode.key.toByteArray)
      NonEmptyShadowNode[K, V](protoNode.hash.toByteArray, protoNode.height, protoNode.balance, decodedKey)
    }
}
Example 46
Source File: PayloadSerializer.scala    From eventuate   with Apache License 2.0 5 votes vote down vote up
package com.rbmhtechnology.eventuate.serializer

import akka.actor.ExtendedActorSystem
import akka.serialization.SerializationExtension
import akka.serialization.SerializerWithStringManifest
import com.google.protobuf.ByteString
import com.rbmhtechnology.eventuate.BinaryPayload
import com.rbmhtechnology.eventuate.serializer.CommonFormats.PayloadFormat


/** Payload serializer for [[BinaryPayload]]s, which already carry serialized bytes and serializer metadata.
  *
  * Writing copies those fields into a [[PayloadFormat]]; reading rebuilds the `BinaryPayload`
  * without deserializing the contained bytes.
  */
class BinaryPayloadSerializer(system: ExtendedActorSystem) extends PayloadSerializer {

  /** Copies bytes, serializer id, manifest flag and (if present) the manifest of a [[BinaryPayload]].
    *
    * `payload` is cast to `BinaryPayload`; any other runtime type fails with a `ClassCastException`.
    */
  override def payloadFormatBuilder(payload: AnyRef): PayloadFormat.Builder = {
    val binary = payload.asInstanceOf[BinaryPayload]
    val formatBuilder = PayloadFormat.newBuilder()
    formatBuilder.setPayload(binary.bytes)
    formatBuilder.setSerializerId(binary.serializerId)
    formatBuilder.setIsStringManifest(binary.isStringManifest)
    binary.manifest.foreach(m => formatBuilder.setPayloadManifest(m))
    formatBuilder
  }

  /** Rebuilds a [[BinaryPayload]] from `payloadFormat`; an unset manifest becomes `None`. */
  override def payload(payloadFormat: PayloadFormat): AnyRef = {
    val manifest =
      if (payloadFormat.hasPayloadManifest) Some(payloadFormat.getPayloadManifest)
      else None
    BinaryPayload(
      payloadFormat.getPayload,
      payloadFormat.getSerializerId,
      manifest,
      payloadFormat.getIsStringManifest)
  }
}
Example 47
Source File: BinaryPayloadManifestFilterSpec.scala    From eventuate   with Apache License 2.0 5 votes vote down vote up
package com.rbmhtechnology.eventuate

import com.google.protobuf.ByteString
import org.scalatest.Matchers
import org.scalatest.WordSpec

object BinaryPayloadManifestFilterSpec {
  /** Builds a [[DurableEvent]] with emitter id "emitterId" whose payload is an empty
    * [[BinaryPayload]] carrying `manifest` and flagged as a string manifest.
    */
  def durableEventWithBinaryPayloadManifest(manifest: Option[String]): DurableEvent = {
    val emptyPayload = BinaryPayload(ByteString.EMPTY, 0, manifest, isStringManifest = true)
    DurableEvent(emptyPayload, "emitterId")
  }
}

class BinaryPayloadManifestFilterSpec extends WordSpec with Matchers {

  import BinaryPayloadManifestFilterSpec._

  /** Runs a filter built from `pattern` against an event whose binary payload has `manifest`. */
  private def accepts(pattern: String, manifest: Option[String]): Boolean =
    BinaryPayloadManifestFilter(pattern.r).apply(durableEventWithBinaryPayloadManifest(manifest))

  "BinaryPayloadManifestFilter" must {
    "pass BinaryPayloads with matching manifest" in {
      accepts("a.*", Some("abc")) should be(true)
    }
    "filter BinaryPayloads with partially matching manifest" in {
      accepts("b", Some("abc")) should be(false)
    }
    "filter BinaryPayloads with non-matching manifest" in {
      accepts("a.*", Some("bc")) should be(false)
    }
    "filter BinaryPayloads without manifest" in {
      accepts("a.*", None) should be(false)
    }
    "filter other payload" in {
      // A plain (non-binary) payload is rejected regardless of the pattern.
      BinaryPayloadManifestFilter("a.*".r).apply(DurableEvent("payload", "emitterId")) should be(false)
    }
  }
}
Example 48
Source File: PrimitiveWrappersSpec.scala    From scalapb-circe   with MIT License 5 votes vote down vote up
package scalapb_circe

import com.google.protobuf.ByteString
import jsontest.test3._
import io.circe.{Encoder, Json}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.must.Matchers

class PrimitiveWrappersSpec extends AnyFlatSpec with Matchers {

  /** Encodes `a` to JSON with its circe encoder. */
  private[this] def render[A](a: A)(implicit A: Encoder[A]): Json =
    A.apply(a)

  /** Expected JSON for a message with exactly one populated field `name`. */
  private[this] def field[A: Encoder](name: String, value: A): Json =
    render(Map(name -> value))

  /** Sample `wBytes` payload; "AwUE" is its base64 form. */
  private[this] val sampleBytes: ByteString = ByteString.copyFrom(Array[Byte](3, 5, 4))

  /** Printer that writes 64-bit integers as JSON numbers rather than strings. */
  private[this] def longsAsNumbers = new Printer(formattingLongAsNumber = true)

  "Empty object" should "give empty json for Wrapper" in {
    JsonFormat.toJson(Wrapper()) must be(render(Map.empty[String, Json]))
  }

  "primitive values" should "serialize properly" in {
    JsonFormat.toJson(Wrapper(wBool = Some(false))) must be(field("wBool", Json.fromBoolean(false)))
    JsonFormat.toJson(Wrapper(wBool = Some(true))) must be(field("wBool", Json.fromBoolean(true)))
    JsonFormat.toJson(Wrapper(wDouble = Some(3.1))) must be(field("wDouble", Json.fromDouble(3.1)))
    JsonFormat.toJson(Wrapper(wFloat = Some(3.0f))) must be(field("wFloat", Json.fromDouble(3.0)))
    JsonFormat.toJson(Wrapper(wInt32 = Some(35544))) must be(field("wInt32", Json.fromLong(35544)))
    JsonFormat.toJson(Wrapper(wInt32 = Some(0))) must be(field("wInt32", Json.fromLong(0)))
    JsonFormat.toJson(Wrapper(wInt64 = Some(125))) must be(field("wInt64", Json.fromString("125")))
    JsonFormat.toJson(Wrapper(wUint32 = Some(125))) must be(field("wUint32", Json.fromLong(125)))
    JsonFormat.toJson(Wrapper(wUint64 = Some(125))) must be(field("wUint64", Json.fromString("125")))
    JsonFormat.toJson(Wrapper(wString = Some("bar"))) must be(field("wString", Json.fromString("bar")))
    JsonFormat.toJson(Wrapper(wString = Some(""))) must be(field("wString", Json.fromString("")))
    JsonFormat.toJson(Wrapper(wBytes = Some(sampleBytes))) must be(field("wBytes", Json.fromString("AwUE")))
    JsonFormat.toJson(Wrapper(wBytes = Some(ByteString.EMPTY))) must be(field("wBytes", Json.fromString("")))
    longsAsNumbers.toJson(Wrapper(wUint64 = Some(125))) must be(field("wUint64", Json.fromLong(125)))
    longsAsNumbers.toJson(Wrapper(wInt64 = Some(125))) must be(field("wInt64", Json.fromLong(125)))
  }

  "primitive values" should "parse properly" in {
    JsonFormat.fromJson[Wrapper](field("wBool", Json.fromBoolean(false))) must be(Wrapper(wBool = Some(false)))
    JsonFormat.fromJson[Wrapper](field("wBool", Json.fromBoolean(true))) must be(Wrapper(wBool = Some(true)))
    JsonFormat.fromJson[Wrapper](field("wDouble", Json.fromDouble(3.1))) must be(Wrapper(wDouble = Some(3.1)))
    JsonFormat.fromJson[Wrapper](field("wFloat", Json.fromDouble(3.0))) must be(Wrapper(wFloat = Some(3.0f)))
    JsonFormat.fromJson[Wrapper](field("wInt32", Json.fromLong(35544))) must be(Wrapper(wInt32 = Some(35544)))
    JsonFormat.fromJson[Wrapper](field("wInt32", Json.fromLong(0))) must be(Wrapper(wInt32 = Some(0)))
    JsonFormat.fromJson[Wrapper](field("wInt64", Json.fromString("125"))) must be(Wrapper(wInt64 = Some(125)))
    JsonFormat.fromJson[Wrapper](field("wUint32", Json.fromLong(125))) must be(Wrapper(wUint32 = Some(125)))
    JsonFormat.fromJson[Wrapper](field("wUint64", Json.fromString("125"))) must be(Wrapper(wUint64 = Some(125)))
    JsonFormat.fromJson[Wrapper](field("wString", Json.fromString("bar"))) must be(Wrapper(wString = Some("bar")))
    JsonFormat.fromJson[Wrapper](field("wString", Json.fromString(""))) must be(Wrapper(wString = Some("")))
    JsonFormat.fromJson[Wrapper](field("wBytes", Json.fromString("AwUE"))) must be(Wrapper(wBytes = Some(sampleBytes)))
    JsonFormat.fromJson[Wrapper](field("wBytes", Json.fromString(""))) must be(Wrapper(wBytes = Some(ByteString.EMPTY)))
  }

}
Example 49
Source File: ProtobufScoringController.scala    From mleap   with Apache License 2.0 5 votes vote down vote up
package ml.combust.mleap.springboot

import java.util.concurrent.CompletionStage

import akka.actor.ActorSystem
import com.google.protobuf.ByteString
import ml.combust.mleap.executor._
import ml.combust.mleap.pb.TransformStatus.STATUS_ERROR
import ml.combust.mleap.pb.{BundleMeta, Mleap, Model, TransformFrameResponse}
import ml.combust.mleap.runtime.serialization.{FrameReader, FrameWriter}
import ml.combust.mleap.springboot.TypeConverters._
import org.apache.commons.lang3.exception.ExceptionUtils
import org.slf4j.LoggerFactory
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.http.HttpStatus
import org.springframework.web.bind.annotation._

import scala.compat.java8.FutureConverters._
import scala.concurrent.Future
import scala.util.{Failure, Success}

/** Spring REST controller exposing an [[MleapExecutor]] over `application/x-protobuf`.
  *
  * Each endpoint converts the Java protobuf request into an executor request, runs it with the
  * `timeout` header value (default "60000", forwarded as-is), and maps the result back to a Java
  * protobuf message on the actor system's dispatcher.
  */
@RestController
@RequestMapping
class ProtobufScoringController(@Autowired val actorSystem : ActorSystem,
                                @Autowired val mleapExecutor: MleapExecutor) {

  private val executor = actorSystem.dispatcher

  /** Loads the model described by `request` and returns its description. */
  @PostMapping(path = Array("/models"),
    consumes = Array("application/x-protobuf; charset=UTF-8"),
    produces = Array("application/x-protobuf; charset=UTF-8"))
  @ResponseStatus(HttpStatus.ACCEPTED)
  def loadModel(@RequestBody request: Mleap.LoadModelRequest,
                @RequestHeader(value = "timeout", defaultValue = "60000") timeout: Int) : CompletionStage[Mleap.Model] = {
    mleapExecutor
      .loadModel(javaPbToExecutorLoadModelRequest(request))(timeout)
      .map(model => Model.toJavaProto(model))(executor).toJava
  }

  /** Unloads the named model and returns its description. */
  @DeleteMapping(path = Array("/models/{model_name}"),
    consumes = Array("application/x-protobuf; charset=UTF-8"),
    produces = Array("application/x-protobuf; charset=UTF-8"))
  def unloadModel(@PathVariable("model_name") modelName: String,
                  @RequestHeader(value = "timeout", defaultValue = "60000") timeout: Int): CompletionStage[Mleap.Model] =
    mleapExecutor
      .unloadModel(UnloadModelRequest(modelName))(timeout)
      .map(model => Model.toJavaProto(model))(executor).toJava

  /** Returns the description of the named model. */
  @GetMapping(path = Array("/models/{model_name}"),
    consumes = Array("application/x-protobuf; charset=UTF-8"),
    produces = Array("application/x-protobuf; charset=UTF-8"))
  def getModel(@PathVariable("model_name") modelName: String,
               @RequestHeader(value = "timeout", defaultValue = "60000") timeout: Int): CompletionStage[Mleap.Model] =
    mleapExecutor
      .getModel(GetModelRequest(modelName))(timeout)
      .map(model => Model.toJavaProto(model))(executor).toJava

  /** Returns the bundle metadata of the named model. */
  @GetMapping(path = Array("/models/{model_name}/meta"),
    consumes = Array("application/x-protobuf; charset=UTF-8"),
    produces = Array("application/x-protobuf; charset=UTF-8"))
  def getMeta(@PathVariable("model_name") modelName: String,
              @RequestHeader(value = "timeout", defaultValue = "60000") timeout: Int) : CompletionStage[Mleap.BundleMeta] =
    mleapExecutor
      .getBundleMeta(GetBundleMetaRequest(modelName))(timeout)
      .map(meta => BundleMeta.toJavaProto(meta))(executor).toJava

  /** Transforms a serialized leap frame with the named model.
    *
    * Failures in decoding the input frame, failures reported by the executor, and failures in
    * serializing the result frame all produce a `STATUS_ERROR` response. That response carries
    * the error message and stack trace; the returned stage does not fail in these cases.
    */
  @PostMapping(path = Array("/models/transform"),
    consumes = Array("application/x-protobuf; charset=UTF-8"),
    produces = Array("application/x-protobuf; charset=UTF-8"))
  def transform(@RequestBody request: Mleap.TransformFrameRequest,
                @RequestHeader(value = "timeout", defaultValue = "60000") timeout: Int) : CompletionStage[Mleap.TransformFrameResponse] = {
    FrameReader(request.getFormat).fromBytes(request.getFrame.toByteArray) match {
      case Success(inputFrame) =>
        mleapExecutor.transform(TransformFrameRequest(request.getModelName, inputFrame, request.getOptions))(timeout)
          .mapAll {
            case Success(Success(resultFrame)) =>
              // Serialization of the result can fail too; report it instead of throwing inside mapAll.
              FrameWriter(resultFrame, request.getFormat).toBytes() match {
                case Success(bytes) =>
                  TransformFrameResponse(tag = request.getTag, frame = ByteString.copyFrom(bytes))
                case Failure(ex) => handleTransformFailure(request.getTag, ex)
              }
            case Success(Failure(ex)) => handleTransformFailure(request.getTag, ex)
            case Failure(ex) => handleTransformFailure(request.getTag, ex)
          }(executor)
          .map(response => TransformFrameResponse.toJavaProto(response))(executor).toJava
      case Failure(ex) =>
        // The error response is already known; no asynchronous work is needed to produce it.
        Future.successful(TransformFrameResponse.toJavaProto(handleTransformFailure(request.getTag, ex))).toJava
    }
  }

  /** Logs `ex` and builds a `STATUS_ERROR` response for `tag` with its message and stack trace. */
  private def handleTransformFailure(tag: Long, ex: Throwable): TransformFrameResponse = {
    ProtobufScoringController.logger.error("Transform error due to ", ex)
    TransformFrameResponse(tag = tag, status = STATUS_ERROR,
      error = ExceptionUtils.getMessage(ex), backtrace = ExceptionUtils.getStackTrace(ex))
  }
}

object ProtobufScoringController {
  /** Shared SLF4J logger, named after the controller class. */
  val logger = {
    val owner = classOf[ProtobufScoringController]
    LoggerFactory.getLogger(owner)
  }
}
Example 50
Source File: ProtobufScoringSpec.scala    From mleap   with Apache License 2.0 5 votes vote down vote up
package ml.combust.mleap.springboot

import java.net.URI

import com.google.protobuf.ByteString
import ml.combust.mleap.pb._
import ml.combust.mleap.runtime.frame.DefaultLeapFrame
import ml.combust.mleap.runtime.serialization.{BuiltinFormats, FrameWriter}
import ml.combust.mleap.springboot.TestUtil.validFrame
import org.junit.runner.RunWith
import org.springframework.boot.test.context.SpringBootTest
import org.springframework.boot.test.context.SpringBootTest.WebEnvironment.RANDOM_PORT
import org.springframework.http.{HttpEntity, HttpHeaders, ResponseEntity}
import org.springframework.test.context.junit4.SpringRunner

/** Runs the shared [[ScoringBase]] scoring scenarios against the protobuf endpoints. */
@RunWith(classOf[SpringRunner])
@SpringBootTest(webEnvironment = RANDOM_PORT)
class ProtobufScoringSpec extends ScoringBase[Mleap.LoadModelRequest, Mleap.Model, Mleap.BundleMeta, Mleap.TransformFrameRequest, Mleap.TransformFrameResponse] {

  override def createLoadModelRequest(modelName: String, uri: URI, createTmpFile: Boolean): HttpEntity[Mleap.LoadModelRequest] = {
    val request = LoadModelRequest(modelName = modelName,
      uri = TestUtil.getBundle(uri, createTmpFile).toString,
      config = Some(ModelConfig(Some(9000L), Some(9000L))))
    new HttpEntity[Mleap.LoadModelRequest](LoadModelRequest.toJavaProto(request), ProtobufScoringSpec.protoHeaders)
  }

  override def createTransformFrameRequest(modelName: String, frame: DefaultLeapFrame, options: Option[TransformOptions]): HttpEntity[Mleap.TransformFrameRequest] =
    transformRequestEntity(modelName, FrameWriter(frame, BuiltinFormats.binary).toBytes().get, options)

  /** Raw serialized-frame request body; serializes the given `frame` (not a fixed fixture). */
  override def createTransformFrameRequest(frame: DefaultLeapFrame): HttpEntity[Array[Byte]] =
    new HttpEntity[Array[Byte]](FrameWriter(frame, leapFrameFormat()).toBytes().get, ProtobufScoringSpec.protoHeaders)

  override def extractModelResponse(response: ResponseEntity[_ <: Any]): Mleap.Model = response.getBody.asInstanceOf[Mleap.Model]

  override def createEmptyBodyRequest(): HttpEntity[Unit] = ProtobufScoringSpec.httpEntityWithProtoHeaders

  override def extractBundleMetaResponse(response: ResponseEntity[_]): Mleap.BundleMeta = response.getBody.asInstanceOf[Mleap.BundleMeta]

  override def extractTransformResponse(response: ResponseEntity[_]): Mleap.TransformFrameResponse = response.getBody.asInstanceOf[Mleap.TransformFrameResponse]

  override def leapFrameFormat(): String = BuiltinFormats.binary

  override def createInvalidTransformFrameRequest(modelName: String, bytes: Array[Byte]): HttpEntity[Mleap.TransformFrameRequest] =
    transformRequestEntity(modelName, bytes, None)

  /** Wraps `frameBytes` in a binary-format transform request (init timeout 35000) with protobuf headers. */
  private def transformRequestEntity(modelName: String,
                                     frameBytes: Array[Byte],
                                     options: Option[TransformOptions]): HttpEntity[Mleap.TransformFrameRequest] = {
    val request = TransformFrameRequest(modelName = modelName,
      format = BuiltinFormats.binary,
      initTimeout = Some(35000L),
      frame = ByteString.copyFrom(frameBytes),
      options = options
    )

    new HttpEntity[Mleap.TransformFrameRequest](TransformFrameRequest.toJavaProto(request),
      ProtobufScoringSpec.protoHeaders)
  }
}

object ProtobufScoringSpec {
  /** Headers for protobuf requests: protobuf content type plus a "timeout" header of "2000". */
  lazy val protoHeaders: HttpHeaders = {
    val headers = new HttpHeaders
    Seq("Content-Type" -> "application/x-protobuf", "timeout" -> "2000").foreach {
      case (name, value) => headers.add(name, value)
    }
    headers
  }

  /** Body-less entity carrying only the protobuf headers. */
  lazy val httpEntityWithProtoHeaders: HttpEntity[Unit] = new HttpEntity[Unit](protoHeaders)
}
Example 51
Source File: ArbitraryProtoUtils.scala    From sparksql-scalapb   with Apache License 2.0 5 votes vote down vote up
package scalapb.spark

import com.google.protobuf.ByteString
import org.scalacheck.Arbitrary
import org.scalacheck.derive.MkArbitrary
import scalapb.spark.test.{all_types2 => AT2}
import scalapb.spark.test3.{all_types3 => AT3}
import scalapb.{GeneratedEnum, GeneratedEnumCompanion, GeneratedMessage, Message}
import shapeless.Strict
import org.scalacheck.Gen
import scalapb.UnknownFieldSet

/** ScalaCheck `Arbitrary` instances for the ScalaPB test messages.
  *
  * Messages are derived through scalacheck-shapeless (imported below). The explicit instances
  * here cover `ByteString`, `UnknownFieldSet` and the generated enums.
  */
object ArbitraryProtoUtils {
  import org.scalacheck.ScalacheckShapeless._

  // ByteStrings are generated by copying arbitrary byte arrays.
  implicit val arbitraryBS = Arbitrary(
    implicitly[Arbitrary[Array[Byte]]].arbitrary
      .map(t => ByteString.copyFrom(t))
  )

  // Default scalacheck-shapeless would choose Unrecognized instances with recognized values.
  // If `e` is flagged unrecognized but its numeric value matches a known entry of the companion,
  // that entry is returned; otherwise `e` is returned unchanged.
  private def fixEnum[A <: GeneratedEnum](
      e: A
  )(implicit cmp: GeneratedEnumCompanion[A]): A = {
    if (e.isUnrecognized) cmp.values.find(_.value == e.value).getOrElse(e)
    else e
  }

  // Wraps the in-scope `Arbitrary[A]` for an enum so that every generated value goes through `fixEnum`.
  def arbitraryEnum[A <: GeneratedEnum: Arbitrary: GeneratedEnumCompanion] = {
    Arbitrary(implicitly[Arbitrary[A]].arbitrary.map(fixEnum(_)))
  }

  // Unknown field sets are always generated empty.
  implicit val arbitraryUnknownFields = Arbitrary(
    Gen.const(UnknownFieldSet.empty)
  )

  implicit val nestedEnum2 = arbitraryEnum[AT2.EnumTest.NestedEnum]

  implicit val nestedEnum3 = arbitraryEnum[AT3.EnumTest.NestedEnum]

  implicit val topLevelEnum2 = arbitraryEnum[AT2.TopLevelEnum]

  implicit val topLevelEnum3 = arbitraryEnum[AT3.TopLevelEnum]

  // NOTE(review): the body summons `Arbitrary[A]`, presumably from the scalacheck-shapeless
  // derivation using the `Strict[MkArbitrary[A]]` evidence in scope. The missing result type
  // likely keeps this method from resolving to itself; confirm before annotating it.
  implicit def arbitraryMessage[A <: GeneratedMessage](implicit
      ev: Strict[MkArbitrary[A]]
  ) = {
    implicitly[Arbitrary[A]]
  }
}
Example 52
Source File: PayloadSerializer.scala    From akka-stream-eventsourcing   with Apache License 2.0 5 votes vote down vote up
package com.github.krasserm.ases.serializer

import akka.actor.ExtendedActorSystem
import akka.serialization.{SerializationExtension, SerializerWithStringManifest}
import com.github.krasserm.ases.serializer.PayloadFormatOuterClass.PayloadFormat
import com.google.protobuf.ByteString

import scala.util.Try

/** Converts payloads to and from [[PayloadFormat]] protobuf messages using Akka serialization.
  *
  * The format records the serializer id and the serialized bytes. When the serializer includes a
  * manifest, it also records either a string manifest or the payload's class name, together with
  * a flag saying which kind was stored.
  */
class PayloadSerializer(system: ExtendedActorSystem) {

  /** Serializes `payload` with the serializer Akka selects for it and returns the populated builder. */
  def payloadFormatBuilder(payload: AnyRef): PayloadFormat.Builder = {
    val serializer = SerializationExtension(system).findSerializerFor(payload)
    val builder = PayloadFormat.newBuilder()

    if (serializer.includeManifest) {
      val (isStringManifest, manifest) = serializer match {
        case s: SerializerWithStringManifest => (true, s.manifest(payload))
        case _ => (false, payload.getClass.getName)
      }
      builder.setIsStringManifest(isStringManifest)
      builder.setPayloadManifest(manifest)
    }
    builder.setSerializerId(serializer.identifier)
    builder.setPayload(ByteString.copyFrom(serializer.toBinary(payload)))
  }

  /** Deserializes the payload stored in `payloadFormat`.
    *
    * The manifest kind decides how deserialization is done: a string manifest, a class-name
    * manifest, or none.
    *
    * @throws Throwable whatever the class lookup or deserialization failed with
    */
  def payload(payloadFormat: PayloadFormat): AnyRef = {
    val payload =
      if (payloadFormat.getIsStringManifest) payloadFromStringManifest(payloadFormat)
      else if (payloadFormat.getPayloadManifest.nonEmpty) payloadFromClassManifest(payloadFormat)
      else payloadFromEmptyManifest(payloadFormat)

    payload.get
  }

  private def payloadFromStringManifest(payloadFormat: PayloadFormat): Try[AnyRef] =
    SerializationExtension(system).deserialize(
      payloadFormat.getPayload.toByteArray,
      payloadFormat.getSerializerId,
      payloadFormat.getPayloadManifest
    )

  // A failed class lookup is returned as a Failure rather than thrown, honouring the Try return type.
  private def payloadFromClassManifest(payloadFormat: PayloadFormat): Try[AnyRef] =
    system.dynamicAccess
      .getClassFor[AnyRef](payloadFormat.getPayloadManifest)
      .flatMap[AnyRef] { manifestClass =>
        SerializationExtension(system).deserialize(
          payloadFormat.getPayload.toByteArray,
          payloadFormat.getSerializerId,
          Some(manifestClass)
        )
      }

  private def payloadFromEmptyManifest(payloadFormat: PayloadFormat): Try[AnyRef] =
    SerializationExtension(system).deserialize(
      payloadFormat.getPayload.toByteArray,
      payloadFormat.getSerializerId,
      None
    )
}
Example 53
Source File: BigQueryTypeSpec.scala    From shapeless-datatype   with Apache License 2.0 5 votes vote down vote up
package shapeless.datatype.bigquery

import java.net.URI

import com.fasterxml.jackson.databind.{ObjectMapper, SerializationFeature}
import com.google.api.services.bigquery.model.TableRow
import com.google.common.io.BaseEncoding
import com.google.protobuf.ByteString
import org.joda.time.{Instant, LocalDate, LocalDateTime, LocalTime}
import org.scalacheck.Prop.forAll
import org.scalacheck.ScalacheckShapeless._
import org.scalacheck._
import shapeless._
import shapeless.datatype.record._

import scala.reflect.runtime.universe._

/** Property tests checking that records survive a round trip through BigQuery [[TableRow]]s,
  * including a JSON serialization step with Jackson.
  */
object BigQueryTypeSpec extends Properties("BigQueryType") {
  import shapeless.datatype.test.Records._
  import shapeless.datatype.test.SerializableUtils._

  // Jackson mapper used to write each TableRow to JSON and read it back.
  val mapper = new ObjectMapper().disable(SerializationFeature.FAIL_ON_EMPTY_BEANS)

  // Element-wise array equality. NOTE(review): presumably picked up by the record matcher,
  // since arrays otherwise compare by reference; confirm.
  implicit def compareByteArrays(x: Array[Byte], y: Array[Byte]) = java.util.Arrays.equals(x, y)
  implicit def compareIntArrays(x: Array[Int], y: Array[Int]) = java.util.Arrays.equals(x, y)

  // Converts `m` to a TableRow, serializes that row to JSON and parses it back, then converts it
  // to `A`. Returns true when the result exists and matches `m` according to `RecordMatcher`.
  // The converter and both conversion functions are passed through `ensureSerializable` first.
  def roundTrip[A: TypeTag, L <: HList](m: A)(implicit
    gen: LabelledGeneric.Aux[A, L],
    fromL: FromTableRow[L],
    toL: ToTableRow[L],
    mr: MatchRecord[L]
  ): Boolean = {
    BigQuerySchema[A] // FIXME: verify the generated schema
    val t = ensureSerializable(BigQueryType[A])
    val f1: SerializableFunction[A, TableRow] =
      new SerializableFunction[A, TableRow] {
        override def apply(m: A): TableRow = t.toTableRow(m)
      }
    val f2: SerializableFunction[TableRow, Option[A]] =
      new SerializableFunction[TableRow, Option[A]] {
        override def apply(m: TableRow): Option[A] = t.fromTableRow(m)
      }
    val toFn = ensureSerializable(f1)
    val fromFn = ensureSerializable(f2)
    val copy = fromFn(mapper.readValue(mapper.writeValueAsString(toFn(m)), classOf[TableRow]))
    val rm = RecordMatcher[A]
    copy.exists(rm(_, m))
  }

  // NOTE(review): these implicit vals have no declared types and the object initializes top to
  // bottom, so each one must stay above the properties that use it.
  // ByteString <-> BYTES column, stored as base64 text.
  implicit val byteStringBigQueryMappableType = BigQueryType.at[ByteString]("BYTES")(
    x => ByteString.copyFrom(BaseEncoding.base64().decode(x.toString)),
    x => BaseEncoding.base64().encode(x.toByteArray)
  )
  property("required") = forAll { m: Required => roundTrip(m) }
  property("optional") = forAll { m: Optional => roundTrip(m) }
  property("repeated") = forAll { m: Repeated => roundTrip(m) }
  property("mixed") = forAll { m: Mixed => roundTrip(m) }
  property("nested") = forAll { m: Nested => roundTrip(m) }
  property("seqs") = forAll { m: Seqs => roundTrip(m) }

  // Joda date/time generators, all derived from the millisecond value of an arbitrary Instant.
  implicit val arbDate = Arbitrary(arbInstant.arbitrary.map(i => new LocalDate(i.getMillis)))
  implicit val arbTime = Arbitrary(arbInstant.arbitrary.map(i => new LocalTime(i.getMillis)))
  implicit val arbDateTime = Arbitrary(
    arbInstant.arbitrary.map(i => new LocalDateTime(i.getMillis))
  )

  case class DateTimeTypes(
    instant: Instant,
    date: LocalDate,
    time: LocalTime,
    dateTime: LocalDateTime
  )
  property("date time types") = forAll { m: DateTimeTypes => roundTrip(m) }

  // Custom mapping: URI <-> STRING column.
  implicit val uriBigQueryType =
    BigQueryType.at[URI]("STRING")(v => URI.create(v.toString), _.toASCIIString)
  property("custom") = forAll { m: Custom => roundTrip(m) }
}
Example 54
Source File: AvroTypeSpec.scala    From shapeless-datatype   with Apache License 2.0 5 votes vote down vote up
package shapeless.datatype.avro

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.net.URI
import java.nio.ByteBuffer

import com.google.protobuf.ByteString
import org.apache.avro.Schema
import org.apache.avro.generic.{GenericDatumReader, GenericDatumWriter, GenericRecord}
import org.apache.avro.io.{DecoderFactory, EncoderFactory}
import org.joda.time.Instant
import org.scalacheck.Prop.forAll
import org.scalacheck.ScalacheckShapeless._
import org.scalacheck._
import shapeless._
import shapeless.datatype.record._

import scala.reflect.runtime.universe._

/** Property tests checking that records survive a round trip through Avro [[GenericRecord]]s,
  * including binary encoding and decoding.
  */
object AvroTypeSpec extends Properties("AvroType") {
  import shapeless.datatype.test.Records._
  import shapeless.datatype.test.SerializableUtils._

  // Element-wise array equality. NOTE(review): presumably picked up by the record matcher,
  // since arrays otherwise compare by reference; confirm.
  implicit def compareByteArrays(x: Array[Byte], y: Array[Byte]) = java.util.Arrays.equals(x, y)
  implicit def compareIntArrays(x: Array[Int], y: Array[Int]) = java.util.Arrays.equals(x, y)

  // Converts `m` to a GenericRecord, binary-encodes and decodes it, then converts it back to `A`.
  // Returns true when the result exists and matches `m` according to `RecordMatcher`. The
  // converter and both conversion functions are passed through `ensureSerializable` first.
  def roundTrip[A: TypeTag, L <: HList](m: A)(implicit
    gen: LabelledGeneric.Aux[A, L],
    fromL: FromAvroRecord[L],
    toL: ToAvroRecord[L],
    mr: MatchRecord[L]
  ): Boolean = {
    val t = ensureSerializable(AvroType[A])
    val f1: SerializableFunction[A, GenericRecord] =
      new SerializableFunction[A, GenericRecord] {
        override def apply(m: A): GenericRecord = t.toGenericRecord(m)
      }
    val f2: SerializableFunction[GenericRecord, Option[A]] =
      new SerializableFunction[GenericRecord, Option[A]] {
        override def apply(m: GenericRecord): Option[A] = t.fromGenericRecord(m)
      }
    val toFn = ensureSerializable(f1)
    val fromFn = ensureSerializable(f2)
    val copy = fromFn(roundTripRecord(toFn(m)))
    val rm = RecordMatcher[A]
    copy.exists(rm(_, m))
  }

  // Writes `r` with Avro's binary encoder into memory and reads it back with the same schema.
  def roundTripRecord(r: GenericRecord): GenericRecord = {
    val writer = new GenericDatumWriter[GenericRecord](r.getSchema)
    val baos = new ByteArrayOutputStream()
    val encoder = EncoderFactory.get().binaryEncoder(baos, null)
    writer.write(r, encoder)
    encoder.flush()
    baos.close()
    val bytes = baos.toByteArray

    val reader = new GenericDatumReader[GenericRecord](r.getSchema)
    val bais = new ByteArrayInputStream(bytes)
    val decoder = DecoderFactory.get().binaryDecoder(bais, null)
    reader.read(null, decoder)
  }

  // NOTE(review): these implicit vals have no declared types and the object initializes top to
  // bottom, so each one must stay above the properties that use it.
  // ByteString <-> BYTES; assumes the decoded value is a ByteBuffer (the cast fails otherwise).
  implicit val byteStringAvroType = AvroType.at[ByteString](Schema.Type.BYTES)(
    v => ByteString.copyFrom(v.asInstanceOf[ByteBuffer]),
    v => ByteBuffer.wrap(v.toByteArray)
  )
  // Joda Instant <-> LONG epoch milliseconds.
  implicit val instantAvroType =
    AvroType.at[Instant](Schema.Type.LONG)(v => new Instant(v.asInstanceOf[Long]), _.getMillis)
  property("required") = forAll { m: Required => roundTrip(m) }
  property("optional") = forAll { m: Optional => roundTrip(m) }
  property("repeated") = forAll { m: Repeated => roundTrip(m) }
  property("mixed") = forAll { m: Mixed => roundTrip(m) }
  property("nested") = forAll { m: Nested => roundTrip(m) }
  property("seqs") = forAll { m: Seqs => roundTrip(m) }

  // Custom mapping: URI <-> STRING.
  implicit val uriAvroType =
    AvroType.at[URI](Schema.Type.STRING)(v => URI.create(v.toString), _.toString)
  property("custom") = forAll { m: Custom => roundTrip(m) }
}
Example 55
Source File: DatastoreType.scala    From shapeless-datatype   with Apache License 2.0 5 votes vote down vote up
package shapeless.datatype.datastore

import com.google.datastore.v1.client.DatastoreHelper.makeValue
import com.google.datastore.v1.{Entity, Value}
import com.google.protobuf.{ByteString, Timestamp}
import org.joda.time.{DateTimeConstants, Instant}
import shapeless._

/** Converts between a case class `A` and Datastore [[Entity]] values through its shapeless
  * `LabelledGeneric` representation `L`.
  *
  * Reading returns `None` whenever the `FromEntity` instance cannot build the record.
  */
class DatastoreType[A] extends Serializable {
  def fromEntityBuilder[L <: HList](m: Entity.Builder)(implicit
    gen: LabelledGeneric.Aux[A, L],
    fromL: FromEntity[L]
  ): Option[A] =
    for (record <- fromL(m)) yield gen.from(record)

  def fromEntity[L <: HList](m: Entity)(implicit
    gen: LabelledGeneric.Aux[A, L],
    fromL: FromEntity[L]
  ): Option[A] =
    fromEntityBuilder[L](m.toBuilder)(gen, fromL)

  def toEntityBuilder[L <: HList](a: A)(implicit
    gen: LabelledGeneric.Aux[A, L],
    toL: ToEntity[L]
  ): Entity.Builder = {
    val record = gen.to(a)
    toL(record)
  }

  def toEntity[L <: HList](a: A)(implicit
    gen: LabelledGeneric.Aux[A, L],
    toL: ToEntity[L]
  ): Entity =
    toEntityBuilder[L](a)(gen, toL).build()
}

object DatastoreType {
  /** Creates a converter for `A`. */
  def apply[A]: DatastoreType[A] = new DatastoreType

  /** Builds a mappable type for `V` from a reader and a writer of Datastore [[Value]]s. */
  def at[V](fromFn: Value => V, toFn: V => Value): BaseDatastoreMappableType[V] =
    new BaseDatastoreMappableType[V] {
      override def from(value: Value): V = fromFn.apply(value)

      override def to(value: V): Value = toFn.apply(value)
    }
}

/** Implicit Datastore mappings for scalar field types, `ByteString`/`Array[Byte]` blobs and Joda [[Instant]]s. */
trait DatastoreMappableTypes {
  import DatastoreType.at

  implicit val booleanEntityMappableType = at[Boolean](_.getBooleanValue, makeValue(_).build())
  // Datastore integers are 64-bit; reading into an Int truncates via `toInt`.
  implicit val intDatastoreMappableType = at[Int](_.getIntegerValue.toInt, makeValue(_).build())
  implicit val longEntityMappableType = at[Long](_.getIntegerValue, makeValue(_).build())
  implicit val floatEntityMappableType = at[Float](_.getDoubleValue.toFloat, makeValue(_).build())
  implicit val doubleEntityMappableType = at[Double](_.getDoubleValue, makeValue(_).build())
  implicit val stringEntityMappableType = at[String](_.getStringValue, makeValue(_).build())
  implicit val byteStringEntityMappableType = at[ByteString](_.getBlobValue, makeValue(_).build())
  implicit val byteArrayEntityMappableType =
    at[Array[Byte]](_.getBlobValue.toByteArray, v => makeValue(ByteString.copyFrom(v)).build())
  implicit val timestampEntityMappableType = at[Instant](toInstant, fromInstant)

  /** Converts a Datastore timestamp value to an [[Instant]], truncating to millisecond precision. */
  private def toInstant(v: Value): Instant = {
    val t = v.getTimestampValue
    new Instant(t.getSeconds * DateTimeConstants.MILLIS_PER_SECOND + t.getNanos / 1000000)
  }

  /** Converts an [[Instant]] to a Datastore timestamp value.
    *
    * Floor division/modulo keep `nanos` within [0, 999999999] for instants before the epoch, as
    * `google.protobuf.Timestamp` requires. With truncating `/` and `%`, -1500 ms would produce
    * seconds = -1 and nanos = -500000000, which is invalid.
    */
  private def fromInstant(i: Instant): Value = {
    val millis = i.getMillis
    val millisPerSecond = DateTimeConstants.MILLIS_PER_SECOND.toLong
    val t = Timestamp
      .newBuilder()
      .setSeconds(Math.floorDiv(millis, millisPerSecond))
      .setNanos(Math.floorMod(millis, millisPerSecond).toInt * 1000000)
    Value.newBuilder().setTimestampValue(t).build()
  }
}
Example 56
Source File: DatastoreMappableType.scala    From shapeless-datatype   with Apache License 2.0 5 votes vote down vote up
package shapeless.datatype.datastore

import com.google.datastore.v1.client.DatastoreHelper._
import com.google.datastore.v1.{Entity, Value}
import com.google.protobuf.{ByteString, Timestamp}
import org.joda.time.Instant
import org.joda.time.DateTimeConstants
import shapeless.datatype.mappable.{BaseMappableType, MappableType}

import scala.collection.JavaConverters._

/** Maps a value type `V` to and from Datastore [[Value]]s stored as properties of an [[Entity.Builder]].
  *
  * Implementations supply `from`/`to`. A sequence of values is stored as a single array-valued
  * property; an absent `Option` leaves the builder untouched.
  */
trait BaseDatastoreMappableType[V] extends MappableType[Entity.Builder, V] {
  def from(value: Value): V
  def to(value: V): Value

  override def get(m: Entity.Builder, key: String): Option[V] =
    for (property <- Option(m.getPropertiesMap.get(key))) yield from(property)

  override def getAll(m: Entity.Builder, key: String): Seq[V] =
    for {
      property <- Option(m.getPropertiesMap.get(key)).toSeq
      element  <- property.getArrayValue.getValuesList.asScala
    } yield from(element)

  override def put(key: String, value: V, tail: Entity.Builder): Entity.Builder =
    tail.putProperties(key, to(value))

  override def put(key: String, value: Option[V], tail: Entity.Builder): Entity.Builder =
    value.fold(tail)(v => tail.putProperties(key, to(v)))

  override def put(key: String, values: Seq[V], tail: Entity.Builder): Entity.Builder = {
    val converted = values.map(to)
    tail.putProperties(key, makeValue(converted.asJava).build())
  }
}

/**
 * Supplies the [[BaseMappableType]] for Datastore entities. Nested `Entity.Builder` records are
 * stored as entity-valued properties. Repeated nested records are stored as an array of entity
 * values, and a missing property reads back as `None` / an empty sequence.
 */
trait DatastoreMappableType extends DatastoreMappableTypes {
  // The type is given explicitly: implicit definitions should not rely on inference. Scala 3
  // rejects inferred implicit types, and in Scala 2 an implicit with an inferred type can be
  // invisible to code earlier in the same file.
  implicit val datastoreBaseMappableType: BaseMappableType[Entity.Builder] =
    new BaseMappableType[Entity.Builder] {
      override def base: Entity.Builder = Entity.newBuilder()

      override def get(m: Entity.Builder, key: String): Option[Entity.Builder] =
        Option(m.getPropertiesMap.get(key)).map(_.getEntityValue.toBuilder)
      override def getAll(m: Entity.Builder, key: String): Seq[Entity.Builder] =
        Option(m.getPropertiesMap.get(key)).toSeq
          .flatMap(_.getArrayValue.getValuesList.asScala.map(_.getEntityValue.toBuilder))

      override def put(key: String, value: Entity.Builder, tail: Entity.Builder): Entity.Builder =
        tail.putProperties(key, makeValue(value).build())
      override def put(
        key: String,
        value: Option[Entity.Builder],
        tail: Entity.Builder
      ): Entity.Builder =
        value.foldLeft(tail)((b, v) => b.putProperties(key, makeValue(v).build()))
      override def put(
        key: String,
        values: Seq[Entity.Builder],
        tail: Entity.Builder
      ): Entity.Builder =
        tail.putProperties(key, makeValue(values.map(v => makeValue(v).build()).asJava).build())
    }
}
Example 57
Source File: TensorFlowMappableType.scala    From shapeless-datatype   with Apache License 2.0 5 votes vote down vote up
package shapeless.datatype.tensorflow

import com.google.protobuf.ByteString
import org.tensorflow.example._
import shapeless.datatype.mappable.{BaseMappableType, MappableType}

/**
 * [[MappableType]] for a value type `V` stored as a TensorFlow [[Feature]] under a field name in a
 * [[Features.Builder]]. Single values are stored as one-element feature lists.
 */
trait BaseTensorFlowMappableType[V] extends MappableType[Features.Builder, V] {
  /** Decodes a single value: the first element of the feature's list (throws if it is empty). */
  def from(value: Feature): V = fromSeq(value).head

  /** Encodes a single value as a one-element feature list. */
  def to(value: V): Feature = toSeq(Seq(value))

  /** Decodes every element of a feature's list. */
  def fromSeq(value: Feature): Seq[V]

  /** Encodes a sequence of values as one feature. */
  def toSeq(value: Seq[V]): Feature

  override def get(m: Features.Builder, key: String): Option[V] =
    Option(m.getFeatureMap.get(key)).map(feature => from(feature))

  // A missing feature reads back as an empty sequence.
  override def getAll(m: Features.Builder, key: String): Seq[V] =
    Option(m.getFeatureMap.get(key)).toList.flatMap(feature => fromSeq(feature))

  override def put(key: String, value: V, tail: Features.Builder): Features.Builder =
    tail.putFeature(key, to(value))

  override def put(key: String, value: Option[V], tail: Features.Builder): Features.Builder =
    value.fold(tail)(v => tail.putFeature(key, to(v)))

  override def put(key: String, values: Seq[V], tail: Features.Builder): Features.Builder =
    tail.putFeature(key, toSeq(values))
}

/**
 * Implicit [[MappableType]] instances for TensorFlow [[Features]].
 *
 * Every member of the base mappable type except `base` is `???` (throws `NotImplementedError`),
 * so nested `Features.Builder` values cannot be read or written.
 */
trait TensorFlowMappableType {
  // All implicits carry explicit types: Scala 3 rejects inferred implicit types, and in Scala 2
  // an implicit with an inferred type can be invisible earlier in the same file.
  implicit val tensorFlowBaseMappableType: BaseMappableType[Features.Builder] =
    new BaseMappableType[Features.Builder] {
      override def base: Features.Builder = Features.newBuilder()
      override def get(m: Features.Builder, key: String): Option[Features.Builder] = ???
      override def getAll(m: Features.Builder, key: String): Seq[Features.Builder] = ???
      override def put(
        key: String,
        value: Features.Builder,
        tail: Features.Builder
      ): Features.Builder = ???
      override def put(
        key: String,
        value: Option[Features.Builder],
        tail: Features.Builder
      ): Features.Builder = ???
      override def put(
        key: String,
        values: Seq[Features.Builder],
        tail: Features.Builder
      ): Features.Builder = ???
    }

  import TensorFlowType._

  implicit val booleanTensorFlowMappableType: BaseTensorFlowMappableType[Boolean] =
    at[Boolean](toBooleans, fromBooleans)
  implicit val intTensorFlowMappableType: BaseTensorFlowMappableType[Int] =
    at[Int](toInts, fromInts)
  implicit val longTensorFlowMappableType: BaseTensorFlowMappableType[Long] =
    at[Long](toLongs, fromLongs)
  implicit val floatTensorFlowMappableType: BaseTensorFlowMappableType[Float] =
    at[Float](toFloats, fromFloats)
  implicit val doubleTensorFlowMappableType: BaseTensorFlowMappableType[Double] =
    at[Double](toDoubles, fromDoubles)
  implicit val byteStringTensorFlowMappableType: BaseTensorFlowMappableType[ByteString] =
    at[ByteString](toByteStrings, fromByteStrings)
  implicit val byteArrayTensorFlowMappableType: BaseTensorFlowMappableType[Array[Byte]] =
    at[Array[Byte]](toByteArrays, fromByteArrays)
  implicit val stringTensorFlowMappableType: BaseTensorFlowMappableType[String] =
    at[String](toStrings, fromStrings)
}
Example 58
Source File: TensorFlowType.scala    From shapeless-datatype   with Apache License 2.0 5 votes vote down vote up
package shapeless.datatype.tensorflow

import com.google.protobuf.ByteString
import org.tensorflow.example._
import shapeless._

import scala.collection.JavaConverters._

/**
 * Converts between a case class `A` and TensorFlow [[Example]] protos through its shapeless
 * [[LabelledGeneric]] representation `L`. Decoding yields an `Option`, as produced by the
 * [[FromFeatures]] instance for `L`.
 */
class TensorFlowType[A] extends Serializable {

  /** Decodes `A` from the features builder of an example builder. */
  def fromExampleBuilder[L <: HList](m: Example.Builder)(implicit
    gen: LabelledGeneric.Aux[A, L],
    fromL: FromFeatures[L]
  ): Option[A] = {
    val features = m.getFeaturesBuilder
    fromL(features).map(repr => gen.from(repr))
  }

  /** Decodes `A` from a built example, working on a fresh builder copy of its features. */
  def fromExample[L <: HList](m: Example)(implicit
    gen: LabelledGeneric.Aux[A, L],
    fromL: FromFeatures[L]
  ): Option[A] = {
    val features = m.getFeatures.toBuilder
    fromL(features).map(repr => gen.from(repr))
  }

  /** Encodes `a` into a new example builder. */
  def toExampleBuilder[L <: HList](a: A)(implicit
    gen: LabelledGeneric.Aux[A, L],
    toL: ToFeatures[L]
  ): Example.Builder = {
    val features = toL(gen.to(a))
    Example.newBuilder().setFeatures(features)
  }

  /** Encodes `a` into a built example. */
  def toExample[L <: HList](a: A)(implicit
    gen: LabelledGeneric.Aux[A, L],
    toL: ToFeatures[L]
  ): Example =
    toExampleBuilder(a)(gen, toL).build()
}

/**
 * Factory for [[TensorFlowType]] plus the per-type encoders/decoders used by the implicit
 * mappable types. Every Scala type handled here is encoded into one of the three feature list
 * kinds used below: int64, float or bytes.
 */
object TensorFlowType {
  def apply[A]: TensorFlowType[A] = new TensorFlowType[A]

  /** Builds a [[BaseTensorFlowMappableType]] from a decoder and an encoder of value sequences. */
  def at[V](
    fromFn: Feature => Seq[V],
    toFn: Seq[V] => Feature.Builder
  ): BaseTensorFlowMappableType[V] =
    new BaseTensorFlowMappableType[V] {
      override def fromSeq(value: Feature): Seq[V] = fromFn(value)
      override def toSeq(value: Seq[V]): Feature = toFn(value).build()
    }

  // Booleans are stored as int64 1/0. On read, any positive value decodes as true.
  def fromBooleans(xs: Seq[Boolean]): Feature.Builder =
    fromLongs(xs.map(b => if (b) 1L else 0L))
  def toBooleans(f: Feature): Seq[Boolean] = toLongs(f).map(_ > 0L)

  def fromLongs(xs: Seq[Long]): Feature.Builder =
    Feature.newBuilder().setInt64List(xs.foldLeft(Int64List.newBuilder())(_.addValue(_)).build())
  // Unbox each java.lang.Long explicitly instead of casting the boxed list to Seq[Long].
  def toLongs(f: Feature): Seq[Long] =
    f.getInt64List.getValueList.asScala.map(_.longValue).toSeq

  def fromInts(xs: Seq[Int]): Feature.Builder = fromLongs(xs.map(_.toLong))
  def toInts(f: Feature): Seq[Int] = toLongs(f).map(_.toInt)

  def fromFloats(xs: Seq[Float]): Feature.Builder =
    Feature.newBuilder().setFloatList(xs.foldLeft(FloatList.newBuilder())(_.addValue(_)).build())
  // Unbox each java.lang.Float explicitly instead of casting the boxed list to Seq[Float].
  def toFloats(f: Feature): Seq[Float] =
    f.getFloatList.getValueList.asScala.map(_.floatValue).toSeq

  // NOTE: doubles are narrowed to a float list on write, so any precision beyond Float is lost
  // on a round trip.
  def fromDoubles(xs: Seq[Double]): Feature.Builder = fromFloats(xs.map(_.toFloat))
  def toDoubles(f: Feature): Seq[Double] = toFloats(f).map(_.toDouble)

  def fromByteStrings(xs: Seq[ByteString]): Feature.Builder =
    Feature.newBuilder().setBytesList(BytesList.newBuilder().addAllValue(xs.asJava))
  def toByteStrings(f: Feature): Seq[ByteString] = f.getBytesList.getValueList.asScala.toSeq

  def fromByteArrays(xs: Seq[Array[Byte]]): Feature.Builder =
    fromByteStrings(xs.map(ByteString.copyFrom))
  def toByteArrays(f: Feature): Seq[Array[Byte]] = toByteStrings(f).map(_.toByteArray)

  // Strings are stored as UTF-8 encoded bytes.
  def fromStrings(xs: Seq[String]): Feature.Builder =
    fromByteStrings(xs.map(ByteString.copyFromUtf8))
  def toStrings(f: Feature): Seq[String] = toByteStrings(f).map(_.toStringUtf8)
}
Example 59
Source File: Records.scala    From shapeless-datatype   with Apache License 2.0 5 votes vote down vote up
package shapeless.datatype.test

import java.net.URI

import com.google.protobuf.ByteString
import org.joda.time.Instant
import org.scalacheck._

/**
 * Record shapes used as test fixtures, plus the ScalaCheck `Arbitrary` instances
 * for the field types that have no built-in generator (ByteString, Joda Instant, URI).
 */
object Records {

  // One mandatory field of every supported scalar type.
  case class Required(
    booleanField: Boolean,
    intField: Int,
    longField: Long,
    floatField: Float,
    doubleField: Double,
    stringField: String,
    byteStringField: ByteString,
    byteArrayField: Array[Byte],
    timestampField: Instant
  )

  // The same types as Required, each wrapped in Option.
  case class Optional(
    booleanField: Option[Boolean],
    intField: Option[Int],
    longField: Option[Long],
    floatField: Option[Float],
    doubleField: Option[Double],
    stringField: Option[String],
    byteStringField: Option[ByteString],
    byteArrayField: Option[Array[Byte]],
    timestampField: Option[Instant]
  )

  // The same types as Required, each as a List.
  case class Repeated(
    booleanField: List[Boolean],
    intField: List[Int],
    longField: List[Long],
    floatField: List[Float],
    doubleField: List[Double],
    stringField: List[String],
    byteStringField: List[ByteString],
    byteArrayField: List[Array[Byte]],
    timestampField: List[Instant]
  )

  // Required, optional (suffix O) and repeated (suffix R) fields side by side.
  case class Mixed(
    longField: Long,
    doubleField: Double,
    stringField: String,
    longFieldO: Option[Long],
    doubleFieldO: Option[Double],
    stringFieldO: Option[String],
    longFieldR: List[Long],
    doubleFieldR: List[Double],
    stringFieldR: List[String]
  )

  // Mixed embedded as a required, optional and repeated nested record.
  case class Nested(
    longField: Long,
    longFieldO: Option[Long],
    longFieldR: List[Long],
    mixedField: Mixed,
    mixedFieldO: Option[Mixed],
    mixedFieldR: List[Mixed]
  )

  // The same element type held in three different collection types.
  case class Seqs(
    array: Array[Int],
    list: List[Int],
    vector: Vector[Int]
  )

  // A type (URI) that needs a user-supplied mapping.
  case class Custom(
    uriField: URI,
    uriFieldO: Option[URI],
    uriFieldR: List[URI]
  )

  // Generators for the non-standard field types above.
  implicit val arbByteString =
    Arbitrary(Gen.alphaStr.map(ByteString.copyFromUtf8))
  // NOTE: only millisecond offsets in [0, Int.MaxValue] are generated (no pre-epoch instants).
  implicit val arbInstant =
    Arbitrary(Gen.chooseNum(0, Int.MaxValue).map(new Instant(_)))
  implicit val arbUri =
    Arbitrary(Gen.alphaStr.map(URI.create))
}
Example 60
Source File: FieldReaderSpec.scala    From protobuf-generic   with Apache License 2.0 5 votes vote down vote up
package me.lyh.protobuf.generic.test

import com.google.protobuf.{ByteString, Message}
import me.lyh.protobuf.generic._
import me.lyh.protobuf.generic.proto3.Schemas._
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

import scala.reflect.ClassTag

class FieldReaderSpec extends AnyFlatSpec with Matchers {

  /**
   * Serializes `record` and reads `fields` back through a [[FieldReader]], then checks that the
   * values equal `expected`. Both the schema and the reader are first passed through
   * `SerializableUtils.ensureSerializable`.
   */
  def read[T <: Message: ClassTag](record: T, fields: List[String], expected: List[Any]): Unit = {
    val fieldReader = SerializableUtils.ensureSerializable(
      FieldReader.of(SerializableUtils.ensureSerializable(Schema.of[T]), fields)
    )
    fieldReader.read(record.toByteArray).toList shouldBe expected
  }

  // Top-level scalar fields; positions line up with `expected` below.
  val fields: List[String] = List(
    "double_field",
    "float_field",
    "int32_field",
    "int64_field",
    "uint32_field",
    "uint64_field",
    "sint32_field",
    "sint64_field",
    "fixed32_field",
    "fixed64_field",
    "sfixed32_field",
    "sfixed64_field",
    "bool_field",
    "string_field",
    "bytes_field",
    "color_field"
  )

  // Values held by Records.optional, in the order of `fields`.
  val expected: List[Any] = List(
    math.Pi,
    math.E.toFloat,
    10,
    15,
    20,
    25,
    30,
    35,
    40,
    45,
    50,
    55,
    true,
    "hello",
    ByteString.copyFromUtf8("world"),
    "WHITE"
  )

  "FieldReader" should "read optional" in {
    read[Optional](Records.optional, fields, expected)

    // An empty message reads back as proto3 defaults (enum default is BLACK).
    val defaults: List[Any] =
      List(0.0, 0.0f, 0, 0L, 0, 0L, 0, 0L, 0, 0L, 0, 0L, false, "", ByteString.EMPTY, "BLACK")
    read[Optional](Records.optionalEmpty, fields, defaults)
  }

  it should "read oneofs" in {
    // The first oneof record is skipped; each remaining record sets exactly one field.
    val cases = Records.oneOfs.drop(1).zip(fields.zip(expected))
    cases.foreach {
      case (record, (field, value)) =>
        read[OneOf](record, List(field), List(value))
    }
  }

  it should "read nested" in {
    val nestedFields = List(
      "mixed_field_o.double_field_o",
      "mixed_field_o.string_field_o",
      "mixed_field_o.bytes_field_o",
      "mixed_field_o.color_field_o"
    )

    val populated = List(math.Pi, "hello", ByteString.copyFromUtf8("world"), "WHITE")
    read[Nested](Records.nested, nestedFields, populated)

    val empty = List(0.0, "", ByteString.EMPTY, "BLACK")
    read[Nested](Records.nestedEmpty, nestedFields, empty)
  }
}
Example 61
Source File: ProtobufGenericSpec.scala    From protobuf-generic   with Apache License 2.0 5 votes vote down vote up
package me.lyh.protobuf.generic.test

import java.io.ByteArrayInputStream
import java.nio.ByteBuffer

import com.google.protobuf.{ByteString, Message}
import me.lyh.protobuf.generic._
import me.lyh.protobuf.generic.proto2.Schemas._

import scala.reflect.ClassTag
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class ProtobufGenericSpec extends AnyFlatSpec with Matchers {

  /**
   * Round-trips `record` through the generic layer and checks the copy equals the original.
   * The steps are:
   *  - the schema survives a JSON round trip;
   *  - byte-array, ByteBuffer and InputStream reads agree;
   *  - the JSON form written back yields a message equal to `record`.
   */
  def roundTrip[T <: Message: ClassTag](record: T): Unit = {
    val schema = SerializableUtils.ensureSerializable(Schema.of[T])
    Schema.fromJson(schema.toJson) shouldBe schema

    val reader = SerializableUtils.ensureSerializable(GenericReader.of(schema))
    val writer = SerializableUtils.ensureSerializable(GenericWriter.of(schema))

    val json = reader.read(record.toByteArray).toJson
    json shouldBe reader.read(ByteBuffer.wrap(record.toByteArray)).toJson
    json shouldBe reader.read(new ByteArrayInputStream(record.toByteArray)).toJson

    val rewritten = writer.write(GenericRecord.fromJson(json))
    ProtobufType[T].parseFrom(rewritten) shouldBe record
  }

  "ProtobufGeneric" should "round trip required" in {
    roundTrip[Required](Records.required)
  }

  it should "round trip optional" in {
    Seq(Records.optional, Records.optionalEmpty).foreach(roundTrip[Optional])
  }

  it should "round trip repeated" in {
    Seq(Records.repeated, Records.repeatedEmpty).foreach(roundTrip[Repeated])
    roundTrip[RepeatedPacked](Records.repeatedPacked)
    roundTrip[RepeatedUnpacked](Records.repeatedUnpacked)
  }

  it should "round trip oneofs" in {
    Records.oneOfs.foreach(roundTrip[OneOf])
  }

  it should "round trip mixed" in {
    Seq(Records.mixed, Records.mixedEmpty).foreach(roundTrip[Mixed])
  }

  it should "round trip nested" in {
    Seq(Records.nested, Records.nestedEmpty).foreach(roundTrip[Nested])
  }

  it should "round trip with custom options" in {
    Seq(Records.customOptionMessage, Records.customOptionMessageEmpty)
      .foreach(roundTrip[CustomOptionMessage])
  }

  it should "round trip with custom defaults" in {
    roundTrip[CustomDefaults](CustomDefaults.getDefaultInstance)
  }

  it should "populate default values" in {
    // Reading the default instance must surface the proto2 custom defaults declared in the schema.
    val defaultBytes = CustomDefaults.getDefaultInstance.toByteArray
    val record = GenericReader.of(Schema.of[CustomDefaults]).read(defaultBytes)

    record.get("double_field") shouldBe 101.0
    record.get("float_field") shouldBe 102.0f
    record.get("int32_field") shouldBe 103
    record.get("int64_field") shouldBe 104L
    record.get("uint32_field") shouldBe 105
    record.get("uint64_field") shouldBe 106L
    record.get("sint32_field") shouldBe 107
    record.get("sint64_field") shouldBe 108L
    record.get("fixed32_field") shouldBe 109
    record.get("fixed64_field") shouldBe 110L
    record.get("sfixed32_field") shouldBe 111
    record.get("sfixed64_field") shouldBe 112L
    record.get("bool_field") shouldBe true
    record.get("string_field") shouldBe "hello"
    // Bytes are exposed Base64-encoded.
    record.get("bytes_field") shouldBe Base64.encode(ByteString.copyFromUtf8("world").toByteArray)
    record.get("color_field") shouldBe "GREEN"
  }
}
Example 62
Source File: PubSubMessage.scala    From akka-cloudpubsub   with Apache License 2.0 5 votes vote down vote up
package com.qubit.pubsub.client

import java.time.{Instant, ZoneOffset, ZonedDateTime}

import com.google.protobuf.{ByteString, Timestamp}
import com.google.pubsub.v1.{
  PubsubMessage => PubSubMessageProto,
  ReceivedMessage => ReceivedPubSubMessageProto
}

import scala.collection.JavaConversions._

/**
 * Scala-side view of a Pub/Sub message.
 *
 * Note: `payload` is an `Array[Byte]`, so the generated case-class equality compares it by
 * reference, not by content.
 *
 * @param payload    raw message data
 * @param msgId      message id, if known
 * @param publishTs  publish time, if known
 * @param attributes message attributes, if any
 */
final case class PubSubMessage(
    payload: Array[Byte],
    msgId: Option[String] = None,
    publishTs: Option[ZonedDateTime] = None,
    attributes: Option[Map[String, String]] = None) {

  /**
   * Converts this message to its protobuf form. The publish time keeps its nanosecond-of-second
   * component rather than being truncated to whole seconds.
   */
  def toProto: PubSubMessageProto = {
    val builder = PubSubMessageProto.newBuilder()
    builder.setData(ByteString.copyFrom(payload))
    publishTs.foreach { ts =>
      builder.setPublishTime(
        Timestamp.newBuilder().setSeconds(ts.toEpochSecond).setNanos(ts.getNano).build())
    }
    msgId.foreach(id => builder.setMessageId(id))
    // Put entries one by one instead of relying on deprecated JavaConversions implicits.
    attributes.foreach { attrs =>
      attrs.foreach { case (key, value) => builder.putAttributes(key, value) }
    }
    builder.build()
  }
}

object PubSubMessage {

  /**
   * Builds a [[PubSubMessage]] from its protobuf form.
   *
   *  - An empty message id (the proto3 default for an unset string) maps to `None`, mirroring
   *    `toProto`, which leaves the id unset for `None`.
   *  - An empty attribute map maps to `None`.
   *  - The publish time is kept at full precision (seconds plus nanos) in UTC; it is `None` when
   *    the proto has no publish time.
   */
  def fromProto(proto: PubSubMessageProto): PubSubMessage = {
    import scala.collection.JavaConverters._

    val payload = proto.getData.toByteArray
    val msgId = Option(proto.getMessageId).filter(_.nonEmpty)
    val attributes =
      if (proto.getAttributesMap.isEmpty) None
      else Some(proto.getAttributesMap.asScala.toMap)
    val publishTs =
      if (proto.hasPublishTime) {
        val ts = proto.getPublishTime
        Some(
          ZonedDateTime.ofInstant(
            Instant.ofEpochSecond(ts.getSeconds, ts.getNanos.toLong),
            ZoneOffset.UTC))
      } else {
        None
      }

    PubSubMessage(payload,
                  msgId = msgId,
                  publishTs = publishTs,
                  attributes = attributes)
  }
}

/** A message received from a subscription, together with its `ackId`. */
final case class ReceivedPubSubMessage(
    ackId: String,
    payload: PubSubMessage)

object ReceivedPubSubMessage {

  /** Converts a received-message proto, decoding the inner message via [[PubSubMessage.fromProto]]. */
  def fromProto(proto: ReceivedPubSubMessageProto): ReceivedPubSubMessage =
    ReceivedPubSubMessage(
      ackId = proto.getAckId,
      payload = PubSubMessage.fromProto(proto.getMessage)
    )
}
Example 63
Source File: ProtobufAnySerializer.scala    From cloudstate   with Apache License 2.0 5 votes vote down vote up
package io.cloudstate.proxy

import akka.serialization.{BaseSerializer, SerializerWithStringManifest}
import akka.actor.ExtendedActorSystem
import com.google.protobuf.ByteString
import com.google.protobuf.any.{Any => pbAny}

/**
 * Akka serializer for ScalaPB `com.google.protobuf.any.Any` values.
 *
 * The binary form is the Any's raw `value` bytes and the string manifest is its `typeUrl`, so
 * `fromBinary` rebuilds the Any from the two. Any other object type is rejected with an
 * `IllegalArgumentException`.
 */
final class ProtobufAnySerializer(override val system: ExtendedActorSystem)
    extends SerializerWithStringManifest
    with BaseSerializer {

  // Shared rejection for objects that are not a protobuf Any.
  private def unsupported(o: AnyRef): Nothing =
    throw new IllegalArgumentException(s"$this only supports com.google.protobuf.any.Any, not ${o.getClass.getName}!")

  final override def manifest(o: AnyRef): String = o match {
    case any: pbAny => any.typeUrl
    case other      => unsupported(other)
  }

  final override def toBinary(o: AnyRef): Array[Byte] = o match {
    case any: pbAny => any.value.toByteArray
    case other      => unsupported(other)
  }

  final override def fromBinary(bytes: Array[Byte], manifest: String): AnyRef =
    if (manifest == null)
      throw new IllegalArgumentException("null manifest detected instead of valid com.google.protobuf.any.Any.typeUrl")
    else
      pbAny(manifest, ByteString.copyFrom(bytes))
}
Example 64
Source File: Warmup.scala    From cloudstate   with Apache License 2.0 5 votes vote down vote up
package io.cloudstate.proxy

import akka.actor.{Actor, ActorLogging, ActorRef, Props, SupervisorStrategy, Terminated}
import com.google.protobuf.ByteString
import io.cloudstate.proxy.eventsourced.EventSourcedEntity.{Configuration, Stop}
import Warmup.Ready
import io.cloudstate.protocol.entity.{ClientAction, Reply}
import io.cloudstate.protocol.event_sourced.{EventSourcedReply, EventSourcedStreamIn, EventSourcedStreamOut}
import io.cloudstate.proxy.entity.{EntityCommand, UserFunctionReply}
import io.cloudstate.proxy.eventsourced.EventSourcedEntity

import scala.concurrent.duration._

object Warmup {

  /** Creates [[Props]] for a [[Warmup]] actor. */
  def props(needsWarmup: Boolean): Props = Props(new Warmup(needsWarmup))

  /**
   * Readiness query. A [[Warmup]] actor replies `false` while its warmup is in progress and
   * `true` once it is warm.
   */
  case object Ready
}


/**
 * Warms up the proxy when `needsWarmup` is true. The steps are:
 *  - spawn a throwaway [[EventSourcedEntity]] ("###warmup-entity") and send it one command;
 *  - act as the user function for that exchange;
 *  - stop the entity once the reply has been forwarded.
 *
 * The actor becomes warm when the watched entity terminates. When `needsWarmup` is false it is
 * warm from the start.
 */
class Warmup(needsWarmup: Boolean) extends Actor with ActorLogging {

  // Empty Any used both as the warmup command payload and as the canned reply.
  private def emptyPayload = com.google.protobuf.any.Any("url", ByteString.EMPTY)

  if (needsWarmup) {
    log.debug("Starting warmup...")

    val entityConfig = Configuration("warmup.Service", "###warmup", 30.seconds, 100)
    // NOTE(review): `self` fills all three collaborator slots, presumably so that everything
    // the entity sends out is routed back to this actor.
    val entity = context.watch(
      context.actorOf(
        EventSourcedEntity.props(entityConfig, "###warmup-entity", self, self, self),
        "entity"
      )
    )

    entity ! EntityCommand(
      entityId = "###warmup-entity",
      name = "foo",
      payload = Some(emptyPayload)
    )

    context.become(warmingUp(entity))
  }

  // Starts in the warm state; the constructor switches to `warmingUp` when a warmup is needed.
  override def receive = warm

  /** Behaviour while the warmup entity `entityRef` is alive. */
  private def warmingUp(entityRef: ActorRef): Receive = {
    case Ready =>
      sender() ! false

    case ConcurrencyEnforcer.Action(_, start) =>
      log.debug("Warmup received action, starting it.")
      start()

    case EventSourcedStreamIn(EventSourcedStreamIn.Message.Event(_), _) =>
      () // Events play no part in the warmup.

    case EventSourcedStreamIn(EventSourcedStreamIn.Message.Init(_), _) =>
      log.debug("Warmup got init.")

    case EventSourcedStreamIn(EventSourcedStreamIn.Message.Command(cmd), _) =>
      log.debug("Warmup got forwarded command")
      // The entity forwarded our own command; answer it with an empty reply.
      val reply = ClientAction(ClientAction.Action.Reply(Reply(Some(emptyPayload))))
      entityRef ! EventSourcedStreamOut(
        EventSourcedStreamOut.Message.Reply(
          EventSourcedReply(commandId = cmd.id, clientAction = Some(reply))
        )
      )

    case _: UserFunctionReply =>
      log.debug("Warmup got forwarded reply")
      // The round trip is complete, so the entity can be stopped.
      entityRef ! Stop

    case Terminated(_) =>
      log.info("Warmup complete")
      context.become(warm)

    case other =>
      // Other messages arriving during warmup are not needed; log and drop them.
      log.debug("Warmup received {}", other.getClass)
  }

  /** Behaviour once warm: readiness queries are answered with `true`. */
  private def warm: Receive = {
    case Ready => sender() ! true
  }

  override def supervisorStrategy: SupervisorStrategy = SupervisorStrategy.stoppingStrategy
}
Example 65
Source File: AnySupportSpec.scala    From cloudstate   with Apache License 2.0 5 votes vote down vote up
package io.cloudstate.javasupport.impl

import com.example.shoppingcart.Shoppingcart
import com.google.protobuf.{ByteString, Empty}
import io.cloudstate.javasupport.Jsonable
import io.cloudstate.protocol.entity.UserFunctionError
import io.cloudstate.protocol.event_sourced.EventSourcedProto
import org.scalatest.{Matchers, OptionValues, WordSpec}

import scala.beans.BeanProperty

class AnySupportSpec extends WordSpec with Matchers with OptionValues {

  // AnySupport over the shopping-cart and event-sourced descriptors, with type URL prefix "com.example".
  private val anySupport = new AnySupport(
    Array(Shoppingcart.getDescriptor, EventSourcedProto.javaDescriptor),
    getClass.getClassLoader,
    "com.example"
  )

  private val addLineItem = Shoppingcart.AddLineItem
    .newBuilder()
    .setName("item")
    .setProductId("id")
    .setQuantity(10)
    .build()

  "Any support" should {

    "support se/deserializing java protobufs" in {
      val encoded = anySupport.encodeScala(addLineItem)
      encoded.typeUrl should ===("com.example/" + Shoppingcart.AddLineItem.getDescriptor.getFullName)
      anySupport.decode(encoded) should ===(addLineItem)
    }

    "support se/deserializing scala protobufs" in {
      val error = UserFunctionError("error")
      val encoded = anySupport.encodeScala(error)
      encoded.typeUrl should ===("com.example/cloudstate.UserFunctionError")
      anySupport.decode(encoded) should ===(error)
    }

    "support resolving a service descriptor" in {
      val service = Shoppingcart.getDescriptor.findServiceByName("ShoppingCart")
      val methods = anySupport.resolveServiceDescriptor(service)
      methods should have size 3
      val addItem = methods("AddItem")

      // Input type
      val input = addItem.inputType
      input.typeUrl should ===("com.example/" + Shoppingcart.AddLineItem.getDescriptor.getFullName)
      input.typeClass should ===(classOf[Shoppingcart.AddLineItem])
      val inputBytes = input.asInstanceOf[ResolvedType[Any]].toByteString(addLineItem)
      input.parseFrom(inputBytes) should ===(addLineItem)

      // Output type - this also checks that when java_multiple_files is true, it works
      val output = addItem.outputType
      output.typeUrl should ===("com.example/" + Empty.getDescriptor.getFullName)
      output.typeClass should ===(classOf[Empty])
      val outputBytes = output.asInstanceOf[ResolvedType[Any]].toByteString(Empty.getDefaultInstance)
      output.parseFrom(outputBytes) should ===(Empty.getDefaultInstance)
    }

    // Round-trips a primitive and its default value; the default must encode to zero bytes.
    def testPrimitive[T](name: String, value: T, defaultValue: T) = {
      val expectedUrl = AnySupport.CloudStatePrimitive + name

      val encoded = anySupport.encodeScala(value)
      encoded.typeUrl should ===(expectedUrl)
      anySupport.decode(encoded) should ===(value)

      val encodedDefault = anySupport.encodeScala(defaultValue)
      encodedDefault.typeUrl should ===(expectedUrl)
      encodedDefault.value.size() shouldBe 0
      anySupport.decode(encodedDefault) should ===(defaultValue)
    }

    "support se/deserializing strings" in testPrimitive("string", "foo", "")
    "support se/deserializing ints" in testPrimitive("int32", 10, 0)
    "support se/deserializing longs" in testPrimitive("int64", 10L, 0L)
    "support se/deserializing floats" in testPrimitive("float", 0.5f, 0f)
    "support se/deserializing doubles" in testPrimitive("double", 0.5d, 0d)
    "support se/deserializing bytes" in testPrimitive("bytes", ByteString.copyFromUtf8("foo"), ByteString.EMPTY)
    "support se/deserializing booleans" in testPrimitive("bool", true, false)

    "support se/deserializing json" in {
      val jsonable = new MyJsonable
      jsonable.field = "foo"
      val encoded = anySupport.encodeScala(jsonable)
      encoded.typeUrl should ===(AnySupport.CloudStateJson + classOf[MyJsonable].getName)
      anySupport.decode(encoded).asInstanceOf[MyJsonable].field should ===("foo")
    }

  }

}

/**
 * Mutable bean fixture annotated with `@Jsonable`; AnySupportSpec uses it to exercise JSON
 * encoding and decoding.
 */
@Jsonable
class MyJsonable {
  // Bean property (getField/setField generated by @BeanProperty); uninitialized (null) by default.
  @BeanProperty var field: String = _
} 
Example 66
Source File: ProtoRDDConversions.scala    From sparksql-protobuf   with Apache License 2.0 5 votes vote down vote up
package com.github.saurfang.parquet.proto.spark.sql

import com.google.protobuf.{ByteString, AbstractMessage}
import com.google.protobuf.Descriptors.{EnumValueDescriptor, FieldDescriptor}
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
import org.apache.spark.sql.Row

object ProtoRDDConversions {

  /**
   * Converts a protobuf message into a Spark SQL [[Row]] with one column per field descriptor,
   * in descriptor order.
   *
   *  - bytes fields become `Array[Byte]`, enums become their value name, and nested messages
   *    become nested rows (recursively); other values are passed through unchanged;
   *  - repeated fields become a `Seq`, which is empty when the field has no elements;
   *  - singular fields absent from `getAllFields` become `null`.
   */
  def messageToRow[A <: AbstractMessage](message: A): Row = {
    // Explicit converters replace the deprecated (removed in 2.13) JavaConversions implicits.
    import scala.collection.JavaConverters._

    def toRowData(fd: FieldDescriptor, obj: AnyRef): Any = fd.getJavaType match {
      case BYTE_STRING => obj.asInstanceOf[ByteString].toByteArray
      case ENUM        => obj.asInstanceOf[EnumValueDescriptor].getName
      case MESSAGE     => messageToRow(obj.asInstanceOf[AbstractMessage])
      case _           => obj
    }

    val presentFields = message.getAllFields
    val columns = message.getDescriptorForType.getFields.asScala.map { fd =>
      Option(presentFields.get(fd)) match {
        case Some(values) if fd.isRepeated =>
          values.asInstanceOf[java.util.List[AnyRef]].asScala.map(toRowData(fd, _)).toSeq
        case Some(value)           => toRowData(fd, value)
        case None if fd.isRepeated => Seq()
        case None                  => null
      }
    }
    Row(columns.toSeq: _*)
  }
}
Example 67
Source File: ProtoRDDConversionSuite.scala    From sparksql-protobuf   with Apache License 2.0 5 votes vote down vote up
package com.github.saurfang.parquet.proto.spark.sql

import com.github.saurfang.parquet.proto.AddressBook.Person
import com.github.saurfang.parquet.proto.AddressBook.Person.{EmptyMessage, PhoneNumber}
import com.github.saurfang.parquet.proto.Simple.SimpleMessage
import com.google.protobuf.ByteString
import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext}
import org.apache.spark.sql.{Row, SQLContext}
import org.scalatest.{FunSuite, Matchers}
import ProtoRDDConversions._

class ProtoRDDConversionSuite extends FunSuite with Matchers {

  test("convert protobuf with simple data type to dataframe") {
    val message = SimpleMessage.newBuilder()
      .setBoolean(true)
      .setDouble(1)
      .setFloat(1F)
      .setInt(1)
      .setLong(1L)
      .setFint(2)
      .setFlong(2L)
      .setSfint(3)
      .setSflong(3L)
      .setSint(-4)
      .setSlong(-4)
      .setString("")
      .setUint(5)
      .setUlong(5L)
      .build

    // Columns follow the message's field-descriptor order; the unset bytes field is null.
    val expectedRow = Row(
      1.0, // double
      1.0F, // float
      1, // int
      1L, // long
      5, // uint
      5L, // ulong
      -4, // sint
      -4L, // slong
      2, // fint
      2L, // flong
      3, // sfint
      3L, // sflong
      true, // boolean
      "", // String
      null // ByteString
    )
    messageToRow(message) shouldBe expectedRow
  }

  test("convert protobuf with byte string") {
    val bytes = Array[Byte](1, 2, 3, 4)
    val message = SimpleMessage.newBuilder().setByteString(ByteString.copyFrom(bytes)).build
    val lastColumn = messageToRow(message).toSeq.last
    lastColumn shouldBe bytes
  }

  test("convert protobuf with repeated fields") {
    val phone = PhoneNumber.newBuilder().setNumber("12345").setType(Person.PhoneType.MOBILE)
    val person = Person.newBuilder()
      .setName("test")
      .setId(0)
      .addAddress("ABC")
      .addAddress("CDE")
      .addPhone(phone)
      .build
    messageToRow(person) shouldBe Row("test", 0, null, Seq(Row("12345", "MOBILE")), Seq("ABC", "CDE"), null)
  }

  test("convert protobuf with empty repeated fields") {
    val person = Person.newBuilder().setName("test").setId(0).build()
    // Repeated fields without elements become empty sequences, not nulls.
    messageToRow(person) shouldBe Row("test", 0, null, Seq(), Seq(), null)
  }
}
Example 68
Source File: EventRecord.scala    From tensorflow_scala   with Apache License 2.0 5 votes vote down vote up
package org.platanios.tensorflow.api.io.events

import org.platanios.tensorflow.proto.TensorProto

import com.google.protobuf.ByteString


/** A value of type `T` together with the wall time and step it was recorded at. */
trait EventRecord[T] {
  val wallTime: Double
  val step: Long
  val value: T
}

/** Record whose value is a single `Float`. */
case class ScalarEventRecord(
    override val wallTime: Double,
    override val step: Long,
    override val value: Float
) extends EventRecord[Float]

/** Record whose value is an encoded image. */
case class ImageEventRecord(
    override val wallTime: Double,
    override val step: Long,
    override val value: ImageValue
) extends EventRecord[ImageValue]

/** Encoded image bytes plus their dimensions and color-space code. */
case class ImageValue(
    encodedImage: ByteString,
    width: Int,
    height: Int,
    colorSpace: Int)

/** Record whose value is an encoded audio clip. */
case class AudioEventRecord(
    override val wallTime: Double,
    override val step: Long,
    override val value: AudioValue
) extends EventRecord[AudioValue]

/** Encoded audio bytes plus their content type and sampling metadata. */
case class AudioValue(
    encodedAudio: ByteString,
    contentType: String,
    sampleRate: Float,
    numChannels: Long,
    lengthFrames: Long)

/** Record whose value is a histogram. */
case class HistogramEventRecord(
    override val wallTime: Double,
    override val step: Long,
    override val value: HistogramValue
) extends EventRecord[HistogramValue]

/** Histogram summary statistics with parallel `bucketLimits` / `buckets` sequences. */
case class HistogramValue(
    min: Double,
    max: Double,
    num: Double,
    sum: Double,
    sumSquares: Double,
    bucketLimits: Seq[Double],
    buckets: Seq[Double])

/** Record whose value is a sequence of histograms. */
case class CompressedHistogramEventRecord(
    override val wallTime: Double,
    override val step: Long,
    override val value: Seq[HistogramValue]
) extends EventRecord[Seq[HistogramValue]]

/** Record whose value is a raw [[TensorProto]]. */
case class TensorEventRecord(
    override val wallTime: Double,
    override val step: Long,
    override val value: TensorProto
) extends EventRecord[TensorProto]
Example 69
Source File: InMemoryTraceBackendServiceIntegrationTestSpec.scala    From haystack-traces   with Apache License 2.0 5 votes vote down vote up
package com.expedia.www.haystack.trace.storage.backends.memory.integration

import java.util.UUID

import com.expedia.open.tracing.backend.{ReadSpansRequest, TraceRecord, WriteSpansRequest}
import com.google.protobuf.ByteString

/**
 * Integration test: a trace record written through the in-memory backend's `writeSpans`
 * can be read back by trace id through `readSpans`.
 */
class InMemoryTraceBackendServiceIntegrationTestSpec extends BaseIntegrationTestSpec {

  describe("In Memory Persistence Service read trace records") {
    it("should get trace records for given traceID from in memory") {
      Given("trace-record ")
      val traceId = UUID.randomUUID().toString
      val serializedSpans = createSerializedSpanBuffer(traceId)
      val traceRecord = TraceRecord.newBuilder()
        .setTraceId(traceId)
        .setSpans(ByteString.copyFrom(serializedSpans))
        .setTimestamp(System.currentTimeMillis())
        .build()

      When("write span is invoked")
      // The response is not inspected; the read-back below verifies the write took effect.
      client.writeSpans(WriteSpansRequest.newBuilder().addRecords(traceRecord).build())

      Then("should be able to retrieve the trace-record back")
      val readSpansResponse = client.readSpans(ReadSpansRequest.newBuilder().addTraceIds(traceId).build())
      readSpansResponse.getRecordsCount shouldBe 1
      readSpansResponse.getRecordsList.get(0).getTraceId shouldEqual traceId
    }
  }
}
Example 70
Source File: CassandraTraceRecordReadResultListener.scala    From haystack-traces   with Apache License 2.0 5 votes vote down vote up
package com.expedia.www.haystack.trace.storage.backends.cassandra.store

import com.codahale.metrics.{Meter, Timer}
import com.datastax.driver.core.exceptions.NoHostAvailableException
import com.datastax.driver.core.{ResultSet, ResultSetFuture, Row}
import com.expedia.open.tracing.api.Trace
import com.expedia.open.tracing.backend.TraceRecord
import com.expedia.www.haystack.trace.storage.backends.cassandra.client.CassandraTableSchema
import com.google.protobuf.ByteString
import org.slf4j.{Logger, LoggerFactory}

import scala.collection.JavaConverters._
import scala.concurrent.Promise
import scala.util.{Failure, Success, Try}

/** Companion that owns the logger shared by all `CassandraTraceRecordReadResultListener` instances. */
object CassandraTraceRecordReadResultListener {
  protected val LOGGER: Logger = {
    val owner = classOf[CassandraTraceRecordReadResultListener]
    LoggerFactory.getLogger(owner)
  }
}

/** Callback run when an asynchronous Cassandra read completes. It converts the returned rows into
  * [[TraceRecord]]s and completes `promise` with them, or with the failure.
  *
  * @param asyncResult the pending Cassandra query result
  * @param timer       latency timer context, closed as soon as the callback runs
  * @param failure     meter marked once for every failed read
  * @param promise     completed with the decoded records, or failed with the read/decoding error
  */
class CassandraTraceRecordReadResultListener(asyncResult: ResultSetFuture,
                                             timer: Timer.Context,
                                             failure: Meter,
                                             promise: Promise[Seq[TraceRecord]]) extends Runnable {

  import CassandraTraceRecordReadResultListener._

  override def run(): Unit = {
    timer.close()

    Try(asyncResult.get)
      .flatMap(tryGetTraceRows)
      .flatMap(mapTraceRecords)
    match {
      case Success(records) =>
        promise.success(records)
      case Failure(ex) =>
        // NOTE(review): the "fatal" branch only logs differently; nothing here tears the app down.
        if (fatalError(ex)) {
          LOGGER.error("Fatal error in reading from cassandra, tearing down the app", ex)
        } else {
          LOGGER.error("Failed in reading the record from cassandra", ex)
        }
        failure.mark()
        promise.failure(ex)
    }
  }

  /** True when `ex`, or any exception in its cause chain, is a [[NoHostAvailableException]]. */
  private def fatalError(ex: Throwable): Boolean = ex match {
    case _: NoHostAvailableException => true
    case _                           => ex.getCause != null && fatalError(ex.getCause)
  }

  /** Materializes all rows of `resultSet`; an empty result is treated as a failure. */
  private def tryGetTraceRows(resultSet: ResultSet): Try[Seq[Row]] = {
    val rows = resultSet.all().asScala.toList
    if (rows.isEmpty) Failure(new RuntimeException("Cassandra read returned no trace rows"))
    else Success(rows)
  }

  /** Decodes every row into a [[TraceRecord]]; any per-row exception fails the whole batch. */
  private def mapTraceRecords(rows: Seq[Row]): Try[List[TraceRecord]] = Try {
    rows.map { row =>
      // Copy only the buffer's readable region. `ByteBuffer.array()` would return the entire backing
      // array (which need not start at the value's position or end at its limit), and it throws for
      // direct or read-only buffers.
      val spans = ByteString.copyFrom(row.getBytes(CassandraTableSchema.SPANS_COLUMN_NAME))
      TraceRecord.newBuilder()
        .setSpans(spans)
        .setTimestamp(row.getLong(CassandraTableSchema.TIMESTAMP_COLUMN_NAME))
        .setTraceId(row.getString(CassandraTableSchema.ID_COLUMN_NAME))
        .build()
    }.toList
  }
}
Example 71
Source File: BigSamplerProto.scala    From ratatool   with Apache License 2.0 5 votes vote down vote up
package com.spotify.ratatool.samplers

import com.google.common.hash.Hasher
import com.google.protobuf.{AbstractMessage, ByteString}
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType
import org.slf4j.LoggerFactory

import scala.reflect.ClassTag

private[samplers] object BigSamplerProto {
  private val log = LoggerFactory.getLogger(BigSamplerProto.getClass)

  /** Builds a sampling key for `m` from the values of `distributionFields`.
    *
    * Each field path is resolved with `getProtobufField`. The resulting values are de-duplicated,
    * rendered with `toString` (a `null` value is rendered as `"null"`) and sorted, so the key does
    * not depend on the order in which the fields are listed.
    */
  private[samplers] def buildKey(distributionFields: Seq[String])(m: AbstractMessage)
  : List[String] = {
    distributionFields.map(f => getProtobufField(m, f)).toSet
      .map { x: Any =>
        // can't call toString on null
        if (x == null) {
          "null"
        } else {
          x.toString
        }
      }.toList.sorted
  }

  // scalastyle:off cyclomatic.complexity
  /** Feeds the value of the (possibly nested) field `fieldStr` of `m` into `hasher`.
    *
    * `fieldStr` is split on `BigSampler.fieldSep`; message-typed segments are descended into
    * recursively. A `null` value leaves `hasher` unchanged.
    *
    * @throws NoSuchElementException if a path segment does not name a field of the message
    */
  private[samplers] def hashProtobufField[T <: AbstractMessage : ClassTag](m: T,
                                                                           fieldStr: String,
                                                                           hasher: Hasher)
  : Hasher = {
    val subfields = fieldStr.split(BigSampler.fieldSep)
    val field = Option(m.getDescriptorForType.findFieldByName(subfields.head)).getOrElse {
      throw new NoSuchElementException(s"Can't find field $fieldStr in protobuf schema")
    }
    val v = m.getField(field)
    if (v == null) {
      log.debug(s"Field `${field.getFullName}` of type ${field.getType} is null - won't account" +
        s" for hash")
      hasher
    } else {
      field.getJavaType match {
        case JavaType.MESSAGE => hashProtobufField(
          v.asInstanceOf[AbstractMessage], subfields.tail.mkString("."), hasher)
        case JavaType.INT => hasher.putLong(v.asInstanceOf[Int].toLong)
        case JavaType.LONG => hasher.putLong(v.asInstanceOf[Long])
        case JavaType.FLOAT => hasher.putFloat(v.asInstanceOf[Float])
        case JavaType.DOUBLE => hasher.putDouble(v.asInstanceOf[Double])
        case JavaType.BOOLEAN => hasher.putBoolean(v.asInstanceOf[Boolean])
        case JavaType.STRING => hasher.putString(v.asInstanceOf[CharSequence],
          BigSampler.utf8Charset)
        case JavaType.BYTE_STRING => hasher.putBytes(v.asInstanceOf[ByteString].toByteArray)
        case JavaType.ENUM =>
          // Reflective `Message.getField` returns an EnumValueDescriptor for enum fields, not a Java
          // enum constant, so a cast to `Enum[_]` would throw ClassCastException. The descriptor's
          // name is the proto value name.
          hasher.putString(
            v.asInstanceOf[com.google.protobuf.Descriptors.EnumValueDescriptor].getName,
            BigSampler.utf8Charset)
        // Array, Union
        // NOTE(review): for repeated fields `getField` returns a java.util.List, which none of the
        // casts above accept, so repeated fields are not supported here.
      }
    }
  }
  // scalastyle:on cyclomatic.complexity

  // scalastyle:off cyclomatic.complexity
  /** Returns the value of the (possibly nested) field `fieldStr` of `m`, or `null` when it is null.
    *
    * @throws NoSuchElementException if a path segment does not name a field of the message
    */
  private[samplers] def getProtobufField[T <: AbstractMessage : ClassTag](m: T, fieldStr: String)
  : Any = {
    val subfields = fieldStr.split(BigSampler.fieldSep)
    val field = Option(m.getDescriptorForType.findFieldByName(subfields.head)).getOrElse {
      throw new NoSuchElementException(s"Can't find field $fieldStr in protobuf schema")
    }
    val v = m.getField(field)
    if (v == null) {
      log.debug(s"Field `${field.getFullName}` of type ${field.getType} is null - won't account" +
        s" for key")
      // Return null explicitly. Otherwise this branch evaluates to Unit, and buildKey would render
      // it as "()" instead of "null".
      null
    } else {
      field.getJavaType match {
        case JavaType.MESSAGE => getProtobufField(
          v.asInstanceOf[AbstractMessage], subfields.tail.mkString("."))
        case _ => v
      }
    }
  }
  // scalastyle:on cyclomatic.complexity
}
Example 72
Source File: MessageSerializer.scala    From aecor   with MIT License 5 votes vote down vote up
package aecor.runtime.akkageneric.serialization

import aecor.runtime.akkageneric.GenericAkkaRuntime.KeyedCommand
import aecor.runtime.akkageneric.GenericAkkaRuntimeActor.{ Command, CommandResult }
import akka.actor.ExtendedActorSystem
import akka.serialization.{ BaseSerializer, SerializerWithStringManifest }
import com.google.protobuf.ByteString
import scodec.bits.BitVector

import scala.collection.immutable.HashMap

/** Akka serializer for the generic runtime's wire messages.
  *
  * Manifests:
  *  - `"A"`: `KeyedCommand`, encoded as a protobuf `msg.KeyedCommand` (key + command bytes)
  *  - `"B"`: `Command`, stored as its raw bytes
  *  - `"C"`: `CommandResult`, stored as its raw bytes
  */
class MessageSerializer(val system: ExtendedActorSystem)
    extends SerializerWithStringManifest
    with BaseSerializer {

  val KeyedCommandManifest = "A"
  val CommandManifest = "B"
  val CommandResultManifest = "C"

  // Decoder per manifest; declared after the manifest vals so they are initialized first.
  private val fromBinaryMap =
    HashMap[String, Array[Byte] => AnyRef](
      KeyedCommandManifest -> keyedCommandFromBinary,
      CommandManifest -> commandFromBinary,
      CommandResultManifest -> commandResultFromBinary
    )

  override def manifest(o: AnyRef): String = o match {
    case KeyedCommand(_, _) => KeyedCommandManifest
    case Command(_)         => CommandManifest
    case CommandResult(_)   => CommandResultManifest
    case x                  => throw new IllegalArgumentException(s"Serialization of [$x] is not supported")
  }

  override def toBinary(o: AnyRef): Array[Byte] = o match {
    case Command(bytes) =>
      bytes.toByteArray
    case CommandResult(bytes) =>
      bytes.toByteArray
    case x @ KeyedCommand(_, _) =>
      entityCommandToBinary(x)
    case x => throw new IllegalArgumentException(s"Serialization of [$x] is not supported")
  }

  /** Decodes `bytes` using the decoder registered for `manifest`.
    *
    * @throws IllegalArgumentException if `manifest` is not one of the known manifests
    */
  override def fromBinary(bytes: Array[Byte], manifest: String): AnyRef =
    fromBinaryMap.get(manifest) match {
      case Some(f) => f(bytes)
      // Report the offending manifest string itself (the lookup result here is always `None`).
      case None    => throw new IllegalArgumentException(s"Unknown manifest [$manifest]")
    }

  private def entityCommandToBinary(a: KeyedCommand): Array[Byte] =
    msg.KeyedCommand(a.key, ByteString.copyFrom(a.bytes.toByteBuffer)).toByteArray

  private def keyedCommandFromBinary(bytes: Array[Byte]): KeyedCommand =
    msg.KeyedCommand.parseFrom(bytes) match {
      case msg.KeyedCommand(key, commandBytes) =>
        KeyedCommand(key, BitVector(commandBytes.asReadOnlyByteBuffer()))
    }

  private def commandFromBinary(bytes: Array[Byte]): Command =
    Command(BitVector(bytes))

  private def commandResultFromBinary(bytes: Array[Byte]): CommandResult =
    CommandResult(BitVector(bytes))
}
Example 73
Source File: MessageSerializer.scala    From aecor   with MIT License 5 votes vote down vote up
package aecor.runtime.akkapersistence.serialization

import aecor.runtime.akkapersistence.AkkaPersistenceRuntime.EntityCommand
import aecor.runtime.akkapersistence.AkkaPersistenceRuntimeActor.{ CommandResult, HandleCommand }
import akka.actor.ExtendedActorSystem
import akka.serialization.{ BaseSerializer, SerializerWithStringManifest }
import com.google.protobuf.ByteString
import scodec.bits.BitVector

import scala.collection.immutable._

/** Akka serializer for the persistence runtime's wire messages.
  *
  * Manifests:
  *  - `"A"`: `HandleCommand`, stored as its raw command bytes
  *  - `"B"`: `EntityCommand`, encoded as a protobuf `msg.EntityCommand` (entity key + command bytes)
  *  - `"C"`: `CommandResult`, stored as its raw result bytes
  */
class MessageSerializer(val system: ExtendedActorSystem)
    extends SerializerWithStringManifest
    with BaseSerializer {

  val HandleCommandManifest = "A"
  val EntityCommandManifest = "B"
  val CommandResultManifest = "C"

  // Decoder per manifest; declared after the manifest vals so they are initialized first.
  private val fromBinaryMap =
    HashMap[String, Array[Byte] => AnyRef](
      HandleCommandManifest -> handleCommandFromBinary,
      EntityCommandManifest -> entityCommandFromBinary,
      CommandResultManifest -> commandResultFromBinary
    )

  override def manifest(o: AnyRef): String = o match {
    case HandleCommand(_)    => HandleCommandManifest
    case EntityCommand(_, _) => EntityCommandManifest
    case CommandResult(_)    => CommandResultManifest
    case x                   => throw new IllegalArgumentException(s"Serialization of [$x] is not supported")
  }

  override def toBinary(o: AnyRef): Array[Byte] = o match {
    case x @ HandleCommand(_) =>
      x.commandBytes.toByteArray
    case CommandResult(resultBytes) =>
      resultBytes.toByteArray
    case x @ EntityCommand(_, _) =>
      entityCommandToBinary(x)
    case x => throw new IllegalArgumentException(s"Serialization of [$x] is not supported")
  }

  /** Decodes `bytes` using the decoder registered for `manifest`.
    *
    * @throws IllegalArgumentException if `manifest` is not one of the known manifests
    */
  override def fromBinary(bytes: Array[Byte], manifest: String): AnyRef =
    fromBinaryMap.get(manifest) match {
      case Some(f) => f(bytes)
      // Report the offending manifest string itself (the lookup result here is always `None`).
      case None    => throw new IllegalArgumentException(s"Unknown manifest [$manifest]")
    }

  private def entityCommandToBinary(a: EntityCommand): Array[Byte] =
    msg.EntityCommand(a.entityKey, ByteString.copyFrom(a.commandBytes.toByteBuffer)).toByteArray

  private def entityCommandFromBinary(bytes: Array[Byte]): EntityCommand =
    msg.EntityCommand.parseFrom(bytes) match {
      case msg.EntityCommand(entityId, commandBytes) =>
        EntityCommand(entityId, BitVector(commandBytes.asReadOnlyByteBuffer))
    }

  private def handleCommandFromBinary(bytes: Array[Byte]): HandleCommand =
    HandleCommand(BitVector(bytes))

  private def commandResultFromBinary(bytes: Array[Byte]): CommandResult =
    CommandResult(BitVector(bytes))

}
Example 74
Source File: TestServiceImpl.scala    From akka-grpc   with Apache License 2.0 5 votes vote down vote up
package akka.grpc.interop

import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.reflect.ClassTag
import scala.collection.immutable

import akka.grpc.scaladsl.{GrpcMarshalling}

import akka.NotUsed
import akka.actor.ActorSystem
import akka.grpc._
import akka.stream.scaladsl.{Flow, Source}
import akka.stream.{ Materializer, SystemMaterializer }

import com.google.protobuf.ByteString
import io.grpc.{ Status, StatusRuntimeException }

// Generated by our plugin
import io.grpc.testing.integration.test.TestService
import io.grpc.testing.integration.messages._
import io.grpc.testing.integration.empty.Empty

object TestServiceImpl {
  /** Builds a payload of `size` zero bytes. */
  private def zeroPayload(size: Int): Payload =
    Payload(body = ByteString.copyFrom(new Array[Byte](size)))

  /** Turns each requested `ResponseParameters` into a response whose payload is `size` zero bytes. */
  val parametersToResponseFlow: Flow[ResponseParameters, StreamingOutputCallResponse, NotUsed] =
    Flow[ResponseParameters].map(params => StreamingOutputCallResponse(Some(zeroPayload(params.size))))
}


/** gRPC interop `TestService` implementation. Calls not needed by the interop tests are left as `???`. */
class TestServiceImpl(implicit sys: ActorSystem) extends TestService {
  import TestServiceImpl._

  implicit val mat: Materializer = SystemMaterializer(sys).materializer
  implicit val ec: ExecutionContext = sys.dispatcher

  override def emptyCall(req: Empty): Future[Empty] =
    Future.successful(Empty())

  /** Replies with `responseSize` zero bytes, or fails with the requested status when one is set. */
  override def unaryCall(req: SimpleRequest): Future[SimpleResponse] = {
    req.responseStatus match {
      case None =>
        Future.successful(SimpleResponse(Some(Payload(ByteString.copyFrom(new Array[Byte](req.responseSize))))))
      case Some(requestStatus) =>
        val responseStatus = Status.fromCodeValue(requestStatus.code).withDescription(requestStatus.message)
        //  - Either one of the following works
        Future.failed(new GrpcServiceException(responseStatus))
        // throw new GrpcServiceException(responseStatus)
    }
  }

  override def cacheableUnaryCall(in: SimpleRequest): Future[SimpleResponse] = ???

  /** Emits one response per requested `ResponseParameters`; a request carrying a status fails the stream. */
  override def fullDuplexCall(in: Source[StreamingOutputCallRequest, NotUsed]): Source[StreamingOutputCallResponse, NotUsed] =
    in.map(req => {
      req.responseStatus.foreach(reqStatus =>
        throw new GrpcServiceException(
          Status.fromCodeValue(reqStatus.code).withDescription(reqStatus.message)))
      req
    }).mapConcat(
      // `toList` yields an immutable Seq on Scala 2.12 and 2.13 alike; the type-parameter form
      // `to[immutable.Seq]` is no longer supported on 2.13.
      _.responseParameters.toList).via(parametersToResponseFlow)

  override def halfDuplexCall(in: Source[StreamingOutputCallRequest, NotUsed]): Source[StreamingOutputCallResponse, NotUsed] = ???

  /** Replies with the total payload size of all received requests (requests without a payload count as 0). */
  override def streamingInputCall(in: Source[StreamingInputCallRequest, NotUsed]): Future[StreamingInputCallResponse] = {
    in
      .map(_.payload.map(_.body.size).getOrElse(0))
      .runFold(0)(_ + _)
      .map { sum =>
        StreamingInputCallResponse(sum)
      }
  }

  /** Emits one response per entry of the request's `responseParameters`. */
  override def streamingOutputCall(in: StreamingOutputCallRequest): Source[StreamingOutputCallResponse, NotUsed] =
    Source(in.responseParameters.toList).via(parametersToResponseFlow)

  override def unimplementedCall(in: Empty): Future[Empty] = ???
}
Example 75
Source File: TestServiceImpl.scala    From akka-grpc   with Apache License 2.0 5 votes vote down vote up
package akka.grpc.interop

import akka.NotUsed
import akka.actor.ActorSystem
import akka.grpc.GrpcServiceException
import akka.stream.{ Materializer, SystemMaterializer }
import akka.stream.scaladsl.{ Flow, Source }

import com.google.protobuf.ByteString

import io.grpc.Status
import io.grpc.testing.integration.empty.Empty

import scala.concurrent.{ ExecutionContext, Future }

// Generated by our plugin
import io.grpc.testing.integration.messages._
import io.grpc.testing.integration.test.TestService

object TestServiceImpl {
  /** Maps each requested `ResponseParameters` to a response carrying `size` zero bytes. */
  val parametersToResponseFlow: Flow[ResponseParameters, StreamingOutputCallResponse, NotUsed] =
    Flow[ResponseParameters].map { p =>
      val payload = Payload(body = ByteString.copyFrom(new Array[Byte](p.size)))
      StreamingOutputCallResponse(Some(payload))
    }
}


/** `TestService` implementation used by the gRPC interop tests; calls the tests do not exercise are `???`. */
class TestServiceImpl(implicit sys: ActorSystem) extends TestService {
  import TestServiceImpl._

  implicit val mat: Materializer = SystemMaterializer(sys).materializer
  implicit val ec: ExecutionContext = sys.dispatcher

  override def emptyCall(req: Empty) =
    Future.successful(Empty())

  /** Replies with `responseSize` zero bytes, or raises the requested status when one is set. */
  override def unaryCall(req: SimpleRequest): Future[SimpleResponse] =
    req.responseStatus match {
      case Some(requestStatus) =>
        val responseStatus = Status.fromCodeValue(requestStatus.code).withDescription(requestStatus.message)
        //  - Either one of the following works
        // Future.failed(new GrpcServiceException(responseStatus))
        throw new GrpcServiceException(responseStatus)
      case None =>
        val body = ByteString.copyFrom(new Array[Byte](req.responseSize))
        Future.successful(SimpleResponse(Some(Payload(body))))
    }

  override def cacheableUnaryCall(in: SimpleRequest): Future[SimpleResponse] = ???

  /** One response per requested parameter set; a request carrying a status fails the stream. */
  override def fullDuplexCall(
      in: Source[StreamingOutputCallRequest, NotUsed]): Source[StreamingOutputCallResponse, NotUsed] =
    in.map { request =>
        request.responseStatus match {
          case Some(st) =>
            throw new GrpcServiceException(Status.fromCodeValue(st.code).withDescription(st.message))
          case None =>
            request
        }
      }
      .mapConcat(_.responseParameters.toList)
      .via(parametersToResponseFlow)

  override def halfDuplexCall(
      in: Source[StreamingOutputCallRequest, NotUsed]): Source[StreamingOutputCallResponse, NotUsed] = ???

  /** Sums the payload sizes of all incoming requests (missing payloads count as 0). */
  override def streamingInputCall(in: Source[StreamingInputCallRequest, NotUsed]): Future[StreamingInputCallResponse] = {
    val payloadSizes = in.map(_.payload.fold(0)(_.body.size))
    payloadSizes.runFold(0)(_ + _).map(total => StreamingInputCallResponse(total))
  }

  /** Emits one response per entry of the request's `responseParameters`. */
  override def streamingOutputCall(in: StreamingOutputCallRequest): Source[StreamingOutputCallResponse, NotUsed] = {
    val requested = in.responseParameters.toList
    Source(requested).via(parametersToResponseFlow)
  }

  override def unimplementedCall(in: Empty): Future[Empty] = ???
}
Example 76
Source File: TestCaseAllFields.scala    From protoless   with Apache License 2.0 5 votes vote down vote up
package io.protoless.tests.samples

import com.google.protobuf.ByteString

import io.protoless.tag._
import io.protoless.tests.samples.Schemas.Color

/** Test fixture with one field for each scalar kind under test.
  *
  * The `@@` tags select the integer encoding: `Unsigned`, `Signed`, `Fixed`, or `Signed with Fixed`.
  * The untagged `i`/`l` fields use the default encoding.
  */
case class TestCaseAllFields(
  d: Double,
  f: Float,
  i: Int,
  l: Long,
  ui: Int @@ Unsigned,
  ul: Long @@ Unsigned,
  si: Int @@ Signed,
  sl: Long @@ Signed,
  fi: Int @@ Fixed,
  fl: Long @@ Fixed,
  sfi: Int @@ Signed with Fixed,
  sfl: Long @@ Signed with Fixed,
  b: Boolean,
  s: String,
  by: ByteString,
  c: Colors.Color
)

/** Pairs the `TestCaseAllFields` fixture with the equivalent Java-protobuf `Schemas.Optional` message. */
object TestCaseAllFields extends TestCase[TestCaseAllFields] {
  /** Extreme or non-default values, chosen to exercise every encoding. */
  override val source: TestCaseAllFields = TestCaseAllFields(
    d = Double.MaxValue,
    f = Float.MaxValue,
    i = Int.MaxValue,
    l = Long.MaxValue,
    ui = unsigned(100),
    ul = unsigned(100L),
    si = signed(Int.MinValue),
    sl = signed(Long.MinValue),
    fi = fixed(Int.MaxValue),
    fl = fixed(Long.MaxValue),
    sfi = signedFixed(Int.MinValue),
    sfl = signedFixed(Long.MinValue),
    b = true,
    s = "Я тебя люблю",
    // Same bytes as copyFrom("Coucou", "utf8"), without the charset-name lookup and its checked
    // UnsupportedEncodingException path.
    by = ByteString.copyFromUtf8("Coucou"),
    c = Colors.Green
  )

  /** The same values as `source`, set on the Java-generated protobuf builder. */
  override val protobuf: ProtoSerializable = ProtoSerializable(Schemas.Optional.newBuilder()
    .setDoubleField(source.d)
    .setFloatField(source.f)
    .setInt32Field(source.i)
    .setInt64Field(source.l)
    .setUint32Field(source.ui)
    .setUint64Field(source.ul)
    .setSint32Field(source.si)
    .setSint64Field(source.sl)
    .setFixed32Field(source.fi)
    .setFixed64Field(source.fl)
    .setSfixed32Field(source.sfi)
    .setSfixed64Field(source.sfl)
    .setBoolField(source.b)
    .setStringField(source.s)
    .setBytesField(source.by)
    .setColorField(Color.GREEN)
    .build())
}
Example 77
Source File: AccountsApiGrpcImpl.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.api.grpc

import com.google.protobuf.ByteString
import com.google.protobuf.wrappers.{BytesValue, StringValue}
import com.wavesplatform.account.{Address, Alias}
import com.wavesplatform.api.common.CommonAccountsApi
import com.wavesplatform.common.state.ByteStr
import com.wavesplatform.protobuf.Amount
import com.wavesplatform.protobuf.transaction.PBTransactions
import com.wavesplatform.transaction.Asset
import io.grpc.stub.StreamObserver
import monix.execution.Scheduler
import monix.reactive.Observable

import scala.concurrent.Future

/** gRPC accounts API backed by [[CommonAccountsApi]]. */
class AccountsApiGrpcImpl(commonApi: CommonAccountsApi)(implicit sc: Scheduler) extends AccountsApiGrpc.AccountsApi {

  /** WAVES balance breakdown (regular, generating, available, effective, lease in/out) for `address`. */
  private def loadWavesBalance(address: Address): BalanceResponse = {
    val details = commonApi.balanceDetails(address)
    BalanceResponse().withWaves(
      BalanceResponse.WavesBalances(
        details.regular,
        details.generating,
        details.available,
        details.effective,
        details.leaseIn,
        details.leaseOut
      )
    )
  }

  private def assetBalanceResponse(v: (Asset.IssuedAsset, Long)): BalanceResponse =
    BalanceResponse().withAsset(Amount(v._1.id.toPBByteString, v._2))

  /** Streams balances for an address.
    *
    * With no assets requested, it streams the WAVES balance followed by the full portfolio. Otherwise
    * it streams one balance per requested asset; an empty asset id means WAVES. Requests without an
    * address currently produce an empty stream.
    */
  override def getBalances(request: BalancesRequest, responseObserver: StreamObserver[BalanceResponse]): Unit = responseObserver.interceptErrors {
    val addressOption: Option[Address] = if (request.address.isEmpty) None else Some(request.address.toAddress)
    val assetIds: Seq[Asset]           = request.assets.map(id => if (id.isEmpty) Asset.Waves else Asset.IssuedAsset(ByteStr(id.toByteArray)))

    val responseStream = (addressOption, assetIds) match {
      case (Some(address), Seq()) =>
        Observable(loadWavesBalance(address)) ++ commonApi.portfolio(address).map(assetBalanceResponse)
      case (Some(address), nonEmptyList) =>
        Observable
          .fromIterable(nonEmptyList)
          .map {
            case Asset.Waves           => loadWavesBalance(address)
            case ia: Asset.IssuedAsset => assetBalanceResponse(ia -> commonApi.assetBalance(address, ia))
          }
      case (None, Seq(_)) => // todo: asset distribution
        Observable.empty
      case (None, _) => // multiple distributions are not supported
        Observable.empty
    }

    responseObserver.completeWith(responseStream)
  }

  override def getScript(request: AccountRequest): Future[ScriptData] = Future {
    commonApi.script(request.address.toAddress) match {
      case Some(desc) => ScriptData(PBTransactions.toPBScript(Some(desc.script)), desc.script.expr.toString, desc.verifierComplexity)
      case None       => ScriptData()
    }
  }

  override def getActiveLeases(request: AccountRequest, responseObserver: StreamObserver[TransactionResponse]): Unit =
    responseObserver.interceptErrors {
      val transactions = commonApi.activeLeases(request.address.toAddress)
      val result       = transactions.map { case (height, transaction) => TransactionResponse(transaction.id(), height, Some(transaction.toPB)) }
      responseObserver.completeWith(result)
    }

  /** Streams data entries of an address: the single entry for `request.key` when a key is given,
    * otherwise whatever `commonApi.dataStream` yields with no key filter.
    */
  override def getDataEntries(request: DataRequest, responseObserver: StreamObserver[DataEntryResponse]): Unit = responseObserver.interceptErrors {
    val address = request.address.toAddress
    val stream = if (request.key.nonEmpty) {
      Observable.fromIterable(commonApi.data(address, request.key))
    } else {
      // The key is known to be empty on this branch, so no key filter applies.
      commonApi.dataStream(address, None)
    }

    responseObserver.completeWith(stream.map(de => DataEntryResponse(request.address, Some(PBTransactions.toPBDataEntry(de)))))
  }

  /** Resolves an alias to its address bytes; a failed parse or lookup fails the returned future. */
  override def resolveAlias(request: StringValue): Future[BytesValue] =
    Future {
      val result = for {
        alias   <- Alias.create(request.value)
        address <- commonApi.resolveAlias(alias)
      } yield BytesValue(ByteString.copyFrom(address.bytes))

      result.explicitGetErr()
    }
}
Example 78
Source File: PBImplicitConversions.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.api.grpc

import com.google.protobuf.ByteString
import com.wavesplatform.account.{Address, AddressScheme, PublicKey}
import com.wavesplatform.block.BlockHeader
import com.wavesplatform.common.state.ByteStr
import com.wavesplatform.common.utils._
import com.wavesplatform.lang.ValidationError
import com.wavesplatform.protobuf.block.{PBBlock, PBBlocks, VanillaBlock}
import com.wavesplatform.protobuf.transaction._
import com.wavesplatform.{block => vb}

//noinspection ScalaStyle
/** Implicit conversions between Waves "vanilla" domain types and their protobuf counterparts. */
trait PBImplicitConversions {
  implicit class VanillaTransactionConversions(tx: VanillaTransaction) {
    def toPB: PBSignedTransaction = PBTransactions.protobuf(tx)
  }

  implicit class PBSignedTransactionConversions(tx: PBSignedTransaction) {
    def toVanilla: Either[ValidationError, VanillaTransaction] = PBTransactions.vanilla(tx)
  }

  implicit class VanillaBlockConversions(block: VanillaBlock) {
    def toPB: PBBlock = PBBlocks.protobuf(block)
  }

  implicit class PBBlockHeaderConversionOps(header: PBBlock.Header) {
    /** Converts to a vanilla header. `signature` is accepted but not used by this conversion. */
    def toVanilla(signature: ByteStr): vb.BlockHeader = {
      BlockHeader(
        header.version.toByte,
        header.timestamp,
        header.reference.toByteStr,
        header.baseTarget,
        header.generationSignature.toByteStr,
        header.generator.toPublicKey,
        header.featureVotes.map(intToShort),
        header.rewardVote,
        header.transactionsRoot.toByteStr
      )
    }
  }

  implicit class VanillaHeaderConversionOps(header: vb.BlockHeader) {
    def toPBHeader: PBBlock.Header = PBBlock.Header(
      0: Byte, // NOTE(review): first field (presumably chain id) is hard-coded to 0 — confirm intended
      header.reference.toPBByteString,
      header.baseTarget,
      header.generationSignature.toPBByteString,
      header.featureVotes.map(shortToInt),
      header.timestamp,
      header.version,
      header.generator,
      header.rewardVote,
      header.transactionsRoot.toPBByteString
    )
  }

  implicit class PBRecipientConversions(r: Recipient) {
    def toAddress: Address = PBRecipients.toAddress(r, AddressScheme.current.chainId).explicitGet()
    def toAddressOrAlias   = PBRecipients.toAddressOrAlias(r, AddressScheme.current.chainId).explicitGet()
  }

  // Explicit result types below: these members are used (directly and through the implicit defs)
  // before their definition in this trait, which is fragile for implicit resolution with inferred types.
  implicit class VanillaByteStrConversions(bytes: ByteStr) {
    def toPBByteString: ByteString = ByteString.copyFrom(bytes.arr)
  }

  implicit class PBByteStringConversions(bytes: ByteString) {
    def toByteStr: ByteStr     = ByteStr(bytes.toByteArray)
    def toPublicKey: PublicKey = PublicKey(bytes.toByteArray)

    /** @throws IllegalArgumentException if the bytes are not a valid address for the current chain */
    def toAddress: Address =
      PBRecipients.toAddress(bytes.toByteArray, AddressScheme.current.chainId).fold(ve => throw new IllegalArgumentException(ve.toString), identity)
  }

  implicit def vanillaByteStrToPBByteString(bs: ByteStr): ByteString = bs.toPBByteString
  implicit def pbByteStringToVanillaByteStr(bs: ByteString): ByteStr = bs.toByteStr

  /** Widens a feature id to Int, treating the Short as unsigned. */
  private[this] implicit def shortToInt(s: Short): Int = {
    java.lang.Short.toUnsignedInt(s)
  }

  /** Narrows a protobuf feature id back to Short, rejecting values outside the Short range. */
  private[this] def intToShort(int: Int): Short = {
    require(int.isValidShort, s"Short overflow: $int")
    int.toShort
  }
}
Example 79
Source File: BlockchainApiGrpcImpl.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.api.grpc

import com.google.protobuf.ByteString
import com.google.protobuf.empty.Empty
import com.wavesplatform.features.{BlockchainFeatureStatus, BlockchainFeatures}
import com.wavesplatform.settings.FeaturesSettings
import com.wavesplatform.state.Blockchain
import monix.execution.Scheduler

import scala.concurrent.Future

/** gRPC blockchain API: feature activation status, current base target and cumulative score. */
class BlockchainApiGrpcImpl(blockchain: Blockchain, featuresSettings: FeaturesSettings)(implicit sc: Scheduler)
    extends BlockchainApiGrpc.BlockchainApi {

  /** Reports the activation-window parameters at `request.height` and the status of every feature
    * that is voted for at that height, approved, or implemented by this node (sorted by id).
    */
  override def getActivationStatus(request: ActivationStatusRequest): Future[ActivationStatusResponse] = Future {
    val functionalitySettings = blockchain.settings.functionalitySettings
    // Read once: used both to enumerate voted features and for each feature's vote count.
    val votes = blockchain.featureVotes(request.height)

    ActivationStatusResponse(
      request.height,
      functionalitySettings.activationWindowSize(request.height),
      functionalitySettings.blocksForFeatureActivation(request.height),
      functionalitySettings.activationWindow(request.height).last,
      (votes.keySet ++
        blockchain.approvedFeatures.keySet ++
        BlockchainFeatures.implemented).toSeq.sorted.map { id =>
        val status = blockchain.featureStatus(id, request.height) match {
          case BlockchainFeatureStatus.Undefined => FeatureActivationStatus.BlockchainFeatureStatus.UNDEFINED
          case BlockchainFeatureStatus.Approved  => FeatureActivationStatus.BlockchainFeatureStatus.APPROVED
          case BlockchainFeatureStatus.Activated => FeatureActivationStatus.BlockchainFeatureStatus.ACTIVATED
        }

        FeatureActivationStatus(
          id,
          BlockchainFeatures.feature(id).fold("Unknown feature")(_.description),
          status,
          (BlockchainFeatures.implemented.contains(id), featuresSettings.supported.contains(id)) match {
            case (false, _) => FeatureActivationStatus.NodeFeatureStatus.NOT_IMPLEMENTED
            case (_, true)  => FeatureActivationStatus.NodeFeatureStatus.VOTED
            case _          => FeatureActivationStatus.NodeFeatureStatus.IMPLEMENTED
          },
          blockchain.featureActivationHeight(id).getOrElse(0),
          // Vote counts are only meaningful while the feature is still undefined.
          if (status.isUndefined) votes.getOrElse(id, 0) else 0
        )
      }
    )
  }

  /** Base target of the last block; fails the future with a descriptive error on an empty chain. */
  override def getBaseTarget(request: Empty): Future[BaseTargetResponse] = Future {
    val last = blockchain.lastBlockHeader.getOrElse(throw new IllegalStateException("Blockchain has no blocks yet"))
    BaseTargetResponse(last.header.baseTarget)
  }

  /** Cumulative chain score as big-endian two's-complement bytes (`BigInt.toByteArray`). */
  override def getCumulativeScore(request: Empty): Future[ScoreResponse] = Future {
    ScoreResponse(ByteString.copyFrom(blockchain.score.toByteArray))
  }
}
Example 80
Source File: RollbackBenchmark.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform

import java.io.File

import com.google.common.primitives.Ints
import com.google.protobuf.ByteString
import com.wavesplatform.account.{Address, AddressScheme, KeyPair}
import com.wavesplatform.block.Block
import com.wavesplatform.common.state.ByteStr
import com.wavesplatform.common.utils._
import com.wavesplatform.database.{LevelDBWriter, openDB}
import com.wavesplatform.protobuf.transaction.PBRecipients
import com.wavesplatform.state.{Diff, Portfolio}
import com.wavesplatform.transaction.Asset.IssuedAsset
import com.wavesplatform.transaction.assets.IssueTransaction
import com.wavesplatform.transaction.{GenesisTransaction, Proofs}
import com.wavesplatform.utils.{NTP, ScorexLogging}
import monix.reactive.Observer

/** Benchmark that measures `LevelDBWriter.rollbackTo` after writing two synthetic blocks.
  *
  * Usage: pass the node config file path as the first argument.
  */
object RollbackBenchmark extends ScorexLogging {
  def main(args: Array[String]): Unit = {
    val settings = Application.loadApplicationConfig(Some(new File(args(0))))
    val db       = openDB(settings.dbSettings.directory)
    // Release the database, the NTP client and the writer even if the benchmark fails midway.
    try {
      val time = new NTP(settings.ntpServer)
      try {
        val levelDBWriter = LevelDBWriter(db, Observer.stopped, settings)
        try {
          val issuer = KeyPair(new Array[Byte](32))

          log.info("Generating addresses")

          val addresses = 1 to 18000 map { i =>
            PBRecipients.toAddress(Ints.toByteArray(i) ++ new Array[Byte](Address.HashLength - 4), AddressScheme.current.chainId).explicitGet()
          }

          log.info("Generating issued assets")

          val assets = 1 to 200 map { i =>
            IssueTransaction(
              1.toByte,
              issuer.publicKey,
              ByteString.copyFromUtf8("asset-" + i),
              ByteString.EMPTY,
              100000e2.toLong,
              2.toByte,
              false,
              None,
              1e8.toLong,
              time.getTimestamp(),
              Proofs(ByteStr(new Array[Byte](64))),
              AddressScheme.current.chainId
            )
          }

          log.info("Building genesis block")
          val genesisBlock = Block
            .buildAndSign(
              1.toByte,
              time.getTimestamp(),
              Block.GenesisReference,
              1000,
              Block.GenesisGenerationSignature,
              GenesisTransaction.create(issuer.publicKey.toAddress, 100000e8.toLong, time.getTimestamp()).explicitGet() +: assets,
              issuer,
              Seq.empty,
              -1
            )
            .explicitGet()

          // Every address holds one unit of every issued asset after genesis.
          val map = assets.map(it => IssuedAsset(it.id()) -> 1L).toMap
          val portfolios = for {
            address <- addresses
          } yield address -> Portfolio(assets = map)

          log.info("Appending genesis block")
          levelDBWriter.append(
            Diff.empty.copy(portfolios = portfolios.toMap),
            0,
            0,
            None,
            genesisBlock.header.generationSignature,
            genesisBlock
          )

          val nextBlock =
            Block
              .buildAndSign(2.toByte, time.getTimestamp(), genesisBlock.id(), 1000, Block.GenesisGenerationSignature, Seq.empty, issuer, Seq.empty, -1)
              .explicitGet()
          val nextDiff = Diff.empty.copy(portfolios = addresses.map(_ -> Portfolio(1, assets = Map(IssuedAsset(assets.head.id()) -> 1L))).toMap)

          log.info("Appending next block")
          levelDBWriter.append(nextDiff, 0, 0, None, ByteStr.empty, nextBlock)

          log.info("Rolling back")
          val start = System.nanoTime()
          levelDBWriter.rollbackTo(genesisBlock.id())
          val end = System.nanoTime()
          log.info(f"Rollback took ${(end - start) * 1e-6}%.3f ms")
        } finally levelDBWriter.close()
      } finally time.close()
    } finally db.close()
  }
}
Example 81
Source File: FunctionTypeMapper.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.lang.contract.meta

import cats.implicits._
import com.google.protobuf.ByteString
import com.wavesplatform.lang.v1.compiler.Types.FINAL
import com.wavesplatform.protobuf.dapp.DAppMeta
import com.wavesplatform.protobuf.dapp.DAppMeta.CallableFuncSignature

/** Converts between the argument types of callable functions and their protobuf encoding in [[DAppMeta]].
  *
  * Every function's argument list is stored as a byte string holding one byte per argument:
  * the index that `mapper` assigns to that argument's type.
  *
  * @param mapper  bidirectional mapping between [[FINAL]] types and integer indices
  * @param version meta version whose `number` is written into every produced [[DAppMeta]]
  */
class FunctionTypeMapper(mapper: TypeBitMapper, version: MetaVersion) {

  /** Encodes the argument types of all functions; returns the first mapping error, if any. */
  def toProto(funcTypes: List[List[FINAL]]): Either[String, DAppMeta] =
    for {
      signatures <- funcTypes.traverse(funcToProto)
    } yield DAppMeta(version.number, signatures)

  /** Encodes one function's argument types as a byte string of type indices. */
  private def funcToProto(types: List[FINAL]): Either[String, CallableFuncSignature] =
    for {
      indices <- types.traverse(t => mapper.toIndex(t).map(_.toByte))
    } yield CallableFuncSignature(ByteString.copyFrom(indices.toArray))

  /** Decodes every stored function signature back into a list of argument types. */
  def fromProto(meta: DAppMeta): Either[String, List[List[FINAL]]] =
    meta.funcs.toList.traverse(protoToFunc)

  /** Decodes one signature by mapping each stored byte back to its type. */
  private def protoToFunc(funcs: CallableFuncSignature): Either[String, List[FINAL]] = {
    val typeIndices = funcs.types.toByteArray.toList
    typeIndices.traverse(index => mapper.fromIndex(index.toInt))
  }
}
Example 82
Source File: GrpcIntegrationSuiteWithThreeAddress.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.it

import com.google.protobuf.ByteString
import com.wavesplatform.account.{Address, KeyPair}
import com.wavesplatform.common.utils.EitherExt2
import com.wavesplatform.it.api.SyncGrpcApi._
import com.wavesplatform.it.util._
import com.wavesplatform.protobuf.transaction.{PBRecipients, PBTransactions, Recipient}
import com.wavesplatform.transaction.transfer.TransferTransaction
import com.wavesplatform.utils.ScorexLogging
import org.scalatest.concurrent.{IntegrationPatience, ScalaFutures}
import org.scalatest.{BeforeAndAfterAll, Matchers, RecoverMethods, Suite}

trait GrpcIntegrationSuiteWithThreeAddress
    extends BeforeAndAfterAll
    with Matchers
    with ScalaFutures
    with IntegrationPatience
    with RecoverMethods
    with IntegrationTestsScheme
    with Nodes
    with ScorexLogging {
  this: Suite =>

  /** First configured node; used as the miner and, by default, as the transaction sender. */
  def miner: Node = nodes.head

  /** Last configured node. */
  def notMiner: Node = nodes.last

  /** Node that broadcasts the funding transfers; override to use a different node. */
  protected def sender: Node = miner

  protected lazy val firstAcc: KeyPair  = KeyPair("first_acc".getBytes("UTF-8"))
  protected lazy val secondAcc: KeyPair = KeyPair("second_acc".getBytes("UTF-8"))
  protected lazy val thirdAcc: KeyPair  = KeyPair("third_acc".getBytes("UTF-8"))

  // Protobuf public-key hashes (address bytes without version, chain id and checksum) of the test accounts.
  protected lazy val firstAddress: ByteString  = PBRecipients.create(Address.fromPublicKey(firstAcc.publicKey)).getPublicKeyHash
  protected lazy val secondAddress: ByteString = PBRecipients.create(Address.fromPublicKey(secondAcc.publicKey)).getPublicKeyHash
  protected lazy val thirdAddress: ByteString  = PBRecipients.create(Address.fromPublicKey(thirdAcc.publicKey)).getPublicKeyHash

  /** Funds each of the three test accounts with `defaultBalance` WAVES from `sender`.
    * It then waits two blocks and checks that the available and effective balances equal that amount.
    */
  abstract protected override def beforeAll(): Unit = {
    super.beforeAll()

    val defaultBalance: Long = 100.waves

    // Logs the available and effective WAVES balance of every account as reported by `node`.
    // A single query per account keeps both figures from the same response.
    def dumpBalances(node: Node, accounts: Seq[ByteString], label: String): Unit =
      accounts.foreach { acc =>
        val balance   = node.wavesBalance(acc)
        val formatted = s"$acc: balance = ${balance.available}, effective = ${balance.effective}"
        log.debug(s"$label account balance:\n$formatted")
      }

    // Blocks until every transaction id is known to every node.
    def waitForTxsToReachAllNodes(txIds: Seq[String]): Unit =
      for {
        txId <- txIds
        node <- nodes
      } node.waitForTransaction(txId)

    // Sends `defaultBalance` to each account and returns the resulting transaction ids.
    def makeTransfers(accounts: Seq[ByteString]): Seq[String] = accounts.map { acc =>
      PBTransactions
        .vanilla(
          sender.broadcastTransfer(sender.keyPair, Recipient().withPublicKeyHash(acc), defaultBalance, sender.fee(TransferTransaction.typeId))
        )
        .explicitGet()
        .id()
        .toString
    }

    def correctStartBalancesFuture(): Unit = {
      nodes.foreach(n => n.waitForHeight(2))
      val accounts = Seq(firstAddress, secondAddress, thirdAddress)

      dumpBalances(sender, accounts, "initial")
      val txs = makeTransfers(accounts)

      val height = nodes.map(_.height).max

      // Two extra blocks so the transfers are mined and the effective balance is updated.
      withClue(s"waitForHeight(${height + 2})") {
        nodes.foreach(n => n.waitForHeight(height + 1))
        nodes.foreach(n => n.waitForHeight(height + 2))
      }

      withClue("waitForTxsToReachAllNodes") {
        waitForTxsToReachAllNodes(txs)
      }

      dumpBalances(sender, accounts, "after transfer")
      accounts.foreach(acc => miner.wavesBalance(acc).available shouldBe defaultBalance)
      accounts.foreach(acc => miner.wavesBalance(acc).effective shouldBe defaultBalance)
    }

    withClue("beforeAll") {
      correctStartBalancesFuture()
    }
  }
}
Example 83
Source File: GetTransactionGrpcSuite.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.it.sync.grpc

import com.google.protobuf.ByteString
import com.wavesplatform.common.utils.{Base58, EitherExt2}
import com.wavesplatform.it.api.SyncGrpcApi._
import com.wavesplatform.it.sync._
import com.wavesplatform.protobuf.transaction.{PBRecipients, PBTransactions, Recipient}

/** Checks that transactions can be fetched over gRPC filtered by sender, recipient and id. */
class GetTransactionGrpcSuite extends GrpcBaseTransactionSuite {

  test("get transaction by sender, by recipient, by sender&recipient and id") {
    val toSecond    = Recipient().withPublicKeyHash(secondAddress)
    val broadcasted = sender.broadcastTransfer(firstAcc, toSecond, transferAmount, minFee, waitForTx = true)
    val txId        = PBTransactions.vanilla(broadcasted).explicitGet().id().toString

    val bySenderAndId          = sender.getTransaction(sender = firstAddress, id = txId).getTransaction
    val byRecipientAndId       = sender.getTransaction(recipient = Some(toSecond), id = txId).getTransaction
    val bySenderRecipientAndId = sender.getTransaction(sender = firstAddress, recipient = Some(toSecond), id = txId).getTransaction

    val expectedSenderKey = ByteString.copyFrom(Base58.decode(firstAcc.publicKey.toString))
    val expectedRecipient = PBRecipients.create(secondAcc.toAddress)

    bySenderAndId.senderPublicKey shouldBe expectedSenderKey
    byRecipientAndId.getTransfer.getRecipient shouldBe expectedRecipient
    bySenderRecipientAndId.senderPublicKey shouldBe expectedSenderKey
    bySenderRecipientAndId.getTransfer.getRecipient shouldBe expectedRecipient
  }

  test("get multiple transactions") {
    val toSecond = Recipient().withPublicKeyHash(secondAddress)
    val txs      = List.fill(10)(sender.broadcastTransfer(thirdAcc, toSecond, transferAmount / 10, minFee, waitForTx = true))
    val txIds    = txs.map(tx => PBTransactions.vanilla(tx).explicitGet().id().toString)

    val fetched = sender.getTransactionSeq(txIds, sender = thirdAddress, recipient = Some(toSecond))
    fetched.size shouldBe 10

    val expectedSenderKey = ByteString.copyFrom(thirdAcc.publicKey.arr)
    val expectedRecipient = PBRecipients.create(secondAcc.toAddress)
    fetched.foreach { response =>
      val inner = response.getTransaction.getTransaction
      inner.senderPublicKey shouldBe expectedSenderKey
      inner.getTransfer.getRecipient shouldBe expectedRecipient
    }
  }
}
Example 84
Source File: InvokeScriptErrorMsgGrpcSuite.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.it.sync.grpc

import com.google.protobuf.ByteString
import com.wavesplatform.common.utils.{Base58, EitherExt2}
import com.wavesplatform.it.api.SyncGrpcApi._
import com.wavesplatform.it.sync._
import com.wavesplatform.lang.v1.estimator.v2.ScriptEstimatorV2
import com.wavesplatform.protobuf.Amount
import com.wavesplatform.protobuf.transaction.{PBTransactions, Recipient}
import com.wavesplatform.transaction.smart.script.ScriptCompiler
import io.grpc.Status.Code

/** Checks the error messages produced when an InvokeScript transaction pays too little fee. */
class InvokeScriptErrorMsgGrpcSuite extends GrpcBaseTransactionSuite {
  private val (contract, contractAddress) = (firstAcc, firstAddress)
  private val caller                      = secondAcc

  /** Deploys a dApp that performs ten ScriptTransfers of the attached payment's asset back to the caller.
    * It also sets `script` on the caller account, so the caller becomes a smart account.
    */
  protected override def beforeAll(): Unit = {
    super.beforeAll()

    val scriptText =
      """
        |{-# STDLIB_VERSION 3 #-}
        |{-# CONTENT_TYPE DAPP #-}
        |
        |@Callable(inv)
        |func default() = {
        | let pmt = inv.payment.extract()
        | TransferSet([ScriptTransfer(inv.caller, 1, pmt.assetId),
        | ScriptTransfer(inv.caller, 1, pmt.assetId),
        | ScriptTransfer(inv.caller, 1, pmt.assetId),
        | ScriptTransfer(inv.caller, 1, pmt.assetId),
        | ScriptTransfer(inv.caller, 1, pmt.assetId),
        | ScriptTransfer(inv.caller, 1, pmt.assetId),
        | ScriptTransfer(inv.caller, 1, pmt.assetId),
        | ScriptTransfer(inv.caller, 1, pmt.assetId),
        | ScriptTransfer(inv.caller, 1, pmt.assetId),
        | ScriptTransfer(inv.caller, 1, pmt.assetId)])
        |}
        |""".stripMargin
    val contractScript = ScriptCompiler.compile(scriptText, ScriptEstimatorV2).explicitGet()._1
    sender.setScript(contract, Right(Some(contractScript)), setScriptFee, waitForTx = true)

    sender.setScript(caller, Right(Some(script)), setScriptFee, waitForTx = true)
  }

  test("cannot invoke script without having enough fee; error message is informative") {
    val asset1 = PBTransactions
      .vanilla(
        sender.broadcastIssue(
          caller,
          "ScriptedAsset",
          someAssetAmount,
          decimals = 0,
          reissuable = true,
          fee = issueFee + smartFee,
          script = Right(Some(script)),
          waitForTx = true
        )
      )
      .explicitGet()
      .id()
      .toString

    val payments = Seq(Amount.of(ByteString.copyFrom(Base58.decode(asset1)), 10))

    // Rejected at broadcast: the fee does not cover the smart-account / scripted-asset surcharge.
    assertGrpcError(
      sender.broadcastInvokeScript(
        caller,
        Recipient().withPublicKeyHash(contractAddress),
        None,
        payments = payments,
        fee = 1000
      ),
      "Transaction sent from smart account. Requires 400000 extra fee. Transaction involves 1 scripted assets",
      Code.INVALID_ARGUMENT
    )

    // Accepted at broadcast, but the invocation's total script count still exceeds what the fee covers.
    val tx = sender.broadcastInvokeScript(
      caller,
      Recipient().withPublicKeyHash(contractAddress),
      None,
      payments = payments,
      fee = 1300000,
      waitForTx = true
    )

    // Fail with a descriptive message (rather than a bare NoSuchElementException) if no error was recorded.
    val errorText = sender
      .stateChanges(tx.id)
      ._2
      .error
      .map(_.text)
      .getOrElse(fail(s"Expected the state changes of invoke ${tx.id} to contain an error"))

    errorText should include regex "Fee in WAVES for InvokeScriptTransaction .* with 12 total scripts invoked does not exceed minimal value"
  }

}
Example 85
Source File: BlockV5GrpcSuite.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.it.sync.grpc

import com.google.protobuf.ByteString
import com.typesafe.config.Config
import com.wavesplatform.api.grpc.BlockRangeRequest
import com.wavesplatform.block.Block
import com.wavesplatform.common.state.ByteStr
import com.wavesplatform.crypto
import com.wavesplatform.it.api.SyncGrpcApi._
import com.wavesplatform.it.sync.activation.ActivationStatusRequest
import com.wavesplatform.it.transactions.NodesFromDocker
import com.wavesplatform.it.{GrpcIntegrationSuiteWithThreeAddress, NodeConfigs, ReportingTestName}
import org.scalatest.{CancelAfterFailure, FreeSpec, Matchers, OptionValues}

import scala.concurrent.duration._

class BlockV5GrpcSuite
    extends FreeSpec
    with Matchers
    with CancelAfterFailure
    with NodesFromDocker
    with ActivationStatusRequest
    with ReportingTestName
    with OptionValues
    with GrpcIntegrationSuiteWithThreeAddress {

  override def nodeConfigs: Seq[Config] =
    NodeConfigs.newBuilder
      .overrideBase(_.quorum(0))
      .withDefault(1)
      .withSpecial(1, _.nonMiner)
      .buildNonConflicting()

  /** Invariants every block checked by this suite must satisfy.
    * The block must have the protobuf version, a VRF-sized generation signature and a valid transactions root.
    */
  private def assertProtoBlock(block: Block): Unit = {
    block.header.version shouldBe Block.ProtoBlockVersion
    block.header.generationSignature.arr.length shouldBe Block.GenerationVRFSignatureLength
    assert(block.transactionsRootValid(), "transactionsRoot is not valid")
  }

  "block v5 appears and blockchain grows" - {
    "when feature activation happened" in {
      sender.waitForHeight(sender.height + 1, 2.minutes)
      val currentHeight = sender.height

      val blockV5     = sender.blockAt(currentHeight)
      val blockV5ById = sender.blockById(ByteString.copyFrom(blockV5.id().arr))

      assertProtoBlock(blockV5)
      blockV5.id().arr.length shouldBe crypto.DigestLength
      blockV5.signature.arr.length shouldBe crypto.SignatureLength
      assertProtoBlock(blockV5ById)

      sender.waitForHeight(currentHeight + 1, 2.minutes)

      val blockAfterVRFUsing     = sender.blockAt(currentHeight + 1)
      val blockAfterVRFUsingById = sender.blockById(ByteString.copyFrom(blockAfterVRFUsing.id().arr))

      assertProtoBlock(blockAfterVRFUsing)
      // The next block must reference the first v5 block.
      ByteStr(sender.blockHeaderAt(currentHeight + 1).reference.toByteArray) shouldBe blockV5.id()
      assertProtoBlock(blockAfterVRFUsingById)

      sender.blockSeq(currentHeight, currentHeight + 2).foreach(block => assertProtoBlock(block))

      sender.blockSeqByAddress(miner.address, currentHeight, currentHeight + 2).foreach { block =>
        block.header.generator shouldBe miner.keyPair.publicKey
        assertProtoBlock(block)
      }

      val byGeneratorKey = BlockRangeRequest.Filter.GeneratorPublicKey(ByteString.copyFrom(miner.keyPair.publicKey.arr))
      NodeExtGrpc(sender).blockSeq(currentHeight, currentHeight + 2, byGeneratorKey).foreach { block =>
        block.header.generator shouldBe miner.keyPair.publicKey
        assertProtoBlock(block)
      }
    }
  }
}
Example 86
Source File: PBAmounts.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.protobuf.transaction
import com.google.protobuf.ByteString
import com.wavesplatform.protobuf.Amount
import com.wavesplatform.transaction.Asset
import com.wavesplatform.transaction.Asset.{IssuedAsset, Waves}
import com.wavesplatform.protobuf.utils.PBImplicitConversions._

/** Conversions between domain assets/amounts and the protobuf [[Amount]] message. */
object PBAmounts {

  /** Protobuf asset id: the raw id bytes for an issued asset, an empty byte string for WAVES. */
  def toPBAssetId(asset: Asset): ByteString =
    asset match {
      case Waves           => ByteString.EMPTY
      case IssuedAsset(id) => id.toByteString
    }

  /** Inverse of [[toPBAssetId]]: an empty byte string denotes WAVES, anything else an issued asset. */
  def toVanillaAssetId(byteStr: ByteString): Asset =
    if (!byteStr.isEmpty) IssuedAsset(byteStr.toByteStr) else Waves

  /** Builds a protobuf [[Amount]] from a domain asset and a quantity. */
  def fromAssetAndAmount(asset: Asset, amount: Long): Amount = {
    val assetId = toPBAssetId(asset)
    Amount(assetId, amount)
  }

  /** Splits a protobuf [[Amount]] into its domain asset and quantity. */
  def toAssetAndAmount(value: Amount): (Asset, Long) = {
    val asset = toVanillaAssetId(value.assetId)
    asset -> value.amount
  }
}
Example 87
Source File: PBRecipients.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.protobuf.transaction
import com.google.common.primitives.Bytes
import com.google.protobuf.ByteString
import com.wavesplatform.account._
import com.wavesplatform.crypto
import com.wavesplatform.lang.ValidationError
import com.wavesplatform.transaction.TxValidationError.GenericError

/** Conversions between domain addresses/aliases and the protobuf [[Recipient]] message. */
object PBRecipients {

  /** An address becomes a public-key-hash recipient; an alias becomes an alias recipient. */
  def create(addressOrAlias: AddressOrAlias): Recipient = addressOrAlias match {
    case address: Address => Recipient().withPublicKeyHash(ByteString.copyFrom(publicKeyHash(address)))
    case alias: Alias     => Recipient().withAlias(alias.name)
    case other            => sys.error("Should not happen " + other)
  }

  /** Decodes an address from raw bytes. The length of `bytes` selects the interpretation:
    *  - `Address.HashLength`: a bare public-key hash; the version/chain-id header and checksum are rebuilt.
    *  - `Address.AddressLength`: a complete serialized address.
    *  - `crypto.KeyLength`: a public key, converted to its address on `chainId`.
    * Any other length is rejected.
    */
  def toAddress(bytes: Array[Byte], chainId: Byte): Either[ValidationError, Address] = bytes.length match {
    case Address.HashLength =>
      val prefixed = Bytes.concat(Array(Address.AddressVersion, chainId), bytes)
      Address.fromBytes(Bytes.concat(prefixed, Address.calcCheckSum(prefixed)), chainId)

    case Address.AddressLength =>
      Address.fromBytes(bytes, chainId)

    case crypto.KeyLength =>
      Right(PublicKey(bytes).toAddress(chainId))

    case _ =>
      Left(GenericError(s"Invalid address length: ${bytes.length}"))
  }

  /** Decodes the address of a public-key-hash recipient; any other recipient kind is an error. */
  def toAddress(r: Recipient, chainId: Byte): Either[ValidationError, Address] = r.recipient match {
    case Recipient.Recipient.PublicKeyHash(hash) => toAddress(hash.toByteArray, chainId)
    case _                                       => Left(GenericError(s"Not an address: $r"))
  }

  /** Decodes the alias of an alias recipient; any other recipient kind is an error. */
  def toAlias(r: Recipient, chainId: Byte): Either[ValidationError, Alias] = r.recipient match {
    case Recipient.Recipient.Alias(name) => Alias.createWithChainId(name, chainId)
    case _                               => Left(GenericError(s"Not an alias: $r"))
  }

  /** Decodes either kind of recipient; an empty recipient is an error. */
  def toAddressOrAlias(r: Recipient, chainId: Byte): Either[ValidationError, AddressOrAlias] =
    r.recipient match {
      case _: Recipient.Recipient.PublicKeyHash => toAddress(r, chainId)
      case _: Recipient.Recipient.Alias         => toAlias(r, chainId)
      case _                                    => Left(GenericError(s"Not an address or alias: $r"))
    }

  /** Strips the two leading header bytes (version and chain id, as assembled in `toAddress`) and the trailing
    * checksum from a serialized address. What remains is the public-key hash.
    */
  @inline
  final def publicKeyHash(address: Address): Array[Byte] = {
    val serialized = address.bytes
    serialized.slice(2, serialized.length - Address.ChecksumLength)
  }
}
Example 88
Source File: PBImplicitConversions.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.protobuf.utils
import com.google.protobuf.ByteString
import com.wavesplatform.account.PublicKey
import com.wavesplatform.common.state.ByteStr
import com.wavesplatform.lang.ValidationError
import com.wavesplatform.protobuf.Amount
import com.wavesplatform.protobuf.transaction._
import com.wavesplatform.transaction.Asset
import com.wavesplatform.transaction.Asset.{IssuedAsset, Waves}

/** Implicit conversions and extension methods bridging domain types and their protobuf counterparts. */
object PBImplicitConversions {
  import com.google.protobuf.{ByteString => PBByteString}
  import com.wavesplatform.account.{AddressOrAlias, Address => VAddress, Alias => VAlias}

  /** Lets an address or alias be used wherever a protobuf [[Recipient]] is expected. */
  implicit def fromAddressOrAlias(addressOrAlias: AddressOrAlias): Recipient =
    PBRecipients.create(addressOrAlias)

  /** Lets an address be used as a byte string holding its full serialized bytes. */
  implicit def fromAddress(address: VAddress): PBByteString =
    PBByteString.copyFrom(address.bytes)

  /** Recipient decoding helpers; each delegates to the same-named [[PBRecipients]] method. */
  implicit class PBRecipientImplicitConversionOps(recipient: Recipient) {
    def toAddress(chainId: Byte): Either[ValidationError, VAddress] =
      PBRecipients.toAddress(recipient, chainId)

    def toAlias(chainId: Byte): Either[ValidationError, VAlias] =
      PBRecipients.toAlias(recipient, chainId)

    def toAddressOrAlias(chainId: Byte): Either[ValidationError, AddressOrAlias] =
      PBRecipients.toAddressOrAlias(recipient, chainId)
  }

  /** Copies a [[ByteStr]] into a protobuf byte string. */
  implicit class ByteStrExt(val bs: ByteStr) extends AnyVal {
    def toByteString: PBByteString = PBByteString.copyFrom(bs.arr)
  }

  /** Copies a protobuf byte string into a [[ByteStr]]. */
  implicit class ByteStringExt(val bs: ByteString) extends AnyVal {
    def toByteStr: ByteStr = ByteStr(bs.toByteArray)
  }

  /** Builds a protobuf [[Amount]]; WAVES leaves the asset id at its default (empty) value. */
  implicit def fromAssetIdAndAmount(v: (VanillaAssetId, Long)): Amount = {
    val (asset, amount) = v
    asset match {
      case IssuedAsset(assetId) => Amount().withAssetId(assetId.toByteString).withAmount(amount)
      case Waves                => Amount().withAmount(amount)
    }
  }

  /** Domain views of a protobuf [[Amount]]. */
  implicit class AmountImplicitConversions(a: Amount) {
    def longAmount: Long = a.amount

    def vanillaAssetId: Asset = PBAmounts.toVanillaAssetId(a.assetId)
  }

  /** Interprets a protobuf byte string as a [[ByteStr]] or a [[PublicKey]]. */
  implicit class PBByteStringOps(bs: PBByteString) {
    def byteStr: ByteStr = ByteStr(bs.toByteArray)

    def publicKeyAccount: PublicKey = PublicKey(bs.toByteArray)
  }

  /** Chain id stored as a byte string: its first byte, or 0 when the byte string is empty. */
  implicit def byteStringToByte(bytes: ByteString): Byte =
    if (!bytes.isEmpty) bytes.byteAt(0) else 0

  /** Inverse of [[byteStringToByte]]: chain id 0 is encoded as an empty byte string. */
  implicit def byteToByteString(chainId: Byte): ByteString =
    if (chainId != 0) ByteString.copyFrom(Array(chainId)) else ByteString.EMPTY
}
Example 89
Source File: PBBlocks.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.protobuf.block

import com.google.protobuf.ByteString
import com.wavesplatform.account.{AddressScheme, PublicKey}
import com.wavesplatform.block.BlockHeader
import com.wavesplatform.common.state.ByteStr
import com.wavesplatform.common.utils.EitherExt2
import com.wavesplatform.protobuf.block.Block.{Header => PBHeader}
import com.wavesplatform.protobuf.transaction.PBTransactions

import scala.util.Try

/** Conversions between domain blocks/headers and their protobuf representation. */
object PBBlocks {

  /** Converts a protobuf header into a domain [[BlockHeader]].
    * NOTE(review): feature votes are narrowed with a plain `toShort` here, whereas `PBBlockHeaders.vanilla`
    * range-checks them to 0..65535 — confirm which behaviour is intended.
    */
  def vanilla(header: PBBlock.Header): BlockHeader =
    BlockHeader(
      header.version.toByte,
      header.timestamp,
      toByteStr(header.reference),
      header.baseTarget,
      toByteStr(header.generationSignature),
      PublicKey(header.generator.toByteArray),
      header.featureVotes.map(vote => vote.toShort),
      header.rewardVote,
      toByteStr(header.transactionsRoot)
    )

  /** Converts a protobuf block into a domain block.
    * A missing header, or any transaction that fails to convert, yields a `Failure`.
    *
    * @param unsafe forwarded unchanged to `PBTransactions.vanilla` for every transaction
    */
  def vanilla(block: PBBlock, unsafe: Boolean = false): Try[VanillaBlock] = Try {
    require(block.header.isDefined, "block header is missing")
    val pbHeader     = block.getHeader
    val transactions = block.transactions.map(tx => PBTransactions.vanilla(tx, unsafe).explicitGet())

    VanillaBlock(vanilla(pbHeader), toByteStr(block.signature), transactions)
  }

  /** Converts a domain header into protobuf, stamping it with the current chain id.
    * NOTE(review): feature votes are widened with a signed `toInt` here, whereas `PBBlockHeaders.protobuf`
    * uses an unsigned conversion — confirm which behaviour is intended.
    */
  def protobuf(header: BlockHeader): PBHeader =
    PBBlock.Header(
      AddressScheme.current.chainId,
      toByteString(header.reference),
      header.baseTarget,
      toByteString(header.generationSignature),
      header.featureVotes.map(vote => vote.toInt),
      header.timestamp,
      header.version,
      ByteString.copyFrom(header.generator.arr),
      header.rewardVote,
      toByteString(header.transactionsRoot)
    )

  /** Converts a domain block (header, signature and transactions) into protobuf. */
  def protobuf(block: VanillaBlock): PBBlock =
    new PBBlock(
      Some(protobuf(block.header)),
      toByteString(block.signature),
      block.transactionData.map(PBTransactions.protobuf)
    )

  /** Zeroes the chain id of the header and of every contained transaction. */
  def clearChainId(block: PBBlock): PBBlock =
    withChainId(block, 0)

  /** Sets the header's and every transaction's chain id to the current address scheme's chain id. */
  def addChainId(block: PBBlock): PBBlock =
    withChainId(block, AddressScheme.current.chainId)

  // Writes `chainId` into the header and into each contained transaction.
  private def withChainId(block: PBBlock, chainId: Int): PBBlock =
    block.update(
      _.header.chainId := chainId,
      _.transactions.foreach(_.transaction.chainId := chainId)
    )

  private def toByteStr(bytes: ByteString): ByteStr = ByteStr(bytes.toByteArray)

  private def toByteString(bytes: ByteStr): ByteString = ByteString.copyFrom(bytes.arr)
}
Example 90
Source File: PBBlockHeaders.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.protobuf.block

import com.google.protobuf.ByteString
import com.wavesplatform.account.{AddressScheme, PublicKey}
import com.wavesplatform.block.BlockHeader
import com.wavesplatform.common.state.ByteStr

/** Conversions between domain block headers and protobuf block headers. */
object PBBlockHeaders {

  /** Converts a domain header into protobuf, stamping it with the current chain id.
    * Feature votes are widened as unsigned 16-bit values.
    */
  def protobuf(header: VanillaBlockHeader): PBBlockHeader =
    PBBlock.Header(
      AddressScheme.current.chainId,
      ByteString.copyFrom(header.reference.arr),
      header.baseTarget,
      ByteString.copyFrom(header.generationSignature.arr),
      header.featureVotes.map(vote => shortToInt(vote)),
      header.timestamp,
      header.version,
      ByteString.copyFrom(header.generator.arr),
      header.rewardVote,
      ByteString.copyFrom(header.transactionsRoot.arr)
    )

  /** Converts a protobuf header into a domain header.
    * Throws `IllegalArgumentException` if a feature vote lies outside 0..65535.
    */
  def vanilla(header: PBBlockHeader): VanillaBlockHeader = {
    val votes = header.featureVotes.map(vote => intToShort(vote))
    BlockHeader(
      header.version.toByte,
      header.timestamp,
      ByteStr(header.reference.toByteArray),
      header.baseTarget,
      ByteStr(header.generationSignature.toByteArray),
      PublicKey(header.generator.toByteArray),
      votes,
      header.rewardVote,
      ByteStr(header.transactionsRoot.toByteArray)
    )
  }

  // Reads the 16 bits of `s` as an unsigned value (0..65535).
  private[this] def shortToInt(s: Short): Int =
    java.lang.Short.toUnsignedInt(s)

  // Inverse of shortToInt; rejects values that do not fit in 16 unsigned bits.
  private[this] def intToShort(value: Int): Short = {
    require(0 <= value && value <= 0xFFFF, s"Short overflow: $value")
    value.toShort
  }
}
Example 91
Source File: package.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform

import java.security.SecureRandom

import com.google.common.base.Charsets
import com.google.protobuf.ByteString
import com.wavesplatform.common.state.ByteStr
import com.wavesplatform.common.state.ByteStr._
import com.wavesplatform.common.utils.Base58
import org.apache.commons.lang3.time.DurationFormatUtils
import play.api.libs.json._

import scala.annotation.tailrec

package object utils extends ScorexLogging {

  private val BytesMaxValue  = 256
  private val Base58MaxValue = 58

  private val BytesLog = math.log(BytesMaxValue)
  private val BaseLog  = math.log(Base58MaxValue)

  /** Upper bound on the number of base58 characters needed to encode `byteArrayLength` bytes. */
  def base58Length(byteArrayLength: Int): Int = math.ceil(BytesLog / BaseLog * byteArrayLength).toInt

  /** Terminates the JVM with the exit code associated with `reason`. */
  def forceStopApplication(reason: ApplicationStopReason = Default): Unit =
    System.exit(reason.code)

  /** Formats a byte count with one decimal place and a unit suffix.
    *
    * @param bytes number of bytes to format
    * @param si    `true` for 1000-based units (kB, MB, ...), `false` for 1024-based units (KiB, MiB, ...)
    * @return e.g. `"1.5 kB"` for `1500` and `"1.0 GB"` for `1000000000` with `si = true`
    */
  def humanReadableSize(bytes: Long, si: Boolean = true): String = {
    val (baseValue, unitStrings) =
      if (si)
        (1000, Vector("B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"))
      else
        (1024, Vector("B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"))

    // Number of times `curBytes` can be divided by `baseValue` before it drops below `baseValue`.
    // Every step must divide by exactly one `baseValue`. Dividing by `baseValue * exponent` would under-count
    // from the third unit up, e.g. rendering 10^9 bytes as "1000.0 MB".
    @tailrec
    def getExponent(curBytes: Long, baseValue: Int, curExponent: Int = 0): Int =
      if (curBytes < baseValue) curExponent
      else getExponent(curBytes / baseValue, baseValue, curExponent + 1)

    val exponent   = getExponent(bytes, baseValue)
    val divisor    = Math.pow(baseValue, exponent)
    val unitString = unitStrings(exponent)

    f"${bytes / divisor}%.1f $unitString"
  }

  /** Formats a duration in milliseconds as `H:mm:ss.SSS` via commons-lang `DurationFormatUtils.formatDurationHMS`. */
  def humanReadableDuration(duration: Long): String =
    DurationFormatUtils.formatDurationHMS(duration)

  /** Runs a side effect on a value and returns the value unchanged. */
  implicit class Tap[A](a: A) {
    def tap(g: A => Unit): A = {
      g(a)
      a
    }
  }

  /** Returns `howMany` bytes drawn from a fresh `SecureRandom`. */
  def randomBytes(howMany: Int = 32): Array[Byte] = {
    val r = new Array[Byte](howMany)
    new SecureRandom().nextBytes(r) //overrides r
    r
  }

  /** JSON format for [[ByteStr]].
    * Values are written via `toString`. They are read from base58 strings, or from base64 strings prefixed
    * with `"base64:"`. Base58 input longer than `Base58.defaultDecodeLimit` characters is rejected before decoding.
    */
  implicit val byteStrFormat: Format[ByteStr] = new Format[ByteStr] {
    override def writes(o: ByteStr): JsValue = JsString(o.toString)

    override def reads(json: JsValue): JsResult[ByteStr] = json match {
      case JsString(v) if v.startsWith("base64:") =>
        decodeBase64(v.substring(7)).fold(e => JsError(s"Error parsing base64: ${e.getMessage}"), b => JsSuccess(b))
      case JsString(v) if v.length > Base58.defaultDecodeLimit =>
        JsError(s"Length ${v.length} exceeds maximum length of ${Base58.defaultDecodeLimit}")
      case JsString(v) => decodeBase58(v).fold(e => JsError(s"Error parsing base58: ${e.getMessage}"), b => JsSuccess(b))
      case _           => JsError("Expected JsString")
    }
  }

  /** UTF-8 encodings of a string as a byte array or a protobuf byte string. */
  implicit class StringBytes(val s: String) extends AnyVal {
    def utf8Bytes: Array[Byte]   = s.getBytes(Charsets.UTF_8)
    def toByteString: ByteString = ByteString.copyFromUtf8(s)
  }
}
Example 92
Source File: BasicMessagesRepoSpec.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.network

import java.io.ByteArrayOutputStream

import com.google.protobuf.{ByteString, CodedOutputStream, WireFormat}
import com.wavesplatform.TransactionGen
import com.wavesplatform.common.state.ByteStr
import com.wavesplatform.common.utils.EitherExt2
import com.wavesplatform.mining.MiningConstraints
import com.wavesplatform.protobuf.block._
import com.wavesplatform.protobuf.transaction._
import com.wavesplatform.transaction.Asset.IssuedAsset
import com.wavesplatform.transaction.smart.SetScriptTransaction
import com.wavesplatform.transaction.{DataTransaction, Proofs, TxVersion}
import org.scalatest._

/** Checks that the network message specs' `maxLength` can hold maximum-sized blocks and transactions. */
class BasicMessagesRepoSpec extends FreeSpec with Matchers with TransactionGen {

  /** Serialized bytes of a length-delimited protobuf field prefix.
    * The prefix is the tag for `fieldNumber` followed by `length` written as an unsigned varint.
    */
  private def lengthDelimitedPrefix(fieldNumber: Int, length: Int): Array[Byte] = {
    val buffer = new ByteArrayOutputStream()
    val coded  = CodedOutputStream.newInstance(buffer)
    coded.writeTag(fieldNumber, WireFormat.WIRETYPE_LENGTH_DELIMITED)
    coded.writeUInt32NoTag(length)
    coded.flush()
    buffer.toByteArray
  }

  "PBBlockSpec max length" in {
    val maxSizedHeader = PBBlock.Header(
      Byte.MaxValue,
      ByteString.copyFrom(bytes64gen.sample.get),
      Long.MaxValue,
      ByteString.copyFrom(byteArrayGen(VanillaBlock.GenerationVRFSignatureLength).sample.get),
      Seq.fill(VanillaBlock.MaxFeaturesInBlock)(Short.MaxValue),
      Long.MaxValue,
      Byte.MaxValue,
      ByteString.copyFrom(bytes32gen.sample.get),
      Long.MaxValue,
      ByteString.copyFrom(bytes32gen.sample.get)
    )
    val maxSignature = ByteString.copyFrom(bytes64gen.sample.get)

    val headerSize    = maxSizedHeader.serializedSize
    val signatureSize = maxSignature.toByteArray.length

    val headerPrefixSize    = lengthDelimitedPrefix(PBBlock.HEADER_FIELD_NUMBER, headerSize).length
    val signaturePrefixSize = lengthDelimitedPrefix(PBBlock.SIGNATURE_FIELD_NUMBER, signatureSize).length
    val transactionPrefixSize =
      lengthDelimitedPrefix(PBBlock.TRANSACTIONS_FIELD_NUMBER, MiningConstraints.MaxTxsSizeInBytes).length

    // The smallest transaction bounds how many transactions (and hence field prefixes) fit in a block.
    val minPossibleTransactionSize = PBTransactions
      .protobuf(
        SetScriptTransaction
          .selfSigned(
            TxVersion.V2,
            accountGen.sample.get,
            None,
            1L,
            0L
          )
          .explicitGet()
      )
      .serializedSize

    val maxTransactionPrefixesSize = transactionPrefixSize * MiningConstraints.MaxTxsSizeInBytes / minPossibleTransactionSize

    val maxSize =
      headerPrefixSize + headerSize +
        signaturePrefixSize + signatureSize +
        MiningConstraints.MaxTxsSizeInBytes +
        maxTransactionPrefixesSize

    maxSize should be <= PBBlockSpec.maxLength
  }

  "PBTransactionSpec max length" in {
    val maxSizeTransaction = PBSignedTransaction(
      Some(
        PBTransaction(
          Byte.MaxValue,
          ByteString.copyFrom(bytes32gen.sample.get),
          Some(PBAmounts.fromAssetAndAmount(IssuedAsset(ByteStr(bytes32gen.sample.get)), Long.MaxValue)),
          Long.MaxValue,
          Byte.MaxValue
        )
      ),
      Seq.fill(Proofs.MaxProofs)(ByteString.copyFrom(byteArrayGen(Proofs.MaxProofSize).sample.get))
    )

    val dataPrefixSize = lengthDelimitedPrefix(Transaction.DATA_TRANSACTION_FIELD_NUMBER, DataTransaction.MaxProtoBytes).length

    val size = maxSizeTransaction.serializedSize + dataPrefixSize + DataTransaction.MaxProtoBytes

    size should be <= PBTransactionSpec.maxLength
  }
}
Example 93
Source File: AudioReceiveFrameTimingBench.scala    From jvm-toxcore-c   with GNU General Public License v3.0 5 votes vote down vote up
package im.tox.tox4j.impl.jni

import java.nio.ByteBuffer

import com.google.protobuf.ByteString
import im.tox.tox4j.av.callbacks.ToxAvEventAdapter
import im.tox.tox4j.av.data._
import im.tox.tox4j.av.proto.{ AvEvents, AudioReceiveFrame }
import im.tox.tox4j.bench.TimingReport
import im.tox.tox4j.bench.ToxBenchBase._

final class AudioReceiveFrameTimingBench extends TimingReport {

  timing.of[AudioReceiveFrame] {

    val audioLength = AudioLength.Length60
    val channels = AudioChannels.Mono
    val samplingRate = SamplingRate.Rate48k

    // Payload size = sample count * channel count * 2 bytes (presumably 16-bit samples).
    val payloadSize = SampleCount(audioLength, samplingRate).value * channels.value * 2
    val payload = ByteString.copyFrom(ByteBuffer.wrap(new Array[Byte](payloadSize)))
    val frame = AudioReceiveFrame(0, payload)

    // One serialized AvEvents message per benchmark point, holding `count` copies of the frame.
    val frames = range("frames")(10000).map { count =>
      val events = AvEvents(audioReceiveFrame = Vector.fill(count)(frame))
      events.toByteArray
    }

    val handler = new ToxAvEventAdapter[Unit] with Serializable

    performance of "60ms per frame at 48k" in {
      using(frames) in { serialized =>
        ToxAvEventDispatch.dispatch(handler, serialized)(())
      }
    }

  }

}
Example 94
Source File: VideoReceiveFrameTimingBench.scala    From jvm-toxcore-c   with GNU General Public License v3.0 5 votes vote down vote up
package im.tox.tox4j.impl.jni

import com.google.protobuf.ByteString
import im.tox.tox4j.av.callbacks.ToxAvEventAdapter
import im.tox.tox4j.av.data._
import im.tox.tox4j.av.proto.{ AvEvents, VideoReceiveFrame }
import im.tox.tox4j.bench.TimingReport
import im.tox.tox4j.bench.ToxBenchBase._

final class VideoReceiveFrameTimingBench extends TimingReport {

  timing.of[VideoReceiveFrame] {

    val width = 100
    val height = 100

    // Y plane has one byte per pixel; the U and V planes are a quarter of that size.
    val lumaSize = width * height
    val chromaSize = width * height / 4

    val y = ByteString.copyFrom(new Array[Byte](lumaSize))
    val uv = ByteString.copyFrom(new Array[Byte](chromaSize))
    val frame = VideoReceiveFrame(0, width, height, y, uv, uv, width, width, width)

    // One serialized AvEvents message per benchmark point, holding `count` copies of the frame.
    val frames = range("frames")(10000).map { count =>
      val events = AvEvents(videoReceiveFrame = Vector.fill(count)(frame))
      events.toByteArray
    }

    // Adapter with its default videoFrameCachedYUV behaviour.
    val nonCachingHandler = new ToxAvEventAdapter[Unit] with Serializable

    // Adapter whose videoFrameCachedYUV always returns the same preallocated Y/U/V arrays
    // (presumably reused by the dispatcher instead of allocating per frame).
    val cachingHandler = new ToxAvEventAdapter[Unit] with Serializable {
      private val cache = Some((new Array[Byte](lumaSize), new Array[Byte](chromaSize), new Array[Byte](chromaSize)))

      override def videoFrameCachedYUV(
        height: Height,
        yStride: Int,
        uStride: Int,
        vStride: Int
      ): Option[(Array[Byte], Array[Byte], Array[Byte])] = cache
    }

    performance of s"${width}x$height" in {
      using(frames) in { serialized =>
        ToxAvEventDispatch.dispatch(nonCachingHandler, serialized)(())
      }
    }

    performance of s"${width}x$height (cached)" in {
      using(frames) in { serialized =>
        ToxAvEventDispatch.dispatch(cachingHandler, serialized)(())
      }
    }

  }

}
Example 95
Source File: FeatureEncoder.scala    From ecosystem   with Apache License 2.0 5 votes vote down vote up
package org.tensorflow.spark.datasources.tfrecords.serde

import org.tensorflow.example._
import com.google.protobuf.ByteString

trait FeatureEncoder[T] {
  
object BytesListFeatureEncoder extends FeatureEncoder[Seq[Array[Byte]]] {

  /** Copies each byte array into a `BytesList` entry, in order, and wraps the list in a `Feature`. */
  override def encode(value: Seq[Array[Byte]]): Feature = {
    val bytesList = value
      .foldLeft(BytesList.newBuilder()) { (builder, bytes) =>
        builder.addValue(ByteString.copyFrom(bytes))
      }
      .build()
    Feature.newBuilder().setBytesList(bytesList).build()
  }
}
Example 96
Source File: MagnolifyTensorFlowExampleTest.scala    From scio   with Apache License 2.0 5 votes vote down vote up
package com.spotify.scio.examples.extra

import com.google.protobuf.ByteString
import com.spotify.scio.io._
import com.spotify.scio.tensorflow.TFRecordIO
import com.spotify.scio.testing._
import org.tensorflow.example._

class MagnolifyTensorFlowExampleTest extends PipelineSpec {
  val textIn = Seq("a b c d e", "a b a b")
  val wordCount = Seq(("a", 3L), ("b", 3L), ("c", 1L), ("d", 1L), ("e", 1L))

  // Expected TFRecord contents: one Example per word, with a "word" bytes feature and a
  // "count" int64 feature.
  val examples = wordCount.map {
    case (word, count) =>
      Example
        .newBuilder()
        .setFeatures(
          Features
            .newBuilder()
            .putFeature(
              "word",
              Feature
                .newBuilder()
                .setBytesList(BytesList.newBuilder().addValue(ByteString.copyFromUtf8(word)))
                .build()
            )
            .putFeature(
              "count",
              Feature
                .newBuilder()
                .setInt64List(Int64List.newBuilder().addValue(count))
                .build()
            )
        )
        .build()
  }
  val textOut = wordCount.map { case (word, count) => s"$word: $count" }

  "MagnolifyTensorFlowWriteExample" should "work" in {
    JobTest[com.spotify.scio.examples.extra.MagnolifyTensorFlowWriteExample.type]
      .args("--input=in.txt", "--output=wc.tfrecords")
      .input(TextIO("in.txt"), textIn)
      .output(TFRecordIO("wc.tfrecords")) {
        _.map(Example.parseFrom) should containInAnyOrder(examples)
      }
      .run()
  }

  "MagnolifyTensorFlowReadExample" should "work" in {
    // JobTest locates the job from the object's type, so this must name the read example
    // (the counterpart of MagnolifyTensorFlowWriteExample above).
    JobTest[com.spotify.scio.examples.extra.MagnolifyTensorFlowReadExample.type]
      .args("--input=wc.tfrecords", "--output=out.txt")
      .input(TFRecordIO("wc.tfrecords"), examples.map(_.toByteArray))
      .output(TextIO("out.txt"))(coll => coll should containInAnyOrder(textOut))
      .run()
  }
}
Example 97
Source File: BigtableExampleTest.scala    From scio   with Apache License 2.0 5 votes vote down vote up
package com.spotify.scio.examples.extra

import com.google.bigtable.v2.{Mutation, Row}
import com.google.protobuf.ByteString
import com.spotify.scio.bigtable._
import com.spotify.scio.io._
import com.spotify.scio.testing._

class BigtableExampleTest extends PipelineSpec {
  import BigtableExample._

  val bigtableOptions = Seq(
    "--bigtableProjectId=my-project",
    "--bigtableInstanceId=my-instance",
    "--bigtableTableId=my-table"
  )

  val textIn = Seq("a b c d e", "a b a b")
  val wordCount = Seq(("a", 3L), ("b", 3L), ("c", 1L), ("d", 1L), ("e", 1L))
  val expectedMutations =
    wordCount.map { case (word, count) => BigtableExample.toMutation(word, count) }

  "BigtableV1WriteExample" should "work" in {
    val writeArgs = bigtableOptions :+ "--input=in.txt"
    JobTest[com.spotify.scio.examples.extra.BigtableWriteExample.type]
      .args(writeArgs: _*)
      .input(TextIO("in.txt"), textIn)
      // format: off
      .output(BigtableIO[(ByteString, Iterable[Mutation])](
        "my-project", "my-instance", "my-table")) {
        _ should containInAnyOrder(expectedMutations)
      }
      // format: on
      .run()
  }

  /** Builds a `Row` via `Rows.newRow` with key `key` and `value` rendered as a UTF-8 string. */
  def toRow(key: String, value: Long): Row = {
    val rowKey = ByteString.copyFromUtf8(key)
    val cellValue = ByteString.copyFromUtf8(value.toString)
    Rows.newRow(rowKey, FAMILY_NAME, COLUMN_QUALIFIER, cellValue)
  }

  val rowsIn = wordCount.map { case (word, count) => toRow(word, count) }
  val expectedText = wordCount.map { case (word, count) => s"$word: $count" }

  "BigtableReadExample" should "work" in {
    val readArgs = bigtableOptions :+ "--output=out.txt"
    JobTest[com.spotify.scio.examples.extra.BigtableReadExample.type]
      .args(readArgs: _*)
      .input(BigtableIO("my-project", "my-instance", "my-table"), rowsIn)
      .output(TextIO("out.txt"))(coll => coll should containInAnyOrder(expectedText))
      .run()
  }
}
Example 98
Source File: BigQueryIT.scala    From scio   with Apache License 2.0 5 votes vote down vote up
package com.spotify.scio.extra.bigquery

import java.{util => ju}

import com.google.protobuf.ByteString
import com.spotify.scio.avro.types.AvroType
import com.spotify.scio.bigquery.client.BigQuery
import com.spotify.scio.bigquery.Table
import com.spotify.scio.bigquery.TableRow
import com.spotify.scio.coders._
import com.spotify.scio.ContextAndArgs
import org.apache.avro.generic.GenericRecord
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryAvroUtilsWrapper
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition
import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

object BigQueryIT {
  @AvroType.fromSchema("""{
      | "type":"record",
      | "name":"Account",
      | "namespace":"com.spotify.scio.avro",
      | "doc":"Record for an account",
      | "fields":[
      |   {"name":"id","type":"long"},
      |   {"name":"type","type":"string"},
      |   {"name":"name","type":"string"},
      |   {"name":"amount","type":"double"},
      |   {"name":"secret","type":"bytes"}]}
    """.stripMargin)
  class Account

  /**
   * Coder for the `GenericRecord`s built from [[Account]], using the Account Avro schema.
   *
   * The result type is spelled out: implicit definitions should not rely on type inference.
   */
  implicit def genericCoder: Coder[GenericRecord] = Coder.avroGenericRecordCoder(Account.schema)

}

final class BigQueryIT extends AnyFlatSpec with Matchers {
  import BigQueryIT._

  it should "save avro to BigQuery" in {
    val (sc, _) = ContextAndArgs(
      Array(
        "--project=data-integration-test",
        "--tempLocation=gs://data-integration-test-eu/temp"
      )
    )
    // Random, dash-free suffix so each run writes to its own table.
    val tableSuffix = ju.UUID.randomUUID().toString.replaceAll("-", "")
    val table = Table.Spec(s"data-integration-test:bigquery_avro_it.${tableSuffix}_accounts")

    val data: Seq[GenericRecord] = for (i <- 1 to 100) yield {
      val account = Account(i, "checking", s"account$i", i.toDouble, ByteString.copyFromUtf8("%20cフーバー"))
      Account.toGenericRecord(account)
    }

    val tap = sc
      .parallelize(data)
      .saveAvroAsBigQuery(
        table.ref,
        Account.schema,
        writeDisposition = WriteDisposition.WRITE_EMPTY,
        createDisposition = CreateDisposition.CREATE_IF_NEEDED
      )

    val result = sc.run().waitUntilDone()

    // Convert the input records with the schema BigQuery actually created, then compare sets.
    val tableSchema = BigQuery.defaultInstance().tables.schema(table.ref)
    val expected: Seq[TableRow] =
      data.map(record => BigQueryAvroUtilsWrapper.convertGenericRecordToTableRow(record, tableSchema))

    result.tap(tap).value.toSet shouldEqual expected.toSet
  }

}
Example 99
Source File: Schemas.scala    From scio   with Apache License 2.0 5 votes vote down vote up
package com.spotify.scio.bigquery.types

import com.google.protobuf.ByteString
import org.joda.time.{Instant, LocalDate, LocalDateTime, LocalTime}

/**
 * Case-class fixtures for BigQuery type tests.
 *
 * NOTE(review): presumably used to exercise schema derivation for the REQUIRED / NULLABLE /
 * REPEATED field modes — confirm against the tests that reference this object.
 */
object Schemas {
  // primitives
  /** One plain (non-optional, non-repeated) field of each primitive type listed. */
  case class Required(
    boolF: Boolean,
    intF: Int,
    longF: Long,
    floatF: Float,
    doubleF: Double,
    stringF: String,
    byteArrayF: Array[Byte],
    byteStringF: ByteString,
    timestampF: Instant,
    dateF: LocalDate,
    timeF: LocalTime,
    datetimeF: LocalDateTime,
    bigDecimalF: BigDecimal,
    geographyF: Geography
  )
  /** The same field types as [[Required]], each wrapped in `Option`. */
  case class Optional(
    boolF: Option[Boolean],
    intF: Option[Int],
    longF: Option[Long],
    floatF: Option[Float],
    doubleF: Option[Double],
    stringF: Option[String],
    byteArrayF: Option[Array[Byte]],
    byteStringF: Option[ByteString],
    timestampF: Option[Instant],
    dateF: Option[LocalDate],
    timeF: Option[LocalTime],
    datetimeF: Option[LocalDateTime],
    bigDecimalF: Option[BigDecimal],
    geographyF: Option[Geography]
  )
  /** The same field types as [[Required]], each as a `List`. */
  case class Repeated(
    boolF: List[Boolean],
    intF: List[Int],
    longF: List[Long],
    floatF: List[Float],
    doubleF: List[Double],
    stringF: List[String],
    byteArrayF: List[Array[Byte]],
    byteStringF: List[ByteString],
    timestampF: List[Instant],
    dateF: List[LocalDate],
    timeF: List[LocalTime],
    datetimeF: List[LocalDateTime],
    bigDecimalF: List[BigDecimal],
    geographyF: List[Geography]
  )

  // records
  /** Nests [[Required]], [[Optional]] and [[Repeated]] as plain record fields. */
  case class RequiredNested(required: Required, optional: Optional, repeated: Repeated)
  /** Nests the primitive fixtures as `Option` record fields. */
  case class OptionalNested(
    required: Option[Required],
    optional: Option[Optional],
    repeated: Option[Repeated]
  )
  /** Nests the primitive fixtures as `List` record fields. */
  case class RepeatedNested(
    required: List[Required],
    optional: List[Optional],
    repeated: List[Repeated]
  )

  /** Record whose fields carry `@description` annotations. */
  case class User(@description("user name") name: String, @description("user age") age: Int)
  /** Record with a described nested record ([[User]]) and a described `Double` field. */
  case class Account(
    @description("account user") user: User,
    @description("in USD") balance: Double
  )
}
Example 100
Source File: ByteStringSerializer.scala    From scio   with Apache License 2.0 5 votes vote down vote up
package com.spotify.scio.coders.instances.kryo

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}
import com.google.protobuf.ByteString
import com.twitter.chill.KSerializer

/**
 * Kryo serializer for protobuf `ByteString`.
 *
 * Wire format: an `Int` length (`Output.writeInt`) followed by the raw bytes.
 */
private[coders] class ByteStringSerializer extends KSerializer[ByteString] {

  /** Reads the `Int` length header, then that many bytes, and copies them into a new ByteString. */
  override def read(kryo: Kryo, input: Input, tpe: Class[ByteString]): ByteString = {
    val n = input.readInt()
    ByteString.copyFrom(input.readBytes(n))
  }

  /** Writes the `Int` length header followed by the ByteString's contents. */
  override def write(kryo: Kryo, output: Output, byteStr: ByteString): Unit = {
    output.writeInt(byteStr.size)
    // Kryo's Output is an OutputStream, so the contents are copied in bulk rather than through
    // one write(int) call per byte. The bytes written are identical.
    byteStr.writeTo(output)
  }
}
Example 101
Source File: TFExampleIOTest.scala    From scio   with Apache License 2.0 5 votes vote down vote up
package com.spotify.scio.tensorflow

import com.google.protobuf.ByteString
import com.spotify.scio.testing._
import magnolify.tensorflow._

object TFExampleIOTest {
  case class Record(i: Int, s: String)

  // Int fields go through the Long example field, String fields through the ByteString one.
  implicit val efInt = ExampleField.from[Long](raw => raw.toInt)(n => n.toLong)
  implicit val efString =
    ExampleField.from[ByteString](bytes => bytes.toStringUtf8)(str => ByteString.copyFromUtf8(str))
  val recordT: ExampleType[Record] = ExampleType[Record]
}

class TFExampleIOTest extends ScioIOSpec {
  import TFExampleIOTest._

  "TFExampleIO" should "work" in {
    // 100 examples encoded from Record(n, n.toString).
    val xs = for (n <- 1 to 100) yield recordT(Record(n, n.toString))
    testTap(xs)(_.saveAsTfRecordFile(_))(".tfrecords")
    testJobTest(xs)(TFExampleIO(_))(_.tfRecordExampleFile(_))(_.saveAsTfRecordFile(_))
  }
}
Example 102
Source File: MetadataSchemaTest.scala    From scio   with Apache License 2.0 5 votes vote down vote up
package com.spotify.scio.tensorflow

import com.google.protobuf.ByteString
import org.tensorflow.example._

import scala.jdk.CollectionConverters._

object MetadataSchemaTest {
  // Keep byte list the same length across examples to be parsed as a fixed shape.
  val e1Features = Map[String, Feature](
    "long" -> longFeature(Seq(1, 2, 3)),
    "bytes" -> byteStrFeature(Seq("a", "b", "c").map(ByteString.copyFromUtf8)),
    "floats" -> floatFeature(Seq(1.0f, 2.0f, 3.0f)),
    "indices" -> longFeature(Seq(1, 9)),
    "values" -> byteStrFeature(Seq("one", "nine").map(ByteString.copyFromUtf8)),
    "dense_shape" -> longFeature(Seq(100)),
    "missing_feature" -> longFeature(Seq(10))
  )
  val e2Features = Map[String, Feature](
    "long" -> longFeature(Seq(6)),
    "bytes" -> byteStrFeature(Seq("d", "e", "f").map(ByteString.copyFromUtf8)),
    "floats" -> floatFeature(Seq(4.0f, 5.0f)),
    "indices" -> longFeature(Seq(1, 2, 80)),
    "values" -> byteStrFeature(Seq("one", "two", "eighty").map(ByteString.copyFromUtf8)),
    "dense_shape" -> longFeature(Seq(100))
  )

  // Each feature list holds one single-valued feature per element.
  val e1FeatureList = Map[String, FeatureList](
    "string_list" -> featureList(
      Seq("one", "two", "eighty").map(v => byteStrFeature(Seq(ByteString.copyFromUtf8(v))))
    ),
    "long_list" -> featureList(Seq(1L, 2L, 3L).map(v => longFeature(Seq(v)))),
    "floats_list" -> featureList(Seq(1.0f, 2.0f, 3.0f).map(v => floatFeature(Seq(v))))
  )

  val examples = Seq(e1Features, e2Features).map(features => mkExample(features))
  val sequenceExamples =
    Seq(e1Features, e2Features).map(context => mkSequenceExample(context, e1FeatureList))

  /** An int64-list feature holding `raw` in order. */
  private def longFeature(raw: Seq[Long]): Feature = {
    val int64s = raw.foldLeft(Int64List.newBuilder())((builder, v) => builder.addValue(v))
    Feature.newBuilder().setInt64List(int64s).build()
  }

  /** A bytes-list feature holding `raw` in order. */
  private def byteStrFeature(raw: Seq[ByteString]): Feature = {
    val bytes = raw.foldLeft(BytesList.newBuilder())((builder, v) => builder.addValue(v))
    Feature.newBuilder().setBytesList(bytes).build()
  }

  /** A float-list feature holding `raw` in order. */
  private def floatFeature(raw: Seq[Float]): Feature = {
    val floats = raw.foldLeft(FloatList.newBuilder())((builder, v) => builder.addValue(v))
    Feature.newBuilder().setFloatList(floats).build()
  }

  /** Packs `fs` into a `FeatureList`, preserving order. */
  private def featureList(fs: Seq[Feature]): FeatureList = {
    val javaFeatures = fs.asJava
    FeatureList.newBuilder().addAllFeature(javaFeatures).build()
  }

  /** An `Example` whose features are exactly `features`. */
  private def mkExample(features: Map[String, Feature]): Example = {
    val packed = Features.newBuilder().putAllFeature(features.asJava)
    Example.newBuilder().setFeatures(packed).build()
  }

  /** A `SequenceExample` with `context` as its context features and `featureList` as its feature lists. */
  private def mkSequenceExample(
    context: Map[String, Feature],
    featureList: Map[String, FeatureList]
  ): SequenceExample = {
    val contextFeatures = Features.newBuilder().putAllFeature(context.asJava)
    val lists = FeatureLists.newBuilder().putAllFeatureList(featureList.asJava)
    SequenceExample.newBuilder().setContext(contextFeatures).setFeatureLists(lists).build()
  }
}
Example 103
Source File: TFSequenceExampleIOTest.scala    From scio   with Apache License 2.0 5 votes vote down vote up
package com.spotify.scio.tensorflow

import com.google.protobuf.ByteString
import com.spotify.scio.testing.ScioIOSpec
import org.tensorflow.example._

import scala.jdk.CollectionConverters._

object TFSequenceExampleIOTest {
  case class Record(i: Int, ss: Seq[String])

  /**
   * Encodes `r.i` as the int64 context feature "i" and each string of `r.ss` as a single-value
   * bytes feature inside the "ss" feature list.
   */
  def toSequenceExample(r: Record): SequenceExample = {
    val iFeature = Feature
      .newBuilder()
      .setInt64List(Int64List.newBuilder().addValue(r.i).build())
      .build()
    val context = Features.newBuilder().putFeature("i", iFeature).build()

    val ssFeatures = r.ss.map { s =>
      val bytes = BytesList.newBuilder().addValue(ByteString.copyFromUtf8(s)).build()
      Feature.newBuilder().setBytesList(bytes).build()
    }
    val ssList = FeatureList.newBuilder().addAllFeature(ssFeatures.asJava).build()
    val featureLists = FeatureLists.newBuilder().putFeatureList("ss", ssList).build()

    SequenceExample.newBuilder().setContext(context).setFeatureLists(featureLists).build()
  }
}

class TFSequenceExampleIOTest extends ScioIOSpec {
  import TFSequenceExampleIOTest._

  "TFSequenceExampleIO" should "work" in {
    // 100 sequence examples, each with the number's string repeated twice in "ss".
    val xs = (1 to 100).map { n =>
      val label = n.toString
      toSequenceExample(Record(n, Seq(label, label)))
    }
    testTap(xs)(_.saveAsTfRecordFile(_))(".tfrecords")
    testJobTest(xs)(TFSequenceExampleIO(_))(_.tfRecordSequenceExampleFile(_))(
      _.saveAsTfRecordFile(_)
    )
  }
}
Example 104
Source File: TypedBigQueryIT.scala    From scio   with Apache License 2.0 5 votes vote down vote up
package com.spotify.scio.bigquery

import com.google.protobuf.ByteString
import com.spotify.scio._
import com.spotify.scio.bigquery.client.BigQuery
import com.spotify.scio.testing._
import magnolify.scalacheck.auto._
import org.apache.beam.sdk.options.PipelineOptionsFactory
import org.joda.time.format.DateTimeFormat
import org.joda.time.{Instant, LocalDate, LocalDateTime, LocalTime}
import org.scalacheck._
import org.scalatest.BeforeAndAfterAll

import scala.util.Random

object TypedBigQueryIT {
  @BigQueryType.toTable
  case class Record(
    bool: Boolean,
    int: Int,
    long: Long,
    float: Float,
    double: Double,
    string: String,
    byteString: ByteString,
    timestamp: Instant,
    date: LocalDate,
    time: LocalTime,
    datetime: LocalDateTime
  )

  // Workaround for millis rounding error: epoch millis are truncated to whole seconds.
  val epochGen: Gen[Long] = Gen.chooseNum[Long](0L, 1000000000000L).map(x => x / 1000 * 1000)

  // Explicit result types: implicit definitions should not rely on type inference.
  implicit val arbByteString: Arbitrary[ByteString] =
    Arbitrary(Gen.alphaStr.map(ByteString.copyFromUtf8))
  implicit val arbInstant: Arbitrary[Instant] = Arbitrary(epochGen.map(new Instant(_)))
  implicit val arbDate: Arbitrary[LocalDate] = Arbitrary(epochGen.map(new LocalDate(_)))
  implicit val arbTime: Arbitrary[LocalTime] = Arbitrary(epochGen.map(new LocalTime(_)))
  implicit val arbDatetime: Arbitrary[LocalDateTime] = Arbitrary(epochGen.map(new LocalDateTime(_)))

  // Record generator derived from the field-level Arbitrary instances above.
  private val recordGen: Gen[Record] = {
    implicitly[Arbitrary[Record]].arbitrary
  }

  // Timestamped, randomly suffixed table name so runs do not collide.
  private val table = {
    val TIME_FORMATTER = DateTimeFormat.forPattern("yyyyMMddHHmmss")
    val now = Instant.now().toString(TIME_FORMATTER)
    val spec =
      "data-integration-test:bigquery_avro_it.records_" + now + "_" + Random.nextInt(Int.MaxValue)
    Table.Spec(spec)
  }
  private val records: List[Record] = Gen.listOfN(1000, recordGen).sample.get
  private val options = PipelineOptionsFactory
    .fromArgs(
      "--project=data-integration-test",
      "--tempLocation=gs://data-integration-test-eu/temp"
    )
    .create()
}

class TypedBigQueryIT extends PipelineSpec with BeforeAndAfterAll {
  import TypedBigQueryIT._

  /** Writes the generated records to the test table and waits for that pipeline to finish. */
  override protected def beforeAll(): Unit = {
    val sc = ScioContext(options)
    sc.parallelize(records).saveAsTypedBigQueryTable(table)

    // The read test depends on the table being fully written. `run()` alone may return before
    // the job completes on a non-blocking runner, so wait explicitly (as BigQueryIT does).
    sc.run().waitUntilDone()
    ()
  }

  /** Deletes the table created in `beforeAll`. */
  override protected def afterAll(): Unit =
    BigQuery.defaultInstance().tables.delete(table.ref)

  "TypedBigQuery" should "read records" in {
    val sc = ScioContext(options)
    sc.typedBigQuery[Record](table) should containInAnyOrder(records)
    // Wait for the pipeline so the containInAnyOrder assertion is evaluated before the test ends.
    sc.run().waitUntilDone()
  }
}
Example 105
Source File: ByteStringSerializerTest.scala    From scio   with Apache License 2.0 5 votes vote down vote up
package com.spotify.scio.coders.instances.kryo

import com.esotericsoftware.kryo.io.{Input, Output}
import com.google.protobuf.ByteString
import com.twitter.chill.{Kryo, KryoSerializer}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class ByteStringSerializerTest extends AnyFlatSpec with Matchers {

  /**
   * Serializes `bs` with `ser`, deserializes the produced bytes, and checks the result equals `bs`.
   */
  private def testRoundTrip(ser: ByteStringSerializer, bs: ByteString): Unit = {
    val k: Kryo = KryoSerializer.registered.newKryo()
    // Growable buffer (maxBufferSize = -1). A fixed `bs.size() * 2` buffer is too small for the
    // Int length header when bs has fewer than 4 bytes (including the empty ByteString).
    val output = new Output(bs.size() + 16, -1)
    ser.write(k, output, bs)
    val back = ser.read(k, new Input(output.toBytes), null)
    bs shouldEqual back
    ()
  }

  "ByteStringSerializer" should "roundtrip large ByteString" in {
    val ser = new ByteStringSerializer
    testRoundTrip(ser, ByteString.copyFrom(Array.fill(1056)(7.toByte)))
  }

  it should "roundtrip empty and tiny ByteStrings" in {
    val ser = new ByteStringSerializer
    testRoundTrip(ser, ByteString.EMPTY)
    testRoundTrip(ser, ByteString.copyFrom(Array[Byte](1)))
  }
}
Example 106
Source File: SCollectionSyntax.scala    From scio   with Apache License 2.0 5 votes vote down vote up
package com.spotify.scio.bigtable.syntax

import com.google.bigtable.v2._
import com.google.cloud.bigtable.config.BigtableOptions
import com.google.protobuf.ByteString
import com.spotify.scio.coders.Coder
import com.spotify.scio.io.ClosedTap
import com.spotify.scio.values.SCollection
import org.joda.time.Duration

import com.spotify.scio.bigtable.BigtableWrite


  /**
   * Writes this collection of (row key, mutations) pairs to Bigtable table `tableId`, using the
   * bulk write mode with `numOfShards` shards and the given `flushInterval`.
   */
  def saveAsBigtable(
    bigtableOptions: BigtableOptions,
    tableId: String,
    numOfShards: Int,
    flushInterval: Duration = BigtableWrite.Bulk.DefaultFlushInterval
  )(implicit coder: Coder[T]): ClosedTap[Nothing] = {
    val bulkParams = BigtableWrite.Bulk(numOfShards, flushInterval)
    self.write(BigtableWrite[T](bigtableOptions, tableId))(bulkParams)
  }
}

trait SCollectionSyntax {

  /**
   * Adds Bigtable write operations to an `SCollection` of (ByteString, mutations) pairs
   * (presumably row key and the mutations to apply to that row).
   */
  implicit def bigtableMutationOps[T <: Mutation](
    sc: SCollection[(ByteString, Iterable[T])]
  ): SCollectionMutationOps[T] = {
    val enriched: SCollectionMutationOps[T] = new SCollectionMutationOps[T](sc)
    enriched
  }
}
Example 107
Source File: RichRowTest.scala    From scio   with Apache License 2.0 5 votes vote down vote up
package com.spotify.scio.bigtable

import com.google.bigtable.v2.{Cell, Column, Family, Row}
import com.google.protobuf.ByteString
import org.scalatest.matchers.should.Matchers
import org.scalatest.flatspec.AnyFlatSpec

import scala.jdk.CollectionConverters._
import scala.collection.immutable.ListMap

class RichRowTest extends AnyFlatSpec with Matchers {
  def bs(s: String): ByteString = ByteString.copyFromUtf8(s)

  val FAMILY_NAME = "family"

  // qualifier -> ListMap(timestamp -> value); each column's cells are listed newest first.
  val dataMap = Seq(
    "a" -> Seq(10 -> "x", 9 -> "y", 8 -> "z"),
    "b" -> Seq(7 -> "u", 6 -> "v", 5 -> "w"),
    "c" -> Seq(4 -> "r", 3 -> "s", 2 -> "t")
  ).map {
    case (qualifier, versions) =>
      val cellsByTimestamp = versions.map { case (ts, value) => ts.toLong -> bs(value) }
      bs(qualifier) -> ListMap(cellsByTimestamp: _*)
  }.toMap

  /** Expected protobuf cells for one column, in the ListMap's (timestamp) order. */
  private def toCells(versions: ListMap[Long, ByteString]) =
    versions.map {
      case (ts, value) => Cell.newBuilder().setTimestampMicros(ts).setValue(value).build()
    }

  /** Each qualifier mapped to the value of its first (newest) cell. */
  private def latestValues =
    dataMap.map { case (qualifier, versions) => qualifier -> versions.head._2 }

  val columns = dataMap.map {
    case (qualifier, versions) =>
      Column.newBuilder().setQualifier(qualifier).addAllCells(toCells(versions).asJava).build()
  }

  val row = Row
    .newBuilder()
    .addFamilies(
      Family
        .newBuilder()
        .setName(FAMILY_NAME)
        .addAllColumns(columns.asJava)
    )
    .build()

  "RichRow" should "support getColumnCells" in {
    dataMap.foreach {
      case (qualifier, versions) =>
        row.getColumnCells(FAMILY_NAME, qualifier) shouldBe toCells(versions)
    }
  }

  it should "support getColumnLatestCell" in {
    dataMap.foreach {
      case (qualifier, versions) =>
        row.getColumnLatestCell(FAMILY_NAME, qualifier) shouldBe toCells(versions).headOption
    }
  }

  it should "support getFamilyMap" in {
    row.getFamilyMap(FAMILY_NAME) shouldBe latestValues
  }

  it should "support getMap" in {
    row.getMap shouldBe Map(FAMILY_NAME -> dataMap)
  }

  it should "support getNoVersionMap" in {
    row.getNoVersionMap shouldBe Map(FAMILY_NAME -> latestValues)
  }

  it should "support getValue" in {
    dataMap.foreach {
      case (qualifier, versions) =>
        row.getValue(FAMILY_NAME, qualifier) shouldBe Some(versions.head._2)
    }
  }
}
Example 108
Source File: BigtableIOTest.scala    From scio   with Apache License 2.0 5 votes vote down vote up
package com.spotify.scio.bigtable

import com.google.bigtable.v2.Mutation.SetCell
import com.google.bigtable.v2.{Mutation, Row}
import com.google.protobuf.ByteString
import com.spotify.scio.testing._

class BigtableIOTest extends ScioIOSpec {
  val projectId = "project"
  val instanceId = "instance"

  "BigtableIO" should "work with input" in {
    // Rows keyed by the UTF-8 encoding of 1..100.
    val xs = for (n <- 1 to 100) yield {
      Row.newBuilder().setKey(ByteString.copyFromUtf8(n.toString)).build()
    }
    testJobTestInput(xs)(BigtableIO(projectId, instanceId, _))(_.bigtable(projectId, instanceId, _))
  }

  it should "work with output" in {
    // (key, single SetCell mutation) pairs; key and cell value are both the number as UTF-8.
    val xs = for (n <- 1 to 100) yield {
      val label = n.toString
      val key = ByteString.copyFromUtf8(label)
      val setCell = SetCell.newBuilder().setValue(ByteString.copyFromUtf8(label))
      val mutation = Mutation.newBuilder().setSetCell(setCell).build()
      (key, Iterable(mutation))
    }
    testJobTestOutput(xs)(BigtableIO(projectId, instanceId, _))(
      _.saveAsBigtable(projectId, instanceId, _)
    )
  }
}