package com.google.firebase.samples.apps.mlkit.java.automl;

import android.content.Context;
import android.graphics.Bitmap;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import android.util.Log;
import android.widget.Toast;
import com.google.android.gms.tasks.Continuation;
import com.google.android.gms.tasks.Task;
import com.google.android.gms.tasks.Tasks;
import com.google.firebase.ml.common.FirebaseMLException;
import com.google.firebase.ml.common.modeldownload.FirebaseModelDownloadConditions;
import com.google.firebase.ml.common.modeldownload.FirebaseModelManager;
import com.google.firebase.ml.vision.FirebaseVision;
import com.google.firebase.ml.vision.automl.FirebaseAutoMLLocalModel;
import com.google.firebase.ml.vision.automl.FirebaseAutoMLRemoteModel;
import com.google.firebase.ml.vision.common.FirebaseVisionImage;
import com.google.firebase.ml.vision.label.FirebaseVisionImageLabel;
import com.google.firebase.ml.vision.label.FirebaseVisionImageLabeler;
import com.google.firebase.ml.vision.label.FirebaseVisionOnDeviceAutoMLImageLabelerOptions;
import com.google.firebase.samples.apps.mlkit.R;
import com.google.firebase.samples.apps.mlkit.common.CameraImageGraphic;
import com.google.firebase.samples.apps.mlkit.common.FrameMetadata;
import com.google.firebase.samples.apps.mlkit.common.GraphicOverlay;
import com.google.firebase.samples.apps.mlkit.java.VisionProcessorBase;
import com.google.firebase.samples.apps.mlkit.java.labeldetector.LabelGraphic;
import com.google.firebase.samples.apps.mlkit.common.preference.PreferenceUtils;
import java.io.IOException;
import java.util.Collections;
import java.util.List;

/**
 * AutoML image labeler Demo.
 *
 * <p>Labels images with either the AutoML model bundled in the app assets or a remote AutoML model
 * downloaded through {@link FirebaseModelManager}, depending on the AutoML model preference read
 * via {@link PreferenceUtils}.
 */
public class AutoMLImageLabelerProcessor
    extends VisionProcessorBase<List<FirebaseVisionImageLabel>> {

  private static final String TAG = "ODAutoMLILProcessor";

  /** Asset path of the manifest describing the locally bundled AutoML model. */
  private static final String LOCAL_MODEL_MANIFEST = "automl/manifest.json";

  /**
   * Confidence threshold passed to the labeler options. A value of 0 presumably means no labels are
   * filtered out by confidence (verify against the ML Kit documentation).
   */
  private static final float CONFIDENCE_THRESHOLD = 0f;

  private final Context context;
  private final FirebaseVisionImageLabeler detector;

  /** Download task for the remote model, or {@code null} when the bundled local model is used. */
  private final Task<Void> modelDownloadingTask;

  private final Mode mode;

  /**
   * Whether the "remote model download failed" Toast has already been shown. It prevents a new
   * Toast from being queued on every detection request after a failed download.
   *
   * <p>NOTE(review): this is a plain field. It assumes detection requests come from a single
   * thread; in the worst case a concurrent caller shows one extra Toast.
   */
  private boolean downloadErrorShown;

  /**
   * The detection mode of the processor. Different modes will have different behavior on whether or
   * not waiting for the model download complete.
   */
  public enum Mode {
    /** Waits for an in-progress model download before labeling the image. */
    STILL_IMAGE,
    /** Skips frames (returns no labels) while the model download is in progress. */
    LIVE_PREVIEW
  }

  /**
   * Creates a processor that uses the bundled local model or a remote model, based on the user's
   * AutoML model preference.
   *
   * <p>For the remote model, a download restricted to Wi-Fi is started immediately. The labeler
   * itself is created right away in both cases.
   *
   * @param context context used to read preferences and resources and to show Toasts
   * @param mode whether detection waits for the model download ({@link Mode#STILL_IMAGE}) or skips
   *     frames until it finishes ({@link Mode#LIVE_PREVIEW})
   * @throws FirebaseMLException if the on-device AutoML labeler cannot be created
   */
  public AutoMLImageLabelerProcessor(Context context, Mode mode) throws FirebaseMLException {
    this.context = context;
    this.mode = mode;

    String modelChoice = PreferenceUtils.getAutoMLRemoteModelChoice(context);
    if (modelChoice.equals(context.getString(R.string.pref_entries_automl_models_local))) {
      Log.d(TAG, "Local model used.");
      FirebaseAutoMLLocalModel localModel =
          new FirebaseAutoMLLocalModel.Builder().setAssetFilePath(LOCAL_MODEL_MANIFEST).build();
      detector =
          createLabeler(new FirebaseVisionOnDeviceAutoMLImageLabelerOptions.Builder(localModel));
      modelDownloadingTask = null;
    } else {
      Log.d(TAG, "Remote model used.");
      String remoteModelName = PreferenceUtils.getAutoMLRemoteModelName(context);
      FirebaseAutoMLRemoteModel remoteModel =
          new FirebaseAutoMLRemoteModel.Builder(remoteModelName).build();
      FirebaseModelDownloadConditions downloadConditions =
          new FirebaseModelDownloadConditions.Builder().requireWifi().build();
      modelDownloadingTask =
          FirebaseModelManager.getInstance().download(remoteModel, downloadConditions);
      detector =
          createLabeler(new FirebaseVisionOnDeviceAutoMLImageLabelerOptions.Builder(remoteModel));
    }
  }

  /**
   * Builds the on-device AutoML labeler from a partially configured options builder. The local and
   * remote paths share this helper so they apply identical options.
   *
   * @param optionsBuilder options builder already bound to a local or remote model
   * @return the created labeler
   * @throws FirebaseMLException if ML Kit fails to create the labeler
   */
  private static FirebaseVisionImageLabeler createLabeler(
      FirebaseVisionOnDeviceAutoMLImageLabelerOptions.Builder optionsBuilder)
      throws FirebaseMLException {
    return FirebaseVision.getInstance()
        .getOnDeviceAutoMLImageLabeler(
            optionsBuilder.setConfidenceThreshold(CONFIDENCE_THRESHOLD).build());
  }

  /** Closes the underlying labeler, logging (not propagating) any {@link IOException}. */
  @Override
  public void stop() {
    try {
      detector.close();
    } catch (IOException e) {
      Log.e(TAG, "Exception thrown while trying to close the image labeler", e);
    }
  }

  /**
   * Runs label detection on {@code image}, taking the remote model download state into account.
   *
   * <ul>
   *   <li>Local model: labels the image directly.
   *   <li>Download still running, {@link Mode#LIVE_PREVIEW}: returns an empty label list.
   *   <li>Download still running, {@link Mode#STILL_IMAGE}: labels the image after the download
   *       task completes.
   *   <li>Download finished: labels the image, or fails if the download failed.
   * </ul>
   */
  @Override
  protected Task<List<FirebaseVisionImageLabel>> detectInImage(final FirebaseVisionImage image) {
    if (modelDownloadingTask == null) {
      // No download task means only the locally bundled model is used. Model can be used directly.
      return detector.processImage(image);
    } else if (!modelDownloadingTask.isComplete()) {
      if (mode == Mode.LIVE_PREVIEW) {
        Log.i(TAG, "Model download is in progress. Skip detecting image.");
        return Tasks.forResult(Collections.<FirebaseVisionImageLabel>emptyList());
      } else {
        Log.i(TAG, "Model download is in progress. Waiting...");
        // Kept as an anonymous class (not a lambda) to match the file's source level.
        return modelDownloadingTask.continueWithTask(
            new Continuation<Void, Task<List<FirebaseVisionImageLabel>>>() {
              @Override
              public Task<List<FirebaseVisionImageLabel>> then(@NonNull Task<Void> task) {
                return processImageOnDownloadComplete(image);
              }
            });
      }
    } else {
      return processImageOnDownloadComplete(image);
    }
  }

  /**
   * Redraws the overlay with the camera image (when provided) and the detected labels.
   *
   * @param originalCameraImage camera frame to draw underneath the labels, or {@code null}
   * @param labels labels produced by {@link #detectInImage}
   * @param frameMetadata metadata of the processed frame (unused here)
   * @param graphicOverlay overlay that is cleared and repopulated
   */
  @Override
  protected void onSuccess(
      @Nullable Bitmap originalCameraImage,
      @NonNull List<FirebaseVisionImageLabel> labels,
      @NonNull FrameMetadata frameMetadata,
      @NonNull GraphicOverlay graphicOverlay) {
    graphicOverlay.clear();
    if (originalCameraImage != null) {
      CameraImageGraphic imageGraphic = new CameraImageGraphic(graphicOverlay, originalCameraImage);
      graphicOverlay.add(imageGraphic);
    }
    LabelGraphic labelGraphic = new LabelGraphic(graphicOverlay, labels);
    graphicOverlay.add(labelGraphic);
    graphicOverlay.postInvalidate();
  }

  /** Logs a failed detection, including failures caused by a failed model download. */
  @Override
  protected void onFailure(@NonNull Exception e) {
    Log.w(TAG, "Label detection failed.", e);
  }

  /**
   * Labels {@code image} once the remote model download task has completed.
   *
   * <p>If the download failed, the error is logged and a failed task wrapping the download
   * exception is returned. The user-facing Toast is shown only for the first such failure.
   * Otherwise this path runs on every detection request (presumably every frame in
   * {@link Mode#LIVE_PREVIEW}) and would queue an unbounded stream of Toasts.
   *
   * @param image image to label
   * @return the labeling task, or a failed task if the model download did not succeed
   */
  private Task<List<FirebaseVisionImageLabel>> processImageOnDownloadComplete(
      FirebaseVisionImage image) {
    if (modelDownloadingTask.isSuccessful()) {
      return detector.processImage(image);
    }

    Exception downloadException = modelDownloadingTask.getException();
    String downloadingError = "Error downloading remote model.";
    Log.e(TAG, downloadingError, downloadException);
    if (!downloadErrorShown) {
      downloadErrorShown = true;
      Toast.makeText(context, downloadingError, Toast.LENGTH_SHORT).show();
    }
    return Tasks.forException(
        new Exception("Failed to download remote model.", downloadException));
  }
}