org.tensorflow.DataType Java Examples

The following examples show how to use org.tensorflow.DataType. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: SamplingDataset.java    From java with Apache License 2.0 6 votes vote down vote up
/**
 * Builds a {@code SamplingDataset} operation in the given scope and wraps it.
 *
 * @param scope current scope
 * @param inputDataset operand added as the first input of the operation
 * @param rate A scalar representing the sample rate. Each element of `input_dataset` is
 * retained with this probability, independent of all other elements.
 * @param seed A scalar representing seed of random number generator.
 * @param seed2 A scalar representing seed2 of random number generator.
 * @param outputTypes values stored in the {@code output_types} attribute
 * @param outputShapes values stored in the {@code output_shapes} attribute
 * @return a new instance of SamplingDataset
 */
@Endpoint(describeByClass = true)
public static SamplingDataset create(Scope scope, Operand<?> inputDataset, Operand<TFloat32> rate, Operand<TInt64> seed, Operand<TInt64> seed2, List<DataType<?>> outputTypes, List<Shape> outputShapes) {
  OperationBuilder builder = scope.env().opBuilder("SamplingDataset", scope.makeOpName("SamplingDataset"));
  // Inputs are registered in the order: dataset, rate, seed, seed2.
  builder.addInput(inputDataset.asOutput());
  builder.addInput(rate.asOutput());
  builder.addInput(seed.asOutput());
  builder.addInput(seed2.asOutput());
  builder = scope.applyControlDependencies(builder);
  // List-valued attributes are handed to the builder as arrays.
  DataType[] typesAttr = outputTypes.toArray(new DataType[0]);
  builder.setAttr("output_types", typesAttr);
  Shape[] shapesAttr = outputShapes.toArray(new Shape[0]);
  builder.setAttr("output_shapes", shapesAttr);
  return new SamplingDataset(builder.build());
}
 
Example #2
Source File: HashTable.java    From java with Apache License 2.0 6 votes vote down vote up
/**
 * Builds a {@code HashTableV2} operation in the given scope and wraps it as a HashTable.
 *
 * @param scope current scope
 * @param keyDtype Type of the table keys.
 * @param valueDtype Type of the table values.
 * @param options carries optional attributes values; each non-null {@code container},
 * {@code sharedName} and {@code useNodeNameSharing} field is copied onto the operation
 * @return a new instance of HashTable
 */
@Endpoint(describeByClass = true)
public static <T extends TType, U extends TType> HashTable create(Scope scope, DataType<T> keyDtype, DataType<U> valueDtype, Options... options) {
  // The operation has no inputs, so control dependencies are applied right after creation.
  OperationBuilder builder = scope.applyControlDependencies(
      scope.env().opBuilder("HashTableV2", scope.makeOpName("HashTable")));
  builder.setAttr("key_dtype", keyDtype);
  builder.setAttr("value_dtype", valueDtype);
  if (options != null) {
    // Only fields that were explicitly set (non-null) are forwarded.
    for (Options option : options) {
      if (option.container != null) {
        builder.setAttr("container", option.container);
      }
      if (option.sharedName != null) {
        builder.setAttr("shared_name", option.sharedName);
      }
      if (option.useNodeNameSharing != null) {
        builder.setAttr("use_node_name_sharing", option.useNodeNameSharing);
      }
    }
  }
  return new HashTable(builder.build());
}
 
Example #3
Source File: DebugNumericsSummary.java    From java with Apache License 2.0 6 votes vote down vote up
/**
 * Builds a {@code DebugNumericSummaryV2} operation in the given scope and wraps it as a
 * DebugNumericsSummary.
 *
 * @param scope current scope
 * @param input Input tensor, to be summarized by the op.
 * @param outputDtype Optional. The type of the output. Can be float32 or float64 (default: float32).
 * @param options carries optional attributes values; each non-null {@code tensorDebugMode}
 * and {@code tensorId} field is copied onto the operation
 * @return a new instance of DebugNumericsSummary
 */
@Endpoint(describeByClass = true)
public static <U extends TNumber, T extends TType> DebugNumericsSummary<U> create(Scope scope, Operand<T> input, DataType<U> outputDtype, Options... options) {
  OperationBuilder builder = scope.env().opBuilder("DebugNumericSummaryV2", scope.makeOpName("DebugNumericsSummary"));
  builder.addInput(input.asOutput());
  builder = scope.applyControlDependencies(builder);
  builder.setAttr("output_dtype", outputDtype);
  if (options != null) {
    // Only fields that were explicitly set (non-null) are forwarded.
    for (Options option : options) {
      if (option.tensorDebugMode != null) {
        builder.setAttr("tensor_debug_mode", option.tensorDebugMode);
      }
      if (option.tensorId != null) {
        builder.setAttr("tensor_id", option.tensorId);
      }
    }
  }
  return new DebugNumericsSummary<U>(builder.build());
}
 
Example #4
Source File: TemporaryVariable.java    From java with Apache License 2.0 6 votes vote down vote up
/**
 * Builds a {@code TemporaryVariable} operation in the given scope and wraps it.
 *
 * @param scope current scope
 * @param shape The shape of the variable tensor.
 * @param dtype The type of elements in the variable tensor.
 * @param options carries optional attributes values; a non-null {@code varName} is copied
 * onto the operation
 * @return a new instance of TemporaryVariable
 */
@Endpoint(describeByClass = true)
public static <T extends TType> TemporaryVariable<T> create(Scope scope, Shape shape, DataType<T> dtype, Options... options) {
  // The operation has no inputs, so control dependencies are applied right after creation.
  OperationBuilder builder = scope.applyControlDependencies(
      scope.env().opBuilder("TemporaryVariable", scope.makeOpName("TemporaryVariable")));
  builder.setAttr("shape", shape);
  builder.setAttr("dtype", dtype);
  if (options != null) {
    for (Options option : options) {
      if (option.varName != null) {
        builder.setAttr("var_name", option.varName);
      }
    }
  }
  return new TemporaryVariable<T>(builder.build());
}
 
Example #5
Source File: DatasetToSingleElement.java    From java with Apache License 2.0 6 votes vote down vote up
/**
 * Builds a {@code DatasetToSingleElement} operation in the given scope and wraps it.
 *
 * @param scope current scope
 * @param dataset A handle to a dataset that contains a single element.
 * @param outputTypes values stored in the {@code output_types} attribute
 * @param outputShapes values stored in the {@code output_shapes} attribute
 * @return a new instance of DatasetToSingleElement
 */
@Endpoint(describeByClass = true)
public static DatasetToSingleElement create(Scope scope, Operand<?> dataset, List<DataType<?>> outputTypes, List<Shape> outputShapes) {
  OperationBuilder builder = scope.env().opBuilder("DatasetToSingleElement", scope.makeOpName("DatasetToSingleElement"));
  builder.addInput(dataset.asOutput());
  builder = scope.applyControlDependencies(builder);
  // List-valued attributes are handed to the builder as arrays.
  DataType[] typesAttr = outputTypes.toArray(new DataType[0]);
  builder.setAttr("output_types", typesAttr);
  Shape[] shapesAttr = outputShapes.toArray(new Shape[0]);
  builder.setAttr("output_shapes", shapesAttr);
  return new DatasetToSingleElement(builder.build());
}
 
Example #6
Source File: OptionalGetValue.java    From java with Apache License 2.0 6 votes vote down vote up
/**
 * Builds an {@code OptionalGetValue} operation in the given scope and wraps it.
 *
 * @param scope current scope
 * @param optional operand added as the single input of the operation
 * @param outputTypes values stored in the {@code output_types} attribute
 * @param outputShapes values stored in the {@code output_shapes} attribute
 * @return a new instance of OptionalGetValue
 */
@Endpoint(describeByClass = true)
public static OptionalGetValue create(Scope scope, Operand<?> optional, List<DataType<?>> outputTypes, List<Shape> outputShapes) {
  OperationBuilder builder = scope.env().opBuilder("OptionalGetValue", scope.makeOpName("OptionalGetValue"));
  builder.addInput(optional.asOutput());
  builder = scope.applyControlDependencies(builder);
  // List-valued attributes are handed to the builder as arrays.
  DataType[] typesAttr = outputTypes.toArray(new DataType[0]);
  builder.setAttr("output_types", typesAttr);
  Shape[] shapesAttr = outputShapes.toArray(new Shape[0]);
  builder.setAttr("output_shapes", shapesAttr);
  return new OptionalGetValue(builder.build());
}
 
Example #7
Source File: LatencyStatsDataset.java    From java with Apache License 2.0 6 votes vote down vote up
/**
 * Builds a {@code LatencyStatsDataset} operation in the given scope and wraps it.
 *
 * @param scope current scope
 * @param inputDataset operand added as the first input of the operation
 * @param tag string operand added as the second input of the operation
 * @param outputTypes values stored in the {@code output_types} attribute
 * @param outputShapes values stored in the {@code output_shapes} attribute
 * @return a new instance of LatencyStatsDataset
 */
@Endpoint(describeByClass = true)
public static LatencyStatsDataset create(Scope scope, Operand<?> inputDataset, Operand<TString> tag, List<DataType<?>> outputTypes, List<Shape> outputShapes) {
  OperationBuilder builder = scope.env().opBuilder("LatencyStatsDataset", scope.makeOpName("LatencyStatsDataset"));
  builder.addInput(inputDataset.asOutput());
  builder.addInput(tag.asOutput());
  builder = scope.applyControlDependencies(builder);
  // List-valued attributes are handed to the builder as arrays.
  DataType[] typesAttr = outputTypes.toArray(new DataType[0]);
  builder.setAttr("output_types", typesAttr);
  Shape[] shapesAttr = outputShapes.toArray(new Shape[0]);
  builder.setAttr("output_shapes", shapesAttr);
  return new LatencyStatsDataset(builder.build());
}
 
Example #8
Source File: MultiDeviceIteratorGetNextFromShard.java    From java with Apache License 2.0 6 votes vote down vote up
/**
 * Builds a {@code MultiDeviceIteratorGetNextFromShard} operation in the given scope and
 * wraps it.
 *
 * @param scope current scope
 * @param multiDeviceIterator A MultiDeviceIterator resource.
 * @param shardNum Integer representing which shard to fetch data for.
 * @param incarnationId Which incarnation of the MultiDeviceIterator is running.
 * @param outputTypes The type list for the return values.
 * @param outputShapes The list of shapes being produced.
 * @return a new instance of MultiDeviceIteratorGetNextFromShard
 */
@Endpoint(describeByClass = true)
public static MultiDeviceIteratorGetNextFromShard create(Scope scope, Operand<?> multiDeviceIterator, Operand<TInt32> shardNum, Operand<TInt64> incarnationId, List<DataType<?>> outputTypes, List<Shape> outputShapes) {
  OperationBuilder builder = scope.env().opBuilder("MultiDeviceIteratorGetNextFromShard", scope.makeOpName("MultiDeviceIteratorGetNextFromShard"));
  // Inputs are registered in the order: iterator, shard number, incarnation id.
  builder.addInput(multiDeviceIterator.asOutput());
  builder.addInput(shardNum.asOutput());
  builder.addInput(incarnationId.asOutput());
  builder = scope.applyControlDependencies(builder);
  // List-valued attributes are handed to the builder as arrays.
  DataType[] typesAttr = outputTypes.toArray(new DataType[0]);
  builder.setAttr("output_types", typesAttr);
  Shape[] shapesAttr = outputShapes.toArray(new Shape[0]);
  builder.setAttr("output_shapes", shapesAttr);
  return new MultiDeviceIteratorGetNextFromShard(builder.build());
}
 
Example #9
Source File: RandomStandardNormal.java    From java with Apache License 2.0 6 votes vote down vote up
/**
 * Builds a {@code RandomStandardNormal} operation in the given scope and wraps it.
 *
 * @param scope current scope
 * @param shape The shape of the output tensor.
 * @param dtype The type of the output.
 * @param options carries optional attributes values; each non-null {@code seed} and
 * {@code seed2} field is copied onto the operation
 * @return a new instance of RandomStandardNormal
 */
@Endpoint(describeByClass = true)
public static <U extends TNumber, T extends TNumber> RandomStandardNormal<U> create(Scope scope, Operand<T> shape, DataType<U> dtype, Options... options) {
  OperationBuilder builder = scope.env().opBuilder("RandomStandardNormal", scope.makeOpName("RandomStandardNormal"));
  builder.addInput(shape.asOutput());
  builder = scope.applyControlDependencies(builder);
  builder.setAttr("dtype", dtype);
  if (options != null) {
    // Only fields that were explicitly set (non-null) are forwarded.
    for (Options option : options) {
      if (option.seed != null) {
        builder.setAttr("seed", option.seed);
      }
      if (option.seed2 != null) {
        builder.setAttr("seed2", option.seed2);
      }
    }
  }
  return new RandomStandardNormal<U>(builder.build());
}
 
Example #10
Source File: LmdbDataset.java    From java with Apache License 2.0 6 votes vote down vote up
/**
 * Builds an {@code ExperimentalLMDBDataset} operation in the given scope and wraps it as
 * an LmdbDataset.
 *
 * @param scope current scope
 * @param filenames string operand added as the single input of the operation
 * @param outputTypes values stored in the {@code output_types} attribute
 * @param outputShapes values stored in the {@code output_shapes} attribute
 * @return a new instance of LmdbDataset
 */
@Endpoint(describeByClass = true)
public static LmdbDataset create(Scope scope, Operand<TString> filenames, List<DataType<?>> outputTypes, List<Shape> outputShapes) {
  OperationBuilder builder = scope.env().opBuilder("ExperimentalLMDBDataset", scope.makeOpName("LmdbDataset"));
  builder.addInput(filenames.asOutput());
  builder = scope.applyControlDependencies(builder);
  // List-valued attributes are handed to the builder as arrays.
  DataType[] typesAttr = outputTypes.toArray(new DataType[0]);
  builder.setAttr("output_types", typesAttr);
  Shape[] shapesAttr = outputShapes.toArray(new Shape[0]);
  builder.setAttr("output_shapes", shapesAttr);
  return new LmdbDataset(builder.build());
}
 
Example #11
Source File: UnbatchDataset.java    From java with Apache License 2.0 6 votes vote down vote up
/**
 * Builds an {@code UnbatchDataset} operation in the given scope and wraps it.
 *
 * @param scope current scope
 * @param inputDataset operand added as the single input of the operation
 * @param outputTypes values stored in the {@code output_types} attribute
 * @param outputShapes values stored in the {@code output_shapes} attribute
 * @return a new instance of UnbatchDataset
 */
@Endpoint(describeByClass = true)
public static UnbatchDataset create(Scope scope, Operand<?> inputDataset, List<DataType<?>> outputTypes, List<Shape> outputShapes) {
  OperationBuilder builder = scope.env().opBuilder("UnbatchDataset", scope.makeOpName("UnbatchDataset"));
  builder.addInput(inputDataset.asOutput());
  builder = scope.applyControlDependencies(builder);
  // List-valued attributes are handed to the builder as arrays.
  DataType[] typesAttr = outputTypes.toArray(new DataType[0]);
  builder.setAttr("output_types", typesAttr);
  Shape[] shapesAttr = outputShapes.toArray(new Shape[0]);
  builder.setAttr("output_shapes", shapesAttr);
  return new UnbatchDataset(builder.build());
}
 
Example #12
Source File: LabelImageTensorflowInputConverter.java    From tensorflow-spring-cloud-stream-app-starters with Apache License 2.0 6 votes vote down vote up
/**
 * Creates the converter and assembles its image pre-processing graph: a string
 * placeholder named "input" is JPEG-decoded with 3 channels, cast to float, given a
 * leading batch dimension, resized bilinearly and normalized as (value - mean) / scale.
 * The final node is stored in {@code graphOutput}.
 */
public LabelImageTensorflowInputConverter() {
	graph = new Graph();
	GraphBuilder b = new GraphBuilder(graph);

	// Constants tied to the pre-trained model published at
	// https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
	// - training images were scaled to 224x224 pixels;
	// - R, G, B colors (1 byte each) were converted to float as (value - Mean)/Scale.
	final int targetHeight = 224;
	final int targetWidth = 224;
	final float mean = 117f;
	final float scale = 1f;

	// Nodes are created one by one, in the same order the nested expression would
	// evaluate them.
	final Output input = b.placeholder("input", DataType.STRING);
	final Output decoded = b.decodeJpeg(input, 3);
	final Output asFloat = b.cast(decoded, DataType.FLOAT);
	final Output batchAxis = b.constant("make_batch", 0);
	final Output batched = b.expandDims(asFloat, batchAxis);
	final Output targetSize = b.constant("size", new int[] {targetHeight, targetWidth});
	final Output resized = b.resizeBilinear(batched, targetSize);
	final Output meanNode = b.constant("mean", mean);
	final Output centered = b.sub(resized, meanNode);
	final Output scaleNode = b.constant("scale", scale);
	graphOutput = b.div(centered, scaleNode);

}
 
Example #13
Source File: NonSerializableDataset.java    From java with Apache License 2.0 6 votes vote down vote up
/**
 * Builds a {@code NonSerializableDataset} operation in the given scope and wraps it.
 *
 * @param scope current scope
 * @param inputDataset operand added as the single input of the operation
 * @param outputTypes values stored in the {@code output_types} attribute
 * @param outputShapes values stored in the {@code output_shapes} attribute
 * @return a new instance of NonSerializableDataset
 */
@Endpoint(describeByClass = true)
public static NonSerializableDataset create(Scope scope, Operand<?> inputDataset, List<DataType<?>> outputTypes, List<Shape> outputShapes) {
  OperationBuilder builder = scope.env().opBuilder("NonSerializableDataset", scope.makeOpName("NonSerializableDataset"));
  builder.addInput(inputDataset.asOutput());
  builder = scope.applyControlDependencies(builder);
  // List-valued attributes are handed to the builder as arrays.
  DataType[] typesAttr = outputTypes.toArray(new DataType[0]);
  builder.setAttr("output_types", typesAttr);
  Shape[] shapesAttr = outputShapes.toArray(new Shape[0]);
  builder.setAttr("output_shapes", shapesAttr);
  return new NonSerializableDataset(builder.build());
}
 
Example #14
Source File: MultiDeviceIterator.java    From java with Apache License 2.0 6 votes vote down vote up
/**
 * Builds a {@code MultiDeviceIterator} operation in the given scope and wraps it.
 *
 * @param scope current scope
 * @param devices A list of devices the iterator works across.
 * @param sharedName If non-empty, this resource will be shared under the given name
 * across multiple sessions.
 * @param container If non-empty, this resource is placed in the given container.
 * Otherwise, a default container is used.
 * @param outputTypes The type list for the return values.
 * @param outputShapes The list of shapes being produced.
 * @return a new instance of MultiDeviceIterator
 */
@Endpoint(describeByClass = true)
public static MultiDeviceIterator create(Scope scope, List<String> devices, String sharedName, String container, List<DataType<?>> outputTypes, List<Shape> outputShapes) {
  // The operation has no inputs, so control dependencies are applied right after creation.
  OperationBuilder builder = scope.applyControlDependencies(
      scope.env().opBuilder("MultiDeviceIterator", scope.makeOpName("MultiDeviceIterator")));
  // List-valued attributes are handed to the builder as arrays.
  String[] devicesAttr = devices.toArray(new String[0]);
  builder.setAttr("devices", devicesAttr);
  builder.setAttr("shared_name", sharedName);
  builder.setAttr("container", container);
  DataType[] typesAttr = outputTypes.toArray(new DataType[0]);
  builder.setAttr("output_types", typesAttr);
  Shape[] shapesAttr = outputShapes.toArray(new Shape[0]);
  builder.setAttr("output_shapes", shapesAttr);
  return new MultiDeviceIterator(builder.build());
}
 
Example #15
Source File: Variable.java    From java with Apache License 2.0 6 votes vote down vote up
/**
 * Builds a {@code VariableV2} operation in the given scope and wraps it as a Variable.
 *
 * @param scope current scope
 * @param shape The shape of the variable tensor.
 * @param dtype The type of elements in the variable tensor.
 * @param options carries optional attributes values; each non-null {@code container} and
 * {@code sharedName} field is copied onto the operation
 * @return a new instance of Variable
 */
@Endpoint(describeByClass = true)
public static <T extends TType> Variable<T> create(Scope scope, Shape shape, DataType<T> dtype, Options... options) {
  // The operation has no inputs, so control dependencies are applied right after creation.
  OperationBuilder builder = scope.applyControlDependencies(
      scope.env().opBuilder("VariableV2", scope.makeOpName("Variable")));
  builder.setAttr("shape", shape);
  builder.setAttr("dtype", dtype);
  if (options != null) {
    // Only fields that were explicitly set (non-null) are forwarded.
    for (Options option : options) {
      if (option.container != null) {
        builder.setAttr("container", option.container);
      }
      if (option.sharedName != null) {
        builder.setAttr("shared_name", option.sharedName);
      }
    }
  }
  return new Variable<T>(builder.build());
}
 
Example #16
Source File: TensorListStack.java    From java with Apache License 2.0 6 votes vote down vote up
/**
 * Builds a {@code TensorListStack} operation in the given scope and wraps it.
 *
 * @param scope current scope
 * @param inputHandle operand added as the first input of the operation
 * @param elementShape int32 operand added as the second input of the operation
 * @param elementDtype value stored in the {@code element_dtype} attribute
 * @param options carries optional attributes values; a non-null {@code numElements} is
 * copied onto the operation
 * @return a new instance of TensorListStack
 */
@Endpoint(describeByClass = true)
public static <T extends TType> TensorListStack<T> create(Scope scope, Operand<?> inputHandle, Operand<TInt32> elementShape, DataType<T> elementDtype, Options... options) {
  OperationBuilder builder = scope.env().opBuilder("TensorListStack", scope.makeOpName("TensorListStack"));
  builder.addInput(inputHandle.asOutput());
  builder.addInput(elementShape.asOutput());
  builder = scope.applyControlDependencies(builder);
  builder.setAttr("element_dtype", elementDtype);
  if (options != null) {
    for (Options option : options) {
      if (option.numElements != null) {
        builder.setAttr("num_elements", option.numElements);
      }
    }
  }
  return new TensorListStack<T>(builder.build());
}
 
Example #17
Source File: AssertNextDataset.java    From java with Apache License 2.0 6 votes vote down vote up
/**
 * Builds an {@code ExperimentalAssertNextDataset} operation in the given scope and wraps
 * it as an AssertNextDataset.
 *
 * @param scope current scope
 * @param inputDataset operand added as the first input of the operation
 * @param transformations string operand added as the second input of the operation
 * @param outputTypes values stored in the {@code output_types} attribute
 * @param outputShapes values stored in the {@code output_shapes} attribute
 * @return a new instance of AssertNextDataset
 */
@Endpoint(describeByClass = true)
public static AssertNextDataset create(Scope scope, Operand<?> inputDataset, Operand<TString> transformations, List<DataType<?>> outputTypes, List<Shape> outputShapes) {
  OperationBuilder builder = scope.env().opBuilder("ExperimentalAssertNextDataset", scope.makeOpName("AssertNextDataset"));
  builder.addInput(inputDataset.asOutput());
  builder.addInput(transformations.asOutput());
  builder = scope.applyControlDependencies(builder);
  // List-valued attributes are handed to the builder as arrays.
  DataType[] typesAttr = outputTypes.toArray(new DataType[0]);
  builder.setAttr("output_types", typesAttr);
  Shape[] shapesAttr = outputShapes.toArray(new Shape[0]);
  builder.setAttr("output_shapes", shapesAttr);
  return new AssertNextDataset(builder.build());
}
 
Example #18
Source File: DirectedInterleaveDataset.java    From java with Apache License 2.0 6 votes vote down vote up
/**
 * Builds an {@code ExperimentalDirectedInterleaveDataset} operation in the given scope and
 * wraps it as a DirectedInterleaveDataset.
 *
 * @param scope current scope
 * @param selectorInputDataset A dataset of scalar `DT_INT64` elements that determines which of the
 * `N` data inputs should produce the next output element.
 * @param dataInputDatasets `N` datasets with the same type that will be interleaved according to
 * the values of `selector_input_dataset`.
 * @param outputTypes values stored in the {@code output_types} attribute
 * @param outputShapes values stored in the {@code output_shapes} attribute
 * @return a new instance of DirectedInterleaveDataset
 */
@Endpoint(describeByClass = true)
public static DirectedInterleaveDataset create(Scope scope, Operand<?> selectorInputDataset, Iterable<Operand<?>> dataInputDatasets, List<DataType<?>> outputTypes, List<Shape> outputShapes) {
  OperationBuilder builder = scope.env().opBuilder("ExperimentalDirectedInterleaveDataset", scope.makeOpName("DirectedInterleaveDataset"));
  // The selector is a single input; the data datasets are added as one input list.
  builder.addInput(selectorInputDataset.asOutput());
  builder.addInputList(Operands.asOutputs(dataInputDatasets));
  builder = scope.applyControlDependencies(builder);
  // List-valued attributes are handed to the builder as arrays.
  DataType[] typesAttr = outputTypes.toArray(new DataType[0]);
  builder.setAttr("output_types", typesAttr);
  Shape[] shapesAttr = outputShapes.toArray(new Shape[0]);
  builder.setAttr("output_shapes", shapesAttr);
  return new DirectedInterleaveDataset(builder.build());
}
 
Example #19
Source File: MaxIntraOpParallelismDataset.java    From java with Apache License 2.0 6 votes vote down vote up
/**
 * Builds a {@code MaxIntraOpParallelismDataset} operation in the given scope and wraps it.
 *
 * @param scope current scope
 * @param inputDataset operand added as the first input of the operation
 * @param maxIntraOpParallelism Identifies the maximum intra-op parallelism to use.
 * @param outputTypes values stored in the {@code output_types} attribute
 * @param outputShapes values stored in the {@code output_shapes} attribute
 * @return a new instance of MaxIntraOpParallelismDataset
 */
@Endpoint(describeByClass = true)
public static MaxIntraOpParallelismDataset create(Scope scope, Operand<?> inputDataset, Operand<TInt64> maxIntraOpParallelism, List<DataType<?>> outputTypes, List<Shape> outputShapes) {
  OperationBuilder builder = scope.env().opBuilder("MaxIntraOpParallelismDataset", scope.makeOpName("MaxIntraOpParallelismDataset"));
  builder.addInput(inputDataset.asOutput());
  builder.addInput(maxIntraOpParallelism.asOutput());
  builder = scope.applyControlDependencies(builder);
  // List-valued attributes are handed to the builder as arrays.
  DataType[] typesAttr = outputTypes.toArray(new DataType[0]);
  builder.setAttr("output_types", typesAttr);
  Shape[] shapesAttr = outputShapes.toArray(new Shape[0]);
  builder.setAttr("output_shapes", shapesAttr);
  return new MaxIntraOpParallelismDataset(builder.build());
}
 
Example #20
Source File: WindowDataset.java    From java with Apache License 2.0 6 votes vote down vote up
/**
 * Builds a {@code WindowDataset} operation in the given scope and wraps it.
 *
 * @param scope current scope
 * @param inputDataset operand added as the first input of the operation
 * @param size An integer scalar, representing the number of elements
 * of the input dataset to combine into a window. Must be positive.
 * @param shift An integer scalar, representing the number of input elements
 * by which the window moves in each iteration.  Defaults to `size`.
 * Must be positive.
 * @param stride An integer scalar, representing the stride of the input elements
 * in the sliding window. Must be positive. The default value of 1 means
 * "retain every input element".
 * @param dropRemainder A Boolean scalar, representing whether the last window should be
 * dropped if its size is smaller than `window_size`.
 * @param outputTypes values stored in the {@code output_types} attribute
 * @param outputShapes values stored in the {@code output_shapes} attribute
 * @return a new instance of WindowDataset
 */
@Endpoint(describeByClass = true)
public static WindowDataset create(Scope scope, Operand<?> inputDataset, Operand<TInt64> size, Operand<TInt64> shift, Operand<TInt64> stride, Operand<TBool> dropRemainder, List<DataType<?>> outputTypes, List<Shape> outputShapes) {
  OperationBuilder builder = scope.env().opBuilder("WindowDataset", scope.makeOpName("WindowDataset"));
  // Inputs are registered in the order: dataset, size, shift, stride, dropRemainder.
  builder.addInput(inputDataset.asOutput());
  builder.addInput(size.asOutput());
  builder.addInput(shift.asOutput());
  builder.addInput(stride.asOutput());
  builder.addInput(dropRemainder.asOutput());
  builder = scope.applyControlDependencies(builder);
  // List-valued attributes are handed to the builder as arrays.
  DataType[] typesAttr = outputTypes.toArray(new DataType[0]);
  builder.setAttr("output_types", typesAttr);
  Shape[] shapesAttr = outputShapes.toArray(new Shape[0]);
  builder.setAttr("output_shapes", shapesAttr);
  return new WindowDataset(builder.build());
}
 
Example #21
Source File: CSRSparseMatrixComponents.java    From java with Apache License 2.0 5 votes vote down vote up
/**
 * Builds a {@code CSRSparseMatrixComponents} operation in the given scope and wraps it.
 *
 * @param scope current scope
 * @param csrSparseMatrix A batched CSRSparseMatrix.
 * @param index The index in `csr_sparse_matrix`'s batch.
 * @param type value stored in the {@code type} attribute
 * @return a new instance of CSRSparseMatrixComponents
 */
@Endpoint(describeByClass = true)
public static <T extends TType> CSRSparseMatrixComponents<T> create(Scope scope, Operand<?> csrSparseMatrix, Operand<TInt32> index, DataType<T> type) {
  OperationBuilder builder = scope.env().opBuilder("CSRSparseMatrixComponents", scope.makeOpName("CSRSparseMatrixComponents"));
  // Inputs first (matrix, then batch index), followed by control dependencies.
  builder.addInput(csrSparseMatrix.asOutput());
  builder.addInput(index.asOutput());
  builder = scope.applyControlDependencies(builder);
  builder.setAttr("type", type);
  return new CSRSparseMatrixComponents<T>(builder.build());
}
 
Example #22
Source File: InfeedDequeue.java    From java with Apache License 2.0 5 votes vote down vote up
/**
 * Builds an {@code InfeedDequeue} operation in the given scope and wraps it.
 *
 * @param scope current scope
 * @param dtype The type of elements in the tensor.
 * @param shape The shape of the tensor.
 * @return a new instance of InfeedDequeue
 */
@Endpoint(describeByClass = true)
public static <T extends TType> InfeedDequeue<T> create(Scope scope, DataType<T> dtype, Shape shape) {
  // The operation has no inputs, so control dependencies are applied right after creation.
  OperationBuilder builder = scope.applyControlDependencies(
      scope.env().opBuilder("InfeedDequeue", scope.makeOpName("InfeedDequeue")));
  builder.setAttr("dtype", dtype);
  builder.setAttr("shape", shape);
  return new InfeedDequeue<T>(builder.build());
}
 
Example #23
Source File: MaxPoolWithArgmax.java    From java with Apache License 2.0 5 votes vote down vote up
/**
 * Builds a {@code MaxPoolWithArgmax} operation in the given scope and wraps it.
 *
 * @param scope current scope
 * @param input 4-D with shape `[batch, height, width, channels]`.  Input to pool over.
 * @param ksize The size of the window for each dimension of the input tensor.
 * @param strides The stride of the sliding window for each dimension of the
 * input tensor.
 * @param Targmax value stored in the {@code Targmax} attribute
 * @param padding The type of padding algorithm to use.
 * @param options carries optional attributes values; a non-null
 * {@code includeBatchInIndex} is copied onto the operation
 * @return a new instance of MaxPoolWithArgmax
 */
@Endpoint(describeByClass = true)
public static <T extends TNumber, U extends TNumber> MaxPoolWithArgmax<T, U> create(Scope scope, Operand<T> input, List<Long> ksize, List<Long> strides, DataType<U> Targmax, String padding, Options... options) {
  OperationBuilder builder = scope.env().opBuilder("MaxPoolWithArgmax", scope.makeOpName("MaxPoolWithArgmax"));
  builder.addInput(input.asOutput());
  builder = scope.applyControlDependencies(builder);
  // The boxed window sizes are unboxed into a primitive array for the attribute.
  long[] ksizeAttr = new long[ksize.size()];
  int ksizePos = 0;
  for (long windowSize : ksize) {
    ksizeAttr[ksizePos++] = windowSize;
  }
  builder.setAttr("ksize", ksizeAttr);
  // Same conversion for the strides.
  long[] stridesAttr = new long[strides.size()];
  int stridesPos = 0;
  for (long stride : strides) {
    stridesAttr[stridesPos++] = stride;
  }
  builder.setAttr("strides", stridesAttr);
  builder.setAttr("Targmax", Targmax);
  builder.setAttr("padding", padding);
  if (options != null) {
    for (Options option : options) {
      if (option.includeBatchInIndex != null) {
        builder.setAttr("include_batch_in_index", option.includeBatchInIndex);
      }
    }
  }
  return new MaxPoolWithArgmax<T, U>(builder.build());
}
 
Example #24
Source File: Recv.java    From java with Apache License 2.0 5 votes vote down vote up
/**
 * Builds an {@code XlaRecv} operation in the given scope and wraps it as a Recv.
 *
 * @param scope current scope
 * @param dtype The type of the tensor.
 * @param tensorName A string key that identifies the channel.
 * @param shape The shape of the tensor.
 * @return a new instance of Recv
 */
@Endpoint(describeByClass = true)
public static <T extends TType> Recv<T> create(Scope scope, DataType<T> dtype, String tensorName, Shape shape) {
  // The operation has no inputs, so control dependencies are applied right after creation.
  OperationBuilder builder = scope.applyControlDependencies(
      scope.env().opBuilder("XlaRecv", scope.makeOpName("Recv")));
  builder.setAttr("dtype", dtype);
  builder.setAttr("tensor_name", tensorName);
  builder.setAttr("shape", shape);
  return new Recv<T>(builder.build());
}
 
Example #25
Source File: TensorListElementShape.java    From java with Apache License 2.0 5 votes vote down vote up
/**
 * Builds a new {@code TensorListElementShape} operation and wraps it in a
 * {@link TensorListElementShape} instance.
 *
 * @param scope current scope
 * @param inputHandle operand added as the only input of the operation
 * @param shapeType value written to the {@code shape_type} attribute
 * @return a new instance of TensorListElementShape
 */
@Endpoint(describeByClass = true)
public static <T extends TNumber> TensorListElementShape<T> create(Scope scope, Operand<?> inputHandle, DataType<T> shapeType) {
  OperationBuilder builder = scope.env().opBuilder("TensorListElementShape", scope.makeOpName("TensorListElementShape"));
  builder.addInput(inputHandle.asOutput());
  // Control dependencies are applied after the inputs and before the attributes.
  builder = scope.applyControlDependencies(builder);
  builder.setAttr("shape_type", shapeType);
  return new TensorListElementShape<>(builder.build());
}
 
Example #26
Source File: TensorJsonConverter.java    From tensorflow with Apache License 2.0 5 votes vote down vote up
/**
 * Deserializes a JSON document into a {@link Tensor}.
 *
 * <p>The JSON is mapped onto a {@code JsonTensor}. Its type name is resolved with
 * {@code DataType.valueOf}, its value is Base64-decoded, and the decoded bytes are
 * wrapped in a {@link ByteBuffer} that is handed to {@code Tensor.create} together
 * with the shape.
 *
 * @param json JSON text describing the tensor (type, shape and Base64 value)
 * @return the tensor created from the decoded type, shape and value
 * @throws RuntimeException if any step fails; the original failure is kept as the cause
 */
public static Tensor toTensor(String json) {
	try {
		JsonTensor jsonTensor = new ObjectMapper().readValue(json, JsonTensor.class);
		DataType dataType = DataType.valueOf(jsonTensor.getType());
		long[] shape = jsonTensor.getShape();
		byte[] tfValue = Base64.getDecoder().decode(jsonTensor.getValue());
		return Tensor.create(dataTypeToClass(dataType), shape, ByteBuffer.wrap(tfValue));
	}
	catch (Throwable throwable) {
		// Kept deliberately broad so every failure still surfaces as a RuntimeException
		// with its cause attached; only the typo in the message ("covert") is corrected.
		throw new RuntimeException(String.format("Can not convert json:'%s' into Tensor", json), throwable);
	}
}
 
Example #27
Source File: JTensorTest.java    From zoltar with Apache License 2.0 5 votes vote down vote up
@Test
public void testFloatTensor() {
  // Wrap a one-dimensional float tensor in a JTensor.
  final float[] expectedValues = {1, 2, 3, 4, 5};
  final Tensor<Float> source = Tensors.create(expectedValues);
  final JTensor wrapped = JTensor.create(source);

  // Metadata and contents must reflect the source tensor.
  assertEquals(DataType.FLOAT, wrapped.dataType());
  assertEquals(1, wrapped.numDimensions());
  assertArrayEquals(shape, wrapped.shape());
  assertArrayEquals(expectedValues, wrapped.floatValue(), 0.0f);

  // Accessors for every other value type are checked through testException.
  testException(wrapped, JTensor::stringValue);
  testException(wrapped, JTensor::intValue);
  testException(wrapped, JTensor::longValue);
  testException(wrapped, JTensor::doubleValue);
}
 
Example #28
Source File: ShardDataset.java    From java with Apache License 2.0 5 votes vote down vote up
/**
 * Builds a new {@code ShardDataset} operation and wraps it in a {@link ShardDataset} instance.
 *
 * @param scope current scope
 * @param inputDataset operand added as the first input of the operation
 * @param numShards An integer representing the number of shards operating in parallel.
 * @param index An integer representing the current worker index.
 * @param outputTypes values written to the {@code output_types} attribute
 * @param outputShapes values written to the {@code output_shapes} attribute
 * @param options carries optional attributes values
 * @return a new instance of ShardDataset
 */
@Endpoint(describeByClass = true)
public static ShardDataset create(Scope scope, Operand<?> inputDataset, Operand<TInt64> numShards, Operand<TInt64> index, List<DataType<?>> outputTypes, List<Shape> outputShapes, Options... options) {
  // Inputs are added in declaration order, then control dependencies are applied.
  OperationBuilder builder = scope.env().opBuilder("ShardDataset", scope.makeOpName("ShardDataset"));
  builder.addInput(inputDataset.asOutput());
  builder.addInput(numShards.asOutput());
  builder.addInput(index.asOutput());
  builder = scope.applyControlDependencies(builder);

  // Required list attributes are passed to the builder as arrays.
  DataType[] typesAttr = outputTypes.toArray(new DataType[0]);
  builder.setAttr("output_types", typesAttr);
  Shape[] shapesAttr = outputShapes.toArray(new Shape[0]);
  builder.setAttr("output_shapes", shapesAttr);

  // Optional attributes: only values that were explicitly set are forwarded.
  if (options != null) {
    for (Options current : options) {
      if (current.requireNonEmpty == null) {
        continue;
      }
      builder.setAttr("require_non_empty", current.requireNonEmpty);
    }
  }
  return new ShardDataset(builder.build());
}
 
Example #29
Source File: AutoShardDataset.java    From java with Apache License 2.0 5 votes vote down vote up
/**
 * Builds a new {@code ExperimentalAutoShardDataset} operation and wraps it in an
 * {@link AutoShardDataset} instance.
 *
 * @param scope current scope
 * @param inputDataset A variant tensor representing the input dataset.
 * @param numWorkers A scalar representing the number of workers to distribute this dataset across.
 * @param index A scalar representing the index of the current worker out of num_workers.
 * @param outputTypes values written to the {@code output_types} attribute
 * @param outputShapes values written to the {@code output_shapes} attribute
 * @param options carries optional attributes values
 * @return a new instance of AutoShardDataset
 */
@Endpoint(describeByClass = true)
public static AutoShardDataset create(Scope scope, Operand<?> inputDataset, Operand<TInt64> numWorkers, Operand<TInt64> index, List<DataType<?>> outputTypes, List<Shape> outputShapes, Options... options) {
  // Inputs are added in declaration order, then control dependencies are applied.
  OperationBuilder builder = scope.env().opBuilder("ExperimentalAutoShardDataset", scope.makeOpName("AutoShardDataset"));
  builder.addInput(inputDataset.asOutput());
  builder.addInput(numWorkers.asOutput());
  builder.addInput(index.asOutput());
  builder = scope.applyControlDependencies(builder);

  // Required list attributes are passed to the builder as arrays.
  DataType[] typesAttr = outputTypes.toArray(new DataType[0]);
  builder.setAttr("output_types", typesAttr);
  Shape[] shapesAttr = outputShapes.toArray(new Shape[0]);
  builder.setAttr("output_shapes", shapesAttr);

  // Optional attributes: only values that were explicitly set are forwarded.
  if (options != null) {
    for (Options current : options) {
      if (current.autoShardPolicy == null) {
        continue;
      }
      builder.setAttr("auto_shard_policy", current.autoShardPolicy);
    }
  }
  return new AutoShardDataset(builder.build());
}
 
Example #30
Source File: BarrierTakeMany.java    From java with Apache License 2.0 5 votes vote down vote up
/**
 * Builds a new {@code BarrierTakeMany} operation and wraps it in a {@link BarrierTakeMany} instance.
 *
 * @param scope current scope
 * @param handle The handle to a barrier.
 * @param numElements A single-element tensor containing the number of elements to
 * take.
 * @param componentTypes The type of each component in a value.
 * @param options carries optional attributes values
 * @return a new instance of BarrierTakeMany
 */
@Endpoint(describeByClass = true)
public static BarrierTakeMany create(Scope scope, Operand<TString> handle, Operand<TInt32> numElements, List<DataType<?>> componentTypes, Options... options) {
  // Inputs are added in declaration order, then control dependencies are applied.
  OperationBuilder builder = scope.env().opBuilder("BarrierTakeMany", scope.makeOpName("BarrierTakeMany"));
  builder.addInput(handle.asOutput());
  builder.addInput(numElements.asOutput());
  builder = scope.applyControlDependencies(builder);

  // Required list attribute is passed to the builder as an array.
  DataType[] typesAttr = componentTypes.toArray(new DataType[0]);
  builder.setAttr("component_types", typesAttr);

  // Optional attributes: each one is forwarded only when it was explicitly set.
  if (options != null) {
    for (Options current : options) {
      if (current.allowSmallBatch != null) {
        builder.setAttr("allow_small_batch", current.allowSmallBatch);
      }
      if (current.waitForIncomplete != null) {
        builder.setAttr("wait_for_incomplete", current.waitForIncomplete);
      }
      if (current.timeoutMs != null) {
        builder.setAttr("timeout_ms", current.timeoutMs);
      }
    }
  }
  return new BarrierTakeMany(builder.build());
}