package org.datavec.api.transform.transform.string;

import lombok.Data;
import org.datavec.api.transform.ColumnOp;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.BaseTransform;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/**
 * Concatenate values of one or more String columns into
 * a new String column. Retains the constituent String
 * columns so user must remove those manually, if desired.
 * The new column is appended after all existing columns.
 *
 * TODO: use new String Reduce functionality in DataVec?
 *
 * @author [email protected]
 */
@JsonIgnoreProperties({"inputSchema"})
@Data
public class ConcatenateStringColumns extends BaseTransform implements ColumnOp {

    private final String newColumnName;
    private final String delimiter;
    private final List<String> columnsToConcatenate;
    private Schema inputSchema;

    /**
     * @param newColumnName        Name of the String column appended to each record
     * @param delimiter            Text inserted between consecutive concatenated values
     * @param columnsToConcatenate A partial or complete order of the columns in the output
     */
    public ConcatenateStringColumns(String newColumnName, String delimiter, String... columnsToConcatenate) {
        this(newColumnName, delimiter, Arrays.asList(columnsToConcatenate));
    }

    /**
     * @param newColumnName        Name of the String column appended to each record
     * @param delimiter            Text inserted between consecutive concatenated values
     * @param columnsToConcatenate A partial or complete order of the columns in the output
     */
    public ConcatenateStringColumns(@JsonProperty("newColumnName") String newColumnName,
                    @JsonProperty("delimiter") String delimiter,
                    @JsonProperty("columnsToConcatenate") List<String> columnsToConcatenate) {
        this.newColumnName = newColumnName;
        this.delimiter = delimiter;
        this.columnsToConcatenate = columnsToConcatenate;
    }

    /**
     * Build the output schema: every input column, followed by a new String column
     * named {@code newColumnName}.
     *
     * @param inputSchema the schema of incoming records
     * @return the schema of records produced by this transform
     * @throws IllegalStateException if any column to concatenate is missing from {@code inputSchema}
     */
    @Override
    public Schema transform(Schema inputSchema) {
        checkColumnsPresent(inputSchema);

        List<ColumnMetaData> outMeta = new ArrayList<>(inputSchema.getColumnMetaData());
        outMeta.add(ColumnType.String.newColumnMetaData(newColumnName));
        return inputSchema.newSchema(outMeta);
    }

    /**
     * Set the schema used to resolve column indices in {@link #map(List)}.
     *
     * @throws IllegalStateException if any column to concatenate is missing from {@code inputSchema}
     */
    @Override
    public void setInputSchema(Schema inputSchema) {
        checkColumnsPresent(inputSchema);
        this.inputSchema = inputSchema;
    }

    @Override
    public Schema getInputSchema() {
        return inputSchema;
    }

    /** Throws IllegalStateException if any column to concatenate is absent from {@code schema}. */
    private void checkColumnsPresent(Schema schema) {
        for (String s : columnsToConcatenate) {
            if (!schema.hasColumn(s)) {
                throw new IllegalStateException("Input schema does not contain column with name \"" + s + "\"");
            }
        }
    }

    /**
     * Append a {@link Text} value holding the string form of each configured column,
     * joined with {@code delimiter}, to a copy of the record.
     *
     * @param writables one record laid out according to the input schema
     * @return a new list containing the original values plus the concatenated column
     * @throws IllegalStateException if {@link #setInputSchema(Schema)} has not been called
     */
    @Override
    public List<Writable> map(List<Writable> writables) {
        if (inputSchema == null) {
            throw new IllegalStateException("Input schema has not been set; call setInputSchema first");
        }
        StringBuilder newColumnText = new StringBuilder();
        boolean first = true;
        for (String columnName : columnsToConcatenate) {
            if (!first) {
                newColumnText.append(delimiter);
            }
            first = false;
            int columnIdx = inputSchema.getIndexOfColumn(columnName);
            newColumnText.append(writables.get(columnIdx));
        }
        // +1 for the appended concatenated column
        List<Writable> out = new ArrayList<>(writables.size() + 1);
        out.addAll(writables);
        out.add(new Text(newColumnText.toString()));
        return out;
    }

    /**
     * Apply {@link #map(List)} to every step of the sequence.
     *
     * @param sequence the time steps of one sequence
     * @return a new list of transformed time steps, in the same order
     */
    @Override
    public List<List<Writable>> mapSequence(List<List<Writable>> sequence) {
        List<List<Writable>> out = new ArrayList<>(sequence.size());
        for (List<Writable> step : sequence) {
            out.add(map(step));
        }
        return out;
    }

    /**
     * Transform an object
     * in to another object
     *
     * @param input the record to transform
     * @return the transformed writable
     */
    @Override
    public Object map(Object input) {
        throw new UnsupportedOperationException(
                        "Unable to map. Please treat this as a special operation. This should be handled by your implementation.");

    }

    /**
     * Transform a sequence
     *
     * @param sequence
     */
    @Override
    public Object mapSequence(Object sequence) {
        throw new UnsupportedOperationException(
                        "Unable to map. Please treat this as a special operation. This should be handled by your implementation.");
    }

    /**
     * Two instances are equal when they write to the same output column using the same
     * delimiter and the same ordered list of source columns. The input schema is runtime
     * state and is intentionally excluded.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o)
            return true;
        if (o == null || getClass() != o.getClass())
            return false;

        ConcatenateStringColumns o2 = (ConcatenateStringColumns) o;
        return Objects.equals(newColumnName, o2.newColumnName)
                        && Objects.equals(delimiter, o2.delimiter)
                        && Objects.equals(columnsToConcatenate, o2.columnsToConcatenate);
    }

    @Override
    public int hashCode() {
        return Objects.hash(newColumnName, delimiter, columnsToConcatenate);
    }

    @Override
    public String toString() {
        return "ConcatenateStringColumns(newColumnName=" + newColumnName + " delimiters=" + delimiter
                        + " columnsToConcatenate=" + columnsToConcatenate + ")";

    }

    /**
     * The output column name
     * after the operation has been applied
     *
     * @return the output column name
     */
    @Override
    public String outputColumnName() {
        return newColumnName;
    }

    /**
     * The output column names
     * This will often be the same as the input
     *
     * @return the output column names
     */
    @Override
    public String[] outputColumnNames() {
        return new String[] {newColumnName};
    }

    /**
     * Returns column names
     * this op is meant to run on
     *
     * @return the names of the columns being concatenated, in concatenation order
     *         (exactly {@code columnsToConcatenate.size()} entries, no padding)
     */
    @Override
    public String[] columnNames() {
        // A zero-length array makes toArray allocate an exactly-sized result; sizing it by
        // the input schema's column count would pad the result with trailing nulls.
        return columnsToConcatenate.toArray(new String[0]);
    }

    /**
     * Returns a singular column name
     * this op is meant to run on
     *
     * @return the first column being concatenated
     * @throws IndexOutOfBoundsException if no columns were configured
     */
    @Override
    public String columnName() {
        return columnsToConcatenate.get(0);
    }
}