/*
 * Copyright (C) 2012 Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.airlift.drift.client.guice;

import com.google.common.collect.ImmutableList;
import com.google.common.reflect.TypeParameter;
import com.google.common.reflect.TypeToken;
import com.google.inject.Binder;
import com.google.inject.Injector;
import com.google.inject.Key;
import com.google.inject.Module;
import com.google.inject.Provides;
import com.google.inject.Scopes;
import com.google.inject.TypeLiteral;
import io.airlift.configuration.ConfigDefaults;
import io.airlift.drift.client.DriftClient;
import io.airlift.drift.client.DriftClientFactory;
import io.airlift.drift.client.DriftClientFactoryManager;
import io.airlift.drift.client.ExceptionClassifier;
import io.airlift.drift.client.MethodInvocationFilter;
import io.airlift.drift.client.address.AddressSelector;
import io.airlift.drift.client.stats.JmxMethodInvocationStatsFactory;
import io.airlift.drift.client.stats.MethodInvocationStatsFactory;
import io.airlift.drift.client.stats.NullMethodInvocationStatsFactory;
import io.airlift.drift.codec.ThriftCodecManager;
import io.airlift.drift.codec.guice.ThriftCodecModule;
import io.airlift.drift.transport.client.DriftClientConfig;
import io.airlift.drift.transport.client.MethodInvokerFactory;
import org.weakref.jmx.MBeanExporter;

import javax.inject.Inject;
import javax.inject.Provider;
import javax.inject.Singleton;

import java.lang.annotation.Annotation;
import java.lang.reflect.Type;
import java.util.List;
import java.util.Optional;
import java.util.Set;

import static com.google.inject.multibindings.Multibinder.newSetBinder;
import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder;
import static io.airlift.configuration.ConfigBinder.configBinder;
import static io.airlift.drift.client.ExceptionClassifier.mergeExceptionClassifiers;
import static io.airlift.drift.client.guice.DriftClientAnnotationFactory.extractDriftClientBindingAnnotation;
import static io.airlift.drift.client.guice.DriftClientAnnotationFactory.getDriftClientAnnotation;
import static io.airlift.drift.codec.metadata.ThriftServiceMetadata.getThriftServiceAnnotation;
import static java.util.Objects.requireNonNull;

/**
 * Helper for registering Drift clients, and their configuration defaults,
 * with a Guice {@link Binder}.
 *
 * <p>Obtain an instance with {@link #driftClientBinder(Binder)}, then call
 * one of the {@code bindDriftClient} methods for each client interface.
 */
public class DriftClientBinder
{
    /**
     * Creates a {@code DriftClientBinder} that registers bindings on the given binder.
     *
     * @param binder the Guice binder to register bindings with; must not be null
     * @return a new binder helper
     */
    public static DriftClientBinder driftClientBinder(Binder binder)
    {
        return new DriftClientBinder(binder);
    }

    private final Binder binder;

    /**
     * Stores the binder (with this class skipped as a binding source) and installs
     * the codec module and the shared {@link DriftClientBinderModule}.
     */
    private DriftClientBinder(Binder binder)
    {
        this.binder = requireNonNull(binder, "binder is null").skipSources(this.getClass());
        binder.install(new ThriftCodecModule());
        binder.install(new DriftClientBinderModule());
    }

    /**
     * Binds a Drift client for the given interface using the {@link DefaultClient} annotation.
     * The config prefix is the service name resolved by {@link #getServiceName(Class)}.
     *
     * @param clientInterface the Thrift service interface to create a client for
     * @return a builder for further configuring the client binding
     */
    public <T> DriftClientBindingBuilder bindDriftClient(Class<T> clientInterface)
    {
        String configPrefix = getServiceName(clientInterface);
        return bindDriftClient(clientInterface, configPrefix, DefaultClient.class);
    }

    /**
     * Binds a Drift client for the given interface qualified by the given annotation type.
     * The config prefix is the service name, followed by {@code "."} and the simple name of
     * the annotation type unless the annotation type is {@link DefaultClient}.
     *
     * @param clientInterface the Thrift service interface to create a client for
     * @param annotationType the annotation type qualifying this client binding
     * @return a builder for further configuring the client binding
     */
    public <T> DriftClientBindingBuilder bindDriftClient(Class<T> clientInterface, Class<? extends Annotation> annotationType)
    {
        String configPrefix = getServiceName(clientInterface);
        if (annotationType != DefaultClient.class) {
            configPrefix += "." + annotationType.getSimpleName();
        }
        return bindDriftClient(clientInterface, configPrefix, annotationType);
    }

    /**
     * Registers the config, the {@code DriftClient<T>} factory and the client instance
     * bindings for one client interface.
     */
    private <T> DriftClientBindingBuilder bindDriftClient(Class<T> clientInterface, String configPrefix, Class<? extends Annotation> annotation)
    {
        // annotation instance under which the per-client config is bound
        Annotation clientAnnotation = getDriftClientAnnotation(clientInterface, annotation);

        configBinder(binder).bindConfig(DriftClientConfig.class, clientAnnotation, configPrefix);

        TypeLiteral<DriftClient<T>> typeLiteral = driftClientTypeLiteral(clientInterface);

        // the instance provider resolves the DriftClient<T> bound below and calls get() on it
        Provider<T> instanceProvider = new DriftClientInstanceProvider<>(clientAnnotation, Key.get(typeLiteral, annotation));
        Provider<DriftClient<T>> factoryProvider = new DriftClientProvider<>(clientInterface, clientAnnotation);

        binder.bind(Key.get(clientInterface, annotation)).toProvider(instanceProvider).in(Scopes.SINGLETON);
        binder.bind(Key.get(typeLiteral, annotation)).toProvider(factoryProvider).in(Scopes.SINGLETON);

        // the default client is additionally injectable without any binding annotation
        if (annotation == DefaultClient.class) {
            binder.bind(Key.get(clientInterface)).toProvider(instanceProvider).in(Scopes.SINGLETON);
            binder.bind(Key.get(typeLiteral)).toProvider(factoryProvider).in(Scopes.SINGLETON);
        }

        return new DriftClientBindingBuilder(binder, clientAnnotation, configPrefix);
    }

    /**
     * Binds defaults for the {@link DriftClientConfig} of the {@link DefaultClient}
     * binding of the given client interface.
     *
     * @param clientInterface the client interface whose config receives the defaults
     * @param configDefaults the defaults to apply
     */
    public <T> void bindClientConfigDefaults(Class<T> clientInterface, ConfigDefaults<DriftClientConfig> configDefaults)
    {
        bindClientConfigDefaults(clientInterface, DefaultClient.class, configDefaults);
    }

    /**
     * Binds defaults for the {@link DriftClientConfig} of the client binding identified
     * by the given client interface and annotation type.
     *
     * @param clientInterface the client interface whose config receives the defaults
     * @param annotationType the annotation type qualifying the client binding
     * @param configDefaults the defaults to apply
     */
    public <T> void bindClientConfigDefaults(Class<T> clientInterface, Class<? extends Annotation> annotationType, ConfigDefaults<DriftClientConfig> configDefaults)
    {
        bindConfigDefaults(clientInterface, annotationType, DriftClientConfig.class, configDefaults);
    }

    /**
     * Binds defaults for an arbitrary config class associated with the
     * {@link DefaultClient} binding of the given client interface.
     *
     * @param clientInterface the client interface the config belongs to
     * @param configClass the config class receiving the defaults
     * @param configDefaults the defaults to apply
     */
    public <T, C> void bindConfigDefaults(Class<T> clientInterface, Class<C> configClass, ConfigDefaults<C> configDefaults)
    {
        // The client annotation must be derived from the client interface (not the config
        // class) so that it matches the annotation used when the client is bound.
        bindConfigDefaults(clientInterface, DefaultClient.class, configClass, configDefaults);
    }

    /**
     * Binds defaults for an arbitrary config class associated with the client binding
     * identified by the given client interface and annotation type.
     *
     * @param clientInterface the client interface the config belongs to
     * @param annotationType the annotation type qualifying the client binding
     * @param configClass the config class receiving the defaults
     * @param configDefaults the defaults to apply
     */
    public <T, C> void bindConfigDefaults(Class<T> clientInterface, Class<? extends Annotation> annotationType, Class<C> configClass, ConfigDefaults<C> configDefaults)
    {
        configBinder(binder).bindConfigDefaults(configClass, getDriftClientAnnotation(clientInterface, annotationType), configDefaults);
    }

    /**
     * Returns the value of the Thrift service annotation on the interface, or the
     * interface's simple class name when that value is empty.
     *
     * @throws NullPointerException if {@code clientInterface} is null
     */
    private static String getServiceName(Class<?> clientInterface)
    {
        requireNonNull(clientInterface, "clientInterface is null");
        String serviceName = getThriftServiceAnnotation(clientInterface).value();
        if (!serviceName.isEmpty()) {
            return serviceName;
        }
        return clientInterface.getSimpleName();
    }

    /**
     * Builds the Guice {@link TypeLiteral} for {@code DriftClient<T>} with {@code T}
     * resolved to the given client interface.
     */
    @SuppressWarnings("unchecked")
    private static <T> TypeLiteral<DriftClient<T>> driftClientTypeLiteral(Class<T> clientInterface)
    {
        Type javaType = new TypeToken<DriftClient<T>>() {}
                .where(new TypeParameter<T>() {}, TypeToken.of(clientInterface))
                .getType();
        // javaType was constructed as DriftClient<T> above, so the cast matches the actual type
        return (TypeLiteral<DriftClient<T>>) TypeLiteral.get(javaType);
    }

    /**
     * Provider of client instances: looks up the {@code DriftClient<T>} bound under
     * {@code key} and returns the result of its {@code get()} method.
     */
    private static class DriftClientInstanceProvider<T>
            extends AbstractAnnotatedProvider<T>
    {
        private final Key<DriftClient<T>> key;

        public DriftClientInstanceProvider(Annotation annotation, Key<DriftClient<T>> key)
        {
            super(annotation);
            this.key = requireNonNull(key, "key is null");
        }

        @Override
        protected T get(Injector injector, Annotation annotation)
        {
            return injector.getInstance(key).get();
        }
    }

    /**
     * Provider of {@code DriftClient<T>}: gathers the per-client config, address selector,
     * exception classifiers and invocation filters from the injector and creates the
     * client through the shared {@link DriftClientFactoryManager}.
     */
    private static class DriftClientProvider<T>
            extends AbstractAnnotatedProvider<DriftClient<T>>
    {
        private static final TypeLiteral<DriftClientFactoryManager<Annotation>> DRIFT_CLIENT_FACTORY_MANAGER_TYPE = new TypeLiteral<DriftClientFactoryManager<Annotation>>() {};
        private static final TypeLiteral<Set<MethodInvocationFilter>> SET_METHOD_INVOCATION_FILTERS_TYPE = new TypeLiteral<Set<MethodInvocationFilter>>() {};
        private static final TypeLiteral<Set<ExceptionClassifier>> SET_EXCEPTION_CLASSIFIER_TYPE = new TypeLiteral<Set<ExceptionClassifier>>() {};

        private final Class<T> clientInterface;

        public DriftClientProvider(Class<T> clientInterface, Annotation annotation)
        {
            super(annotation);
            this.clientInterface = requireNonNull(clientInterface, "clientInterface is null");
        }

        @Override
        protected DriftClient<T> get(Injector injector, Annotation clientAnnotation)
        {
            DriftClientConfig config = injector.getInstance(Key.get(DriftClientConfig.class, clientAnnotation));
            DriftClientFactoryManager<Annotation> driftClientFactoryManager = injector.getInstance(Key.get(DRIFT_CLIENT_FACTORY_MANAGER_TYPE));

            AddressSelector<?> addressSelector = injector.getInstance(Key.get(AddressSelector.class, clientAnnotation));

            // classifiers bound for this client are listed ahead of the globally bound ones
            ExceptionClassifier exceptionClassifier = mergeExceptionClassifiers(ImmutableList.<ExceptionClassifier>builder()
                    .addAll(injector.getInstance(Key.get(SET_EXCEPTION_CLASSIFIER_TYPE, clientAnnotation))) // per-client
                    .addAll(injector.getInstance(Key.get(SET_EXCEPTION_CLASSIFIER_TYPE))) // global
                    .build());

            // only the filters bound under this client's annotation are used
            List<MethodInvocationFilter> filters = ImmutableList.copyOf(injector.getInstance(Key.get(SET_METHOD_INVOCATION_FILTERS_TYPE, clientAnnotation)));

            DriftClientFactory driftClientFactory = driftClientFactoryManager.createDriftClientFactory(clientAnnotation, addressSelector, exceptionClassifier);
            return driftClientFactory.createDriftClient(clientInterface, extractDriftClientBindingAnnotation(clientAnnotation), filters, config);
        }
    }

    /**
     * Default provider of {@link MethodInvocationStatsFactory}: returns a
     * {@link JmxMethodInvocationStatsFactory} when an {@link MBeanExporter} is present,
     * otherwise a {@link NullMethodInvocationStatsFactory}.
     */
    private static class DefaultMethodInvocationStatsFactoryProvider
            implements Provider<MethodInvocationStatsFactory>
    {
        private final Optional<MBeanExporter> mbeanExporter;

        @Inject
        public DefaultMethodInvocationStatsFactoryProvider(Optional<MBeanExporter> mbeanExporter)
        {
            this.mbeanExporter = mbeanExporter;
        }

        @Override
        public MethodInvocationStatsFactory get()
        {
            return mbeanExporter
                    .map(JmxMethodInvocationStatsFactory::new)
                    // widen to the interface type so orElseGet can supply the null factory
                    .map(MethodInvocationStatsFactory.class::cast)
                    .orElseGet(NullMethodInvocationStatsFactory::new);
        }
    }

    /**
     * Module holding the bindings shared by all Drift clients: the global exception
     * classifier set, the optional {@link MBeanExporter}, the default
     * {@link MethodInvocationStatsFactory} and the {@link DriftClientFactoryManager}.
     */
    private static class DriftClientBinderModule
            implements Module
    {
        @Override
        public void configure(Binder binder)
        {
            // declare the global classifier set so it can be injected even when nothing is added to it
            newSetBinder(binder, ExceptionClassifier.class);
            newOptionalBinder(binder, MBeanExporter.class);
            newOptionalBinder(binder, MethodInvocationStatsFactory.class)
                    .setDefault()
                    .toProvider(DefaultMethodInvocationStatsFactoryProvider.class)
                    .in(Scopes.SINGLETON);
        }

        @Provides
        @Singleton
        private static DriftClientFactoryManager<Annotation> getDriftClientFactory(
                ThriftCodecManager codecManager,
                MethodInvokerFactory<Annotation> methodInvokerFactory,
                MethodInvocationStatsFactory methodInvocationStatsFactory)
        {
            return new DriftClientFactoryManager<>(codecManager, methodInvokerFactory, methodInvocationStatsFactory);
        }

        // All instances of this module compare equal and share a hash code.
        // NOTE(review): presumably so that installing it once per DriftClientBinder is
        // treated by Guice as installing a single module; confirm against Guice behavior.
        @Override
        public boolean equals(Object o)
        {
            if (this == o) {
                return true;
            }
            return o != null && getClass() == o.getClass();
        }

        @Override
        public int hashCode()
        {
            return getClass().hashCode();
        }
    }
}