package com.distelli.graphql;

import graphql.schema.DataFetcher;
import graphql.schema.DataFetchingEnvironment;
import graphql.execution.batched.Batched;
import graphql.execution.batched.BatchedDataFetcher;
import graphql.schema.DataFetchingEnvironmentImpl;

import java.util.List;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.lang.reflect.Method;

public class ResolverDataFetcher implements DataFetcher {
    private DataFetcher fetcher;
    private Resolver resolver;
    private boolean isBatched;
    private int listDepth;
    public ResolverDataFetcher(DataFetcher fetcher, Resolver resolver, int listDepth) {
        this.fetcher = fetcher;
        this.resolver = resolver;
        this.listDepth = listDepth;
        if ( fetcher instanceof BatchedDataFetcher ) {
            this.isBatched = true;
        } else {
            try {
                Method getMethod = fetcher.getClass()
                    .getMethod("get", DataFetchingEnvironment.class);
                this.isBatched = null != getMethod.getAnnotation(Batched.class);
            } catch (NoSuchMethodException e) {
                throw new IllegalArgumentException(e);
            }
        }
    }

    @Batched
    @Override
    public Object get(DataFetchingEnvironment env) {
        List<Object> unresolved = new ArrayList<>();
        Object result;
        int depth = listDepth;
        if ( env.getSource() instanceof List ) { // batched.
            result = getBatched(env);
            if ( null != resolver ) addUnresolved(unresolved, result, ++depth);
        } else {
            result = getUnbatched(env);
            if ( null != resolver ) addUnresolved(unresolved, result, depth);
        }
        if ( null == resolver ) return result;
        return replaceResolved(result, resolver.resolve(unresolved).iterator(), depth);
    }

    public Object replaceResolved(Object result, Iterator<Object> resolved, int depth) {
        if ( depth <= 0 ) {
            return resolved.next();
        }
        List<Object> resolvedResults = new ArrayList<>();
        if ( null == result ) return null;
        for ( Object elm : (List)result ) {
            resolvedResults.add(replaceResolved(elm, resolved, depth-1));
        }
        return resolvedResults;
    }

    public void addUnresolved(List<Object> unresolved, Object result, int depth) {
        if ( depth <= 0 ) {
            unresolved.add(result);
            return;
        }
        if ( ! (result instanceof List) ) {
            if ( null == result ) return;
            throw new IllegalStateException("Fetcher "+fetcher+" expected to return a List for each result, got="+result);
        }
        for ( Object elm : (List)result ) {
            addUnresolved(unresolved, elm, depth-1);
        }
    }

    public Object getUnbatched(DataFetchingEnvironment env) {
        if ( ! isBatched ) {
            try {
                return fetcher.get(env);
            } catch (Exception e) {
                throw new IllegalStateException(e);
            }
        }
        DataFetchingEnvironmentImpl.Builder builder = DataFetchingEnvironmentImpl.newDataFetchingEnvironment(env);

        DataFetchingEnvironment envCopy = builder.build();

        try {
            Object result = fetcher.get(envCopy);
            if ( !(result instanceof List) || ((List)result).size() != 1 ) {
                throw new IllegalStateException("Batched fetcher "+fetcher+" expected to return list of 1");
            }
            return ((List)result).get(0);
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
    }

    public List<Object> getBatched(DataFetchingEnvironment env) {
        List sources = env.getSource();
        if ( isBatched ) {
            try {
                Object result = fetcher.get(env);
                if ( !(result instanceof List) || ((List)result).size() != sources.size() ) {
                    throw new IllegalStateException("Batched fetcher "+fetcher+" expected to return list of "+sources.size());
                }
                return (List<Object>)result;
            } catch (Exception e) {
                throw new IllegalStateException(e);
            }
        }
        List<Object> result = new ArrayList<>();
        for ( Object source : sources ) {
            DataFetchingEnvironmentImpl.Builder builder = DataFetchingEnvironmentImpl.newDataFetchingEnvironment(env);
            builder.source(source);
            DataFetchingEnvironment envCopy = builder.build();

            try {
                result.add(fetcher.get(envCopy));
            } catch (Exception e) {
                throw new IllegalStateException(e);
            }
        }
        return result;
    }

    @Override
    public String toString() {
        return "ResolverDataFetcher{"+
            "resolver="+resolver+
            ", fetcher="+fetcher+
            ", isBatched="+isBatched+
            ", listDepth="+listDepth+
            "}";
    }
}