package com.amitshekhar.tflite; import android.content.res.AssetFileDescriptor; import android.content.res.AssetManager; import android.graphics.Bitmap; import android.graphics.RectF; import android.util.Log; import org.tensorflow.lite.Interpreter; import java.io.BufferedReader; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStreamReader; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.PriorityQueue; /** * Created by amitshekhar on 17/03/18. */ public abstract class Classifier { /** * An immutable result returned by a Classifier describing what was recognized. */ public class Recognition { /** * A unique identifier for what has been recognized. Specific to the class, not the instance of * the object. */ private final String id; /** * Display name for the recognition. */ private final String title; /** * A sortable score for how good the recognition is relative to others. Higher should be better. */ private final Float confidence; /** * Optional location within the source image for the location of the recognized object. */ private RectF location; public int detectedClass; public Recognition( final String id, final String title, final Float confidence, final RectF location, int detectedClass) { this.id = id; this.title = title; this.confidence = confidence; this.location = location; this.detectedClass = detectedClass; } public String getId() { return id; } public String getTitle() { return title; } public Float getConfidence() { return confidence; } public RectF getLocation() { return new RectF(location); } public void setLocation(RectF location) { this.location = location; } @Override public String toString() { String resultString = ""; if (id != null) { resultString += "[" + id + "] "; } if (title != null) { resultString += title + " "; } if (confidence != null) { resultString += String.format("(%.1f%%) ", confidence * 100.0f); } if (location != null) { resultString += location + " "; } return resultString.trim(); } } protected float mNmsThresh = 0.5f; protected List<String> mLabelList; protected static final int NUM_BOXES_PER_BLOCK = 3; protected static final int BATCH_SIZE = 1; protected static final int PIXEL_SIZE = 3; protected Interpreter mInterpreter; protected int mInputSize; protected int[][] mMasks; protected int[] mAnchors; protected int[] mOutWidth; public Classifier (AssetManager assetManager, String modelPath, String labelPath, int inputSize) throws IOException { mInterpreter = new Interpreter(loadModelFile(assetManager, modelPath)); mLabelList = loadLabelList(assetManager, labelPath); StringBuilder builder = new StringBuilder(); for (String label: mLabelList) { builder.append(label).append(" "); } Log.d("wangmin", "Labels are:\n" + builder.toString()); mInputSize = inputSize; } //non maximum suppression protected ArrayList<Recognition> nms(ArrayList<Recognition> list) { ArrayList<Recognition> nmsList = new ArrayList<Recognition>(); for (int k = 0; k < mLabelList.size(); k++) { //1.find max confidence per class PriorityQueue<Recognition> pq = new PriorityQueue<Recognition>( 10, new Comparator<Recognition>() { @Override public int compare(final Recognition lhs, final Recognition rhs) { // Intentionally reversed to put high confidence at the head of the queue. return Float.compare(rhs.getConfidence(), lhs.getConfidence()); } }); for (int i = 0; i < list.size(); ++i) { if (list.get(i).detectedClass == k) { pq.add(list.get(i)); } } Log.d("wangmin", "class[" + k + "] pq size: " + pq.size()); //2.do non maximum suppression while(pq.size() > 0) { //insert detection with max confidence Recognition[] a = new Recognition[pq.size()]; Recognition[] detections = pq.toArray(a); Recognition max = detections[0]; nmsList.add(max); Log.d("wangmin", "before nms pq size: " + pq.size()); //clear pq to do next nms pq.clear(); for (int j = 1; j < detections.length; j++) { Recognition detection = detections[j]; RectF b = detection.getLocation(); if (box_iou(max.getLocation(), b) < mNmsThresh){ pq.add(detection); } } Log.d("wangmin", "after nms pq size: " + pq.size()); } } return nmsList; } protected float box_iou(RectF a, RectF b) { return box_intersection(a, b)/box_union(a, b); } protected float box_intersection(RectF a, RectF b) { float w = overlap((a.left + a.right) / 2, a.right - a.left, (b.left + b.right) / 2, b.right - b.left); float h = overlap((a.top + a.bottom) / 2, a.bottom - a.top, (b.top + b.bottom) / 2, b.bottom - b.top); if(w < 0 || h < 0) return 0; float area = w*h; return area; } protected float box_union(RectF a, RectF b) { float i = box_intersection(a, b); float u = (a.right - a.left)*(a.bottom - a.top) + (b.right - b.left)*(b.bottom - b.top) - i; return u; } protected float overlap(float x1, float w1, float x2, float w2) { float l1 = x1 - w1/2; float l2 = x2 - w2/2; float left = l1 > l2 ? l1 : l2; float r1 = x1 + w1/2; float r2 = x2 + w2/2; float right = r1 < r2 ? r1 : r2; return right - left; } protected void close() { mInterpreter.close(); mInterpreter = null; } protected void softmax(final float[] vals) { float max = Float.NEGATIVE_INFINITY; for (final float val : vals) { max = Math.max(max, val); } float sum = 0.0f; for (int i = 0; i < vals.length; ++i) { vals[i] = (float) Math.exp(vals[i] - max); sum += vals[i]; } for (int i = 0; i < vals.length; ++i) { vals[i] = vals[i] / sum; } } protected float expit(final float x) { return (float) (1. / (1. + Math.exp(-x))); } protected MappedByteBuffer loadModelFile(AssetManager assetManager, String modelPath) throws IOException { AssetFileDescriptor fileDescriptor = assetManager.openFd(modelPath); FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); FileChannel fileChannel = inputStream.getChannel(); long startOffset = fileDescriptor.getStartOffset(); long declaredLength = fileDescriptor.getDeclaredLength(); return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); } /** Writes Image data into a {@code ByteBuffer}. */ protected ByteBuffer convertBitmapToByteBuffer(Bitmap bitmap) { ByteBuffer byteBuffer = ByteBuffer.allocateDirect(4 * BATCH_SIZE * mInputSize * mInputSize * PIXEL_SIZE); byteBuffer.order(ByteOrder.nativeOrder()); int[] intValues = new int[mInputSize * mInputSize]; bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); int pixel = 0; for (int i = 0; i < mInputSize; ++i) { for (int j = 0; j < mInputSize; ++j) { final int val = intValues[pixel++]; byteBuffer.putFloat(((val >> 16) & 0xFF) / 255.0f); byteBuffer.putFloat(((val >> 8) & 0xFF) / 255.0f); byteBuffer.putFloat((val & 0xFF) / 255.0f); } } return byteBuffer; } public ArrayList<Recognition> RecognizeImage(Bitmap bitmap) { ByteBuffer byteBuffer = convertBitmapToByteBuffer(bitmap); Map<Integer, Object> outputMap = new HashMap<>(); for (int i = 0; i < mOutWidth.length; i++) { float[][][][][] out = new float[1][mOutWidth[i]][mOutWidth[i]][3][5 + mLabelList.size()]; outputMap.put(i, out); } Log.d("wangmin", "mObjThresh: " + getObjThresh()); Object[] inputArray = {byteBuffer}; mInterpreter.runForMultipleInputsOutputs(inputArray, outputMap); ArrayList<Recognition> detections = new ArrayList<Recognition>(); for (int i = 0; i < mOutWidth.length; i++) { int gridWidth = mOutWidth[i]; float[][][][][] out = (float[][][][][])outputMap.get(i); Log.d("wangmin", "out[" + i + "] detect start"); for (int y = 0; y < gridWidth; ++y) { for (int x = 0; x < gridWidth; ++x) { for (int b = 0; b < NUM_BOXES_PER_BLOCK; ++b) { final int offset = (gridWidth * (NUM_BOXES_PER_BLOCK * (mLabelList.size() + 5))) * y + (NUM_BOXES_PER_BLOCK * (mLabelList.size() + 5)) * x + (mLabelList.size() + 5) * b; final float confidence = expit(out[0][y][x][b][4]); int detectedClass = -1; float maxClass = 0; final float[] classes = new float[mLabelList.size()]; for (int c = 0; c < mLabelList.size(); ++c) { classes[c] = out[0][y][x][b][5+c]; } softmax(classes); for (int c = 0; c < mLabelList.size(); ++c) { if (classes[c] > maxClass) { detectedClass = c; maxClass = classes[c]; } } final float confidenceInClass = maxClass * confidence; if (confidenceInClass > getObjThresh()) { final float xPos = (x + expit(out[0][y][x][b][0])) * (mInputSize / gridWidth); final float yPos = (y + expit(out[0][y][x][b][1])) * (mInputSize / gridWidth); final float w = (float) (Math.exp(out[0][y][x][b][2]) * mAnchors[2 * mMasks[i][b] + 0]); final float h = (float) (Math.exp(out[0][y][x][b][3]) * mAnchors[2 * mMasks[i][b] + 1]); Log.d("wangmin","box x:" + xPos + ", y:" + yPos + ", w:" + w + ", h:" + h); final RectF rect = new RectF( Math.max(0, xPos - w / 2), Math.max(0, yPos - h / 2), Math.min(bitmap.getWidth() - 1, xPos + w / 2), Math.min(bitmap.getHeight() - 1, yPos + h / 2)); Log.d("wangmin", "detect " + mLabelList.get(detectedClass) + ", confidence: " + confidenceInClass + ", box: " + rect.toString()); detections.add(new Recognition("" + offset, mLabelList.get(detectedClass), confidenceInClass, rect, detectedClass)); } } } } Log.d("wangmin", "out[" + i + "] detect end"); } final ArrayList<Recognition> recognitions = nms(detections); return recognitions; } protected List<String> loadLabelList(AssetManager assetManager, String labelPath) throws IOException { List<String> labelList = new ArrayList<>(); BufferedReader reader = new BufferedReader(new InputStreamReader(assetManager.open(labelPath))); String line; while ((line = reader.readLine()) != null) { labelList.add(line); } reader.close(); return labelList; } public int getInputSize() { return mInputSize; } protected abstract float getObjThresh(); }