/*
 * To change this license header, choose License Headers in Project Properties.
 * To change this template file, choose Tools | Templates
 * and open the template in the editor.
 */
package org.drools.devguide.phreakinspector.model;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.io.IOUtils;
import org.drools.core.base.ClassObjectType;
import org.drools.core.base.dataproviders.MVELDataProvider;
import org.drools.core.common.EmptyBetaConstraints;
import org.drools.core.impl.KnowledgeBaseImpl;
import org.drools.core.reteoo.AccumulateNode;
import org.drools.core.reteoo.AlphaNode;
import org.drools.core.reteoo.EntryPointNode;
import org.drools.core.reteoo.FromNode;
import org.drools.core.reteoo.JoinNode;
import org.drools.core.reteoo.LeftInputAdapterNode;
import org.drools.core.reteoo.LeftTupleSink;
import org.drools.core.reteoo.LeftTupleSinkPropagator;
import org.drools.core.reteoo.NotNode;
import org.drools.core.reteoo.ObjectSink;
import org.drools.core.reteoo.ObjectTypeNode;
import org.drools.core.reteoo.QueryElementNode;
import org.drools.core.reteoo.QueryTerminalNode;
import org.drools.core.reteoo.Rete;
import org.drools.core.reteoo.RightInputAdapterNode;
import org.drools.core.reteoo.RuleTerminalNode;
import org.drools.core.reteoo.Sink;
import org.drools.core.rule.Declaration;
import org.drools.core.rule.EntryPointId;
import org.drools.core.rule.constraint.MvelConstraint;
import org.drools.core.spi.BetaNodeFieldConstraint;
import org.drools.core.spi.ObjectType;
import org.kie.api.KieBase;
import org.kie.api.KieServices;
import org.kie.api.builder.Message;
import org.kie.api.builder.Results;
import org.kie.api.io.Resource;
import org.kie.api.io.ResourceType;
import org.kie.api.runtime.KieContainer;
import org.kie.internal.utils.KieHelper;
import org.stringtemplate.v4.ST;

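/**
 * Walks the Rete network of a {@link KieBase}, collecting {@link Node} entries
 * for the network nodes it visits, and renders them through the
 * {@code /templates/viz.template} StringTemplate template.
 *
 * @author esteban
 */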
public class PhreakInspector {

    private final Map<Integer, Node> nodes = new HashMap<>();

    /**
     * Renders the network of a KieBase obtained from the classpath KieContainer.
     *
     * @param kieBaseName name of the KieBase to look up in the classpath container.
     * @return stream over the rendered template.
     * @throws IOException if rendering fails while reading the template.
     * @throws IllegalStateException if verifying the container reports warnings or errors.
     */
    public InputStream fromClassPathKieContainer(String kieBaseName) throws IOException {
        KieContainer container = createContainer();
        KieBase kieBase = container.getKieBase(kieBaseName);
        return fromKieBase(kieBase);
    }

    /**
     * Builds a KieBase from the given resources and renders its network.
     *
     * @param resources resources to compile, each mapped to its resource type.
     * @return stream over the rendered template.
     * @throws IOException if rendering fails while reading the template.
     * @throws IllegalStateException if verifying the resources reports warnings or errors.
     */
    public InputStream fromResources(Map<Resource, ResourceType> resources) throws IOException {
        KieBase kieBase = buildKieBase(resources);
        return fromKieBase(kieBase);
    }

    /**
     * Walks the Rete network of the given KieBase and renders the collected
     * nodes through the template.
     * <p>
     * For every entry point a node is registered; for every object type node
     * below it another node is registered and linked from the entry point, and
     * each of its sinks is then visited recursively.
     *
     * @param kb the KieBase to inspect; it is cast to {@link KnowledgeBaseImpl}.
     * @return stream over the rendered template.
     * @throws IOException if the template cannot be read.
     * @throws ClassCastException if {@code kb} is not a {@link KnowledgeBaseImpl}.
     * @throws UnsupportedOperationException if the network contains a sink type
     *         this inspector does not handle.
     */
    public InputStream fromKieBase(KieBase kb) throws IOException {
        // The map is an instance field keyed by node id and shared by every call.
        // Without this reset, nodes collected from a previously inspected KieBase
        // would be rendered again (or silently overwritten on an id collision).
        nodes.clear();

        KnowledgeBaseImpl kbase = (KnowledgeBaseImpl) kb;
        Rete rete = kbase.getRete();

        for (EntryPointNode entryPoint : rete.getEntryPointNodes().values()) {
            Node epNode = new Node(entryPoint.getId(), entryPoint.getEntryPoint().getEntryPointId(), Node.TYPE.ENTRY_POINT);
            nodes.put(epNode.getId(), epNode);

            for (ObjectTypeNode otn : entryPoint.getObjectTypeNodes().values()) {
                // Class-based types are labelled with the class name; anything
                // else falls back to the object type's own string form.
                ObjectType objectType = otn.getObjectType();
                String nodeLabel = objectType instanceof ClassObjectType
                        ? ((ClassObjectType) objectType).getClassName()
                        : objectType.toString();

                Node otNode = new Node(otn.getId(), nodeLabel, Node.TYPE.OBJECT_TYPE);
                nodes.put(otNode.getId(), otNode);
                epNode.addTargetNode(otNode.getId());

                for (ObjectSink sink : otn.getSinkPropagator().getSinks()) {
                    this.visitObjectSink(sink, otNode);
                }
            }
        }

        // TODO: segment inspection is not implemented.
        return this.generateGraphViz(nodes);
    }

    /**
     * Visits an object sink by handing it to the generic sink dispatcher.
     *
     * @param oSink the sink to visit.
     * @param parentNode the already registered node the sink hangs from.
     */
    private void visitObjectSink(ObjectSink oSink, Node parentNode) {
        visitSink(oSink, parentNode);
    }

    /**
     * Visits a left tuple sink by handing it to the generic sink dispatcher.
     *
     * @param ltSink the sink to visit.
     * @param parentNode the already registered node the sink hangs from.
     */
    private void visitLeftTupleSink(LeftTupleSink ltSink, Node parentNode) {
        visitSink(ltSink, parentNode);
    }

    /**
     * Dispatches a sink to the visit method matching its concrete node type.
     * The type checks run in a fixed order and the first match wins.
     *
     * @param sink the network node to visit.
     * @param parentNode the already registered node the sink hangs from.
     * @throws UnsupportedOperationException if the sink is of none of the
     *         handled types; the message is the sink's string form.
     */
    private void visitSink(Sink sink, Node parentNode) {
        if (sink instanceof LeftInputAdapterNode) {
            visitLeftInputAdapterNode((LeftInputAdapterNode) sink, parentNode);
            return;
        }
        if (sink instanceof RightInputAdapterNode) {
            visitRightInputAdapterNode((RightInputAdapterNode) sink, parentNode);
            return;
        }
        if (sink instanceof AlphaNode) {
            visitAlphaNode((AlphaNode) sink, parentNode);
            return;
        }
        if (sink instanceof JoinNode) {
            visitBetaNode((JoinNode) sink, parentNode);
            return;
        }
        if (sink instanceof NotNode) {
            visitNotNode((NotNode) sink, parentNode);
            return;
        }
        if (sink instanceof QueryElementNode) {
            visitQueryElementNode((QueryElementNode) sink, parentNode);
            return;
        }
        if (sink instanceof AccumulateNode) {
            visitAccumulateNode((AccumulateNode) sink, parentNode);
            return;
        }
        if (sink instanceof RuleTerminalNode) {
            visitRuleTerminalNode((RuleTerminalNode) sink, parentNode);
            return;
        }
        if (sink instanceof QueryTerminalNode) {
            visitQueryTerminalNode((QueryTerminalNode) sink, parentNode);
            return;
        }
        if (sink instanceof FromNode) {
            visitFromNode((FromNode) sink, parentNode);
            return;
        }

        // Any other node type is not supported by this inspector.
        throw new UnsupportedOperationException(sink.toString());
    }

    /**
     * Visits the sinks of a left input adapter node. No node is registered for
     * the adapter itself: its sinks are linked straight to {@code parentNode}.
     *
     * @param lian the adapter whose sinks are visited.
     * @param parentNode the node the adapter's sinks are attached to.
     */
    private void visitLeftInputAdapterNode(LeftInputAdapterNode lian, Node parentNode) {
        LeftTupleSink[] children = lian.getSinkPropagator().getSinks();
        for (int i = 0; i < children.length; i++) {
            visitLeftTupleSink(children[i], parentNode);
        }
    }

    /**
     * Visits the sinks of a right input adapter node. No node is registered for
     * the adapter itself: its sinks are linked straight to {@code parentNode}.
     *
     * @param rian the adapter whose sinks are visited.
     * @param parentNode the node the adapter's sinks are attached to.
     */
    private void visitRightInputAdapterNode(RightInputAdapterNode rian, Node parentNode) {
        ObjectSink[] children = rian.getSinkPropagator().getSinks();
        for (int i = 0; i < children.length; i++) {
            visitObjectSink(children[i], parentNode);
        }
    }

    /**
     * Registers an ALPHA node labelled with the alpha constraint's string form,
     * links it from {@code parentNode} and visits each of its sinks.
     *
     * @param alpha the alpha node being visited.
     * @param parentNode the node the new entry is linked from.
     */
    private void visitAlphaNode(AlphaNode alpha, Node parentNode) {
        final Node current = new Node(alpha.getId(), alpha.getConstraint().toString(), Node.TYPE.ALPHA);
        nodes.put(current.getId(), current);
        parentNode.addTargetNode(current.getId());

        for (ObjectSink child : alpha.getSinkPropagator().getSinks()) {
            visitObjectSink(child, current);
        }
    }

    /**
     * Registers a BETA node labelled with the join's constraint string, links
     * it from {@code parentNode} and visits each of its sinks.
     *
     * @param join the join node being visited.
     * @param parentNode the node the new entry is linked from.
     */
    private void visitBetaNode(JoinNode join, Node parentNode) {
        final Node current = new Node(join.getId(), createConstraintsString(join), Node.TYPE.BETA);
        nodes.put(current.getId(), current);
        parentNode.addTargetNode(current.getId());

        for (LeftTupleSink child : join.getSinkPropagator().getSinks()) {
            visitLeftTupleSink(child, current);
        }
    }

    /**
     * Registers a FROM node labelled via {@code createConstraintsString}, links
     * it from {@code parentNode} and visits each of its sinks.
     *
     * @param from the from node being visited.
     * @param parentNode the node the new entry is linked from.
     */
    private void visitFromNode(FromNode from, Node parentNode) {
        final Node current = new Node(from.getId(), createConstraintsString(from), Node.TYPE.FROM);
        nodes.put(current.getId(), current);
        parentNode.addTargetNode(current.getId());

        for (LeftTupleSink child : from.getSinkPropagator().getSinks()) {
            visitLeftTupleSink(child, current);
        }
    }

    /**
     * Registers a NOT node labelled with the node's string form, links it from
     * {@code parentNode} and visits each of its sinks.
     *
     * @param not the not node being visited.
     * @param parentNode the node the new entry is linked from.
     */
    private void visitNotNode(NotNode not, Node parentNode) {
        final Node current = new Node(not.getId(), not.toString(), Node.TYPE.NOT);
        nodes.put(current.getId(), current);
        parentNode.addTargetNode(current.getId());

        for (LeftTupleSink child : not.getSinkPropagator().getSinks()) {
            visitLeftTupleSink(child, current);
        }
    }

    /**
     * Registers a QUERY_ELEMENT node labelled with the node's string form,
     * links it from {@code parentNode} and visits each of its sinks.
     *
     * @param qen the query element node being visited.
     * @param parentNode the node the new entry is linked from.
     */
    private void visitQueryElementNode(QueryElementNode qen, Node parentNode) {
        final Node current = new Node(qen.getId(), qen.toString(), Node.TYPE.QUERY_ELEMENT);
        nodes.put(current.getId(), current);
        parentNode.addTargetNode(current.getId());

        for (LeftTupleSink child : qen.getSinkPropagator().getSinks()) {
            visitLeftTupleSink(child, current);
        }
    }

    /**
     * Registers an ACCUMULATE node labelled with the node's string form, links
     * it from {@code parentNode} and visits each of its sinks.
     *
     * @param an the accumulate node being visited.
     * @param parentNode the node the new entry is linked from.
     */
    private void visitAccumulateNode(AccumulateNode an, Node parentNode) {
        final Node current = new Node(an.getId(), an.toString(), Node.TYPE.ACCUMULATE);
        nodes.put(current.getId(), current);
        parentNode.addTargetNode(current.getId());

        for (LeftTupleSink child : an.getSinkPropagator().getSinks()) {
            visitLeftTupleSink(child, current);
        }
    }

    /**
     * Registers a RULE_TERMINAL node labelled with the rule's name and links it
     * from {@code parentNode}. Terminal nodes are leaves: nothing further is
     * visited.
     *
     * @param rtn the rule terminal node being visited.
     * @param parentNode the node the new entry is linked from.
     */
    private void visitRuleTerminalNode(RuleTerminalNode rtn, Node parentNode) {
        final String ruleName = rtn.getRule().getName();
        final Node terminal = new Node(rtn.getId(), ruleName, Node.TYPE.RULE_TERMINAL);
        nodes.put(terminal.getId(), terminal);
        parentNode.addTargetNode(terminal.getId());
    }

    /**
     * Registers a QUERY_TERMINAL node labelled with the query's name and links
     * it from {@code parentNode}. Terminal nodes are leaves: nothing further is
     * visited.
     *
     * @param qtn the query terminal node being visited.
     * @param parentNode the node the new entry is linked from.
     */
    private void visitQueryTerminalNode(QueryTerminalNode qtn, Node parentNode) {
        final String queryName = qtn.getRule().getName();
        final Node terminal = new Node(qtn.getId(), queryName, Node.TYPE.QUERY_TERMINAL);
        nodes.put(terminal.getId(), terminal);
        parentNode.addTargetNode(terminal.getId());
    }

    /**
     * Builds the label of a join node from the expressions of its MVEL
     * constraints.
     * <p>
     * Every {@link MvelConstraint} contributes its expression followed by a
     * comma; constraints of any other type are skipped. The previous
     * implementation assigned instead of appending, so only the last MVEL
     * constraint ever made it into the label.
     *
     * @param joinNode the join node whose constraints are rendered.
     * @return the concatenated expressions, each followed by {@code ","}; an
     *         empty string if the node has no constraints or none of them is
     *         an MVEL constraint.
     */
    private String createConstraintsString(JoinNode joinNode) {
        BetaNodeFieldConstraint[] constraints = joinNode.getConstraints();
        if (constraints == null) {
            return "";
        }

        StringBuilder result = new StringBuilder();
        for (BetaNodeFieldConstraint constraint : constraints) {
            if (constraint instanceof MvelConstraint) {
                // The trailing comma is kept on purpose so that labels of
                // single-constraint joins stay exactly as they were before.
                result.append(((MvelConstraint) constraint).getExpression()).append(',');
            }
        }

        return result.toString();
    }

    /**
     * Returns the label used for a from node.
     * <p>
     * Rendering the actual from expression is not implemented, so a fixed
     * placeholder is returned. The data provider is deliberately not inspected:
     * the earlier code cast it to {@code MVELDataProvider} without using the
     * result, which could only add a {@link ClassCastException} for other
     * provider implementations.
     *
     * @param fromNode the from node being labelled (currently unused).
     * @return the placeholder text {@code "<from expression here>"}.
     */
    private String createConstraintsString(FromNode fromNode) {
        // TODO: derive the label from the node's data provider once supported.
        return "<from expression here>";
    }

    /**
     * Renders the collected nodes through the {@code /templates/viz.template}
     * classpath resource.
     * <p>
     * The template receives two attributes: {@code items}, with all nodes, and
     * {@code itemsByGroup}, with the nodes grouped by the group of their type.
     *
     * @param nodes the nodes to render, keyed by id.
     * @return stream over the rendered text, encoded as UTF-8.
     * @throws IOException if the template resource is missing or cannot be read.
     */
    private InputStream generateGraphViz(Map<Integer, Node> nodes) throws IOException {
        final String templatePath = "/templates/viz.template";

        String template;
        // IOUtils.toString does not close the stream it reads, so close it here.
        try (InputStream templateStream = PhreakInspector.class.getResourceAsStream(templatePath)) {
            if (templateStream == null) {
                // Fail with a descriptive error rather than a bare NullPointerException.
                throw new IOException("Template resource not found on the classpath: " + templatePath);
            }
            // Explicit charset instead of the platform default, so rendering is
            // the same on every machine.
            // NOTE(review): assumes the template file is UTF-8 (or plain ASCII) - confirm.
            template = IOUtils.toString(templateStream, StandardCharsets.UTF_8.name());
        }

        ST st = new ST(template, '$', '$');
        st.add("items", nodes.values());

        Map<String, List<Node>> itemsByGroup = nodes.values().stream()
                .collect(Collectors.groupingBy(n -> n.getType().getGroup()));
        st.add("itemsByGroup", itemsByGroup);

        return new ByteArrayInputStream(st.render().getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Obtains the classpath KieContainer and checks its verification results.
     *
     * @return the verified classpath container.
     * @throws IllegalStateException if verification reports warnings or errors.
     */
    private KieContainer createContainer() {
        KieContainer classpathContainer = KieServices.Factory.get().getKieClasspathContainer();

        Results verification = classpathContainer.verify();
        assertBuildResults(verification);

        return classpathContainer;
    }

    /**
     * Builds a KieBase from the given resources after checking that they
     * verify without warnings or errors.
     *
     * @param resources resources to compile, each mapped to its resource type.
     * @return the KieBase built from the resources.
     * @throws IllegalStateException if verification reports warnings or errors.
     */
    private KieBase buildKieBase(Map<Resource, ResourceType> resources) {
        final KieHelper helper = new KieHelper();
        resources.forEach((resource, type) -> helper.addResource(resource, type));

        Results verification = helper.verify();
        assertBuildResults(verification);

        return helper.build();
    }

    /**
     * Fails if the given build results contain any warning or error.
     * <p>
     * Each offending message is printed to standard output on its own line
     * before the exception is thrown. The format previously had no line
     * terminator, so consecutive messages ran together on one line.
     *
     * @param results the build or verification results to check.
     * @throws IllegalStateException if at least one WARNING or ERROR message is present.
     */
    private void assertBuildResults(Results results) {
        if (!results.hasMessages(Message.Level.WARNING, Message.Level.ERROR)) {
            return;
        }

        for (Message message : results.getMessages(Message.Level.WARNING, Message.Level.ERROR)) {
            // %n terminates each message with the platform line separator.
            System.out.printf("[%s] - %s[%s,%s]: %s%n", message.getLevel(), message.getPath(), message.getLine(), message.getColumn(), message.getText());
        }

        throw new IllegalStateException("Compilation errors were found. Check the logs.");
    }
}