/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */
package com.chaoticity.dependensee;

/**
 *
 * @author Awais Athar
 */

import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.process.CoreLabelTokenFactory;
import edu.stanford.nlp.process.PTBTokenizer;
import edu.stanford.nlp.process.TokenizerFactory;
import edu.stanford.nlp.trees.*;

import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.font.FontRenderContext;
import java.awt.font.TextLayout;
import java.awt.geom.Rectangle2D;
import java.awt.image.BufferedImage;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

public class Main {
    
    private static TreebankLanguagePack tlp = new PennTreebankLanguagePack();
    private static GrammaticalStructureFactory gsf = tlp.grammaticalStructureFactory();
    
    public static void main(String[] args) throws Exception {

        if (args.length == 2) {
            writeImage(args[0], args[1]);
        } else if (args.length == 3 && "-t".equalsIgnoreCase(args[0])) {
            writeFromTextFile(args[1], args[2]);
        } else if (args.length == 3 && "-c".equalsIgnoreCase(args[0])) {
            writeFromCONLLFile(args[1], args[2]);
        } else if (args.length == 4 && "-s".equalsIgnoreCase(args[0])) {
        	writeImage(args[2], args[3], Integer.parseInt(args[1])); 
        } else {
            printHelp();
        }
    }
    
    private static void printHelp() throws Exception {
        System.out.println("Usage: com.chaoticity.dependensee.Main <sentence> <image file>");
        System.out.println("Usage: com.chaoticity.dependensee.Main -t <input Stanford file> <image file>");
        System.out.println("Usage: com.chaoticity.dependensee.Main -c <input CoNLL file> <image file>");

    }
    
    private static Graph getGraph(Tree tree) throws Exception {
        ArrayList<TaggedWord> words = tree.taggedYield();
        GrammaticalStructure gs = gsf.newGrammaticalStructure(tree);
        Collection<TypedDependency> tdl = gs.typedDependencies();
        Graph g = new Graph(words);
        for (TypedDependency td : tdl) {
            g.addEdge(td.gov().index() - 1, td.dep().index() - 1, td.reln().toString());
        }
        try {
            g.setRoot(GrammaticalStructure.getRoots(tdl).iterator().next().gov().toString());
        } catch (Exception ex) {
            //System.err.println("Cannot find dependency graph root. Setting root to first");
            if (g.nodes.size() > 0) {
                g.setRoot(g.nodes.get(0).label);
            }
        }
        return g;
    }


    private static Graph getGraph(Collection<TypedDependency> tdl) {
        Graph g = new Graph();
        for (TypedDependency td: tdl) {
            g.addNode(td.dep().word()+"-"+td.dep().index(), td.dep().tag());
            g.addNode(td.gov().word()+"-"+td.gov().index(), td.gov().tag());
            g.addEdge(td.gov().index() - 1, td.dep().index() - 1, td.reln().toString());
        }
        return g;
    }


    public static Graph getGraph(String sentence) throws Exception {
        LexicalizedParser lp = LexicalizedParser.loadModel("edu/stanford/nlp/models/lexparser/englishPCFG.ser.gz");
        lp.setOptionFlags(new String[]{"-maxLength", "500", "-retainTmpSubcategories"});
        TokenizerFactory<CoreLabel> tokenizerFactory =
                PTBTokenizer.factory(new CoreLabelTokenFactory(), "");
        List<CoreLabel> wordList = tokenizerFactory.getTokenizer(new StringReader(sentence)).tokenize();
        Tree tree = lp.apply(wordList);
        GrammaticalStructure gs = gsf.newGrammaticalStructure(tree);
        Collection<TypedDependency> tdl = gs.typedDependencies();
        return getGraph(tree, tdl);
    }
    
    public static Graph getGraph(String sentence, LexicalizedParser lp) throws Exception {
        TokenizerFactory<CoreLabel> tokenizerFactory =
                PTBTokenizer.factory(new CoreLabelTokenFactory(), "");
        List<CoreLabel> wordList = tokenizerFactory.getTokenizer(new StringReader(sentence)).tokenize();
        Tree tree = lp.apply(wordList);
        GrammaticalStructure gs = gsf.newGrammaticalStructure(tree);
        Collection<TypedDependency> tdl = gs.typedDependencies();
        return getGraph(tree, tdl);
    }
    
    private static Graph getGraph(Tree tree, Collection<TypedDependency> tdl) throws Exception {
        ArrayList<TaggedWord> words = tree.taggedYield();
        Graph g = new Graph(words);
        for (TypedDependency td : tdl) {
            g.addEdge(td.gov().index() - 1, td.dep().index() - 1, td.reln().toString());
        }
        try {
            g.setRoot(GrammaticalStructure.getRoots(tdl).iterator().next().gov().toString());
        } catch (Exception ex) {
            //System.err.println("Cannot find dependency graph root. Setting root to first");
            if (g.nodes.size() > 0) {
                g.setRoot(g.nodes.get(0).label);
            }
        }
        return g;
    }

    private static int getNextHeight(Graph graph, Edge n) {
        int height = 3;
        boolean isFree = false;
        while (!isFree) {
            boolean overlapped = false;
            for (Edge e : graph.edges) {
                if (!e.visible || n == e) {
                    continue;
                }
                int eFirst = e.sourceIndex < e.targetIndex ? e.sourceIndex : e.targetIndex;
                int eSecond = e.sourceIndex < e.targetIndex ? e.targetIndex : e.sourceIndex;
                int nFirst = n.sourceIndex < n.targetIndex ? n.sourceIndex : n.targetIndex;
                int nSecond = n.sourceIndex < n.targetIndex ? n.targetIndex : n.sourceIndex;
                if (e.height == height
                        && ((nFirst > eFirst && nFirst < eSecond)
                        || (nSecond > eFirst && nSecond < eSecond)
                        || (eSecond > nFirst && eSecond < nSecond)
                        || (eSecond > nFirst && eSecond < nSecond)
                        || (n.targetIndex == eFirst)
                        || (n.targetIndex == eSecond))) {
                    overlapped = true;
                    //System.out.println("overlap = "+ n +" and " + e + " at height " + height);
                }
            }
            if (!overlapped) {
                isFree = true;
            } else {
                height++;
            }
            
        }
        return height;
    }
    
    public static void writeImage(String sentence, String outFile) throws Exception {
        writeImage(sentence, outFile, 1);
    }
    
    public static void writeImage(String sentence, String outFile, int scale) throws Exception {
        
        LexicalizedParser lp = null;
        try {
            lp = LexicalizedParser.loadModel("edu/stanford/nlp/models/lexparser/englishPCFG.ser.gz");
        } catch (Exception e) {
            System.err.println("Could not load file englishPCFG.ser.gz. Try placing this file in the same directory as Dependencee.jar");
            return;
        }
        
        lp.setOptionFlags(new String[]{"-maxLength", "500", "-retainTmpSubcategories"});
        TokenizerFactory<CoreLabel> tokenizerFactory =
                PTBTokenizer.factory(new CoreLabelTokenFactory(), "");
        List<CoreLabel> wordList = tokenizerFactory.getTokenizer(new StringReader(sentence)).tokenize();
        Tree tree = lp.apply(wordList);
        writeImage(tree, outFile, scale);
        
    }
    
    public static void writeImage(String sentence, String outFile, LexicalizedParser lp) throws Exception {
        
        Tree parse;
        try {
            TokenizerFactory<CoreLabel> tokenizerFactory =
                PTBTokenizer.factory(new CoreLabelTokenFactory(), "");
        List<CoreLabel> wordList = tokenizerFactory.getTokenizer(new StringReader(sentence)).tokenize();
            parse = lp.apply(wordList);            
        } catch (Exception e) {
            throw e;
        }
        writeImage(parse, outFile);
        
    }
    
    public static void writeImage(Tree tree, String outFile) throws Exception {
        writeImage(tree, outFile, 1);
    }
    
    public static void writeImage(Tree tree, Collection<TypedDependency> tdl, String outFile) throws Exception {
        Graph g = getGraph(tree, tdl);
        writeImage(g,outFile,1);
    }



    public static void writeImage(Tree tree, Collection<TypedDependency> tdl, String outFile, int scale) throws Exception {
        Graph g = getGraph(tree, tdl);
        writeImage(g,outFile,scale);
    }
    
    public static void writeImage(Tree tree, String outFile, int scale) throws Exception {
        Graph g = getGraph(tree);
        writeImage(g,outFile,scale);
    }

    public static void writeImage(Graph g, String outFile) throws Exception {
        writeImage(g,outFile,1);
    }

    public static void writeImage(Collection<TypedDependency> tdl, String outFile) throws Exception {
        writeImage(tdl,outFile,1);
    }

    public static void writeImage(Collection<TypedDependency> tdl, String outFile, int scale) throws Exception {
        Graph g = getGraph(tdl);
        writeImage(g,outFile,scale);
    }

    public static void writeImage(Graph g, String outFile, int scale) throws Exception {
        BufferedImage image = createTextImage(g, scale);
        ImageIO.write(image, "png", new File(outFile));
    }

    public static BufferedImage createTextImage(Graph graph, int scale) throws Exception {
        
        Font wordFont = new Font("Arial", Font.PLAIN, 12 * scale);
        FontRenderContext frc = new FontRenderContext(null, true, false);
        
        int spaceHeight = 20 * scale;
        int spaceWidth = 20 * scale;
        double totalWidth = spaceWidth;

        // calculate word positions
        for (Integer i : graph.nodes.keySet()) {
            Node node = graph.nodes.get(i);
            TextLayout layout = new TextLayout(node.toString(), wordFont, frc);
            Rectangle2D bounds = layout.getBounds();
            node.position.setRect(totalWidth, 0, bounds.getWidth(), bounds.getHeight());
            totalWidth += node.position.getWidth() + spaceWidth;
        }
        int imageWidth = (int) Math.ceil(totalWidth);
        int imageHeight = spaceHeight * (6 * scale + graph.nodes.size());
        int baseline = imageHeight - 30 * scale;

        // create image
        BufferedImage image = new BufferedImage(imageWidth, imageHeight, BufferedImage.TYPE_INT_RGB);
        Graphics2D g = image.createGraphics();
        g.setBackground(Color.white);
        g.clearRect(0, 0, imageWidth, imageHeight);
        g.setColor(Color.black);
        g.setFont(wordFont);
        g.setRenderingHint(RenderingHints.KEY_TEXT_ANTIALIASING, RenderingHints.VALUE_TEXT_ANTIALIAS_ON);
        g.setRenderingHint(RenderingHints.KEY_FRACTIONALMETRICS, RenderingHints.VALUE_FRACTIONALMETRICS_OFF);


        // draw words
        for (Integer i : graph.nodes.keySet()) {
            Node node = graph.nodes.get(i);
            node.position.setRect(node.position.getX(), baseline - spaceHeight, node.position.getWidth(), node.position.getHeight());
            g.drawString(node.toString(), (int) node.position.getX(), (int) node.position.getY());
        }
        
        Font posFont = new Font("Arial", Font.PLAIN, 8 * scale);
        g.setColor(Color.darkGray);
        g.setFont(posFont);
        for (Integer i : graph.nodes.keySet()) {
            Node node = graph.nodes.get(i);
            node.position.setRect(node.position.getX(), baseline - 10 * scale, node.position.getWidth(), node.position.getHeight());
            g.drawString(node.pos, (int) node.position.getX(), (int) node.position.getY());
        }
        g.setColor(Color.black);

        // draw lines
        int lineDistance = 5 * scale;
        int arrowBase = 2 * scale;
        int maxHeight = 0;
        for (Integer i : graph.nodes.keySet()) {
            Node node = graph.nodes.get(i);
            int spacer = (int) node.position.getWidth() / 2 - (node.outEdges.size() / 2 * lineDistance);
            for (Edge e : node.outEdges) {
                int height = getNextHeight(graph, e);
                if (height > maxHeight) {
                    maxHeight = height;
                }
                e.height = height;
                int targetSpacer = (int) e.target.position.getWidth() / 2 - ((e.target.outEdges.size() + 2) / 2 * lineDistance);
                // horizontal line
                g.drawLine(
                        (int) e.source.position.getX() + spacer,
                        baseline - (height * spaceHeight),
                        (int) e.target.position.getX() + targetSpacer,
                        baseline - (height * spaceHeight));

                // source vertical line
                g.drawLine(
                        (int) e.source.position.getX() + spacer,
                        baseline - (height * spaceHeight),
                        (int) e.source.position.getX() + spacer,
                        baseline - spaceHeight * 2);


                // target vertical line
                g.drawLine(
                        (int) e.target.position.getX() + targetSpacer,
                        baseline - (height * spaceHeight),
                        (int) e.target.position.getX() + targetSpacer,
                        baseline - spaceHeight * 2);

                // target arrowhead
                g.drawLine(
                        (int) e.target.position.getX() - arrowBase + targetSpacer,
                        baseline - spaceHeight * 2 - 4 * scale,
                        (int) e.target.position.getX() + targetSpacer,
                        baseline - spaceHeight * 2);
                g.drawLine(
                        (int) e.target.position.getX() + arrowBase + targetSpacer,
                        baseline - spaceHeight * 2 - 4 * scale,
                        (int) e.target.position.getX() + targetSpacer,
                        baseline - spaceHeight * 2);
                e.visible = true;
                spacer += lineDistance;
            }
            
        }

        //draw relation labels

        Font relFont = new Font("Arial", Font.PLAIN, 10 * scale);
        g.setColor(Color.blue);
        g.setFont(relFont);
        
        for (Integer i : graph.nodes.keySet()) {
            Node node = graph.nodes.get(i);
            int spacer = (int) node.position.getWidth() / 2 - (node.outEdges.size() / 2 * lineDistance);
            for (Edge e : node.outEdges) {
                int targetSpacer = (int) e.target.position.getWidth() / 2 - ((e.target.outEdges.size() + 2) / 2 * lineDistance);
                int x = (int) (e.source.position.getX() < e.target.position.getX()
                        ? e.source.position.getX() + spacer
                        : e.target.position.getX() + targetSpacer);
                TextLayout layout = new TextLayout(e.label, relFont, frc);
                Rectangle2D bounds = layout.getBounds();
                int clearWidth = (int) Math.ceil(bounds.getWidth());
                int clearHeight = (int) Math.ceil(bounds.getHeight()) + 2 * scale;
                g.clearRect(x, baseline - (e.height * spaceHeight) - clearHeight - 2 * scale,
                        clearWidth, clearHeight);
                g.drawString(e.label, x, baseline - (e.height * spaceHeight) - 3 * scale);
                
                spacer += lineDistance;
            }
        }
        
        g.dispose();
        int ystart = imageHeight - spaceHeight * (maxHeight + 3 * scale);
        return image.getSubimage(0, ystart, imageWidth, imageHeight - ystart);
    }
    
    public static void writeFromTextFile(String infile, String outfile) throws Exception {
        Graph g = new Graph();
        BufferedReader input = new BufferedReader(new FileReader(infile));
        String line = null;
        while ((line = input.readLine()) != null) {
            if ("".equals(line)) {
                continue;
            }
            int relEnd = line.indexOf("(");
            int secondWordStart = line.indexOf(", ", relEnd + 1);
            String rel = line.substring(0, relEnd);
            String gov = line.substring(relEnd + 1, secondWordStart);
            String dep = line.substring(secondWordStart + 2, line.length() - 1);
            Node govNode = g.addNode(gov, "");
            Node depNode = g.addNode(dep, "");
            g.addEdge(govNode, depNode, rel);
        }

        BufferedImage image = createTextImage(g, 1);
        ImageIO.write(image, "png", new File(outfile));
    }

    public static void writeFromCONLLFile(String infile, String outfile) throws Exception {
        Graph g = new Graph();
        BufferedReader input = new BufferedReader(new FileReader(infile));
        String line = null;
        List<Edge> tempEdges = new ArrayList<Edge>();
        while ((line = input.readLine()) != null) {
            if ("".equals(line)) break; // stop at sentence boundary
            if (line.startsWith("#")) continue; // skip comments

            String[] parts = line.split("\\s+");

            if (!parts[0].matches("^-?\\d+$")) continue; //skip ranges

            g.addNode(parts[1],Integer.parseInt(parts[0]),parts[2]);
            tempEdges.add( new Edge(
                    Integer.parseInt(parts[6])-1,
                    Integer.parseInt(parts[0])-1,
                    parts[7]));

        }
        for (Edge e: tempEdges ) {
            if (e.sourceIndex==-1 ) {
                g.setRoot(e.sourceIndex);
                continue;
            }
            g.addEdge(g.nodes.get(e.sourceIndex), g.nodes.get(e.targetIndex),e.label);
        }

        BufferedImage image = Main.createTextImage(g,1);
        ImageIO.write(image, "png", new File(outfile));
    }
}