/*-
 * #%L
 * Amazon Athena Query Federation SDK
 * %%
 * Copyright (C) 2019 - 2020 Amazon Web Services
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */
package com.amazonaws.athena.connector.lambda.serde.v2;

import com.amazonaws.athena.connector.lambda.data.BlockAllocator;
import com.amazonaws.athena.connector.lambda.request.FederationRequest;
import com.amazonaws.athena.connector.lambda.request.FederationResponse;
import com.amazonaws.athena.connector.lambda.serde.FederatedIdentitySerDe;
import com.amazonaws.athena.connector.lambda.serde.PingRequestSerDe;
import com.amazonaws.athena.connector.lambda.serde.PingResponseSerDe;
import com.amazonaws.services.lambda.invoke.LambdaFunctionException;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.databind.BeanDescription;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.JsonSerializer;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializerProvider;
import com.fasterxml.jackson.databind.cfg.DeserializerFactoryConfig;
import com.fasterxml.jackson.databind.cfg.SerializerFactoryConfig;
import com.fasterxml.jackson.databind.deser.BeanDeserializerFactory;
import com.fasterxml.jackson.databind.deser.DefaultDeserializationContext;
import com.fasterxml.jackson.databind.deser.DeserializerFactory;
import com.fasterxml.jackson.databind.deser.Deserializers;
import com.fasterxml.jackson.databind.module.SimpleDeserializers;
import com.fasterxml.jackson.databind.module.SimpleSerializers;
import com.fasterxml.jackson.databind.ser.BeanSerializerFactory;
import com.fasterxml.jackson.databind.ser.SerializerFactory;
import com.fasterxml.jackson.databind.ser.Serializers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

/**
 * Builds the {@link ObjectMapper} used for the V2 serialization format of {@link FederationRequest},
 * {@link FederationResponse} and (deserialization only) {@link LambdaFunctionException}.
 * <p>
 * The mappers produced here are "strict". When asked for a serializer, or for a bean deserializer, of a type
 * that was not explicitly registered, they throw instead of falling back to Jackson's reflection or
 * annotation based handling.
 */
public class ObjectMapperFactoryV2
{
    private static final SerializerFactory SERIALIZER_FACTORY;

    static {
        // Serializers can be static since they don't need a BlockAllocator
        ImmutableList<JsonSerializer<?>> sers = ImmutableList.of(createRequestSerializer(), createResponseSerializer());
        SimpleSerializers serializers = new SimpleSerializers(sers);
        SerializerFactoryConfig config = new SerializerFactoryConfig().withAdditionalSerializers(serializers);
        SERIALIZER_FACTORY = new StrictSerializerFactory(config);
    }

    private ObjectMapperFactoryV2(){}

    /**
     * Custom SerializerFactory that *only* uses the custom serializers that we inject into the {@link ObjectMapper}.
     */
    private static final class StrictSerializerFactory extends BeanSerializerFactory
    {
        private StrictSerializerFactory(SerializerFactoryConfig config)
        {
            super(config);
        }

        @Override
        public StrictSerializerFactory withConfig(SerializerFactoryConfig config)
        {
            // Reuse this instance when the config is unchanged, as Jackson's own factories do.
            if (_factoryConfig == config) {
                return this;
            }
            return new StrictSerializerFactory(config);
        }

        /**
         * Returns the first registered custom serializer that handles {@code origType}.
         *
         * @throws IllegalArgumentException if no registered serializer handles the type; the default
         *         bean or reflection based serializer is never used as a fallback
         */
        @Override
        @SuppressWarnings("unchecked")
        public JsonSerializer<Object> createSerializer(SerializerProvider prov, JavaType origType)
                throws JsonMappingException
        {
            for (Serializers serializers : customSerializers()) {
                JsonSerializer<?> ser = serializers.findSerializer(prov.getConfig(), origType, null);
                if (ser != null) {
                    // Safe: Jackson only asks for a serializer of the type the serializer was registered for.
                    return (JsonSerializer<Object>) ser;
                }
            }
            throw new IllegalArgumentException("No explicitly configured serializer for " + origType);
        }
    }

    /**
     * Custom DeserializerFactory that *only* uses the custom deserializers that we inject into the {@link ObjectMapper}.
     * Only bean deserializer creation is restricted here; other creation paths keep the
     * {@link BeanDeserializerFactory} behavior.
     */
    private static final class StrictDeserializerFactory extends BeanDeserializerFactory
    {
        private StrictDeserializerFactory(DeserializerFactoryConfig config)
        {
            super(config);
        }

        @Override
        public DeserializerFactory withConfig(DeserializerFactoryConfig config)
        {
            // Reuse this instance when the config is unchanged, as Jackson's own factories do.
            if (_factoryConfig == config) {
                return this;
            }
            return new StrictDeserializerFactory(config);
        }

        /**
         * Returns the first registered custom deserializer that handles bean type {@code type}.
         *
         * @throws IllegalArgumentException if no registered deserializer handles the type; reflection based
         *         bean deserialization is never used as a fallback
         */
        @Override
        @SuppressWarnings("unchecked")
        public JsonDeserializer<Object> createBeanDeserializer(DeserializationContext ctxt, JavaType type, BeanDescription beanDesc)
                throws JsonMappingException
        {
            for (Deserializers d : _factoryConfig.deserializers()) {
                JsonDeserializer<?> deser = d.findBeanDeserializer(type, ctxt.getConfig(), beanDesc);
                if (deser != null) {
                    // Safe: Jackson only asks for a deserializer of the type the deserializer was registered for.
                    return (JsonDeserializer<Object>) deser;
                }
            }
            throw new IllegalArgumentException("No explicitly configured deserializer for " + type);
        }
    }

    /**
     * Locked down ObjectMapper that only uses the serializers/deserializers provided and does not fall back to annotation or reflection
     * based serialization.
     */
    private static final class StrictObjectMapper extends ObjectMapper
    {
        private StrictObjectMapper(BlockAllocator allocator)
        {
            // Give every mapper its own JsonFactory. ObjectMapper's constructor installs itself as the factory's
            // codec when the factory has none. A shared static factory would therefore stay bound to the first
            // mapper ever created, including that mapper's BlockAllocator-backed deserializers, and would hand
            // that stale codec to the parsers and generators of every later mapper.
            super(new JsonFactory());
            _serializerFactory = SERIALIZER_FACTORY;

            ImmutableMap<Class<?>, JsonDeserializer<?>> desers = ImmutableMap.of(
                    FederationRequest.class, createRequestDeserializer(allocator),
                    FederationResponse.class, createResponseDeserializer(allocator),
                    LambdaFunctionException.class, new LambdaFunctionExceptionSerDe.Deserializer());
            SimpleDeserializers deserializers = new SimpleDeserializers(desers);
            DeserializerFactoryConfig dConfig = new DeserializerFactoryConfig().withAdditionalDeserializers(deserializers);
            _deserializationContext = new DefaultDeserializationContext.Impl(new StrictDeserializerFactory(dConfig));
            // required by LambdaInvokerFactory
            disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES);
        }
    }

    /**
     * Creates a new strict {@link ObjectMapper} for the V2 serialization format.
     *
     * @param allocator allocator handed to the Block deserializers inside the request and response deserializers;
     *                  the serializers do not need it and are shared across all mappers
     * @return a new mapper whose deserializers are bound to {@code allocator}
     */
    public static ObjectMapper create(BlockAllocator allocator)
    {
        return new StrictObjectMapper(allocator);
    }

    /**
     * Wires together the component serializers that make up the {@link FederationRequest} serializer.
     */
    private static FederationRequestSerDe.Serializer createRequestSerializer()
    {
        FederatedIdentitySerDe.Serializer identity = new FederatedIdentitySerDe.Serializer();
        TableNameSerDe.Serializer tableName = new TableNameSerDe.Serializer();
        SchemaSerDe.Serializer schema = new SchemaSerDe.Serializer();
        BlockSerDe.Serializer block = new BlockSerDe.Serializer(schema);
        ArrowTypeSerDe.Serializer arrowType = new ArrowTypeSerDe.Serializer();
        MarkerSerDe.Serializer marker = new MarkerSerDe.Serializer(block);
        RangeSerDe.Serializer range = new RangeSerDe.Serializer(marker);
        EquatableValueSetSerDe.Serializer equatableValueSet = new EquatableValueSetSerDe.Serializer(block);
        SortedRangeSetSerDe.Serializer sortedRangeSet = new SortedRangeSetSerDe.Serializer(arrowType, range);
        AllOrNoneValueSetSerDe.Serializer allOrNoneValueSet = new AllOrNoneValueSetSerDe.Serializer(arrowType);
        ValueSetSerDe.Serializer valueSet = new ValueSetSerDe.Serializer(equatableValueSet, sortedRangeSet, allOrNoneValueSet);
        ConstraintsSerDe.Serializer constraints = new ConstraintsSerDe.Serializer(valueSet);
        S3SpillLocationSerDe.Serializer s3SpillLocation = new S3SpillLocationSerDe.Serializer();
        SpillLocationSerDe.Serializer spillLocation = new SpillLocationSerDe.Serializer(s3SpillLocation);
        EncryptionKeySerDe.Serializer encryptionKey = new EncryptionKeySerDe.Serializer();
        SplitSerDe.Serializer split = new SplitSerDe.Serializer(spillLocation, encryptionKey);
        PingRequestSerDe.Serializer ping = new PingRequestSerDe.Serializer(identity);
        ListSchemasRequestSerDe.Serializer listSchemas = new ListSchemasRequestSerDe.Serializer(identity);
        ListTablesRequestSerDe.Serializer listTables = new ListTablesRequestSerDe.Serializer(identity);
        GetTableRequestSerDe.Serializer getTable = new GetTableRequestSerDe.Serializer(identity, tableName);
        GetTableLayoutRequestSerDe.Serializer getTableLayout = new GetTableLayoutRequestSerDe.Serializer(identity, tableName, constraints, schema);
        GetSplitsRequestSerDe.Serializer getSplits = new GetSplitsRequestSerDe.Serializer(identity, tableName, block, constraints);
        ReadRecordsRequestSerDe.Serializer readRecords = new ReadRecordsRequestSerDe.Serializer(identity, tableName, constraints, schema, split);
        UserDefinedFunctionRequestSerDe.Serializer userDefinedFunction = new UserDefinedFunctionRequestSerDe.Serializer(identity, block, schema);
        return new FederationRequestSerDe.Serializer(
                ping,
                listSchemas,
                listTables,
                getTable,
                getTableLayout,
                getSplits,
                readRecords,
                userDefinedFunction);
    }

    /**
     * Wires together the component deserializers that make up the {@link FederationRequest} deserializer.
     *
     * @param allocator allocator passed to the Block deserializer
     */
    private static FederationRequestSerDe.Deserializer createRequestDeserializer(BlockAllocator allocator)
    {
        FederatedIdentitySerDe.Deserializer identity = new FederatedIdentitySerDe.Deserializer();
        TableNameSerDe.Deserializer tableName = new TableNameSerDe.Deserializer();
        SchemaSerDe.Deserializer schema = new SchemaSerDe.Deserializer();
        BlockSerDe.Deserializer block = new BlockSerDe.Deserializer(allocator, schema);
        ArrowTypeSerDe.Deserializer arrowType = new ArrowTypeSerDe.Deserializer();
        MarkerSerDe.Deserializer marker = new MarkerSerDe.Deserializer(block);
        RangeSerDe.Deserializer range = new RangeSerDe.Deserializer(marker);
        EquatableValueSetSerDe.Deserializer equatableValueSet = new EquatableValueSetSerDe.Deserializer(block);
        SortedRangeSetSerDe.Deserializer sortedRangeSet = new SortedRangeSetSerDe.Deserializer(arrowType, range);
        AllOrNoneValueSetSerDe.Deserializer allOrNoneValueSet = new AllOrNoneValueSetSerDe.Deserializer(arrowType);
        ValueSetSerDe.Deserializer valueSet = new ValueSetSerDe.Deserializer(equatableValueSet, sortedRangeSet, allOrNoneValueSet);
        ConstraintsSerDe.Deserializer constraints = new ConstraintsSerDe.Deserializer(valueSet);
        S3SpillLocationSerDe.Deserializer s3SpillLocation = new S3SpillLocationSerDe.Deserializer();
        SpillLocationSerDe.Deserializer spillLocation = new SpillLocationSerDe.Deserializer(s3SpillLocation);
        EncryptionKeySerDe.Deserializer encryptionKey = new EncryptionKeySerDe.Deserializer();
        SplitSerDe.Deserializer split = new SplitSerDe.Deserializer(spillLocation, encryptionKey);

        PingRequestSerDe.Deserializer ping = new PingRequestSerDe.Deserializer(identity);
        ListSchemasRequestSerDe.Deserializer listSchemas = new ListSchemasRequestSerDe.Deserializer(identity);
        ListTablesRequestSerDe.Deserializer listTables = new ListTablesRequestSerDe.Deserializer(identity);
        GetTableRequestSerDe.Deserializer getTable = new GetTableRequestSerDe.Deserializer(identity, tableName);
        GetTableLayoutRequestSerDe.Deserializer getTableLayout = new GetTableLayoutRequestSerDe.Deserializer(identity, tableName, constraints, schema);
        GetSplitsRequestSerDe.Deserializer getSplits = new GetSplitsRequestSerDe.Deserializer(identity, tableName, block, constraints);
        ReadRecordsRequestSerDe.Deserializer readRecords = new ReadRecordsRequestSerDe.Deserializer(identity, tableName, constraints, schema, split);
        UserDefinedFunctionRequestSerDe.Deserializer userDefinedFunction = new UserDefinedFunctionRequestSerDe.Deserializer(identity, block, schema);

        return new FederationRequestSerDe.Deserializer(
                ping,
                listSchemas,
                listTables,
                getTable,
                getTableLayout,
                getSplits,
                readRecords,
                userDefinedFunction);
    }

    /**
     * Wires together the component serializers that make up the {@link FederationResponse} serializer.
     */
    private static FederationResponseSerDe.Serializer createResponseSerializer()
    {
        TableNameSerDe.Serializer tableName = new TableNameSerDe.Serializer();
        SchemaSerDe.Serializer schema = new SchemaSerDe.Serializer();
        BlockSerDe.Serializer block = new BlockSerDe.Serializer(schema);
        S3SpillLocationSerDe.Serializer s3SpillLocation = new S3SpillLocationSerDe.Serializer();
        SpillLocationSerDe.Serializer spillLocation = new SpillLocationSerDe.Serializer(s3SpillLocation);
        EncryptionKeySerDe.Serializer encryptionKey = new EncryptionKeySerDe.Serializer();
        SplitSerDe.Serializer split = new SplitSerDe.Serializer(spillLocation, encryptionKey);

        PingResponseSerDe.Serializer ping = new PingResponseSerDe.Serializer();
        ListSchemasResponseSerDe.Serializer listSchemas = new ListSchemasResponseSerDe.Serializer();
        ListTablesResponseSerDe.Serializer listTables = new ListTablesResponseSerDe.Serializer(tableName);
        GetTableResponseSerDe.Serializer getTable = new GetTableResponseSerDe.Serializer(tableName, schema);
        GetTableLayoutResponseSerDe.Serializer getTableLayout = new GetTableLayoutResponseSerDe.Serializer(tableName, block);
        GetSplitsResponseSerDe.Serializer getSplits = new GetSplitsResponseSerDe.Serializer(split);
        ReadRecordsResponseSerDe.Serializer readRecords = new ReadRecordsResponseSerDe.Serializer(block);
        RemoteReadRecordsResponseSerDe.Serializer remoteReadRecords = new RemoteReadRecordsResponseSerDe.Serializer(schema, spillLocation, encryptionKey);
        UserDefinedFunctionResponseSerDe.Serializer userDefinedFunction = new UserDefinedFunctionResponseSerDe.Serializer(block);

        return new FederationResponseSerDe.Serializer(
                ping,
                listSchemas,
                listTables,
                getTable,
                getTableLayout,
                getSplits,
                readRecords,
                remoteReadRecords,
                userDefinedFunction);
    }

    /**
     * Wires together the component deserializers that make up the {@link FederationResponse} deserializer.
     *
     * @param allocator allocator passed to the Block deserializer
     */
    private static FederationResponseSerDe.Deserializer createResponseDeserializer(BlockAllocator allocator)
    {
        TableNameSerDe.Deserializer tableName = new TableNameSerDe.Deserializer();
        SchemaSerDe.Deserializer schema = new SchemaSerDe.Deserializer();
        BlockSerDe.Deserializer block = new BlockSerDe.Deserializer(allocator, schema);
        S3SpillLocationSerDe.Deserializer s3SpillLocation = new S3SpillLocationSerDe.Deserializer();
        SpillLocationSerDe.Deserializer spillLocation = new SpillLocationSerDe.Deserializer(s3SpillLocation);
        EncryptionKeySerDe.Deserializer encryptionKey = new EncryptionKeySerDe.Deserializer();
        SplitSerDe.Deserializer split = new SplitSerDe.Deserializer(spillLocation, encryptionKey);

        PingResponseSerDe.Deserializer ping = new PingResponseSerDe.Deserializer();
        ListSchemasResponseSerDe.Deserializer listSchemas = new ListSchemasResponseSerDe.Deserializer();
        ListTablesResponseSerDe.Deserializer listTables = new ListTablesResponseSerDe.Deserializer(tableName);
        GetTableResponseSerDe.Deserializer getTable = new GetTableResponseSerDe.Deserializer(tableName, schema);
        GetTableLayoutResponseSerDe.Deserializer getTableLayout = new GetTableLayoutResponseSerDe.Deserializer(tableName, block);
        GetSplitsResponseSerDe.Deserializer getSplits = new GetSplitsResponseSerDe.Deserializer(split);
        ReadRecordsResponseSerDe.Deserializer readRecords = new ReadRecordsResponseSerDe.Deserializer(block);
        RemoteReadRecordsResponseSerDe.Deserializer remoteReadRecords = new RemoteReadRecordsResponseSerDe.Deserializer(schema, spillLocation, encryptionKey);
        UserDefinedFunctionResponseSerDe.Deserializer userDefinedFunction = new UserDefinedFunctionResponseSerDe.Deserializer(block);

        return new FederationResponseSerDe.Deserializer(
                ping,
                listSchemas,
                listTables,
                getTable,
                getTableLayout,
                getSplits,
                readRecords,
                remoteReadRecords,
                userDefinedFunction);
    }
}