org.tensorflow.framework.TensorProto Java Examples

The following examples show how to use org.tensorflow.framework.TensorProto. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: TensorConverter.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Converts a TensorFlow {@code TensorProto} into a Vespa tensor of the given type.
 *
 * <p>Cell values are first read from the typed value fields of the proto via
 * {@code readValuesOf}. If none are present, the data is assumed to be packed into the
 * {@code tensor_content} bytes instead. Those bytes are decoded through a temporary
 * {@code org.tensorflow.Tensor}, which is closed once the conversion is done.
 *
 * @param tensorProto the serialized TensorFlow tensor
 * @param type the Vespa tensor type to build. NOTE(review): this is presumably a bound indexed
 *        type, since the builder is cast to {@code IndexedTensor.BoundBuilder}; confirm with callers
 * @return the converted Vespa tensor
 * @throws IllegalArgumentException if {@code readValuesOf} does not support the proto's dtype
 */
static Tensor toVespaTensor(TensorProto tensorProto, OrderedTensorType type) {
    Values values = readValuesOf(tensorProto);
    if (values.size() == 0) {
        // Might be stored as "tensor_content" instead. The intermediate TensorFlow tensor is
        // created only for this conversion and holds native resources, so release it here.
        // Tensor.close() is idempotent, so this is safe even if the overload also closes it.
        try (org.tensorflow.Tensor tfTensor = readTensorContentOf(tensorProto)) {
            return toVespaTensor(tfTensor, type);
        }
    }
    IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type());
    // Values are laid out in the proto's flat order, which maps onto the direct cell index.
    for (int i = 0; i < values.size(); ++i) {
        builder.cellByDirectIndex(i, values.get(i));
    }
    return builder.build();
}
 
Example #2
Source File: TensorConverter.java    From vespa with Apache License 2.0 5 votes vote down vote up
/**
 * Returns a {@code Values} view over the typed value field of {@code tensorProto} that matches
 * its data type.
 *
 * <p>DT_INT16 and DT_INT32 are both handled by {@code ProtoIntValues}. If the returned view is
 * empty, the caller falls back to the packed {@code tensor_content} bytes.
 *
 * @param tensorProto the tensor whose values to expose
 * @return a view over the proto's values for its dtype
 * @throws IllegalArgumentException if the dtype is not one of the supported types listed below;
 *         the message now names the offending dtype
 */
private static Values readValuesOf(TensorProto tensorProto) {
    switch (tensorProto.getDtype()) {
        case DT_BOOL: return new ProtoBoolValues(tensorProto);
        case DT_HALF: return new ProtoHalfValues(tensorProto);
        case DT_INT16: case DT_INT32: return new ProtoIntValues(tensorProto);
        case DT_INT64: return new ProtoInt64Values(tensorProto);
        case DT_FLOAT: return new ProtoFloatValues(tensorProto);
        case DT_DOUBLE: return new ProtoDoubleValues(tensorProto);
        default: throw new IllegalArgumentException(
                "Unsupported data type in attribute tensor import: " + tensorProto.getDtype());
    }
}
 
Example #3
Source File: TFTensorMappers.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a mapper for a {@code DT_BOOL} tensor; {@code newMapper} selects this class for that dtype.
 *
 * @param tensorProto the tensor to map, passed unchanged to the superclass constructor
 */
public BoolTensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
 
Example #4
Source File: ModelServerClassification.java    From hazelcast-jet-demos with Apache License 2.0 4 votes vote down vote up
/**
 * Builds a Jet pipeline that classifies every review in {@code reviewsMap} using a TensorFlow
 * Serving model reached over gRPC. Each review is logged together with its classification.
 *
 * <p>Each service context owns a {@code WordIndex} loaded from the attached "data" directory and
 * a plaintext gRPC channel to {@code serverAddress}. The channel is shut down when the context
 * is destroyed. The async stage is called with {@code 16, true}. Per Jet's
 * {@code mapUsingServiceAsync} signature, these are the maximum number of concurrent requests
 * and whether to preserve order.
 *
 * @param serverAddress gRPC target of the TensorFlow Serving instance
 * @param reviewsMap source map whose values are the review texts to classify
 * @return the assembled (not yet submitted) pipeline
 */
private static Pipeline buildPipeline(String serverAddress, IMap<Long, String> reviewsMap) {
    ServiceFactory<Tuple2<PredictionServiceFutureStub, WordIndex>, Tuple2<PredictionServiceFutureStub, WordIndex>>
            tfServingContext = ServiceFactory
            .withCreateContextFn(context -> {
                WordIndex wordIndex = new WordIndex(context.attachedDirectory("data"));
                ManagedChannel channel = ManagedChannelBuilder.forTarget(serverAddress)
                                                              .usePlaintext().build();
                return Tuple2.tuple2(PredictionServiceGrpc.newFutureStub(channel), wordIndex);
            })
            // The stub was created directly from a ManagedChannel above, hence the cast.
            .withDestroyContextFn(t -> ((ManagedChannel) t.f0().getChannel()).shutdownNow())
            .withCreateServiceFn((context, tuple2) -> tuple2);

    Pipeline p = Pipeline.create();
    p.readFrom(Sources.map(reviewsMap))
     .map(Map.Entry::getValue)
     .mapUsingServiceAsync(tfServingContext, 16, true, (t, review) -> {
         TensorProto featuresTensorProto = buildFeaturesTensor(t.f1().createTensorInput(review));
         Predict.PredictRequest request = buildPredictRequest(featuresTensorProto);

         return toCompletableFuture(t.f0().predict(request))
                 .thenApply(response -> {
                     float classification = response
                             .getOutputsOrThrow("dense_1/Sigmoid:0")
                             .getFloatVal(0);
                     // emit the review along with the classification
                     return tuple2(review, classification);
                 });
     })
     .setLocalParallelism(1) // one worker is enough to drive the async calls
     .writeTo(Sinks.logger());
    return p;
}

/**
 * Packs {@code featuresTensorData} row by row into a DT_FLOAT {@code TensorProto} of shape
 * [rows, columns].
 *
 * <p>The column count is taken from the first row, so the array must be non-empty (otherwise
 * this throws {@code ArrayIndexOutOfBoundsException}). It is also assumed to be rectangular.
 */
private static TensorProto buildFeaturesTensor(float[][] featuresTensorData) {
    TensorProto.Builder featuresTensorBuilder = TensorProto.newBuilder();
    for (float[] row : featuresTensorData) {
        for (float v : row) {
            featuresTensorBuilder.addFloatVal(v);
        }
    }
    TensorShapeProto.Dim rowsDim =
            TensorShapeProto.Dim.newBuilder().setSize(featuresTensorData.length).build();
    TensorShapeProto.Dim columnsDim =
            TensorShapeProto.Dim.newBuilder().setSize(featuresTensorData[0].length).build();
    TensorShapeProto featuresShape =
            TensorShapeProto.newBuilder().addDim(rowsDim).addDim(columnsDim).build();
    return featuresTensorBuilder.setDtype(org.tensorflow.framework.DataType.DT_FLOAT)
                                .setTensorShape(featuresShape)
                                .build();
}

/**
 * Builds the gRPC predict request that feeds {@code featuresTensor} as input "input_review" to
 * version 1 of the "reviewSentiment" model.
 */
private static Predict.PredictRequest buildPredictRequest(TensorProto featuresTensor) {
    Int64Value version = Int64Value.newBuilder().setValue(1).build();
    Model.ModelSpec modelSpec =
            Model.ModelSpec.newBuilder().setName("reviewSentiment").setVersion(version).build();
    return Predict.PredictRequest.newBuilder()
                                 .setModelSpec(modelSpec)
                                 .putInputs("input_review", featuresTensor)
                                 .build();
}
 
Example #5
Source File: TFTensorMappers.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a mapper for a {@code DT_STRING} tensor; {@code newMapper} selects this class for that dtype.
 *
 * @param tensorProto the tensor to map, passed unchanged to the superclass constructor
 */
public StringTensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
 
Example #6
Source File: TFTensorMappers.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a mapper for a {@code DT_UINT64} tensor; {@code newMapper} selects this class for that dtype.
 *
 * @param tensorProto the tensor to map, passed unchanged to the superclass constructor
 */
public UInt64TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
 
Example #7
Source File: TFTensorMappers.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a mapper for a {@code DT_UINT32} tensor; {@code newMapper} selects this class for that dtype.
 *
 * @param tensorProto the tensor to map, passed unchanged to the superclass constructor
 */
public UInt32TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
 
Example #8
Source File: TFTensorMappers.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a mapper for a {@code DT_UINT16} tensor; {@code newMapper} selects this class for that dtype.
 *
 * @param tensorProto the tensor to map, passed unchanged to the superclass constructor
 */
public UInt16TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
 
Example #9
Source File: TFTensorMappers.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a mapper for a {@code DT_UINT8} tensor; {@code newMapper} selects this class for that dtype.
 *
 * @param tensorProto the tensor to map, passed unchanged to the superclass constructor
 */
public UInt8TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
 
Example #10
Source File: TFTensorMappers.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a mapper for a {@code DT_INT64} tensor; {@code newMapper} selects this class for that dtype.
 *
 * @param tensorProto the tensor to map, passed unchanged to the superclass constructor
 */
public Int64TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
 
Example #11
Source File: TFTensorMappers.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a mapper for a {@code DT_INT32} tensor; {@code newMapper} selects this class for that dtype.
 *
 * @param tensorProto the tensor to map, passed unchanged to the superclass constructor
 */
public Int32TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
 
Example #12
Source File: TFTensorMappers.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a mapper for a {@code DT_INT16} tensor; {@code newMapper} selects this class for that dtype.
 *
 * @param tensorProto the tensor to map, passed unchanged to the superclass constructor
 */
public Int16TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
 
Example #13
Source File: TFTensorMappers.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a mapper for a {@code DT_INT8} tensor; {@code newMapper} selects this class for that dtype.
 *
 * @param tensorProto the tensor to map, passed unchanged to the superclass constructor
 */
public Int8TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
 
Example #14
Source File: TFTensorMappers.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a mapper for a {@code DT_BFLOAT16} tensor; {@code newMapper} selects this class for that dtype.
 *
 * @param tensorProto the tensor to map, passed unchanged to the superclass constructor
 */
public BFloat16TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
 
Example #15
Source File: TFTensorMappers.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a mapper for a {@code DT_DOUBLE} tensor; {@code newMapper} selects this class for that dtype.
 *
 * @param tensorProto the tensor to map, passed unchanged to the superclass constructor
 */
public Float64TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
 
Example #16
Source File: TFTensorMappers.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a mapper for a {@code DT_FLOAT} tensor; {@code newMapper} selects this class for that dtype.
 *
 * @param tensorProto the tensor to map, passed unchanged to the superclass constructor
 */
public Float32TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
 
Example #17
Source File: TFTensorMappers.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a mapper for a {@code DT_HALF} tensor; {@code newMapper} selects this class for that dtype.
 *
 * @param tensorProto the tensor to map, passed unchanged to the superclass constructor
 */
public Float16TensorMapper(TensorProto tensorProto) {
    super(tensorProto);
}
 
Example #18
Source File: TFTensorMappers.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a mapper over the given TensorFlow tensor proto.
 *
 * @param tensorProto the tensor to map. It is stored by reference in {@code tfTensor}; it is
 *                    neither copied nor null-checked here.
 */
public BaseTensorMapper(TensorProto tensorProto){
    this.tfTensor = tensorProto;
}
 
Example #19
Source File: TFTensorMappers.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Creates the {@code TFTensorMapper} implementation matching the data type of {@code tp}.
 *
 * <p>Real, integer, unsigned, string and boolean dtypes each have a dedicated mapper. All
 * other dtypes are rejected: quantized, complex and reference ({@code *_REF}) types, plus
 * resource/variant/invalid/unrecognized ones. The exception message names the category and
 * the dtype.
 *
 * @param tp the tensor to wrap
 * @return a mapper for {@code tp}
 * @throws IllegalStateException if the dtype cannot be mapped
 */
public static TFTensorMapper<?,?> newMapper(TensorProto tp){
    // Category used in the error message for dtypes that have no mapper.
    String unmappableCategory;

    switch (tp.getDtype()) {
        // Floating point
        case DT_HALF:     return new Float16TensorMapper(tp);
        case DT_FLOAT:    return new Float32TensorMapper(tp);
        case DT_DOUBLE:   return new Float64TensorMapper(tp);
        case DT_BFLOAT16: return new BFloat16TensorMapper(tp);

        // Signed integers
        case DT_INT8:  return new Int8TensorMapper(tp);
        case DT_INT16: return new Int16TensorMapper(tp);
        case DT_INT32: return new Int32TensorMapper(tp);
        case DT_INT64: return new Int64TensorMapper(tp);

        // Non-numeric
        case DT_STRING: return new StringTensorMapper(tp);
        case DT_BOOL:   return new BoolTensorMapper(tp);

        // Unsigned integers
        case DT_UINT8:  return new UInt8TensorMapper(tp);
        case DT_UINT16: return new UInt16TensorMapper(tp);
        case DT_UINT32: return new UInt32TensorMapper(tp);
        case DT_UINT64: return new UInt64TensorMapper(tp);

        case DT_QINT8:
        case DT_QUINT8:
        case DT_QINT32:
        case DT_QINT16:
        case DT_QUINT16:
            unmappableCategory = "quantized type";
            break;

        case DT_COMPLEX64:
        case DT_COMPLEX128:
            unmappableCategory = "complex type";
            break;

        case DT_FLOAT_REF:
        case DT_DOUBLE_REF:
        case DT_INT32_REF:
        case DT_UINT8_REF:
        case DT_INT16_REF:
        case DT_INT8_REF:
        case DT_STRING_REF:
        case DT_COMPLEX64_REF:
        case DT_INT64_REF:
        case DT_BOOL_REF:
        case DT_QINT8_REF:
        case DT_QUINT8_REF:
        case DT_QINT32_REF:
        case DT_BFLOAT16_REF:
        case DT_QINT16_REF:
        case DT_QUINT16_REF:
        case DT_UINT16_REF:
        case DT_COMPLEX128_REF:
        case DT_HALF_REF:
        case DT_RESOURCE_REF:
        case DT_VARIANT_REF:
        case DT_UINT32_REF:
        case DT_UINT64_REF:
            unmappableCategory = "reference type";
            break;

        case UNRECOGNIZED:
        case DT_RESOURCE:
        case DT_VARIANT:
        case DT_INVALID:
        default:
            unmappableCategory = "type";
            break;
    }

    // Every mappable dtype returned above; anything reaching here is rejected.
    throw new IllegalStateException("Unable to map " + unmappableCategory + ": " + tp.getDtype());
}
 
Example #20
Source File: TensorConverter.java    From vespa with Apache License 2.0 4 votes vote down vote up
/**
 * Decodes the packed {@code tensor_content} bytes of {@code tensorProto} into a newly created
 * TensorFlow tensor.
 *
 * <p>The element class comes from the proto's dtype (via {@code dataTypeToClass}) and the shape
 * from its dimension list (via {@code asSizeArray}). The returned tensor is a fresh object, and
 * nothing here closes it.
 *
 * @param tensorProto the tensor whose raw content to decode
 * @return a new TensorFlow tensor holding the content bytes
 */
private static org.tensorflow.Tensor readTensorContentOf(TensorProto tensorProto) {
    // Evaluated in the same order as the arguments would be: element type, shape, then data.
    Class<?> elementType = dataTypeToClass(tensorProto.getDtype());
    long[] shape = asSizeArray(tensorProto.getTensorShape().getDimList());
    java.nio.ByteBuffer rawContent = tensorProto.getTensorContent().asReadOnlyByteBuffer();
    return org.tensorflow.Tensor.create(elementType, shape, rawContent);
}
 
Example #21
Source File: TensorConverter.java    From vespa with Apache License 2.0 votes vote down vote up
/** Value view used by {@code readValuesOf} for {@code DT_DOUBLE} protos; wraps {@code tensorProto} via the superclass. */
ProtoDoubleValues(TensorProto tensorProto) { super(tensorProto); } 
Example #22
Source File: TensorConverter.java    From vespa with Apache License 2.0 votes vote down vote up
/** Value view used by {@code readValuesOf} for {@code DT_FLOAT} protos; wraps {@code tensorProto} via the superclass. */
ProtoFloatValues(TensorProto tensorProto) { super(tensorProto); } 
Example #23
Source File: TensorConverter.java    From vespa with Apache License 2.0 votes vote down vote up
/** Value view used by {@code readValuesOf} for {@code DT_INT64} protos; wraps {@code tensorProto} via the superclass. */
ProtoInt64Values(TensorProto tensorProto) { super(tensorProto); } 
Example #24
Source File: TensorConverter.java    From vespa with Apache License 2.0 votes vote down vote up
/** Value view used by {@code readValuesOf} for both {@code DT_INT16} and {@code DT_INT32} protos; wraps {@code tensorProto} via the superclass. */
ProtoIntValues(TensorProto tensorProto) { super(tensorProto); } 
Example #25
Source File: TensorConverter.java    From vespa with Apache License 2.0 votes vote down vote up
/** Value view used by {@code readValuesOf} for {@code DT_HALF} protos; wraps {@code tensorProto} via the superclass. */
ProtoHalfValues(TensorProto tensorProto) { super(tensorProto); } 
Example #26
Source File: TensorConverter.java    From vespa with Apache License 2.0 votes vote down vote up
/** Value view used by {@code readValuesOf} for {@code DT_BOOL} protos; wraps {@code tensorProto} via the superclass. */
ProtoBoolValues(TensorProto tensorProto) { super(tensorProto); } 
Example #27
Source File: TensorConverter.java    From vespa with Apache License 2.0 votes vote down vote up
/** Stores {@code tensorProto} by reference (no copy, no null check) for the typed value views to read from. */
ProtoValues(TensorProto tensorProto) { this.tensorProto = tensorProto; }