/*
 * Copyright 2018 Tinkoff Bank
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package ru.tinkoff.eclair.printer.processor;

import org.springframework.beans.BeanUtils;
import org.springframework.core.io.Resource;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import org.springframework.core.io.support.ResourcePatternResolver;
import org.springframework.core.type.classreading.CachingMetadataReaderFactory;
import org.springframework.core.type.classreading.MetadataReaderFactory;
import org.springframework.oxm.jaxb.Jaxb2Marshaller;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.SystemPropertyUtils;

import javax.xml.bind.JAXBElement;
import javax.xml.bind.annotation.XmlElementDecl;
import javax.xml.bind.annotation.XmlRegistry;
import javax.xml.bind.annotation.XmlRootElement;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import static java.util.Objects.isNull;
import static java.util.Objects.nonNull;
import static org.springframework.util.StringUtils.hasText;

/**
 * {@link PrinterPreProcessor} that prepares objects for marshalling with a {@link Jaxb2Marshaller} by wrapping
 * them into a {@link JAXBElement} when they cannot be marshalled as root elements on their own.
 * <p>
 * Inputs that already are {@link JAXBElement}s, or whose class is annotated with {@link XmlRootElement}, are
 * returned unchanged. For any other input the processor searches the {@link XmlRegistry}-annotated classes known
 * to the marshaller for a factory method that:
 * <ul>
 * <li>is annotated with {@link XmlElementDecl},</li>
 * <li>takes exactly one parameter whose type is exactly the input's class, and</li>
 * <li>returns {@code JAXBElement<InputClass>}.</li>
 * </ul>
 * The marshaller's candidate classes come from its context path, or from its classes to be bound when no context
 * path is set. If no matching method exists, the input is returned unchanged.
 * <p>
 * Lookup results, including "not found", are cached per input class. Registry instances are cached per registry
 * class.
 *
 * @author Vyacheslav Klapatnyuk
 */
public class JaxbElementWrapper implements PrinterPreProcessor {

    /**
     * Sentinel stored in {@link #wrapperMethodCache} to remember that no wrapper method exists for a class.
     * <p>
     * It is never invoked. Any non-null method that cannot be a registry factory method would serve.
     */
    private static final Method EMPTY_METHOD = BeanUtils.findMethod(JaxbElementWrapper.class, "process", Object.class);

    private static final ResourcePatternResolver resourcePatternResolver = new PathMatchingResourcePatternResolver();
    private static final MetadataReaderFactory metadataReaderFactory = new CachingMetadataReaderFactory(resourcePatternResolver);

    /**
     * Input class -> matching registry factory method, or {@link #EMPTY_METHOD} when there is none.
     */
    private final Map<Class<?>, Method> wrapperMethodCache = new ConcurrentHashMap<>();

    /**
     * Registry class -> lazily created registry instance used as the factory method's target.
     */
    private final Map<Class<?>, Object> wrapperCache = new ConcurrentHashMap<>();

    private final Jaxb2Marshaller jaxb2Marshaller;

    /**
     * Candidate registry classes resolved from the marshaller on first use, then reused.
     * <p>
     * The classpath scan behind it is costly, so it is not repeated for every new input class. Concurrent first
     * calls may each compute it; the results are identical as long as the marshaller's configuration does not
     * change after first use, which is assumed here.
     */
    private volatile Class<?>[] wrapperClasses;

    public JaxbElementWrapper(Jaxb2Marshaller jaxb2Marshaller) {
        this.jaxb2Marshaller = jaxb2Marshaller;
    }

    /**
     * Returns {@code input} unchanged when it can already be marshalled as a root element. Otherwise tries to
     * wrap it into a {@link JAXBElement} using a matching registry factory method.
     *
     * @param input object to prepare; {@code null} is not supported, because its class is dereferenced
     * @return the input itself, or the result of the matching registry factory method
     */
    @Override
    public Object process(Object input) {
        if (input instanceof JAXBElement) {
            return input;
        }
        if (nonNull(input.getClass().getAnnotation(XmlRootElement.class))) {
            return input;
        }
        return wrap(jaxb2Marshaller, input);
    }

    /**
     * Package-private accessor to the live registry-instance cache (presumably used by tests).
     */
    Map<Class<?>, Object> getWrapperCache() {
        return wrapperCache;
    }

    /**
     * Resolves (using the per-class cache) the factory method for the input's class and applies it.
     */
    private Object wrap(Jaxb2Marshaller jaxb2Marshaller, Object input) {
        Class<?> clazz = input.getClass();

        Method cached = wrapperMethodCache.get(clazz);
        if (nonNull(cached)) {
            return cached == EMPTY_METHOD ? input : wrap(input, cached);
        }

        Class<?>[] classes = getWrapperClasses(jaxb2Marshaller);
        Method method = isNull(classes) ? null : findMethod(classes, clazz);
        if (isNull(method)) {
            // Remember the miss so the registries are not searched again for this class.
            wrapperMethodCache.put(clazz, EMPTY_METHOD);
            return input;
        }

        wrapperMethodCache.put(clazz, method);
        return wrap(input, method);
    }

    /**
     * Returns the candidate registry classes, computing them on the first call.
     * <p>
     * A {@code null} result is not cached. Recomputing it only costs reading the marshaller's classes to be
     * bound, because {@code null} can occur only when no context path is set.
     */
    private Class<?>[] getWrapperClasses(Jaxb2Marshaller jaxb2Marshaller) {
        Class<?>[] classes = wrapperClasses;
        if (isNull(classes)) {
            classes = findWrapperClasses(jaxb2Marshaller);
            wrapperClasses = classes;
        }
        return classes;
    }

    /**
     * Collects candidate registry classes from the marshaller configuration.
     * <p>
     * When a context path is set, every readable {@code .class} resource located directly in each ':'-separated
     * package of that path, across the whole classpath, is loaded. Otherwise the marshaller's classes to be bound
     * are returned as-is.
     *
     * @return candidate classes; may be {@code null} when no context path is set and the marshaller has no
     * classes to be bound
     */
    private Class<?>[] findWrapperClasses(Jaxb2Marshaller jaxb2Marshaller) {
        String contextPath = jaxb2Marshaller.getContextPath();
        if (!hasText(contextPath)) {
            return jaxb2Marshaller.getClassesToBeBound();
        }
        List<Class<?>> classes = new ArrayList<>();
        for (String path : contextPath.split(":")) {
            for (Resource resource : pathToResources(path)) {
                if (resource.isReadable()) {
                    classes.add(forName(resource));
                }
            }
        }
        return classes.toArray(new Class<?>[0]);
    }

    /**
     * Finds all {@code .class} resources directly in the given package on the whole classpath.
     * System-property placeholders in the package name are resolved first.
     *
     * @throws UncheckedIOException if the resources cannot be resolved
     */
    private Resource[] pathToResources(String path) {
        String resourcePath = ClassUtils.convertClassNameToResourcePath(SystemPropertyUtils.resolvePlaceholders(path));
        String packageSearchPath = ResourcePatternResolver.CLASSPATH_ALL_URL_PREFIX + resourcePath + "/*.class";
        try {
            return resourcePatternResolver.getResources(packageSearchPath);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to resolve resources for pattern: " + packageSearchPath, e);
        }
    }

    /**
     * Loads the class described by a {@code .class} resource.
     * <p>
     * The class is loaded through the same class loader that located the resource, and without running its
     * static initializer.
     *
     * @throws UncheckedIOException  if the class metadata cannot be read
     * @throws IllegalStateException if the class cannot be found by that class loader
     */
    private Class<?> forName(Resource resource) {
        String className;
        try {
            className = metadataReaderFactory.getMetadataReader(resource).getClassMetadata().getClassName();
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to read class metadata from " + resource, e);
        }
        try {
            return ClassUtils.forName(className, resourcePatternResolver.getClassLoader());
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException("Failed to load class " + className + " from " + resource, e);
        }
    }

    /**
     * Returns the first factory method of an {@link XmlRegistry}-annotated class that wraps
     * {@code parameterClass}, or {@code null} if none matches.
     */
    private Method findMethod(Class<?>[] classes, Class<?> parameterClass) {
        for (Class<?> clazz : classes) {
            if (nonNull(clazz.getAnnotation(XmlRegistry.class))) {
                for (Method method : ReflectionUtils.getAllDeclaredMethods(clazz)) {
                    if (byParameterType(method, parameterClass) && byReturnType(method, parameterClass) && byAnnotation(method)) {
                        return method;
                    }
                }
            }
        }
        return null;
    }

    /**
     * True if the method takes exactly one parameter whose type is exactly {@code parameterClass}.
     */
    private boolean byParameterType(Method method, Class<?> parameterClass) {
        return method.getParameterCount() == 1 && method.getParameterTypes()[0].equals(parameterClass);
    }

    /**
     * True if the method's generic return type is exactly {@code JAXBElement<parameterClass>}.
     */
    private boolean byReturnType(Method method, Class<?> parameterClass) {
        Type genericReturnType = method.getGenericReturnType();
        if (genericReturnType instanceof ParameterizedType) {
            ParameterizedType parameterizedType = (ParameterizedType) genericReturnType;
            if (parameterizedType.getRawType().equals(JAXBElement.class)) {
                Type[] actualTypeArguments = parameterizedType.getActualTypeArguments();
                if (actualTypeArguments.length == 1 && actualTypeArguments[0].equals(parameterClass)) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * True if the method is annotated with {@link XmlElementDecl}.
     */
    private boolean byAnnotation(Method method) {
        return nonNull(method.getAnnotation(XmlElementDecl.class));
    }

    /**
     * Invokes the factory method on a cached instance of its declaring registry class.
     * <p>
     * The instance is created with the registry's no-arg constructor on first use.
     */
    private Object wrap(Object input, Method method) {
        Class<?> wrapperClass = method.getDeclaringClass();
        Object wrapper = wrapperCache.computeIfAbsent(wrapperClass, type -> BeanUtils.instantiateClass(type));
        return ReflectionUtils.invokeMethod(method, wrapper, input);
    }
}