/*
 * Copyright (c) 2015-2019, Cloudera, Inc. All Rights Reserved.
 *
 * Cloudera, Inc. licenses this file to you under the Apache License,
 * Version 2.0 (the "License"). You may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
 * CONDITIONS OF ANY KIND, either express or implied. See the License for
 * the specific language governing permissions and limitations under the
 * License.
 */

package com.cloudera.labs.envelope.translate;

import com.cloudera.labs.envelope.component.ComponentFactory;
import com.cloudera.labs.envelope.component.InstantiatedComponent;
import com.cloudera.labs.envelope.component.InstantiatesComponents;
import com.cloudera.labs.envelope.schema.DeclaresExpectingSchema;
import com.cloudera.labs.envelope.schema.DeclaresProvidingSchema;
import com.cloudera.labs.envelope.schema.UsesProvidedSchema;
import com.cloudera.labs.envelope.utils.ConfigUtils;
import com.cloudera.labs.envelope.utils.RowUtils;
import com.cloudera.labs.envelope.utils.SchemaUtils;
import com.google.common.collect.Lists;
import com.typesafe.config.Config;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

@SuppressWarnings("serial")
public class TranslateFunction implements FlatMapFunction<Row, Row>, InstantiatesComponents,
    DeclaresExpectingSchema, DeclaresProvidingSchema, UsesProvidedSchema {

  public static final String APPEND_RAW_ENABLED_CONFIG = "append.raw.enabled";
  public static final boolean APPEND_RAW_ENABLED_DEFAULT = false;
  public static final String HAD_ERROR_FIELD_NAME = "_had_error";

  private Config config;
  private StructType providedSchema;

  private transient Translator translator;

  private static Logger LOG = LoggerFactory.getLogger(TranslateFunction.class);

  public TranslateFunction(Config config) {
    this.config = config;
  }

  @Override
  public Iterator<Row> call(Row message) throws Exception {
    validateMessageSchema(message);
    Iterable<Row> translationResults;

    try {
      translationResults = getTranslator().translate(message);
    }
    catch (Exception e) {
      Row error = appendHadErrorFlag(message, true);
      return Collections.singleton(error).iterator();
    }

    List<Row> translated = Lists.newArrayList();
    for (Row translationResult : translationResults) {
      validateTranslatedSchema(translationResult);

      if (doesAppendRaw()) {
        translationResult = appendRawFields(translationResult, message);
      }
      translationResult = appendHadErrorFlag(translationResult, false);

      translated.add(translationResult);
    }

    return translated.iterator();
  }

  @Override
  public StructType getExpectingSchema() {
    return getTranslator().getExpectingSchema();
  }

  @Override
  public StructType getProvidingSchema() {
    StructType translatedSchema = getTranslator().getProvidingSchema();

    if (doesAppendRaw()) {
      translatedSchema = SchemaUtils.appendFields(
          translatedSchema, Arrays.asList(addFieldNameUnderscores(providedSchema).fields()));
    }

    return translatedSchema;
  }

  @Override
  public void receiveProvidedSchema(StructType providedSchema) {
    this.providedSchema = providedSchema;

    if (getTranslator() instanceof UsesProvidedSchema) {
      ((UsesProvidedSchema)getTranslator()).receiveProvidedSchema(providedSchema);
    }
  }

  @Override
  public Set<InstantiatedComponent> getComponents(Config config, boolean configure) {
    return Collections.singleton(new InstantiatedComponent(
        getTranslator(configure), getTranslatorConfig(config), "Translator"
    ));
  }

  public Translator getTranslator() {
    return getTranslator(true);
  }

  public synchronized Translator getTranslator(boolean configure) {
    if (configure) {
      if (translator == null) {
        translator = ComponentFactory.create(Translator.class, getTranslatorConfig(config), true);
        LOG.debug("Translator created: " + translator.getClass().getName());
      }

      return translator;
    }
    else {
      return ComponentFactory.create(Translator.class, getTranslatorConfig(config), false);
    }
  }

  private Config getTranslatorConfig(Config config) {
    // Don't pass the append raw configuration to the translator itself as it doesn't use it
    return config.withoutPath(APPEND_RAW_ENABLED_CONFIG);
  }

  private void validateMessageSchema(Row message) {
    if (message.schema() == null) {
      throw new RuntimeException("Translator must be provided raw messages with an embedded schema");
    }

    if (!hasValueField(message)) {
      throw new RuntimeException("Translator must be provided raw messages with a '" +
          Translator.VALUE_FIELD_NAME + "' field");
    }
  }

  private void validateTranslatedSchema(Row translationResult) {
    if (translationResult.schema() == null) {
      throw new RuntimeException("Translator must translate to rows with an embedded schema");
    }
  }

  private boolean doesAppendRaw() {
    return ConfigUtils.getOrElse(config, APPEND_RAW_ENABLED_CONFIG, APPEND_RAW_ENABLED_DEFAULT);
  }

  private Row appendRawFields(Row translated, Row message) {
    for (StructField messageField : message.schema().fields()) {
      translated = RowUtils.append(
          translated,
          "_" + messageField.name(),
          messageField.dataType(),
          messageField.nullable(),
          message.getAs(messageField.name()));
    }

    return translated;
  }

  private boolean hasValueField(Row message) {
    for (String fieldName : message.schema().fieldNames()) {
      if (fieldName.equals(Translator.VALUE_FIELD_NAME)) {
        return true;
      }
    }

    return false;
  }

  private StructType addFieldNameUnderscores(StructType without) {
    List<StructField> withFields = Lists.newArrayList();

    for (StructField withoutField : without.fields()) {
      String withName = "_" + withoutField.name();
      if (Arrays.asList(without.fieldNames()).contains(withName)) {
        throw new RuntimeException("Can not append raw field '" + withName + "' because that " +
            "field already exists as a result of the translation");
      }

      StructField withField = DataTypes.createStructField(
          withName, withoutField.dataType(), withoutField.nullable(), withoutField.metadata());

      withFields.add(withField);
    }

    return DataTypes.createStructType(withFields);
  }

  private Row appendHadErrorFlag(Row row, boolean hadError) {
    return RowUtils.append(row, HAD_ERROR_FIELD_NAME, DataTypes.BooleanType, false, hadError);
  }

}