package com.google.api.graphql.rejoiner;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.inject.Provider;
import com.google.protobuf.Descriptors;
import com.google.protobuf.Message;
import graphql.schema.DataFetcher;
import graphql.schema.DataFetchingEnvironment;
import graphql.schema.GraphQLFieldDefinition;
import graphql.schema.GraphQLOutputType;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.util.stream.Stream;

/**
 * SchemaModule that generates queries and mutations for GAX gRPC clients, such as the Google Cloud
 * Platform APIs.
 */
public abstract class GaxSchemaModule extends SchemaModule {

  /**
   * Suffix of the client accessor methods that are looked up for each whitelisted RPC name; it is
   * appended when searching for methods and stripped again to form the GraphQL field name.
   */
  private static final String CALLABLE_SUFFIX = "Callable";

  /**
   * Builds one GraphQL field for every whitelisted RPC exposed by a client class.
   *
   * <p>For each name in {@code methodWhitelist}, the public no-argument method {@code
   * <name>Callable} of {@code client} is looked up. Its generic return type is expected to carry
   * the request message class as the first type argument and the response message class as the
   * second. The resulting field is named after the whitelisted RPC, takes a single {@code input}
   * argument derived from the request descriptor, and has the response message as its type.
   *
   * <p>When the field is fetched, the {@code input} argument is converted to a request message,
   * the callable is obtained from an instance supplied by {@code getProvider(client)}, and its
   * {@code futureCall(Object)} method is invoked reflectively; whatever that call returns is
   * returned from the data fetcher unchanged.
   *
   * @param client the client class whose {@code *Callable} accessors are inspected
   * @param methodWhitelist RPC method names (without the {@code Callable} suffix) to expose
   * @return the generated field definitions, one per matching accessor method
   * @throws RuntimeException if reflection on the client or its message classes fails
   */
  protected ImmutableList<GraphQLFieldDefinition> serviceToFields(
      Class<?> client, ImmutableList<String> methodWhitelist) {
    return getMethods(client, methodWhitelist)
        .map(callableAccessor -> toFieldDefinition(client, callableAccessor))
        .collect(ImmutableList.toImmutableList());
  }

  /** Creates the GraphQL field backed by the callable that {@code callableAccessor} returns. */
  private GraphQLFieldDefinition toFieldDefinition(Class<?> client, Method callableAccessor) {
    try {
      callableAccessor.setAccessible(true);
      /* com.google.api.gax.rpc.UnaryCallable<Req, Resp> */
      ParameterizedType callable = (ParameterizedType) callableAccessor.getGenericReturnType();
      GraphQLOutputType responseType = getReturnType(callable);
      Class<? extends Message> requestMessageClass = messageClassAt(callable, 0);
      Descriptors.Descriptor requestDescriptor =
          (Descriptors.Descriptor) requestMessageClass.getMethod("getDescriptor").invoke(null);
      // Empty request instance used as the template for every incoming call.
      Message requestMessage =
          ((Message.Builder) requestMessageClass.getMethod("newBuilder").invoke(null))
              .buildPartial();
      Provider<?> service = getProvider(client);

      GqlInputConverter inputConverter =
          GqlInputConverter.newBuilder().add(requestDescriptor.getFile()).build();

      DataFetcher dataFetcher =
          (DataFetchingEnvironment env) -> {
            Message input =
                inputConverter.createProtoBuf(
                    requestDescriptor, requestMessage.toBuilder(), env.getArgument("input"));
            try {
              Object callableInstance = callableAccessor.invoke(service.get());
              // Looked up on the runtime class, so setAccessible is needed in case that class
              // itself is not public.
              Method futureCall =
                  callableInstance.getClass().getMethod("futureCall", Object.class);
              futureCall.setAccessible(true);
              return futureCall.invoke(callableInstance, new Object[] {input});
            } catch (InvocationTargetException e) {
              // Surface the failure raised by the client instead of the reflection wrapper so
              // the error reported to the caller describes what actually went wrong.
              throw new RuntimeException(targetOf(e));
            } catch (Exception e) {
              throw new RuntimeException(e);
            }
          };

      return GraphQLFieldDefinition.newFieldDefinition()
          .name(transformName(callableAccessor.getName()))
          .argument(GqlInputConverter.createArgument(requestDescriptor, "input"))
          .type(responseType)
          .dataFetcher(dataFetcher)
          .build();
    } catch (InvocationTargetException e) {
      throw new RuntimeException(targetOf(e));
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Returns the public no-argument methods of {@code clientClass} whose name is a whitelisted RPC
   * name followed by {@code Callable}.
   *
   * <p>Methods that declare parameters are skipped: the generated data fetcher invokes the
   * accessor without arguments, so such a method could only fail at query time.
   */
  private Stream<Method> getMethods(Class<?> clientClass, ImmutableList<String> methodWhitelist) {
    ImmutableSet<String> asyncNameWhitelist =
        methodWhitelist
            .stream()
            .map(name -> name + CALLABLE_SUFFIX)
            .collect(ImmutableSet.toImmutableSet());

    return ImmutableList.copyOf(clientClass.getMethods())
        .stream()
        .filter(method -> asyncNameWhitelist.contains(method.getName()))
        .filter(method -> method.getParameterCount() == 0);
  }

  /**
   * Resolves the GraphQL output type for the response message, i.e. the second type argument of
   * {@code parameterizedType}, and registers its descriptor through {@code addExtraType}.
   */
  private GraphQLOutputType getReturnType(ParameterizedType parameterizedType)
      throws NoSuchMethodException, InvocationTargetException, IllegalAccessException {
    Class<? extends Message> responseClass = messageClassAt(parameterizedType, 1);
    Descriptors.Descriptor responseDescriptor =
        (Descriptors.Descriptor) responseClass.getMethod("getDescriptor").invoke(null);
    addExtraType(responseDescriptor);
    return ProtoToGql.getReference(responseDescriptor);
  }

  /**
   * Returns the type argument at {@code index} as a message class.
   *
   * <p>The cast is unchecked: the caller relies on the type arguments of the callable being
   * concrete message classes. A type argument that is not a {@link Class} fails here with a
   * {@link ClassCastException}.
   */
  @SuppressWarnings("unchecked")
  private static Class<? extends Message> messageClassAt(ParameterizedType type, int index) {
    return (Class<? extends Message>) type.getActualTypeArguments()[index];
  }

  /** Returns the exception thrown by the reflectively invoked method, or {@code e} if absent. */
  private static Throwable targetOf(InvocationTargetException e) {
    Throwable cause = e.getCause();
    return cause != null ? cause : e;
  }

  /** Strips the trailing {@code Callable} suffix from an accessor method name. */
  private static String transformName(String name) {
    return name.substring(0, name.length() - CALLABLE_SUFFIX.length());
  }
}