package org.apdplat.data.generator.mysql2es;

import com.alibaba.fastjson.JSON;
import org.apdplat.data.generator.mysql.MySQLUtils;
import org.apdplat.data.generator.utils.Config;
import org.apache.http.HttpHost;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestHighLevelClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.OutputStreamWriter;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Created by ysc on 20/04/2018.
 */
public abstract class Output {
    private static final Logger LOGGER = LoggerFactory.getLogger(Output.class);

    private static final String HOST = (Config.getStringValue("es.host") == null ? "192.168.252.193" : Config.getStringValue("es.host"));
    private static final String PORT = (Config.getStringValue("es.port") == null ? "9200" : Config.getStringValue("es.port"));
    private static final String MODE = (Config.getStringValue("es.mode") == null ? "es" : Config.getStringValue("es.mode"));
    protected static final int BATCH_SIZE = Config.getIntValue("es.batchSize") == -1 ? 5 : Config.getIntValue("es.batchSize");
    private static final int MYSQL_PAGE_SIZE = Config.getIntValue("mysql.pageSize") == -1 ? 1000 : Config.getIntValue("mysql.pageSize");
    protected static final int START_PAGE = Config.getIntValue("output.start.page") == -1 ? 0 : Config.getIntValue("output.start.page");
    protected static final String ASYNC_OUTPUT = Config.getStringValue("output.async") == null ? "true" : Config.getStringValue("output.async");
    private static final int THREAD_COUNT = Config.getIntValue("output.async.thread.count") == -1 ? 1 : Config.getIntValue("output.async.thread.count");
    private static final ExecutorService EXECUTOR_SERVICE = Executors.newCachedThreadPool();
    private static final BlockingQueue<Map<String, Object>> BLOCKING_QUEUE = new LinkedBlockingQueue<>(THREAD_COUNT);
    private static volatile boolean running = true;

    private static final AtomicLong COUNT = new AtomicLong();

    private static final RestHighLevelClient CLIENT = new RestHighLevelClient(
            RestClient.builder(
                    new HttpHost(HOST, Integer.parseInt(PORT), "http")));

    static {
        if (!"file".equals(MODE) && "true".equalsIgnoreCase(ASYNC_OUTPUT)) {
            for (int i = 0; i < THREAD_COUNT; i++) {
                EXECUTOR_SERVICE.submit(() -> {
                    while (running) {
                        try {
                            Map<String, Object> map = BLOCKING_QUEUE.take();
                            if (map.get("data") == null) {
                                running = false;
                                break;
                            }
                            List<Map<String, Object>> data = (List<Map<String, Object>>) map.get("data");
                            String index = (String) map.get("index");
                            String type = (String) map.get("type");
                            output(index, type, null, data);
                        } catch (Exception e) {
                            LOGGER.error("获取数据异常", e);
                        }
                    }
                });
            }
        }
    }

    protected void writeBatch(String index, String type, BufferedWriter bufferedWriter, List<Map<String, Object>> list) throws Exception{
        if("true".equalsIgnoreCase(ASYNC_OUTPUT)){
            List<Map<String, Object>> newList = new ArrayList<>(list.size());
            newList.addAll(list);
            Map<String, Object> map = new HashMap<>();
            map.put("data", newList);
            map.put("index", index);
            map.put("type", type);
            BLOCKING_QUEUE.put(map);
            list.clear();
        }else{
            output(index, type, bufferedWriter, list);
        }
    }

    private static void output(String index, String type, BufferedWriter bufferedWriter, List<Map<String, Object>> list) {
        try {
            if ("file".equals(MODE)) {
                writeBatchToFile(index, type, bufferedWriter, list);
            } else {
                writeBatchToES(index, type, list);
            }
            list.forEach(item -> {
                if (item.get("geo_location") != null) {
                    ((Map<String, Float>) item.get("geo_location")).clear();
                }
                item.clear();
            });
            list.clear();
        }catch (Exception e){
            LOGGER.error("数据输出异常", e);
            throw new RuntimeException(e);
        }
    }

    private static void writeBatchToES(String index, String type, List<Map<String, Object>> list) throws Exception{
        if(list.isEmpty()){
            return;
        }
        BulkRequest request = new BulkRequest();
        for(Map<String, Object> data : list) {
            String id = data.get("id").toString();
            request.add(
                    new IndexRequest(index, type, id)
                            .source(data));

        }
        BulkResponse bulkResponse = CLIENT.bulk(request);
        if (bulkResponse.hasFailures()) {
            for (BulkItemResponse bulkItemResponse : bulkResponse) {
                if (bulkItemResponse.isFailed()) {
                    BulkItemResponse.Failure failure = bulkItemResponse.getFailure();
                    LOGGER.error("ES索引失败: {}", failure.getMessage());
                }
            }
        }
    }

    private static void writeBatchToFile(String index, String type, BufferedWriter bufferedWriter, List<Map<String, Object>> list) throws Exception{
        StringBuilder batchJson = new StringBuilder();
        for(Map<String, Object> data : list) {
            String json = JSON.toJSONString(data);
            String id = data.get("id").toString();
            batchJson.append("\n{ \"index\":{ \"_id\": \""+id+"\"} }\n").append(json).append("\n");
        }

        String command = "curl -H \"Content-Type: application/json\" -XPUT 'http://"+HOST+":"+PORT+"/"+index+"/"+type+"/_bulk' -d '"+batchJson.toString()+"';";

        bufferedWriter.write(command+"\n");
        bufferedWriter.flush();
    }

    protected void generateCommand(String table, String sql, String index, String type, String shellFileName) {
        Connection con = MySQLUtils.getConnection();
        if(con == null){
            return ;
        }
        PreparedStatement pst = null;
        ResultSet rs = null;
        List<Map<String, Object>> list = new ArrayList<>();
        try(BufferedWriter bufferedWriter = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(new File(shellFileName)), "utf-8"))) {
            int count = MySQLUtils.getCount(table);
            int page = count / MYSQL_PAGE_SIZE;
            if(count % MYSQL_PAGE_SIZE > 0){
                page++;
            }
            LOGGER.info("表纪录数量: {}, 页面大小: {}, 总页数: {}", count, MYSQL_PAGE_SIZE, page);
            for(int i=START_PAGE; i<page; i++) {
                String join = "";
                if(sql.contains("where")){
                    join = " and ";
                }else{
                    join = " where ";
                }
                String sqlWithPage = sql + join + table + ".id > " + i*MYSQL_PAGE_SIZE + " and " + table +".id <= " + (i*MYSQL_PAGE_SIZE + MYSQL_PAGE_SIZE) +  ";";
                processPage(index, type, sqlWithPage, bufferedWriter, con, pst, rs, list);
            }
            writeBatch(index, type, bufferedWriter, list);
            bufferedWriter.flush();
            if (!"file".equals(MODE)) {
                for (int i = 0; i < THREAD_COUNT; i++) {
                    Map<String, Object> map = new HashMap<>();
                    map.put("data", null);
                    map.put("index", null);
                    map.put("type", null);
                    BLOCKING_QUEUE.put(map);
                }
            }
        } catch (Exception e) {
            LOGGER.error("查询失败", e);
        } finally {
            MySQLUtils.close(con, pst, rs);
        }
    }

    private void processPage(String index, String type, String sql, BufferedWriter bufferedWriter, Connection con, PreparedStatement pst, ResultSet rs, List<Map<String, Object>> list){
        try {
            LOGGER.info("开始查询, SQL: {}", sql);
            pst = con.prepareStatement(sql);
            rs = pst.executeQuery();
            LOGGER.info("查询结束, 开始处理数据");
            while (rs.next()) {
                Map<String, Object> row = getRow(rs);
                list.add(row);
                COUNT.incrementAndGet();
                if(COUNT.get() % 1000 == 0) {
                    LOGGER.info("已写: {}", COUNT.get());
                }
                if(list.size() % BATCH_SIZE == 0) {
                    writeBatch(index, type, bufferedWriter, list);
                }
            }
        }catch (Exception e){
            LOGGER.error("处理页面异常", e);
        }
    }

    protected abstract Map<String, Object> getRow(ResultSet rs);

    public abstract void run();
}