package edu.ml.tensorflow.util;

import org.tensorflow.DataType;
import org.tensorflow.Graph;
import org.tensorflow.Output;
import org.tensorflow.Tensor;
import org.tensorflow.types.UInt8;

/**
 * Defines the necessary operations for image processing and builds the graph.
 *
 * <p>Each public method adds exactly one operation to the wrapped {@link Graph} and
 * returns that operation's output at index 0. Operations whose node name is chosen by
 * this class (everything except the {@code constant} overloads) are named after their
 * op type on first use and receive a numeric suffix ({@code _1}, {@code _2}, ...) on
 * later uses, so the same kind of operation can be added to one graph more than once.
 */
public class GraphBuilder {
    /** Graph that receives every operation created through this builder. */
    private final Graph graph;

    /**
     * Creates a builder that adds operations to the given graph.
     *
     * @param graph the graph to add operations to; must not be {@code null}
     * @throws NullPointerException if {@code graph} is {@code null}
     */
    public GraphBuilder(Graph graph) {
        // Fail at construction time rather than on the first operation that is added.
        if (graph == null) {
            throw new NullPointerException("graph");
        }
        this.graph = graph;
    }

    /**
     * Adds a {@code Div} operation with {@code x} and {@code y} as its inputs.
     *
     * @param x first input of the operation
     * @param y second input of the operation
     * @return output 0 of the new operation
     */
    public Output<Float> div(Output<Float> x, Output<Float> y) {
        return binaryOp("Div", x, y);
    }

    /**
     * Adds a {@code ResizeBilinear} operation with {@code images} and {@code size} as
     * its inputs.
     *
     * @param images first input of the operation
     * @param size second input of the operation
     * @param <T> element type of {@code images}
     * @return output 0 of the new operation
     */
    public <T> Output<Float> resizeBilinear(Output<T> images, Output<Integer> size) {
        return binaryOp3("ResizeBilinear", images, size);
    }

    /**
     * Adds an {@code ExpandDims} operation with {@code input} and {@code dim} as its
     * inputs.
     *
     * @param input first input of the operation
     * @param dim second input of the operation
     * @param <T> element type of {@code input} and of the result
     * @return output 0 of the new operation
     */
    public <T> Output<T> expandDims(Output<T> input, Output<Integer> dim) {
        return binaryOp3("ExpandDims", input, dim);
    }

    /**
     * Adds a {@code Cast} operation whose {@code DstT} attribute is the
     * {@link DataType} corresponding to {@code type}.
     *
     * @param value input of the operation
     * @param type Java class identifying the destination data type
     * @param <T> element type of {@code value}
     * @param <U> element type of the result
     * @return output 0 of the new operation
     */
    public <T, U> Output<U> cast(Output<T> value, Class<U> type) {
        DataType dtype = DataType.fromClass(type);
        return graph.opBuilder("Cast", uniqueName("Cast"))
                .addInput(value)
                .setAttr("DstT", dtype)
                .build()
                .<U>output(0);
    }

    /**
     * Adds a {@code DecodeJpeg} operation reading from {@code contents}.
     *
     * @param contents input of the operation
     * @param channels value assigned to the operation's {@code channels} attribute
     * @return output 0 of the new operation
     */
    public Output<UInt8> decodeJpeg(Output<String> contents, long channels) {
        return graph.opBuilder("DecodeJpeg", uniqueName("DecodeJpeg"))
                .addInput(contents)
                .setAttr("channels", channels)
                .build()
                .<UInt8>output(0);
    }

    /**
     * Adds a {@code Const} operation holding {@code value}.
     *
     * <p>The value is wrapped in a temporary {@link Tensor} that is closed as soon as
     * the operation has been built. The node name is taken from the caller unchanged.
     *
     * @param name node name of the new operation
     * @param value Java object passed to {@link Tensor#create(Object, Class)}
     * @param type Java class identifying the tensor's data type
     * @param <T> element type of the result
     * @return output 0 of the new operation
     */
    public <T> Output<T> constant(String name, Object value, Class<T> type) {
        try (Tensor<T> t = Tensor.<T>create(value, type)) {
            return graph.opBuilder("Const", name)
                    .setAttr("dtype", DataType.fromClass(type))
                    .setAttr("value", t)
                    .build()
                    .<T>output(0);
        }
    }

    /**
     * Adds a {@code Const} operation of type {@code String} from a byte array.
     *
     * @param name node name of the new operation
     * @param value constant value
     * @return output 0 of the new operation
     */
    public Output<String> constant(String name, byte[] value) {
        return this.constant(name, value, String.class);
    }

    /**
     * Adds a {@code Const} operation of type {@code Integer} from a single int.
     *
     * @param name node name of the new operation
     * @param value constant value
     * @return output 0 of the new operation
     */
    public Output<Integer> constant(String name, int value) {
        return this.constant(name, value, Integer.class);
    }

    /**
     * Adds a {@code Const} operation of type {@code Integer} from an int array.
     *
     * @param name node name of the new operation
     * @param value constant value
     * @return output 0 of the new operation
     */
    public Output<Integer> constant(String name, int[] value) {
        return this.constant(name, value, Integer.class);
    }

    /**
     * Adds a {@code Const} operation of type {@code Float} from a single float.
     *
     * @param name node name of the new operation
     * @param value constant value
     * @return output 0 of the new operation
     */
    public Output<Float> constant(String name, float value) {
        return this.constant(name, value, Float.class);
    }

    /** Adds a two-input operation whose inputs and output share one element type. */
    private <T> Output<T> binaryOp(String type, Output<T> in1, Output<T> in2) {
        return this.<T, T, T>binaryOp3(type, in1, in2);
    }

    /** Adds a two-input operation whose inputs and output may differ in element type. */
    private <T, U, V> Output<T> binaryOp3(String type, Output<U> in1, Output<V> in2) {
        return graph.opBuilder(type, uniqueName(type))
                .addInput(in1)
                .addInput(in2)
                .build()
                .<T>output(0);
    }

    /**
     * Returns a node name based on {@code base} that is not yet present in the graph.
     *
     * <p>{@code base} itself is returned when it is free, so the first operation of a
     * kind keeps the plain op-type name; otherwise the first free name of the form
     * {@code base_N} with N = 1, 2, ... is returned.
     */
    private String uniqueName(String base) {
        if (graph.operation(base) == null) {
            return base;
        }
        int suffix = 1;
        while (graph.operation(base + "_" + suffix) != null) {
            suffix++;
        }
        return base + "_" + suffix;
    }
}