package name.falgout.jeffrey.testing.junit.mockito;

import static java.util.stream.Collectors.toList;

import com.google.common.collect.Iterables;
import java.lang.reflect.Parameter;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Stream;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.ParameterContext;
import org.junit.jupiter.api.extension.ParameterResolutionException;
import org.junit.jupiter.api.extension.ParameterResolver;
import org.junit.jupiter.api.extension.TestInstancePostProcessor;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;

public final class MockitoExtension implements TestInstancePostProcessor, AfterEachCallback,
    ParameterResolver {
  private final Set<ParameterFactory> parameterFactories;

  public MockitoExtension() {
    parameterFactories = new LinkedHashSet<>();
    parameterFactories.add(new MockParameterFactory());
    parameterFactories.add(new CaptorParameterFactory());
  }

  @Override
  public void postProcessTestInstance(Object testInstance, ExtensionContext context)
      throws Exception {
    MockitoAnnotations.initMocks(testInstance);
  }

  @Override
  public void afterEach(ExtensionContext context) throws Exception {
    Mockito.validateMockitoUsage();
  }

  @Override
  public boolean supportsParameter(ParameterContext parameterContext,
      ExtensionContext extensionContext)
      throws ParameterResolutionException {
    return getSupportedFactories(parameterContext.getParameter()).findAny().isPresent();
  }

  @Override
  public Object resolveParameter(ParameterContext parameterContext,
      ExtensionContext extensionContext)
      throws ParameterResolutionException {
    List<ParameterFactory> validFactories =
        getSupportedFactories(parameterContext.getParameter()).collect(toList());

    if (validFactories.size() > 1) {
      throw new ParameterResolutionException(
          String.format("Too many factories: %s for parameter: %s",
              validFactories,
              parameterContext.getParameter()));
    }

    return Iterables.getOnlyElement(validFactories)
        .getParameterValue(parameterContext.getParameter());
  }

  private Stream<ParameterFactory> getSupportedFactories(Parameter parameter) {
    return parameterFactories.stream().filter(factory -> factory.supports(parameter));
  }
}