package org.broadinstitute.hellbender.tools;

import com.google.common.annotations.VisibleForTesting;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.reference.ReferenceSequence;
import htsjdk.samtools.reference.ReferenceSequenceFile;
import htsjdk.samtools.util.OverlapDetector;
import htsjdk.variant.variantcontext.VariantContext;
import htsjdk.variant.variantcontext.writer.VariantContextWriter;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.broadinstitute.barclay.argparser.Advanced;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.ArgumentCollection;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import org.broadinstitute.hellbender.cmdline.*;
import org.broadinstitute.hellbender.cmdline.programgroups.SparkProgramGroup;
import org.broadinstitute.hellbender.engine.*;
import org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource;
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.engine.spark.GATKSparkTool;
import org.broadinstitute.hellbender.engine.spark.SparkReadShard;
import org.broadinstitute.hellbender.exceptions.GATKException;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.HaplotypeCaller;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.HaplotypeCallerArgumentCollection;
import org.broadinstitute.hellbender.tools.walkers.haplotypecaller.HaplotypeCallerEngine;
import org.broadinstitute.hellbender.utils.IntervalUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
import org.broadinstitute.hellbender.utils.reference.ReferenceBases;
import scala.Tuple2;

import java.io.IOException;
import java.io.Serializable;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

/**
 * Call germline SNPs and indels via local re-assembly of haplotypes
 *
 * This is an implementation of {@link HaplotypeCaller} using Spark to distribute the computation.
 * It is still in an early stage of development and does not yet support all of the options that the non-Spark version does.
 *
 * Specifically, it does not support the --dbsnp, --comp, and --bamOutput options.
 *
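 * <h3>Example</h3>
 * <p>A minimal sketch of an invocation (the launcher name and exact argument names may differ
 * between GATK versions; a 2bit reference is required so that it can be broadcast):</p>
 * <pre>
 * gatk-launch HaplotypeCallerSpark \
 *   -R reference.2bit \
 *   -I input.bam \
 *   -O output.vcf
 * </pre>
 *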
 */
@CommandLineProgramProperties(summary = "Call germline SNPs and indels via local re-assembly of haplotypes, distributed with Spark", oneLineSummary = "HaplotypeCaller on Spark", programGroup = SparkProgramGroup.class)
public final class HaplotypeCallerSpark extends GATKSparkTool {
    private static final long serialVersionUID = 1L;

    @Argument(fullName= StandardArgumentDefinitions.OUTPUT_LONG_NAME, shortName = StandardArgumentDefinitions.OUTPUT_SHORT_NAME, doc = "Single file to which variants should be written")
    public String output;

    @ArgumentCollection
    public final ShardingArgumentCollection shardingArgs = new ShardingArgumentCollection();

    public static class ShardingArgumentCollection implements Serializable {
        private static final long serialVersionUID = 1L;

        @Argument(fullName="readShardSize", shortName="readShardSize", doc = "Maximum size of each read shard, in bases. For good performance, this should be much larger than the maximum assembly region size.", optional = true)
        public int readShardSize = HaplotypeCaller.DEFAULT_READSHARD_SIZE;

        @Argument(fullName="readShardPadding", shortName="readShardPadding", doc = "Each read shard has this many bases of extra context on each side. Read shards must have as much or more padding than assembly regions.", optional = true)
        public int readShardPadding = HaplotypeCaller.DEFAULT_READSHARD_PADDING;

        @Argument(fullName = "minAssemblyRegionSize", shortName = "minAssemblyRegionSize", doc = "Minimum size of an assembly region", optional = true)
        public int minAssemblyRegionSize = HaplotypeCaller.DEFAULT_MIN_ASSEMBLY_REGION_SIZE;

        @Argument(fullName = "maxAssemblyRegionSize", shortName = "maxAssemblyRegionSize", doc = "Maximum size of an assembly region", optional = true)
        public int maxAssemblyRegionSize = HaplotypeCaller.DEFAULT_MAX_ASSEMBLY_REGION_SIZE;

        @Argument(fullName = "assemblyRegionPadding", shortName = "assemblyRegionPadding", doc = "Number of additional bases of context to include around each assembly region", optional = true)
        public int assemblyRegionPadding = HaplotypeCaller.DEFAULT_ASSEMBLY_REGION_PADDING;

        @Advanced
        @Argument(fullName = "activeProbabilityThreshold", shortName = "activeProbabilityThreshold", doc="Minimum probability for a locus to be considered active.", optional = true)
        public double activeProbThreshold = HaplotypeCaller.DEFAULT_ACTIVE_PROB_THRESHOLD;

        @Advanced
        @Argument(fullName = "maxProbPropagationDistance", shortName = "maxProbPropagationDistance", doc="Upper limit on how many bases away probability mass can be moved around when calculating the boundaries between active and inactive assembly regions", optional = true)
        public int maxProbPropagationDistance = HaplotypeCaller.DEFAULT_MAX_PROB_PROPAGATION_DISTANCE;

    }

    @ArgumentCollection
    public HaplotypeCallerArgumentCollection hcArgs = new HaplotypeCallerArgumentCollection();

    @Override
    public boolean requiresReads(){
        return true;
    }

    @Override
    public boolean requiresReference(){
        return true;
    }


    @Override
    protected void runTool(final JavaSparkContext ctx) {
        final List<SimpleInterval> intervals = hasIntervals() ? getIntervals() : IntervalUtils.getAllIntervalsForReference(getHeaderForReads().getSequenceDictionary());
        final JavaRDD<VariantContext> variants = callVariantsWithHaplotypeCaller(getAuthHolder(), ctx, getReads(), getHeaderForReads(), getReference(), intervals, hcArgs, shardingArgs);
        writeVariants(variants);
    }

    @Override
    public List<ReadFilter> getDefaultReadFilters() {
        return HaplotypeCallerEngine.makeStandardHCReadFilters();
    }

    /**
     * Call variants using HaplotypeCaller on Spark and return an RDD of {@link VariantContext}.
     *
     * This may be called from any Spark pipeline in order to call variants from an RDD of {@link GATKRead}.
     *
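     * <p>A minimal usage sketch (assumes the caller already has a {@code JavaSparkContext}, an RDD of
     * reads, and a matching header and reference; the variable names below are illustrative only):</p>
     * <pre>{@code
     * final JavaRDD<VariantContext> variants = HaplotypeCallerSpark.callVariantsWithHaplotypeCaller(
     *         authHolder, ctx, reads, header, reference, intervals,
     *         new HaplotypeCallerArgumentCollection(), new ShardingArgumentCollection());
     * }</pre>
     *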
     * @param authHolder authorization needed for reading the reference
     * @param ctx the Spark context
     * @param reads the reads from which variants should be called
     * @param header the header that goes with the reads
     * @param reference the reference to use when calling
     * @param intervals the intervals to restrict calling to
     * @param hcArgs haplotype caller arguments
     * @param shardingArgs arguments to control how the assembly regions are sharded
     * @return an RDD of variants
     */
    public static JavaRDD<VariantContext> callVariantsWithHaplotypeCaller(
            final AuthHolder authHolder,
            final JavaSparkContext ctx,
            final JavaRDD<GATKRead> reads,
            final SAMFileHeader header,
            final ReferenceMultiSource reference,
            final List<SimpleInterval> intervals,
            final HaplotypeCallerArgumentCollection hcArgs,
            final ShardingArgumentCollection shardingArgs) {
        Utils.validateArg(hcArgs.dbsnp.dbsnp == null, "HaplotypeCallerSpark does not yet support -D or --dbsnp arguments" );
        Utils.validateArg(hcArgs.comps.isEmpty(), "HaplotypeCallerSpark does not yet support -comp or --comp arguments" );
        Utils.validateArg(hcArgs.bamOutputPath == null, "HaplotypeCallerSpark does not yet support -bamout or --bamOutput");
        if ( !reference.isCompatibleWithSparkBroadcast()){
            throw new UserException.Require2BitReferenceForBroadcast();
        }

        final Broadcast<ReferenceMultiSource> referenceBroadcast = ctx.broadcast(reference);
        final Broadcast<HaplotypeCallerArgumentCollection> hcArgsBroadcast = ctx.broadcast(hcArgs);
        final OverlapDetector<ShardBoundary> overlaps = getShardBoundaryOverlapDetector(header, intervals, shardingArgs.readShardSize, shardingArgs.readShardPadding);
        final Broadcast<OverlapDetector<ShardBoundary>> shardBoundariesBroadcast = ctx.broadcast(overlaps);

        final JavaRDD<Shard<GATKRead>> readShards = createReadShards(shardBoundariesBroadcast, reads);

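        // Two-stage pipeline: first turn each read shard into assembly regions (paired with the shard's
        // non-padded interval), then call variants on each region, dropping duplicate calls at shard edges.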
        final JavaRDD<Tuple2<AssemblyRegion, SimpleInterval>> assemblyRegions = readShards
                .mapPartitions(shardsToAssemblyRegions(authHolder, referenceBroadcast, hcArgsBroadcast, shardingArgs, header));

        return assemblyRegions.mapPartitions(callVariantsFromAssemblyRegions(authHolder, header, referenceBroadcast, hcArgsBroadcast));
    }

    /**
     * Call variants from tuples of {@link AssemblyRegion} and {@link SimpleInterval}.
     * The interval should be the non-padded shard boundary for the shard in which the corresponding AssemblyRegion was
     * created; it is used to eliminate redundant variant calls at the edges of shard boundaries.
     */
    private static FlatMapFunction<Iterator<Tuple2<AssemblyRegion, SimpleInterval>>, VariantContext> callVariantsFromAssemblyRegions(
            final AuthHolder authHolder,
            final SAMFileHeader header,
            final Broadcast<ReferenceMultiSource> referenceBroadcast,
            final Broadcast<HaplotypeCallerArgumentCollection> hcArgsBroadcast) {
        return regionAndIntervals -> {
            // HaplotypeCallerEngine isn't serializable but is expensive to instantiate, so construct one per partition and reuse it for all regions in that partition
            final ReferenceMultiSourceAdapter referenceReader = new ReferenceMultiSourceAdapter(referenceBroadcast.getValue(), authHolder);
            final HaplotypeCallerEngine hcEngine = new HaplotypeCallerEngine(hcArgsBroadcast.value(), header, referenceReader);
            return iteratorToStream(regionAndIntervals).flatMap(regionToVariants(hcEngine)).iterator();
        };
    }

    private static <T> Stream<T> iteratorToStream(Iterator<T> iterator) {
        Iterable<T> regionsIterable = () -> iterator;
        return StreamSupport.stream(regionsIterable.spliterator(), false);
    }

    private static Function<Tuple2<AssemblyRegion, SimpleInterval>, Stream<? extends VariantContext>> regionToVariants(HaplotypeCallerEngine hcEngine) {
        return regionAndInterval -> {
            final List<VariantContext> variantContexts = hcEngine.callRegion(regionAndInterval._1(), new FeatureContext());
            final SimpleInterval shardBoundary = regionAndInterval._2();
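            // Keep only variants whose start lies within the non-padded shard interval, so that a call near
            // a shard edge is emitted exactly once: by the shard whose boundary contains its start position.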
            return variantContexts.stream()
                .filter(vc -> shardBoundary.contains(new SimpleInterval(vc.getContig(), vc.getStart(), vc.getStart())));
        };
    }

    /**
     * Write variants to the output file. This is currently slow and memory-intensive on a full-size input,
     * since it collects all variants to the driver before sorting and writing them.
     *
     * This will be replaced by a parallel writer similar to what's done with {@link org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSink}
     */
    private void writeVariants(JavaRDD<VariantContext> variants) {
        final List<VariantContext> collectedVariants = variants.collect();
        final SAMSequenceDictionary referenceDictionary = getReferenceSequenceDictionary();

        final List<VariantContext> sortedVariants = collectedVariants.stream()
                .sorted((o1, o2) -> IntervalUtils.compareLocatables(o1, o2, referenceDictionary))
                .collect(Collectors.toList());

        final HaplotypeCallerEngine hcEngine = new HaplotypeCallerEngine(hcArgs, getHeaderForReads(), new ReferenceMultiSourceAdapter(getReference(), getAuthHolder()));
        try(final VariantContextWriter writer = hcEngine.makeVCFWriter(output, getBestAvailableSequenceDictionary())) {
            hcEngine.writeHeader(writer, getHeaderForReads().getSequenceDictionary(), Collections.emptySet());
            sortedVariants.forEach(writer::add);
        }
    }

    /**
     * Create an RDD of {@link Shard} from an RDD of {@link GATKRead}
     * @param shardBoundariesBroadcast  broadcast of an {@link OverlapDetector} loaded with the intervals that should be used for creating ReadShards
     * @param reads RDD of {@link GATKRead}
     * @return an RDD of reads grouped into potentially overlapping shards
     */
    private static JavaRDD<Shard<GATKRead>> createReadShards(final Broadcast<OverlapDetector<ShardBoundary>> shardBoundariesBroadcast, final JavaRDD<GATKRead> reads) {
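        // Pair each read with every shard boundary it overlaps; a read that falls in the padding of
        // adjacent shards is duplicated into each of those shards.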
        final JavaPairRDD<ShardBoundary, GATKRead> paired = reads.flatMapToPair(read -> {
            final Collection<ShardBoundary> overlappingShards = shardBoundariesBroadcast.value().getOverlaps(read);
            return overlappingShards.stream().map(key -> new Tuple2<>(key, read)).iterator();
        });
        final JavaPairRDD<ShardBoundary, Iterable<GATKRead>> shardsWithReads = paired.groupByKey();
        return shardsWithReads.map(shard -> new SparkReadShard(shard._1(), shard._2()));
    }

    /**
     * @return an {@link OverlapDetector} loaded with {@link ShardBoundary} objects
     * derived from the -L intervals
     */
    private static OverlapDetector<ShardBoundary> getShardBoundaryOverlapDetector(final SAMFileHeader header, final List<SimpleInterval> intervals, final int readShardSize, final int readShardPadding) {
        final OverlapDetector<ShardBoundary> shardBoundaryOverlapDetector = new OverlapDetector<>(0, 0);
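        // Index each shard boundary by its padded interval so that reads overlapping only the padding
        // are still assigned to that shard.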
        intervals.stream()
                .flatMap(interval -> Shard.divideIntervalIntoShards(interval, readShardSize, readShardPadding, header.getSequenceDictionary()).stream())
                .forEach(boundary -> shardBoundaryOverlapDetector.addLhs(boundary, boundary.getPaddedInterval()));
        return shardBoundaryOverlapDetector;
    }

    /**
     * @return an RDD of {@link Tuple2} of {@link AssemblyRegion} and {@link SimpleInterval}, which pairs each AssemblyRegion with the
     * non-padded interval of the shard it was generated in
     */
    private static FlatMapFunction<Iterator<Shard<GATKRead>>, Tuple2<AssemblyRegion, SimpleInterval>> shardsToAssemblyRegions(
            final AuthHolder authHolder,
            final Broadcast<ReferenceMultiSource> reference,
            final Broadcast<HaplotypeCallerArgumentCollection> hcArgsBroadcast,
            final ShardingArgumentCollection assemblyArgs,
            final SAMFileHeader header) {
        return shards -> {
            final ReferenceMultiSource referenceMultiSource = reference.value();
            final ReferenceMultiSourceAdapter referenceSource = new ReferenceMultiSourceAdapter(referenceMultiSource, authHolder);
            final HaplotypeCallerEngine hcEngine = new HaplotypeCallerEngine(hcArgsBroadcast.value(), header, referenceSource);

            return iteratorToStream(shards).flatMap(shardToRegion(assemblyArgs, header, referenceSource, hcEngine)).iterator();
        };
    }

    private static Function<Shard<GATKRead>, Stream<? extends Tuple2<AssemblyRegion, SimpleInterval>>> shardToRegion(
            ShardingArgumentCollection assemblyArgs,
            SAMFileHeader header,
            ReferenceMultiSourceAdapter referenceSource,
            HaplotypeCallerEngine evaluator) {
        return shard -> {
            final ReferenceContext refContext = new ReferenceContext(referenceSource, shard.getPaddedInterval());

            //TODO load features as a side input
            final FeatureContext features = new FeatureContext();

            final Iterable<AssemblyRegion> assemblyRegions = AssemblyRegion.createFromReadShard(
                    shard, header, refContext, features, evaluator,
                    assemblyArgs.minAssemblyRegionSize, assemblyArgs.maxAssemblyRegionSize,
                    assemblyArgs.assemblyRegionPadding, assemblyArgs.activeProbThreshold,
                    assemblyArgs.maxProbPropagationDistance);

            return StreamSupport.stream(assemblyRegions.spliterator(), false)
                    .map(a -> new Tuple2<>(a, shard.getInterval()));
        };
    }

    /**
     * Adapter to allow a 2bit reference to be used in HaplotypeCallerEngine.
     * This is not intended as a general-purpose adapter; it only enables the operations needed by {@link HaplotypeCallerEngine}.
     * This should not be used outside of this class except for testing purposes.
     */
    @VisibleForTesting
    public static final class ReferenceMultiSourceAdapter implements ReferenceSequenceFile, ReferenceDataSource, Serializable{
        private static final long serialVersionUID = 1L;

        private final ReferenceMultiSource source;
        private final AuthHolder auth;
        private final SAMSequenceDictionary sequenceDictionary;

        public ReferenceMultiSourceAdapter(final ReferenceMultiSource source, final AuthHolder auth) {
            this.source = source;
            this.auth = auth;
            sequenceDictionary = source.getReferenceSequenceDictionary(null);
        }

        @Override
        public ReferenceSequence queryAndPrefetch(final String contig, final long start, final long stop) {
            return getSubsequenceAt(contig, start, stop);
        }

        @Override
        public SAMSequenceDictionary getSequenceDictionary() {
            return source.getReferenceSequenceDictionary(null);
        }

        @Override
        public ReferenceSequence nextSequence() {
            throw new UnsupportedOperationException("nextSequence is not implemented");
        }

        @Override
        public void reset() {
            throw new UnsupportedOperationException("reset is not implemented");
        }

        @Override
        public boolean isIndexed() {
            return true;
        }

        @Override
        public ReferenceSequence getSequence(final String contig) {
            throw new UnsupportedOperationException("getSequence is not supported");
        }

        @Override
        public ReferenceSequence getSubsequenceAt(final String contig, final long start, final long stop) {
            try {
                final ReferenceBases bases = source.getReferenceBases(auth.asPipelineOptionsDeprecated(), new SimpleInterval(contig, (int) start, (int) stop));
                return new ReferenceSequence(contig, sequenceDictionary.getSequenceIndex(contig), bases.getBases());
            } catch (final IOException e) {
                throw new GATKException(String.format("Failed to load reference bases for %s:%d-%d", contig, start, stop), e);
            }
        }

        @Override
        public void close() {
            // doesn't do anything because you can't close a two-bit file
        }

        @Override
        public Iterator<Byte> iterator() {
            throw new UnsupportedOperationException("iterator is not supported");
        }
    }

}