package gan;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import javax.imageio.ImageIO;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;

/**
 * Static helpers for converting between {@link BufferedImage} and ND4J {@link INDArray}
 * image tensors, and for reading/writing such images on disk.
 *
 * <p>Tensors produced and consumed here are rank 4 with shape
 * {@code [batch, channels, height, width]}; only the first batch element is ever used.
 * Channel values are 8-bit intensities (0..255).
 *
 * @author hezf
 * @date 19/1/17
 */
public class ImageUtils {

    /**
     * Writes the first image of {@code data} to {@code path} as a JPEG, creating missing
     * parent directories first.
     *
     * <p>This is a best-effort operation: I/O problems are reported on {@code System.err}
     * and do not propagate to the caller.
     *
     * @param path destination file; may be a bare file name without a parent directory
     * @param data image tensor accepted by {@link #toBufferedImage(INDArray)}
     */
    public static void save(String path, INDArray data) {
        File file = new File(path);
        // getParentFile() returns null for a bare file name such as "out.jpg".
        File dir = file.getParentFile();
        if (dir != null && !dir.exists() && !dir.mkdirs() && !dir.isDirectory()) {
            System.err.println("ImageUtils.save: could not create directory " + dir);
        }
        BufferedImage bufferedImage = toBufferedImage(data);
        try {
            // ImageIO.write signals "no suitable writer" by returning false, not by throwing.
            if (!ImageIO.write(bufferedImage, "jpg", file)) {
                System.err.println("ImageUtils.save: no JPEG writer available for " + file);
            }
        } catch (IOException e) {
            System.err.println("ImageUtils.save: failed to write " + file);
            e.printStackTrace();
        }
    }

    /**
     * Reads an image file and converts it with {@link #toINDArrayBGR(BufferedImage)}.
     *
     * @param path image file to read
     * @return the image tensor, or {@code null} if the file could not be read or decoded
     *         (the failure is reported on {@code System.err})
     */
    public static INDArray load(String path) {
        try {
            BufferedImage image = ImageIO.read(new File(path));
            if (image == null) {
                // ImageIO.read returns null when no registered reader understands the file.
                System.err.println("ImageUtils.load: no image reader could decode " + path);
                return null;
            }
            return toINDArrayBGR(image);
        } catch (IOException e) {
            System.err.println("ImageUtils.load: failed to read " + path);
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Converts an image into a tensor of shape {@code [1, bands, height, width]}, where
     * {@code bands} is the number of bands of the image raster.
     *
     * <p>Despite the method name, the channel planes are written in red, green, blue order:
     * plane 0 always holds red, plane 1 holds green when {@code bands > 1}, and plane 2 holds
     * blue when {@code bands > 2}. Planes beyond the third (e.g. alpha) are allocated but not
     * written by this method.
     *
     * @param image source image; must not be {@code null}
     * @return the image as a rank-4 tensor with values in 0..255
     */
    public static INDArray toINDArrayBGR(BufferedImage image) {
        int height = image.getHeight();
        int width = image.getWidth();
        int bands = image.getRaster().getNumBands();
        int planeSize = width * height;

        int[] pixels = getRGB(image, 0, 0, width, height, new int[planeSize]);

        INDArray flat = Nd4j.create(1, planeSize * bands);
        for (int idx = 0; idx < planeSize; idx++) {
            int argb = pixels[idx];
            flat.putScalar(idx, (argb >> 16) & 0xFF);
            if (bands > 1) {
                flat.putScalar(idx + planeSize, (argb >> 8) & 0xFF);
            }
            if (bands > 2) {
                flat.putScalar(idx + planeSize * 2, argb & 0xFF);
            }
        }

        int[] shape = new int[] {bands, height, width};
        return Nd4j.expandDims(flat.reshape(shape), 0);
    }

    /**
     * Splits a packed 32-bit colour into its four bytes, most significant first.
     *
     * @param color packed colour; for an ARGB pixel the bytes are alpha, red, green, blue
     * @return a new four-element array {@code {bits 31-24, bits 23-16, bits 15-8, bits 7-0}},
     *         each in the range 0..255
     */
    public static int[] trimRGBColor(int color) {
        int[] rgb = new int[4];
        rgb[0] = (color >> 24) & 0xff;
        rgb[1] = (color >> 16) & 0xff;
        rgb[2] = (color >> 8) & 0xff;
        rgb[3] = color & 0xff;
        return rgb;
    }

    /**
     * Renders the first batch element of a {@code [batch, channels, height, width]} tensor as
     * an opaque {@link BufferedImage#TYPE_3BYTE_BGR} image.
     *
     * <p>Channels are interpreted as red, green, blue (matching
     * {@link #toINDArrayBGR(BufferedImage)}). A single-channel tensor is rendered as
     * greyscale; with two channels the blue component is 0. Values are truncated to integers
     * and clamped to 0..255 so that out-of-range values cannot corrupt neighbouring channels.
     *
     * @param data rank-4 image tensor
     * @return the rendered image
     * @throws IllegalArgumentException if {@code data} is not rank 4
     */
    public static BufferedImage toBufferedImage(INDArray data) {
        long[] shape = data.shape();
        if (shape.length != 4) {
            throw new IllegalArgumentException(
                    "Expected a rank-4 [batch, channels, height, width] array but got rank "
                            + shape.length);
        }
        int channels = (int) shape[1];
        int height = (int) shape[2];
        int width = (int) shape[3];
        int planeSize = width * height;

        BufferedImage image = new BufferedImage(width, height, BufferedImage.TYPE_3BYTE_BGR);

        // Copy the first batch element into a fresh C-ordered buffer so the flat
        // channel-plane indices below hold regardless of how `data` is laid out in memory
        // (views, permuted or F-ordered arrays).
        int[] values = data.slice(0).dup('c').data().asInt();

        int[] pixels = new int[planeSize];
        for (int idx = 0; idx < planeSize; idx++) {
            int red = clampToByte(values[idx]);
            int green;
            int blue;
            if (channels == 1) {
                // Greyscale: replicate the single channel instead of drawing it as pure red.
                green = red;
                blue = red;
            } else {
                green = clampToByte(values[idx + planeSize]);
                blue = channels > 2 ? clampToByte(values[idx + planeSize * 2]) : 0;
            }
            pixels[idx] = 0xFF000000 | red << 16 | green << 8 | blue;
        }
        setRGB(image, 0, 0, width, height, pixels);
        return image;
    }

    /** Limits {@code value} to the 0..255 range of one 8-bit colour component. */
    private static int clampToByte(int value) {
        return Math.max(0, Math.min(255, value));
    }

    /**
     * Reads a rectangle of packed pixels from an image.
     *
     * <p>For {@link BufferedImage#TYPE_INT_ARGB} and {@link BufferedImage#TYPE_INT_RGB}
     * images the raster data elements are fetched directly; for every other type this
     * delegates to {@link BufferedImage#getRGB(int, int, int, int, int[], int, int)}.
     *
     * @param image  source image
     * @param x      left edge of the rectangle
     * @param y      top edge of the rectangle
     * @param width  rectangle width
     * @param height rectangle height
     * @param pixels destination array passed through to the underlying call
     * @return the array holding the pixels, as returned by the underlying call
     */
    public static int[] getRGB(BufferedImage image, int x, int y, int width, int height, int[] pixels) {
        int type = image.getType();
        if (type == BufferedImage.TYPE_INT_ARGB || type == BufferedImage.TYPE_INT_RGB) {
            return (int[]) image.getRaster().getDataElements(x, y, width, height, pixels);
        }
        return image.getRGB(x, y, width, height, pixels, 0, width);
    }

    /**
     * Writes a rectangle of packed pixels into an image.
     *
     * <p>For {@link BufferedImage#TYPE_INT_ARGB} and {@link BufferedImage#TYPE_INT_RGB}
     * images the raster data elements are set directly; for every other type this delegates
     * to {@link BufferedImage#setRGB(int, int, int, int, int[], int, int)}.
     *
     * @param image  destination image
     * @param x      left edge of the rectangle
     * @param y      top edge of the rectangle
     * @param width  rectangle width
     * @param height rectangle height
     * @param pixels packed pixels in row-major order, {@code width} entries per row
     */
    public static void setRGB(BufferedImage image, int x, int y, int width, int height, int[] pixels) {
        int type = image.getType();
        if (type == BufferedImage.TYPE_INT_ARGB || type == BufferedImage.TYPE_INT_RGB) {
            image.getRaster().setDataElements(x, y, width, height, pixels);
        } else {
            image.setRGB(x, y, width, height, pixels, 0, width);
        }
    }
}