package site.acsi.baidu.dog.service.impl;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.primitives.Doubles;
import com.google.common.primitives.Ints;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import org.springframework.stereotype.Service;
import site.acsi.baidu.dog.service.IVerCodeParseService;
import site.acsi.baidu.dog.util.ImageUtils;

import javax.annotation.Resource;
import java.awt.image.BufferedImage;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.net.URL;
import java.util.*;

/**
 * @author Acsi
 * @date 2018/2/9
 */
@Service("local")
public class LocalVerCodeParseServiceImpl implements IVerCodeParseService {

    @Resource
    private ImageUtils imageUtils;

    private Map<Integer, String> labels = Maps.newHashMapWithExpectedSize(36);

    private svm_model model;

    private static final int CHAR_NUM = 36;

    public LocalVerCodeParseServiceImpl() throws IOException {
        String labelName = "1234567890abcdefghijklmnopqrstuvwxyz";
        for (int i = 0; i < CHAR_NUM; i++) {
            labels.put(i + 1, String.valueOf(labelName.charAt(i)));
        }
//        URL url = getClass().getClassLoader().getResource("svm.model");
//        Preconditions.checkNotNull(url, "无法找到svm模型");
//        model = svm.svm_load_model(new BufferedReader(new FileReader(url.getFile())));
    }

    @Override
    public String predict(String imgData) throws IOException {
        // 转化成bufferedImage
        BufferedImage image = imageUtils.convertBase64DataToBufferedImage(imgData);
        // 图像预处理
        image = imageUtils.preProcess(image);
        // 切割
        List<BufferedImage> subImgs = imageUtils.cfs(image);
        // 过滤
        List<BufferedImage> filterImgs = filter(subImgs);
        // 转成svm格式数据
        List<String> svmTest = formatSvm(filterImgs);
        // 预测
        StringBuilder result = doPredict(svmTest);

        return result.toString();
    }

    private StringBuilder doPredict(List<String> svmTest) {
        StringBuilder result = new StringBuilder();
        for(String line : svmTest) {
            StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:");
            Double target = Doubles.tryParse(st.nextToken());
            Preconditions.checkNotNull(target);
            int m = st.countTokens() / 2;
            svm_node[] x = new svm_node[m];
            for (int j = 0; j < m; j++) {
                x[j] = new svm_node();
                Integer index = Ints.tryParse(st.nextToken());
                Double value = Doubles.tryParse(st.nextToken());
                Preconditions.checkNotNull(index);
                Preconditions.checkNotNull(value);
                x[j].index = index;
                x[j].value = value;
            }

            double v = svm.svm_predict(model, x);
            result.append(labels.get((int)v));

        }
        return result;
    }

    private List<String> formatSvm(List<BufferedImage> filterImgs) {
        List<String> svmTest = Lists.newArrayList();
        for (BufferedImage img : filterImgs) {
            int width = img.getWidth();
            int height = img.getHeight();
            int index = 1;
            // 默认无标号,则为-1
            StringBuilder tmpLine = new StringBuilder("-1 ");
            for (int y = 0; y < height; y++) {
                for (int x = 0; x < width; x++) {
                    // 黑色点标记为1
                    int value = imageUtils.isBlack(img.getRGB(x, y)) ? 1 : 0;
                    tmpLine.append(index).append(":").append(value).append(" ");
                    index ++;
                }
            }
            svmTest.add(tmpLine + "\r\n");
        }
        return svmTest;
    }

    private List<BufferedImage> filter(List<BufferedImage> imgs) {
        List<BufferedImage> filterSortedList = new ArrayList<>();
        filterSortedList.addAll(imgs);
        filterSortedList.sort(Comparator.comparingInt(img -> -(img.getWidth() * img.getHeight())));

        filterSortedList = filterSortedList.subList(0, 4);
        List<BufferedImage> imageList = new ArrayList<>();
        List<BufferedImage> finalFilterSortedList = filterSortedList;
        imgs.forEach(img -> {
            if (finalFilterSortedList.contains(img)) {
                imageList.add(img);
            }
        });
        return imageList;
    }
}