package searcher.api.expand;

import org.neo4j.driver.v1.Record;
import org.neo4j.driver.v1.Session;
import org.neo4j.driver.v1.StatementResult;
import searcher.api.ApiLocatorContext;
import searcher.api.SubGraph;
import searcher.index.LuceneSearchResult;
import utils.VectorUtils;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Expands a search-result subgraph with the Lucene hits whose node sets are closest (by the
 * vector distance in {@code context.id2Vec}) to the subgraph's nodes, and records any Neo4j
 * relationships between those hits and the nodes of a second, linked result.
 */
public class ExpandFlossNodes {

    /** How many of the closest Lucene hits are added to the expanded subgraph. */
    private static final int MAX_EXPANSIONS = 3;

    /**
     * Finds relationships in either direction between two nodes given by internal id.
     * Parameterised so the query text is constant; the result column is named {@code id(r)}.
     */
    private static final String EDGE_QUERY =
            "match (a)-[r]-(b) where id(a)=$a and id(b)=$b return id(r)";

    /**
     * Builds a new subgraph from {@code searchResult1} plus up to {@value #MAX_EXPANSIONS}
     * Lucene hits for {@code query}, choosing the hits with the smallest
     * {@link #dist(Set, Set, ApiLocatorContext)} to {@code searchResult1}'s nodes.
     *
     * @param query               text passed to the context's Lucene searcher
     * @param searchResult1       subgraph whose nodes and {@code cost} are copied into the result;
     *                            its nodes are also the reference set for ranking hits
     * @param linkedSearchResult1 subgraph whose nodes are checked for relationships to each added
     *                            hit; matching relationship ids become edges of the result
     *                            (its nodes themselves are not added)
     * @param context             supplies the Lucene searcher, the Neo4j connection and
     *                            the id-to-vector map
     * @return a new subgraph; the inputs are not modified
     */
    public static SubGraph run(String query, SubGraph searchResult1, SubGraph linkedSearchResult1,  ApiLocatorContext context){
        SubGraph r = new SubGraph();
        r.getNodes().addAll(searchResult1.getNodes());
        r.cost = searchResult1.cost;

        List<LuceneSearchResult> ranked =
                rankByDistance(context.getLuceneSearcher().query(query), searchResult1.getNodes(), context);

        int limit = Math.min(MAX_EXPANSIONS, ranked.size());
        List<Long> expansionIds = new ArrayList<>(limit);
        for (int i = 0; i < limit; i++) {
            long hitId = ranked.get(i).id;
            r.getNodes().add(hitId);
            expansionIds.add(hitId);
        }

        Set<Long> linkedNodes = linkedSearchResult1.getNodes();
        if (expansionIds.isEmpty() || linkedNodes.isEmpty()) {
            // Nothing to look up; avoid opening a session at all.
            return r;
        }
        // One session for all lookups, closed even if a query fails.
        try (Session session = context.connection.session()) {
            for (long hitId : expansionIds) {
                for (long node : linkedNodes) {
                    addConnectingEdges(session, node, hitId, r);
                }
            }
        }
        return r;
    }

    /**
     * Returns a copy of {@code hits} sorted by ascending distance of each hit's node set to
     * {@code anchorNodes}. Each distance is computed once; the sort is stable, so ties keep
     * the searcher's original order. The caller's list is left untouched.
     */
    private static List<LuceneSearchResult> rankByDistance(List<LuceneSearchResult> hits,
                                                           Set<Long> anchorNodes,
                                                           ApiLocatorContext context) {
        List<LuceneSearchResult> ranked = new ArrayList<>(hits);
        // Identity-keyed so the cache does not depend on LuceneSearchResult's equals/hashCode.
        Map<LuceneSearchResult, Double> distances = new IdentityHashMap<>();
        for (LuceneSearchResult hit : ranked) {
            distances.put(hit, dist(hit.nodeSet, anchorNodes, context));
        }
        ranked.sort(Comparator.comparingDouble(hit -> distances.get(hit)));
        return ranked;
    }

    /**
     * Adds to {@code target}'s edges the id of every relationship between {@code nodeA} and
     * {@code nodeB}, in either direction.
     */
    private static void addConnectingEdges(Session session, long nodeA, long nodeB, SubGraph target) {
        Map<String, Object> params = new HashMap<>();
        params.put("a", nodeA);
        params.put("b", nodeB);
        StatementResult rs = session.run(EDGE_QUERY, params);
        while (rs.hasNext()) {
            Record item = rs.next();
            target.getEdges().add(item.get("id(r)").asLong());
        }
    }

    /**
     * Averages, over the nodes of {@code nodeSet1} that have a vector in
     * {@code context.id2Vec}, the smallest distance from that node to any vectorised node of
     * {@code nodeSet2}. Nodes without a vector are ignored on both sides.
     *
     * @return the average minimum distance, or {@link Double#MAX_VALUE} when {@code nodeSet1}
     *         has no vectorised node, or when a vectorised node of {@code nodeSet1} has no
     *         finite-comparable partner (e.g. {@code nodeSet2} has no vectorised node)
     */
    public static double dist(Set<Long> nodeSet1, Set<Long> nodeSet2, ApiLocatorContext context){
        // The vectorised subset of nodeSet2 is the same for every id1, so compute it once.
        List<Long> vectorised2 = new ArrayList<>();
        for (long id2 : nodeSet2) {
            if (context.id2Vec.containsKey(id2)) {
                vectorised2.add(id2);
            }
        }

        double sum = 0;
        double count = 0;
        for (long id1 : nodeSet1) {
            if (!context.id2Vec.containsKey(id1)) {
                continue;
            }
            count++;
            double minDist = Double.MAX_VALUE;
            for (long id2 : vectorised2) {
                double d = dist(id1, id2, context);
                if (d < minDist) {
                    minDist = d;
                }
            }
            if (minDist == Double.MAX_VALUE) {
                // This node cannot be matched, so the whole set is treated as infinitely far.
                return Double.MAX_VALUE;
            }
            sum += minDist;
        }
        if (count == 0) {
            return Double.MAX_VALUE;
        }
        return sum / count;
    }

    /** Distance between the vectors of two node ids, as defined by {@link VectorUtils#dist}. */
    private static double dist(long node1, long node2, ApiLocatorContext context){
        return VectorUtils.dist(node1,node2,context.id2Vec);
    }

}