/* * Copyright 2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.cloud.stream.app.tensorflow.processor; import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Base64; import java.util.HashMap; import java.util.Map; import com.fasterxml.jackson.databind.ObjectMapper; import org.tensorflow.DataType; import org.tensorflow.Tensor; import org.tensorflow.types.UInt8; /** * Utility that helps to covert {@link Tensor} to Json and in reverse. * @author Christian Tzolov */ public class TensorJsonConverter { public static String toJson(Tensor tensor) { // Retrieve all bytes in the buffer ByteBuffer buffer = ByteBuffer.allocate(tensor.numBytes()); tensor.writeTo(buffer); buffer.clear(); byte[] bytes = new byte[buffer.capacity()]; buffer.get(bytes, 0, bytes.length); long[] shape = tensor.shape(); String bytesBase64 = Base64.getEncoder().encodeToString(bytes); return String.format("{ \"type\": \"%s\", \"shape\": %s, \"value\": \"%s\" }", tensor.dataType().name(), Arrays.toString(shape), bytesBase64); } public static Tensor toTensor(String json) { try { JsonTensor jsonTensor = new ObjectMapper().readValue(json, JsonTensor.class); DataType dataType = DataType.valueOf(jsonTensor.getType()); long[] shape = jsonTensor.getShape(); byte[] tfValue = Base64.getDecoder().decode(jsonTensor.getValue()); return Tensor.create(dataTypeToClass(dataType), shape, ByteBuffer.wrap(tfValue)); } catch (Throwable throwable) { throw new RuntimeException(String.format("Can not covert json:'%s' into Tensor", json), throwable); } } private static final Map<DataType, Class<?>> typeToClassMap = new HashMap<>(); static { typeToClassMap.put(DataType.FLOAT, Float.class); typeToClassMap.put(DataType.DOUBLE, Double.class); typeToClassMap.put(DataType.INT32, Integer.class); typeToClassMap.put(DataType.UINT8, UInt8.class); typeToClassMap.put(DataType.INT64, Long.class); typeToClassMap.put(DataType.BOOL, Boolean.class); typeToClassMap.put(DataType.STRING, String.class); } private static Class<?> dataTypeToClass(DataType dataType) { Class<?> clazz = typeToClassMap.get(dataType); if (clazz == null) { throw new IllegalArgumentException("No class found for dataType: " + dataType); } return clazz; } }