package net.jgp.labs.spark.x.datasource;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.lang3.StringUtils;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.sources.BaseRelation;
import org.apache.spark.sql.sources.TableScan;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

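/**
 * A custom Spark data source relation: it scans a text file line by line and
 * counts, for each line, the occurrences of every registered criterion
 * (substring). The resulting table has a "line" column holding the line
 * number, plus one integer column per criterion.
 */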
public class SubStringCounterRelation extends BaseRelation
    implements Serializable, TableScan {

  private static final long serialVersionUID = -3441600156161255871L;
  private static final Logger log = LoggerFactory.getLogger(
      SubStringCounterRelation.class);

  private SQLContext sqlContext;
  private String filename;
  private boolean updateSchema = true;

  /**
   * Cached schema. Never access this field directly; always rely on the
   * schema() method.
   */
  private StructType schema = null;
  private List<String> criteria = new ArrayList<>();

  /**
   * Constructs the schema. It is cached internally to avoid rebuilding it
   * every time; the cache is invalidated when a criterion is added.
   */
  @Override
  public StructType schema() {
    log.debug("-> schema()");
    if (updateSchema || schema == null) {
      List<StructField> sfl = new ArrayList<>();
      sfl.add(DataTypes.createStructField("line", DataTypes.IntegerType, true));
      for (String crit : this.criteria) {
        sfl.add(DataTypes.createStructField(crit, DataTypes.IntegerType,
            false));
      }
      schema = DataTypes.createStructType(sfl);
      updateSchema = false;
    }
    return schema;
  }
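
  // For illustration only (hypothetical criteria): after addCriteria("a")
  // and addCriteria("b"), schema() returns the equivalent of:
  //   StructType(StructField(line,IntegerType,true),
  //              StructField(a,IntegerType,false),
  //              StructField(b,IntegerType,false))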

  @Override
  public SQLContext sqlContext() {
    log.debug("-> sqlContext()");
    return this.sqlContext;
  }

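  /**
   * Injects the SQLContext, typically called by the RelationProvider that
   * creates this relation.
   */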
  public void setSqlContext(SQLContext arg0) {
    this.sqlContext = arg0;
  }

  @Override
  public RDD<Row> buildScan() {
    log.debug("-> buildScan()");

    // I have isolated the work in a method to keep the plumbing code as
    // simple as possible.
    List<List<Integer>> table = collectData();

    @SuppressWarnings("resource") // cannot be closed here, done elsewhere
    JavaSparkContext sparkContext = new JavaSparkContext(sqlContext
        .sparkContext());
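    // Turn each List<Integer> into a Spark Row; the column order must match
    // the order of the fields built in schema().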
    JavaRDD<Row> rowRDD = sparkContext.parallelize(table)
        .map(row -> RowFactory.create(row.toArray()));

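    // TableScan's contract expects a Scala RDD<Row>, hence the unwrapping.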
    return rowRDD.rdd();
  }

  /**
   * Builds the data table that will then be turned into an RDD. How it is
   * built is slightly out of the scope of this example: it simply goes
   * through the source file line by line and counts the occurrences of each
   * substring we want to analyze.
   * 
   * @return The data as a List of List of Integer. I know, it could be more
   *         elegant.
   */
  private List<List<Integer>> collectData() {
    List<List<Integer>> table = new ArrayList<>();

    FileReader fileReader;
    try {
      fileReader = new FileReader(filename);
    } catch (FileNotFoundException e) {
      log.error("File [{}] not found, got {}", filename, e.getMessage(), e);
      return table;
    }

    BufferedReader br = new BufferedReader(fileReader);
    String line;
    int lineCount = 0;
    int criteriaCount = this.criteria.size();
    List<Integer> row0;
    while (true) {
      try {
        line = br.readLine();
      } catch (IOException e) {
        log.error("Error while reading [{}], got {}", filename, e.getMessage(),
            e);
        break;
      }
      if (line == null) {
        // End of file: stop here so that the trailing null does not become an
        // extra row of zeros.
        break;
      }
      row0 = new ArrayList<>();
      row0.add(lineCount);
      for (int i = 0; i < criteriaCount; i++) {
        // Count the occurrences of this criterion in the current line.
        int v = StringUtils.countMatches(line, this.criteria.get(i));
        row0.add(v);
      }
      table.add(row0);
      lineCount++;
    }

    try {
      br.close();
    } catch (IOException e) {
      log.error("Error while closing [{}], got {}", filename, e.getMessage(),
          e);
    }

    return table;
  }

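  /** Sets the path of the text file that buildScan() will analyze. */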
  public void setFilename(String filename) {
    this.filename = filename;
  }

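  /**
   * Registers one more substring to count. Each criterion becomes a column,
   * so the cached schema is flagged for rebuild.
   */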
  public void addCriteria(String crit) {
    this.updateSchema = true;
    this.criteria.add(crit);
  }
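
  /*
   * Usage sketch (illustrative only): this relation is meant to be created by
   * a companion RelationProvider, which would map read options onto
   * setFilename() and addCriteria(). Assuming such a provider lives in this
   * package, a scan could look like:
   *
   *   Dataset<Row> df = spark.read()
   *       .format("net.jgp.labs.spark.x.datasource")
   *       .load();
   *   df.show();
   *
   * Each resulting row holds the line number followed by one match count per
   * criterion. The provider and its option mapping are assumptions, not
   * defined in this file.
   */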

}