/*
 * Copyright (C) 2018 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package com.google.cloud.teleport.spanner;

import com.google.cloud.ByteArray;
import com.google.cloud.Date;
import com.google.cloud.Timestamp;
import com.google.cloud.spanner.Mutation;
import com.google.cloud.teleport.spanner.ddl.Column;
import com.google.cloud.teleport.spanner.ddl.Table;
import com.google.common.annotations.VisibleForTesting;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.avro.LogicalType;
import org.apache.avro.LogicalTypes;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.util.Utf8;
import org.apache.beam.sdk.transforms.SerializableFunction;

/** Converts {@link GenericRecord} to {@link Mutation}. */
public class AvroRecordConverter implements SerializableFunction<GenericRecord, Mutation> {

  private final Table table;

  public AvroRecordConverter(Table table) {
    this.table = table;
  }

  public Mutation apply(GenericRecord record) {
    Schema schema = record.getSchema();
    List<Schema.Field> fields = schema.getFields();
    Mutation.WriteBuilder builder = Mutation.newInsertOrUpdateBuilder(table.name());

    for (Schema.Field field : fields) {
      String fieldName = field.name();

      Column column = table.column(fieldName);
      if (column == null) {
        throw new IllegalArgumentException(
            String.format(
                "Cannot find corresponding column for field %s in table %s schema %s",
                fieldName, table.prettyPrint(), schema.toString(true)));
      }

      Schema avroFieldSchema = field.schema();
      if (avroFieldSchema.getType() == Schema.Type.UNION) {
        Schema unpacked = AvroUtil.unpackNullable(avroFieldSchema);
        if (unpacked != null) {
          avroFieldSchema = unpacked;
        }
      }
      LogicalType logicalType = LogicalTypes.fromSchema(avroFieldSchema);
      Schema.Type avroType = avroFieldSchema.getType();

      switch (column.type().getCode()) {
        case BOOL:
          builder.set(column.name()).to(readBool(record, avroType, fieldName).orElse(null));
          break;
        case INT64:
          builder.set(column.name()).to(readInt64(record, avroType, fieldName).orElse(null));
          break;
        case FLOAT64:
          builder.set(column.name()).to(readFloat64(record, avroType, fieldName).orElse(null));
          break;
        case STRING:
          builder.set(column.name()).to(readString(record, avroType, fieldName).orElse(null));
          break;
        case BYTES:
          builder.set(column.name()).to(readBytes(record, avroType, fieldName).orElse(null));
          break;
        case TIMESTAMP:
          builder
              .set(column.name())
              .to(readTimestamp(record, avroType, logicalType, fieldName).orElse(null));
          break;
        case DATE:
          builder
              .set(column.name())
              .to(readDate(record, avroType, logicalType, fieldName).orElse(null));
          break;
        case ARRAY:
          {
            Schema arraySchema = avroFieldSchema.getElementType();
            if (arraySchema.getType() == Schema.Type.UNION) {
              Schema unpacked = AvroUtil.unpackNullable(arraySchema);
              if (unpacked != null) {
                arraySchema = unpacked;
              }
            }
            LogicalType arrayLogicalType = LogicalTypes.fromSchema(arraySchema);
            Schema.Type arrayType = arraySchema.getType();
            switch (column.type().getArrayElementType().getCode()) {
              case BOOL:
                builder
                    .set(column.name())
                    .toBoolArray(readBoolArray(record, arrayType, fieldName).orElse(null));
                break;
              case INT64:
                builder
                    .set(column.name())
                    .toInt64Array(readInt64Array(record, arrayType, fieldName).orElse(null));
                break;
              case FLOAT64:
                builder
                    .set(column.name())
                    .toFloat64Array(readFloat64Array(record, arrayType, fieldName).orElse(null));
                break;
              case STRING:
                builder
                    .set(column.name())
                    .toStringArray(readStringArray(record, arrayType, fieldName).orElse(null));
                break;
              case BYTES:
                builder
                    .set(column.name())
                    .toBytesArray(readBytesArray(record, arrayType, fieldName).orElse(null));
                break;
              case TIMESTAMP:
                builder
                    .set(column.name())
                    .toTimestampArray(
                        readTimestampArray(record, arrayType, arrayLogicalType, fieldName)
                            .orElse(null));
                break;
              case DATE:
                builder
                    .set(column.name())
                    .toDateArray(readDateArray(record, arrayType, fieldName).orElse(null));
                break;
              default:
                throw new IllegalArgumentException(
                    String.format(
                        "Cannot convert field %s in schema %s table %s",
                        fieldName, schema.toString(true), table.prettyPrint()));
            }
            break;
          }
        default:
          throw new IllegalArgumentException(
              String.format(
                  "Cannot convert field %s in schema %s table %s",
                  fieldName, schema.toString(true), table.prettyPrint()));
      }
    }

    return builder.build();
  }

  @SuppressWarnings("unchecked")
  private Optional<List<ByteArray>> readBytesArray(
      GenericRecord record, Schema.Type avroType, String fieldName) {
    switch (avroType) {
      case STRING:
        {
          List<Utf8> value = (List<Utf8>) record.get(fieldName);
          if (value == null) {
            return Optional.empty();
          }
          return Optional.of(
              value
                  .stream()
                  .map(x -> x == null ? null : ByteArray.copyFrom(x.getBytes()))
                  .collect(Collectors.toList()));
        }
      case BYTES:
        {
          List<ByteBuffer> value = (List<ByteBuffer>) record.get(fieldName);
          if (value == null) {
            return Optional.empty();
          }
          return Optional.of(
              value
                  .stream()
                  .map(x -> x == null ? null : ByteArray.copyFrom(x))
                  .collect(Collectors.toList()));
        }
      default:
        throw new IllegalArgumentException("Cannot interpret " + avroType + " as BYTES");
    }
  }

  @SuppressWarnings("unchecked")
  private Optional<List<Date>> readDateArray(
      GenericRecord record, Schema.Type avroType, String fieldName) {
    switch (avroType) {
      case STRING:
        {
          List<Utf8> value = (List<Utf8>) record.get(fieldName);
          if (value == null) {
            return Optional.empty();
          }
          return Optional.of(
              value
                  .stream()
                  .map(x -> x == null ? null : Date.parseDate(x.toString()))
                  .collect(Collectors.toList()));
        }
      default:
        throw new IllegalArgumentException("Cannot interpret " + avroType + " as DATE");
    }
  }

  @VisibleForTesting
  @SuppressWarnings("unchecked")
  static Optional<List<Timestamp>> readTimestampArray(
      GenericRecord record, Schema.Type avroType, LogicalType logicalType, String fieldName) {
    Object fieldValue = record.get(fieldName);
    if (fieldValue == null) {
      return Optional.empty();
    }
    switch (avroType) {
      case LONG:
        {
          List<Long> value = (List<Long>) fieldValue;
          // Default to microseconds
          if (logicalType == null || LogicalTypes.timestampMicros().equals(logicalType)) {
            return Optional.of(
                value
                    .stream()
                    .map(x -> x == null ? null : Timestamp.ofTimeMicroseconds(x))
                    .collect(Collectors.toList()));
          } else if (LogicalTypes.timestampMillis().equals(logicalType)) {
            return Optional.of(
                value
                    .stream()
                    .map(x -> x == null ? null : Timestamp.ofTimeMicroseconds(1000L * x))
                    .collect(Collectors.toList()));
          } else {
            throw new IllegalArgumentException(
                String.format(
                    "Cannot interpret Avrotype LONG LogicalType %s as TIMESTAMP", logicalType));
          }
        }
      case STRING:
        {
          List<Utf8> value = (List<Utf8>) fieldValue;
          return Optional.of(
              value
                  .stream()
                  .map(x -> x == null ? null : Timestamp.parseTimestamp(x.toString()))
                  .collect(Collectors.toList()));
        }
      default:
        throw new IllegalArgumentException("Cannot interpret " + avroType + " as TIMESTAMP");
    }
  }

  @VisibleForTesting
  @SuppressWarnings("unchecked")
  static Optional<List<String>> readStringArray(
      GenericRecord record, Schema.Type avroType, String fieldName) {
    List<Object> fieldValue = (List<Object>) record.get(fieldName);
    if (fieldValue == null) {
      return Optional.empty();
    }
    switch (avroType) {
      case BOOLEAN:
      case FLOAT:
      case DOUBLE:
      case LONG:
      case INT:
      case STRING:
        // This relies on the .toString() method present in all classes.
        // It is not necessary to know the exact type of x for that.
        return Optional.of(
            fieldValue.stream()
                .map(x -> x == null ? null : String.valueOf(x))
                .collect(Collectors.toList()));
      default:
        throw new IllegalArgumentException("Cannot interpret " + avroType + " as STRING");
    }
  }

  @VisibleForTesting
  @SuppressWarnings("unchecked")
  static Optional<List<Double>> readFloat64Array(
      GenericRecord record, Schema.Type avroType, String fieldName) {
    Object fieldValue = record.get(fieldName);
    if (fieldValue == null) {
      return Optional.empty();
    }
    switch (avroType) {
      // For type check at compile time, the type of x has to be specified (as cast) so that
      // convertability to double can be verified.
      case DOUBLE:
        return Optional.of((List<Double>) fieldValue);
      case FLOAT:
        {
          List<Float> value = (List<Float>) fieldValue;
          return Optional.of(
              value.stream()
                  .map(x -> x == null ? null : Double.valueOf(x))
                  .collect(Collectors.toList()));
        }
      case INT:
        {
          List<Integer> value = (List<Integer>) fieldValue;
          return Optional.of(
              value.stream()
                  .map(x -> x == null ? null : Double.valueOf(x))
                  .collect(Collectors.toList()));
        }
      case LONG: {
          List<Long> value = (List<Long>) record.get(fieldName);
          return Optional.of(
              value.stream()
                  .map(x -> x == null ? null : Double.valueOf(x))
                  .collect(Collectors.toList()));
      }
      case STRING: {
          List<Utf8> value = (List<Utf8>) record.get(fieldName);
          return Optional.of(
              value
                  .stream()
                  .map(x -> x == null ? null : Double.valueOf(x.toString()))
                  .collect(Collectors.toList()));
      }
      default:
        throw new IllegalArgumentException("Cannot interpret " + avroType + " as FLOAT64");
    }
  }

  @VisibleForTesting
  @SuppressWarnings("unchecked")
  static Optional<List<Long>> readInt64Array(
      GenericRecord record, Schema.Type avroType, String fieldName) {
    Object fieldValue = record.get(fieldName);
    if (fieldValue == null) {
      return Optional.empty();
    }
    switch (avroType) {
      // For type check at compile time, the type of x has to be specified (as cast) so that
      // convertability to long can be verified.
      case LONG:
        return Optional.of((List<Long>) fieldValue);
      case INT: {
          List<Integer> value = (List<Integer>) fieldValue;
          return Optional.of(
              value.stream()
                  .map(x -> x == null ? null : Long.valueOf(x))
                  .collect(Collectors.toList()));
      }
      case STRING: {
          List<Utf8> value = (List<Utf8>) fieldValue;
          return Optional.of(
              value.stream()
                  .map(x -> x == null ? null : Long.valueOf(x.toString()))
                  .collect(Collectors.toList()));
      }
      default:
        throw new IllegalArgumentException("Cannot interpret " + avroType + " as INT64");
    }
  }

  @VisibleForTesting
  @SuppressWarnings("unchecked")
  static Optional<List<Boolean>> readBoolArray(
      GenericRecord record, Schema.Type avroType, String fieldName) {
    switch (avroType) {
      case BOOLEAN:
        return Optional.ofNullable((List<Boolean>) record.get(fieldName));
      case STRING:
        {
          List<Utf8> value = (List<Utf8>) record.get(fieldName);
          if (value == null) {
            return Optional.empty();
          }
          List<Boolean> result =
              value
                  .stream()
                  .map(x -> x == null ? null : Boolean.valueOf(x.toString()))
                  .collect(Collectors.toList());
          return Optional.of(result);
        }
      default:
        throw new IllegalArgumentException("Cannot interpret " + avroType + " as BOOL");
    }
  }

  private Optional<Date> readDate(
      GenericRecord record, Schema.Type avroType, LogicalType logicalType, String fieldName) {
    switch (avroType) {
      case INT:
        if (logicalType == null || !LogicalTypes.date().equals(logicalType)) {
          throw new IllegalArgumentException(
              "Cannot interpret Avrotype INT Logicaltype " + logicalType + " as DATE");
        }
        // Avro Date is number of days since Jan 1, 1970.
        // Have to convert to Java Date first before creating google.cloud.core.Date
        return Optional.ofNullable((Integer) record.get(fieldName))
            .map(x -> new java.util.Date((long) x * 24L * 3600L * 1000L))
            .map(Date::fromJavaUtilDate);
      case STRING:
        return Optional.ofNullable((Utf8) record.get(fieldName))
            .map(Utf8::toString)
            .map(Date::parseDate);
      default:
        throw new IllegalArgumentException("Cannot interpret " + avroType + " as DATE");
    }
  }

  private Optional<Timestamp> readTimestamp(
      GenericRecord record, Schema.Type avroType, LogicalType logicalType, String fieldName) {
    switch (avroType) {
      case LONG:
        if (LogicalTypes.timestampMillis().equals(logicalType)) {
          return Optional.ofNullable((Long) record.get(fieldName))
              .map(x -> Timestamp.ofTimeMicroseconds(1000L * x));
        }
        if (LogicalTypes.timestampMicros().equals(logicalType)) {
          return Optional.ofNullable((Long) record.get(fieldName))
              .map(Timestamp::ofTimeMicroseconds);
        }
        // Default to micro-seconds.
        return Optional.ofNullable((Long) record.get(fieldName)).map(Timestamp::ofTimeMicroseconds);
      case STRING:
        return Optional.ofNullable((Utf8) record.get(fieldName))
            .map(Utf8::toString)
            .map(Timestamp::parseTimestamp);
      default:
        throw new IllegalArgumentException("Cannot interpret " + avroType + " as TIMESTAMP");
    }
  }

  private static Optional<ByteArray> readBytes(
      GenericRecord record, Schema.Type avroType, String fieldName) {
    switch (avroType) {
      case BYTES:
        return Optional.ofNullable((ByteBuffer) record.get(fieldName)).map(ByteArray::copyFrom);
      case STRING:
        return Optional.ofNullable((Utf8) record.get(fieldName))
            .map(Utf8::toString)
            .map(ByteArray::copyFrom);
      default:
        throw new IllegalArgumentException("Cannot interpret " + avroType + " as BYTES");
    }
  }

  private static Optional<Boolean> readBool(
      GenericRecord record, Schema.Type avroType, String fieldName) {
    switch (avroType) {
      case BOOLEAN:
        return Optional.ofNullable((Boolean) record.get(fieldName));
      case STRING:
        return Optional.ofNullable((Utf8) record.get(fieldName))
            .map(Utf8::toString)
            .map(Boolean::parseBoolean);
      default:
        throw new IllegalArgumentException("Cannot interpret " + avroType + " as BOOL");
    }
  }

  private static Optional<Double> readFloat64(
      GenericRecord record, Schema.Type avroType, String fieldName) {
    switch (avroType) {
      case INT:
        return Optional.ofNullable((Integer) record.get(fieldName)).map(x -> (double) x);
      case LONG:
        return Optional.ofNullable((Long) record.get(fieldName)).map(x -> (double) x);
      case FLOAT:
        return Optional.ofNullable((Float) record.get(fieldName)).map(x -> (double) x);
      case DOUBLE:
        return Optional.ofNullable((Double) record.get(fieldName));
      case STRING:
        return Optional.ofNullable((Utf8) record.get(fieldName))
            .map(Utf8::toString)
            .map(Double::valueOf);
      default:
        throw new IllegalArgumentException("Cannot interpret " + avroType + " as FLOAT64");
    }
  }

  private static Optional<Long> readInt64(
      GenericRecord record, Schema.Type avroType, String fieldName) {
    switch (avroType) {
      case INT:
        return Optional.ofNullable((Integer) record.get(fieldName)).map(x -> (long) x);
      case LONG:
        return Optional.ofNullable((Long) record.get(fieldName));
      case STRING:
        return Optional.ofNullable((Utf8) record.get(fieldName))
            .map(Utf8::toString)
            .map(Long::valueOf);
      default:
        throw new IllegalArgumentException("Cannot interpret " + avroType + " as INT64");
    }
  }

  private static Optional<String> readString(
      GenericRecord record, Schema.Type avroType, String fieldName) {
    switch (avroType) {
      case INT:
        return Optional.ofNullable((Integer) record.get(fieldName)).map(String::valueOf);
      case LONG:
        return Optional.ofNullable((Long) record.get(fieldName)).map(String::valueOf);
      case FLOAT:
        return Optional.ofNullable((Float) record.get(fieldName)).map(String::valueOf);
      case DOUBLE:
        return Optional.ofNullable((Double) record.get(fieldName)).map(String::valueOf);
      case STRING:
        return Optional.ofNullable((Utf8) record.get(fieldName)).map(Utf8::toString);
      default:
        throw new IllegalArgumentException("Cannot interpret " + avroType + " as STRING");
    }
  }
}