package com.rntensorflow.imagerecognition; import android.graphics.Bitmap; import android.graphics.BitmapFactory; import android.graphics.Canvas; import android.graphics.Matrix; import com.facebook.react.bridge.*; import com.rntensorflow.RNTensorflowInference; import com.rntensorflow.ResourceManager; import org.tensorflow.Tensor; import java.io.IOException; import java.nio.FloatBuffer; import java.util.*; public class ImageRecognizer { private static final int IMAGE_MEAN = 117; private static final float IMAGE_STD = 1; private static final int MAX_RESULTS = 3; private static final float THRESHOLD = 0.1f; private RNTensorflowInference inference; private ResourceManager resourceManager; private int imageMean; private float imageStd; private String[] labels; public ImageRecognizer(RNTensorflowInference inference, ResourceManager resourceManager, int imageMean, float imageStd, String[] labels) { this.inference = inference; this.resourceManager = resourceManager; this.imageMean = imageMean; this.imageStd = imageStd; this.labels = labels; } public static ImageRecognizer init( ReactContext reactContext, String modelFilename, String labelFilename, Integer imageMean, Double imageStd) throws IOException { Integer imageMeanResolved = imageMean != null ? imageMean : IMAGE_MEAN; Float imageStdResolved = imageStd != null ? imageStd.floatValue() : IMAGE_STD; RNTensorflowInference inference = RNTensorflowInference.init(reactContext, modelFilename); ResourceManager resourceManager = new ResourceManager(reactContext); String[] labels = resourceManager.loadResourceAsString(labelFilename).split("\\r?\\n"); return new ImageRecognizer(inference, resourceManager, imageMeanResolved, imageStdResolved, labels); } public WritableArray recognizeImage(final String image, final String inputName, final Integer inputSize, final String outputName, final Integer maxResults, final Double threshold) { String inputNameResolved = inputName != null ? inputName : "input"; String outputNameResolved = outputName != null ? outputName : "output"; Integer maxResultsResolved = maxResults != null ? maxResults : MAX_RESULTS; Float thresholdResolved = threshold != null ? threshold.floatValue() : THRESHOLD; Bitmap bitmapRaw = loadImage(resourceManager.loadResource(image)); int inputSizeResolved = inputSize != null ? inputSize : 224; int[] intValues = new int[inputSizeResolved * inputSizeResolved]; float[] floatValues = new float[inputSizeResolved * inputSizeResolved * 3]; Bitmap bitmap = Bitmap.createBitmap(inputSizeResolved, inputSizeResolved, Bitmap.Config.ARGB_8888); Matrix matrix = createMatrix(bitmapRaw.getWidth(), bitmapRaw.getHeight(), inputSizeResolved, inputSizeResolved); final Canvas canvas = new Canvas(bitmap); canvas.drawBitmap(bitmapRaw, matrix, null); bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); for (int i = 0; i < intValues.length; ++i) { final int val = intValues[i]; floatValues[i * 3 + 0] = (((val >> 16) & 0xFF) - imageMean) / imageStd; floatValues[i * 3 + 1] = (((val >> 8) & 0xFF) - imageMean) / imageStd; floatValues[i * 3 + 2] = ((val & 0xFF) - imageMean) / imageStd; } Tensor tensor = Tensor.create(new long[]{1, inputSizeResolved, inputSizeResolved, 3}, FloatBuffer.wrap(floatValues)); inference.feed(inputNameResolved, tensor); inference.run(new String[] {outputNameResolved}, false); ReadableArray outputs = inference.fetch(outputNameResolved); List<WritableMap> results = new ArrayList<>(); for (int i = 0; i < outputs.size(); ++i) { if (outputs.getDouble(i) > thresholdResolved) { WritableMap entry = new WritableNativeMap(); entry.putString("id", String.valueOf(i)); entry.putString("name", labels.length > i ? labels[i] : "unknown"); entry.putDouble("confidence", outputs.getDouble(i)); results.add(entry); } } Collections.sort(results, new Comparator<ReadableMap>() { @Override public int compare(ReadableMap first, ReadableMap second) { return Double.compare(second.getDouble("confidence"), first.getDouble("confidence")); } }); int finalSize = Math.min(results.size(), maxResultsResolved); WritableArray array = new WritableNativeArray(); for (int i = 0; i < finalSize; i++) { array.pushMap(results.get(i)); } inference.getTfContext().reset(); return array; } private Bitmap loadImage(byte[] image) { BitmapFactory.Options options = new BitmapFactory.Options(); options.inPreferredConfig = Bitmap.Config.ARGB_8888; return BitmapFactory.decodeByteArray(image, 0, image.length); } private Matrix createMatrix(int srcWidth, int srcHeight, int dstWidth, int dstHeight) { Matrix matrix = new Matrix(); if (srcWidth != dstWidth || srcHeight != dstHeight) { float scaleFactorX = dstWidth / (float) srcWidth; float scaleFactorY = dstHeight / (float) srcHeight; float scaleFactor = Math.max(scaleFactorX, scaleFactorY); matrix.postScale(scaleFactor, scaleFactor); } matrix.invert(new Matrix()); return matrix; } }