package org.ngseq.metagenomics;

import htsjdk.samtools.*;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.RecordWriter;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.lib.input.FileSplit;
import org.apache.hadoop.mapreduce.lib.input.NLineInputFormat;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.hive.HiveContext;
import org.seqdoop.hadoop_bam.*;
import org.seqdoop.hadoop_bam.util.SAMHeaderReader;
import scala.Tuple2;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

/**
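 * Helper for writing alignment data from Spark to HDFS as BAM, headerless SAM,
 * FASTQ or Parquet, using the Hadoop-BAM output formats.
 *
 * <p>Illustrative usage (the path and variable names below are hypothetical;
 * {@code jsc} is an existing JavaSparkContext and {@code samLines} a JavaRDD of
 * SAM-formatted lines):
 * <pre>
 *     new HDFSWriter(samLines, "/user/example/out", "bam", jsc);
 * </pre>
 *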
 * Created by Altti Ilari Maarala on 7/20/16.
 */

public class HDFSWriter {

    public static String alignmentTable = "alignments";
    // Minimal SAM text header; header fields are tab-separated, so use a real tab character.
    public static String samheader = "@HD\tVN:0.7.15-r1140";

    // Parsed alignment records; never assigned in this class, so getRecords() returns null.
    private final JavaPairRDD<LongWritable, SAMRecordWritable> records = null;

    public JavaPairRDD<LongWritable, SAMRecordWritable> getRecords() {
        return records;
    }

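    /** Creates a writer without reading input or broadcasting a header. */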
    public HDFSWriter(JavaSparkContext sc) throws IOException {

    }

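    /**
     * Optionally reads the SAM header from {@code inputpath} and broadcasts it
     * to the executors.
     */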
    public HDFSWriter(JavaSparkContext sc, String inputpath, boolean broadcastHeader) throws IOException {

        if (broadcastHeader) {
            // Read the SAM header from the input file and broadcast it to the executors.
            // Note that the broadcast handle is not retained by this constructor.
            SAMFileHeader header = SAMHeaderReader.readSAMHeaderFrom(new Path(inputpath), sc.hadoopConfiguration());
            final Broadcast<SAMFileHeader> headerBc = sc.broadcast(header);
        }
    }

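    /**
     * Parses SAM-formatted lines from {@code alignmentRDD} and writes them to
     * {@code outputDir} as BAM, headerless SAM, FASTQ or Parquet, depending on
     * {@code format}.
     */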
    public HDFSWriter(JavaRDD<String> alignmentRDD, String outputDir, String format, JavaSparkContext sc) throws IOException {

        SAMFileHeader header = new SAMFileHeader();
        header.setTextHeader(samheader);

        if(format.equals("bam")||format.equals("sam")){

            if(format.equals("bam")){
                final Broadcast<SAMFileHeader> headerBc = sc.broadcast(header);
                BAMHeaderOutputFormat.setHeader(headerBc.getValue());

                // Parse the SAM lines, attach the broadcast header on every partition,
                // wrap the records as Writables, and save as BAM.
                JavaRDD<SAMRecord> samRDD = alignmentsToSAM(alignmentRDD, header);
                JavaRDD<SAMRecord> samHRDD = setPartitionHeaders(samRDD, headerBc);
                JavaPairRDD<SAMRecord, SAMRecordWritable> samWritableRDD = readsToWritable(samHRDD, headerBc);
                samWritableRDD.saveAsNewAPIHadoopFile(outputDir, SAMRecord.class, SAMRecordWritable.class, BAMHeaderOutputFormat.class, sc.hadoopConfiguration());
            }
            if(format.equals("sam")){
                JavaRDD<SAMRecord> samRDD = alignmentsToSAM(alignmentRDD, header);
                // These properties configure Hadoop-BAM's AnySAMOutputFormat to emit headerless SAM.
                sc.hadoopConfiguration().setBoolean(KeyIgnoringAnySAMOutputFormat.WRITE_HEADER_PROPERTY, false);
                sc.hadoopConfiguration().set(AnySAMOutputFormat.OUTPUT_SAM_FORMAT_PROPERTY, "sam");
                final Broadcast<SAMFileHeader> headerBc = sc.broadcast(header);

                BAMHeaderOutputFormat.setHeader(headerBc.getValue());
                JavaRDD<SAMRecord> samHRDD = setPartitionHeaders(samRDD, headerBc);
                JavaPairRDD<SAMRecord, SAMRecordWritable> samWritableRDD = readsToWritable(samHRDD, headerBc);
                samWritableRDD.saveAsNewAPIHadoopFile(outputDir, SAMRecord.class, SAMRecordWritable.class, BAMHeaderOutputFormat.class, sc.hadoopConfiguration());
            }
        }

        if(format.equals("fastq")){
            // Convert the parsed alignments back to read name/sequence/quality tuples and save as FASTQ.
            JavaPairRDD<Text, SequencedFragment> newfastqRDD = alignmentsToFastq(alignmentRDD, header);
            newfastqRDD.saveAsNewAPIHadoopFile(outputDir, Text.class, SequencedFragment.class, FastqOutputFormat.class, sc.hadoopConfiguration());
        }
        if(format.equals("parquet")){
            JavaRDD<SAMRecord> samRDD = alignmentsToSAM(alignmentRDD, header);
            // Project each record onto the MyAlignment bean so it can be written as Parquet.
            JavaRDD<MyAlignment> rdd = samRDD.map(bam -> new MyAlignment(bam.getReadName(), bam.getStart(), bam.getReferenceName(), bam.getReadLength(),
                    new String(bam.getReadBases(), StandardCharsets.UTF_8), bam.getCigarString(), bam.getReadUnmappedFlag(), bam.getDuplicateReadFlag()));

            SQLContext sqlContext = new HiveContext(sc.sc());
            Dataset<Row> samDF = sqlContext.createDataFrame(rdd, MyAlignment.class);
            samDF.registerTempTable(alignmentTable);
            samDF.write().parquet(outputDir);
        }

    }

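    /** Parses SAM text lines into SAMRecords, skipping lines that fail to parse. */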
    private static JavaRDD<SAMRecord> alignmentsToSAM(JavaRDD<String> alignmentRDD, SAMFileHeader header) {
        return alignmentRDD.mapPartitions(alns -> {

            List<SAMRecord> records = new ArrayList<SAMRecord>();

            final SAMLineParser samLP = new SAMLineParser(new DefaultSAMRecordFactory(), ValidationStringency.SILENT, header, null, null);
            while (alns.hasNext()) {

                String aln = alns.next().replace("\r\n", "").replace("\n", "").replace(System.lineSeparator(), "");
                SAMRecord record = null;
                try{
                    record = samLP.parseLine(aln);
                    records.add(record);
                }catch(SAMFormatException e){
                    // Skip lines that cannot be parsed, but log the reason.
                    System.err.println(e.getMessage());
                }
            }
            return records.iterator();
        });
    }

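    /**
     * Converts SAM text lines into FASTQ fragments, taking the read name from
     * column 1 (with a /1 or /2 suffix for paired reads) and the bases and
     * qualities from columns 10 and 11.
     */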
    private static JavaPairRDD<Text, SequencedFragment> alignmentsToFastq(JavaRDD<String> alignmentRDD, SAMFileHeader header) {
        return alignmentRDD.mapPartitionsToPair(alns -> {

            List<Tuple2<Text, SequencedFragment>> records = new ArrayList<Tuple2<Text, SequencedFragment>>();
            final SAMLineParser samLP = new SAMLineParser(new DefaultSAMRecordFactory(), ValidationStringency.SILENT, header, null, null);
            while (alns.hasNext()) {
                String aln = alns.next().replace("\r\n", "").replace("\n", "").replace(System.lineSeparator(), "");
                try{
                    SAMRecord sam = samLP.parseLine(aln);
                    String[] fields = aln.split("\\t");
                    String name = fields[0];
                    if(sam.getReadPairedFlag()){
                        if(sam.getFirstOfPairFlag())
                            name = name+"/1";
                        if(sam.getSecondOfPairFlag())
                            name = name+"/2";
                    }

                    String bases = fields[9];    // SEQ column
                    String quality = fields[10]; // QUAL column

                    Text t = new Text(name);
                    SequencedFragment sf = new SequencedFragment();
                    sf.setSequence(new Text(bases));
                    sf.setQuality(new Text(quality));
                    records.add(new Tuple2<Text, SequencedFragment>(t, sf));
                }catch(SAMFormatException e){
                    // Skip lines that cannot be parsed, but log the reason.
                    System.err.println(e.getMessage());
                }
            }
            return records.iterator();
        });
    }

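    /**
     * Reads two FASTQ files in {@code splitlen}-line splits and emits their
     * records in interleaved order (a read followed by its mate). Assumes both
     * files contain the same reads in the same order.
     */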
    private static JavaPairRDD<Text, SequencedFragment> interleaveReads(String fastq, String fastq2, int splitlen, JavaSparkContext sc) throws IOException {

        FileSystem fs = FileSystem.get(new Configuration());

        FileStatus fst = fs.getFileStatus(new Path(fastq));
        FileStatus fst2 = fs.getFileStatus(new Path(fastq2));

        List<FileSplit> nlif = NLineInputFormat.getSplitsForFile(fst, sc.hadoopConfiguration(), splitlen);
        List<FileSplit> nlif2 = NLineInputFormat.getSplitsForFile(fst2, sc.hadoopConfiguration(), splitlen);

        JavaRDD<FileSplit> splitRDD = sc.parallelize(nlif);
        JavaRDD<FileSplit> splitRDD2 = sc.parallelize(nlif2);
        JavaPairRDD<FileSplit, FileSplit> zips = splitRDD.zip(splitRDD2);

        return zips.flatMapToPair( splits ->  {

            FastqInputFormat.FastqRecordReader fqreader = new FastqInputFormat.FastqRecordReader(new Configuration(), splits._1);
            FastqInputFormat.FastqRecordReader fqreader2 = new FastqInputFormat.FastqRecordReader(new Configuration(), splits._2);

            ArrayList<Tuple2<Text, SequencedFragment>> reads = new ArrayList<Tuple2<Text, SequencedFragment>>();
            while (fqreader.nextKeyValue()) {
                String key = fqreader.getCurrentKey().toString();
                String[] keysplit = key.split(" ");
                key = keysplit[0];

                SequencedFragment sf = new SequencedFragment();
                sf.setQuality(new Text(fqreader.getCurrentValue().getQuality().toString()));
                sf.setSequence(new Text(fqreader.getCurrentValue().getSequence().toString()));

                if (fqreader2.nextKeyValue()) {

                    String key2 = fqreader2.getCurrentKey().toString();
                    String[] keysplit2 = key2.split(" ");
                    key2 = keysplit2[0];
                    //key2 = key2.replace(" 2:N:0:1","/2");

                    SequencedFragment sf2 = new SequencedFragment();
                    sf2.setQuality(new Text(fqreader2.getCurrentValue().getQuality().toString()));
                    sf2.setSequence(new Text(fqreader2.getCurrentValue().getSequence().toString()));
                    reads.add(new Tuple2<Text, SequencedFragment>(new Text(key), sf));
                    reads.add(new Tuple2<Text, SequencedFragment>(new Text(key2), sf2));
                }
            }

            return reads.iterator();

        });
    }

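    /** Maps SAMRecords to FASTQ name/sequence/quality pairs, suffixing paired reads with /1 or /2. */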
    private static JavaPairRDD<Text, SequencedFragment> mapSAMRecordsToFastq(JavaRDD<SAMRecord> bamRDD) {

        // Map each SAMRecord to a FASTQ-style (read name, sequence + quality) pair
        JavaPairRDD<Text, SequencedFragment> fastqRDD = bamRDD.mapToPair(read -> {

            String name = read.getReadName();
            if(read.getReadPairedFlag()){
                if(read.getFirstOfPairFlag())
                    name = name+"/1";
                if(read.getSecondOfPairFlag())
                    name = name+"/2";
            }

            //TODO: check values
            Text t = new Text(name);
            SequencedFragment sf = new SequencedFragment();
            sf.setSequence(new Text(read.getReadString()));
            sf.setQuality(new Text(read.getBaseQualityString()));

            return new Tuple2<Text, SequencedFragment>(t, sf);
        });
        return fastqRDD;
    }

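    /** Wraps the records as Writables and saves them to {@code outpath} as BAM. */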
    public void writeRecords(JavaRDD<SAMRecord> records, Broadcast<SAMFileHeader> header, String outpath, SparkContext sc) {

        JavaPairRDD<SAMRecord, SAMRecordWritable> bamWritableRDD = readsToWritable(records, header);

        //Distribute records to HDFS as BAM

        bamWritableRDD.saveAsNewAPIHadoopFile(outpath, SAMRecord.class, SAMRecordWritable.class, BAMHeaderOutputFormat.class, sc.hadoopConfiguration());
    }


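    /**
     * Wraps each record as a SAMRecordWritable, first making sure the broadcast
     * header's sequence dictionary contains the record's reference name.
     */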
    public static JavaPairRDD<SAMRecord, SAMRecordWritable> readsToWritable(JavaRDD<SAMRecord> records, Broadcast<SAMFileHeader> header) {
        return records.mapToPair(read -> {

            // The sequence dictionary must be populated here because no reference
            // dictionary is supplied with the header file; add the read's reference
            // name if it is not present yet.
            if(header.getValue().getSequenceDictionary()==null) header.getValue().setSequenceDictionary(new SAMSequenceDictionary());
            if(header.getValue().getSequenceDictionary().getSequence(read.getReferenceName())==null)
                header.getValue().getSequenceDictionary().addSequence(new SAMSequenceRecord(read.getReferenceName()));

            read.setHeaderStrict(header.getValue());
            final SAMRecordWritable samRecordWritable = new SAMRecordWritable();
            samRecordWritable.set(read);
            return new Tuple2<>(read, samRecordWritable);
        });
    }

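    /** Wraps each record as a SAMRecordWritable, keeping the header already attached to the record. */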
    public static JavaPairRDD<SAMRecord, SAMRecordWritable> readsToWritableNoRef(JavaRDD<SAMRecord> records) {
        return records.mapToPair(read -> {
            //read.setHeaderStrict(read.getHeader());
            read.setHeader(read.getHeader());
            final SAMRecordWritable samRecordWritable = new SAMRecordWritable();
            samRecordWritable.set(read);
            return new Tuple2<>(read, samRecordWritable);
        });
    }

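    /** Wraps each record as a SAMRecordWritable without touching its header. */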
    public static JavaPairRDD<SAMRecord, SAMRecordWritable> readsToWritableNoHeader(JavaRDD<SAMRecord> records) {
        return records.mapToPair(read -> {
            final SAMRecordWritable samRecordWritable = new SAMRecordWritable();
            samRecordWritable.set(read);
            return new Tuple2<>(read, samRecordWritable);
        });
    }

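    /**
     * Registers the broadcast header with {@link BAMHeaderOutputFormat} on every
     * partition so that executor-side writers have a header available.
     */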
    public static JavaRDD<SAMRecord> setPartitionHeaders(final JavaRDD<SAMRecord> reads, final Broadcast<SAMFileHeader> header) {

        return reads.mapPartitions(records -> {
            // Register the broadcast header with the output format on this executor so
            // that partition-level record writers can construct their codecs.
            BAMHeaderOutputFormat.setHeader(header.getValue());
            return records;
        });
    }

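    /**
     * BAM output format that supplies a statically configured header to the
     * record writer, since the writers need a header to create a codec even when
     * the header itself is not written out.
     */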
    public static class BAMHeaderOutputFormat extends KeyIgnoringBAMOutputFormat<NullWritable>{
        public static SAMFileHeader samheader;
        public static boolean writeHeader = true;

        public static void setHeader(final SAMFileHeader header) {
            samheader = header;
        }

        public void setWriteHeader(final boolean s) {
            writeHeader = s;
        }

        @Override
        public RecordWriter<NullWritable, SAMRecordWritable> getRecordWriter(TaskAttemptContext ctx,
                                                                             Path outputPath) throws IOException {
            // the writers require a header in order to create a codec, even if
            // the header isn't being written out
            setSAMHeader(samheader);
            setWriteHeader(writeHeader);

            return super.getRecordWriter(ctx, outputPath);
        }
    }

}