/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.spark.util;

import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.io.IOUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.LocatedFileStatus;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.RemoteIterator;
import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.serializer.SerializerInstance;
import org.deeplearning4j.spark.api.Repartition;
import org.deeplearning4j.spark.api.RepartitionStrategy;
import org.deeplearning4j.spark.data.BatchDataSetsFunction;
import org.deeplearning4j.spark.data.shuffle.SplitDataSetExamplesPairFlatMapFunction;
import org.deeplearning4j.spark.impl.common.CountPartitionsFunction;
import org.deeplearning4j.spark.impl.common.SplitPartitionsFunction;
import org.deeplearning4j.spark.impl.common.SplitPartitionsFunction2;
import org.deeplearning4j.spark.impl.common.repartition.BalancedPartitioner;
import org.deeplearning4j.spark.impl.common.repartition.HashingBalancedPartitioner;
import org.deeplearning4j.spark.impl.common.repartition.MapTupleToPairFlatMap;
import org.deeplearning4j.spark.impl.repartitioner.EqualRepartitioner;
import org.deeplearning4j.core.util.UIDProvider;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import scala.Tuple2;

import java.io.*;
import java.lang.reflect.Array;
import java.net.URI;
import java.nio.ByteBuffer;
import java.util.*;

/**
 * Various utilities for Spark
 *
 * @author Alex Black
 */
@Slf4j
public class SparkUtils {

    /**
     * Message used for every exception thrown by {@link #checkKryoConfiguration(JavaSparkContext, Logger)}
     * when the INDArray serialization round trip fails under Kryo.
     */
    private static final String KRYO_EXCEPTION_MSG = "Kryo serialization detected without an appropriate registrator "
                    + "for ND4J INDArrays.\nWhen using Kryo, An appropriate Kryo registrator must be used to avoid"
                    + " serialization issues (NullPointerException) with off-heap data in INDArrays.\n"
                    + "Use nd4j-kryo_2.10 or _2.11 artifact, with sparkConf.set(\"spark.kryo.registrator\", \"org.nd4j.kryo.Nd4jRegistrator\");\n"
                    + "See https://deeplearning4j.konduit.ai/distributed-deep-learning/howto#how-to-use-kryo-serialization-with-dl-4-j-and-nd-4-j for more details";

    /**
     * Cached result of {@link #getSparkExecutorId()}. Null until that method has computed the value;
     * it is only assigned while holding the lock on {@code SparkUtils.class}.
     */
    private static String sparkExecutorId;

    /** Static utility class: not meant to be instantiated. */
    private SparkUtils() {}

    /**
     * Validate the Kryo setup of the given Spark context.
     * <p>
     * If the configured serializer is not Kryo, or the ND4J registrator is configured, nothing is tested.
     * Otherwise a small test INDArray is serialized and deserialized with the context's serializer, so that a
     * broken configuration fails here rather than later in a job.
     *
     * @param javaSparkContext Spark context whose configuration is checked
     * @param log              Logger (not used by the current implementation)
     * @return Always true when the method returns normally
     * @throws RuntimeException if the test array cannot be serialized, deserialized, or does not survive the
     *                          round trip unchanged
     */
    public static boolean checkKryoConfiguration(JavaSparkContext javaSparkContext, Logger log) {
        //Nothing to verify unless Kryo is the configured serializer
        final String configuredSerializer = javaSparkContext.getConf().get("spark.serializer", null);
        if (!"org.apache.spark.serializer.KryoSerializer".equals(configuredSerializer)) {
            return true;
        }

        //The official ND4J registrator is known to handle INDArrays
        final String configuredRegistrator = javaSparkContext.getConf().get("spark.kryo.registrator", null);
        if ("org.nd4j.kryo.Nd4jRegistrator".equals(configuredRegistrator)) {
            return true;
        }

        //Unknown or missing registrator: it may be a custom one that copes with INDArrays (off-heap data),
        // so test an actual round trip now and fail early if it does not work
        SerializerInstance serializerInstance;
        ByteBuffer serialized;
        try {
            serializerInstance = javaSparkContext.env().serializer().newInstance();
            serialized = serializerInstance.serialize(Nd4j.linspace(1, 5, 5), null);
        } catch (Exception e) {
            //Serialization itself failed - not expected to happen
            throw new RuntimeException(KRYO_EXCEPTION_MSG, e);
        }

        if (serialized == null) {
            //Not expected to happen either
            throw new RuntimeException(
                            KRYO_EXCEPTION_MSG + "\n(Got: null ByteBuffer from Spark SerializerInstance)");
        }

        //Serializing worked; deserializing may still fail if the Kryo config is wrong
        boolean roundTripOk;
        try {
            INDArray restored = (INDArray) serializerInstance.deserialize(serialized, null);
            //equals(...) may itself fail on a malformed INDArray, so it stays inside the try block
            roundTripOk = Nd4j.linspace(1, 5, 5).equals(restored);
        } catch (Exception e) {
            throw new RuntimeException(KRYO_EXCEPTION_MSG, e);
        }

        if (!roundTripOk) {
            throw new RuntimeException(KRYO_EXCEPTION_MSG + "\n(Error during deserialization: test array"
                            + " was not deserialized successfully)");
        }

        //Round trip through Kryo succeeded
        return true;
    }

    /**
     * Write a String to a file (on HDFS or local) in UTF-8 format.
     * Convenience overload that unwraps the {@link JavaSparkContext} and delegates to
     * {@link #writeStringToFile(String, String, SparkContext)}.
     *
     * @param path    Path to write to
     * @param toWrite String to write
     * @param sc      Spark context
     * @throws IOException If the write fails
     */
    public static void writeStringToFile(String path, String toWrite, JavaSparkContext sc) throws IOException {
        SparkContext underlyingContext = sc.sc();
        writeStringToFile(path, toWrite, underlyingContext);
    }

    /**
     * Write a String to a file (on HDFS or local) in UTF-8 format.
     * The file system is obtained from the Hadoop configuration of the given Spark context.
     *
     * @param path    Path to write to
     * @param toWrite String to write
     * @param sc      Spark context
     * @throws IOException If the file cannot be created or written
     */
    public static void writeStringToFile(String path, String toWrite, SparkContext sc) throws IOException {
        FileSystem fs = FileSystem.get(sc.hadoopConfiguration());
        Path destination = new Path(path);
        try (BufferedOutputStream out = new BufferedOutputStream(fs.create(destination))) {
            byte[] encoded = toWrite.getBytes("UTF-8");
            out.write(encoded);
        }
    }

    /**
     * Read a UTF-8 format String from HDFS (or local).
     * Convenience overload that unwraps the {@link JavaSparkContext} and delegates to
     * {@link #readStringFromFile(String, SparkContext)}.
     *
     * @param path Path of the file to read the string from
     * @param sc   Spark context
     * @return Content of the file, decoded as UTF-8
     * @throws IOException If the read fails
     */
    public static String readStringFromFile(String path, JavaSparkContext sc) throws IOException {
        SparkContext underlyingContext = sc.sc();
        return readStringFromFile(path, underlyingContext);
    }

    /**
     * Read a UTF-8 format String from HDFS (or local).
     * The file system is obtained from the Hadoop configuration of the given Spark context.
     *
     * @param path Path of the file to read the string from
     * @param sc   Spark context
     * @return Content of the file, decoded as UTF-8
     * @throws IOException If the file cannot be opened or read
     */
    public static String readStringFromFile(String path, SparkContext sc) throws IOException {
        FileSystem fs = FileSystem.get(sc.hadoopConfiguration());
        Path source = new Path(path);
        try (BufferedInputStream in = new BufferedInputStream(fs.open(source))) {
            return new String(IOUtils.toByteArray(in), "UTF-8");
        }
    }

    /**
     * Write an object to HDFS (or local) using default Java object serialization.
     * Convenience overload that unwraps the {@link JavaSparkContext} and delegates to
     * {@link #writeObjectToFile(String, Object, SparkContext)}.
     *
     * @param path    Path to write the object to
     * @param toWrite Object to write
     * @param sc      Spark context
     * @throws IOException If the write fails
     */
    public static void writeObjectToFile(String path, Object toWrite, JavaSparkContext sc) throws IOException {
        SparkContext underlyingContext = sc.sc();
        writeObjectToFile(path, toWrite, underlyingContext);
    }

    /**
     * Write an object to HDFS (or local) using default Java object serialization.
     * The file system is obtained from the Hadoop configuration of the given Spark context.
     *
     * @param path    Path to write the object to
     * @param toWrite Object to write
     * @param sc      Spark context
     * @throws IOException If the file cannot be created or the object cannot be serialized
     */
    public static void writeObjectToFile(String path, Object toWrite, SparkContext sc) throws IOException {
        FileSystem fileSystem = FileSystem.get(sc.hadoopConfiguration());
        //Both streams are managed resources: the ObjectOutputStream is closed first, which flushes any data it
        // still buffers into the underlying stream before that stream is closed in turn
        try (BufferedOutputStream bos = new BufferedOutputStream(fileSystem.create(new Path(path)));
                        ObjectOutputStream oos = new ObjectOutputStream(bos)) {
            oos.writeObject(toWrite);
        }
    }

    /**
     * Read an object from HDFS (or local) using default Java object serialization.
     * Convenience overload that unwraps the {@link JavaSparkContext} and delegates to
     * {@link #readObjectFromFile(String, Class, SparkContext)}.
     *
     * @param path File to read
     * @param type Class of the object to read
     * @param sc   Spark context
     * @param <T>  Type of the object to read
     * @return The deserialized object
     * @throws IOException If the read fails
     */
    public static <T> T readObjectFromFile(String path, Class<T> type, JavaSparkContext sc) throws IOException {
        SparkContext underlyingContext = sc.sc();
        return readObjectFromFile(path, type, underlyingContext);
    }

    /**
     * Read an object from HDFS (or local) using default Java object serialization.
     * The file system is obtained from the Hadoop configuration of the given Spark context.
     *
     * @param path File to read
     * @param type Class of the object to read; the deserialized object is checked against it
     * @param sc   Spark context
     * @param <T>  Type of the object to read
     * @return The deserialized object
     * @throws IOException        If the file cannot be opened or read
     * @throws RuntimeException   If the class of the serialized object cannot be found
     * @throws ClassCastException If the deserialized object is not an instance of {@code type}
     */
    public static <T> T readObjectFromFile(String path, Class<T> type, SparkContext sc) throws IOException {
        FileSystem fileSystem = FileSystem.get(sc.hadoopConfiguration());
        //The raw stream is its own resource: the ObjectInputStream constructor reads the stream header and may
        // throw, in which case the underlying stream must still be closed
        try (BufferedInputStream bis = new BufferedInputStream(fileSystem.open(new Path(path)));
                        ObjectInputStream ois = new ObjectInputStream(bis)) {
            Object o;
            try {
                o = ois.readObject();
            } catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }

            //Checked cast: a type mismatch fails here instead of at some later use of the returned value
            return type.cast(o);
        }
    }

    /**
     * Repartition the specified RDD (or not) using the given {@link Repartition} and {@link RepartitionStrategy} settings
     *
     * @param rdd                 RDD to repartition
     * @param repartition         Setting for when repartitioning is to be conducted
     * @param repartitionStrategy Setting for how repartitioning is to be conducted
     * @param objectsPerPartition Desired number of objects per partition (only passed on for the Balanced strategy)
     * @param numPartitions       Total number of partitions
     * @param <T>                 Type of the RDD
     * @return Repartitioned RDD, or original RDD if no repartitioning was conducted
     * @throws RuntimeException if the repartition strategy is not one of the handled values
     */
    public static <T> JavaRDD<T> repartition(JavaRDD<T> rdd, Repartition repartition,
                    RepartitionStrategy repartitionStrategy, int objectsPerPartition, int numPartitions) {
        if (repartition == Repartition.Never) {
            return rdd;
        }

        switch (repartitionStrategy) {
            case SparkDefault: {
                //Skip only when asked to repartition on a mismatch and the partition count already matches
                boolean alreadyMatching = repartition == Repartition.NumPartitionsWorkersDiffers
                                && rdd.partitions().size() == numPartitions;
                return alreadyMatching ? rdd : rdd.repartition(numPartitions);
            }
            case Balanced: {
                return repartitionBalanceIfRequired(rdd, repartition, objectsPerPartition, numPartitions);
            }
            case ApproximateBalanced: {
                return repartitionApproximateBalance(rdd, repartition, numPartitions);
            }
            default: {
                throw new RuntimeException("Unknown repartition strategy: " + repartitionStrategy);
            }
        }
    }

    /**
     * Repartition a RDD (given the {@link Repartition} setting) into {@code numPartitions} partitions using a
     * {@link HashingBalancedPartitioner}, whose weights are derived from the current per-partition element counts.
     *
     * @param rdd           RDD to repartition
     * @param repartition   Repartitioning setting
     * @param numPartitions Number of partitions to have
     * @param <T>           Type of RDD
     * @return Repartitioned RDD, or the original RDD if no repartitioning was performed
     * @throws RuntimeException if the repartition setting is not one of the handled values
     */
    public static <T> JavaRDD<T> repartitionApproximateBalance(JavaRDD<T> rdd, Repartition repartition,
                    int numPartitions) {
        final int origNumPartitions = rdd.partitions().size();

        //First decide whether any work is needed at all
        switch (repartition) {
            case Never:
                return rdd;
            case NumPartitionsWorkersDiffers:
                if (origNumPartitions == numPartitions) {
                    return rdd;
                }
                //Partition counts differ: treated the same way as Always (intentional fall-through)
            case Always:
                break;
            default:
                throw new RuntimeException("Unknown setting for repartition: " + repartition);
        }

        // Count each partition...
        List<Integer> countPerPartition =
                        rdd.mapPartitionsWithIndex(new Function2<Integer, Iterator<T>, Iterator<Integer>>() {
                            @Override
                            public Iterator<Integer> call(Integer partitionIdx, Iterator<T> elements)
                                            throws Exception {
                                int seen = 0;
                                while (elements.hasNext()) {
                                    elements.next();
                                    seen++;
                                }
                                return Collections.singletonList(seen).iterator();
                            }
                        }, true).collect();

        int totalCount = 0;
        for (int partitionCount : countPerPartition) {
            totalCount += partitionCount;
        }
        final double ideal = (double) totalCount / numPartitions;

        final int sharedPartitions = Math.min(origNumPartitions, numPartitions);
        final int allPartitions = Math.max(origNumPartitions, numPartitions);
        List<Double> partitionWeights = new ArrayList<>(allPartitions);

        // partitions present both before and after: weight is the current count relative to the ideal count
        for (int p = 0; p < sharedPartitions; p++) {
            partitionWeights.add((double) countPerPartition.get(p) / ideal);
        }
        // partitions in the initial set and not in the final one get -1 => elements always jump
        // partitions in the final set not in the initial one get 0 => aim to receive the average amount
        for (int p = sharedPartitions; p < allPartitions; p++) {
            partitionWeights.add(p >= numPartitions ? -1D : 0D);
        }

        // zipWithUniqueId won't trigger a spark job, which is different from {@link org.apache.spark.rdd.RDD#zipWithIndex}
        JavaPairRDD<Tuple2<Long, Integer>, T> keyedById = rdd.zipWithUniqueId()
                        .mapToPair(new PairFunction<Tuple2<T, Long>, Tuple2<Long, Integer>, T>() {
                            @Override
                            public Tuple2<Tuple2<Long, Integer>, T> call(Tuple2<T, Long> elementAndId) {
                                Tuple2<Long, Integer> key = new Tuple2<Long, Integer>(elementAndId._2(), 0);
                                return new Tuple2<>(key, elementAndId._1());
                            }
                        });

        HashingBalancedPartitioner partitioner =
                        new HashingBalancedPartitioner(Collections.singletonList(partitionWeights));
        JavaPairRDD<Tuple2<Long, Integer>, T> rebalanced = keyedById.partitionBy(partitioner);

        //Drop the keys again, keeping only the payload
        return rebalanced.map(new Function<Tuple2<Tuple2<Long, Integer>, T>, T>() {
            @Override
            public T call(Tuple2<Tuple2<Long, Integer>, T> keyAndPayload) {
                return keyAndPayload._2();
            }
        });
    }

    /**
     * Repartition a RDD (given the {@link Repartition} setting) such that we have approximately
     * {@code numPartitions} partitions, each of which has {@code objectsPerPartition} objects.
     * <p>
     * The elements of every partition are counted first (this runs a Spark job); if the RDD already has
     * {@code numPartitions} partitions that each hold exactly {@code objectsPerPartition} objects, it is
     * returned unchanged.
     *
     * @param rdd RDD to repartition
     * @param repartition Repartitioning setting
     * @param objectsPerPartition Number of objects we want in each partition
     * @param numPartitions Number of partitions to have
     * @param <T> Type of RDD
     * @return Repartitioned RDD, or the original RDD if no repartitioning was performed
     * @throws RuntimeException if the repartition setting is not one of the handled values
     */
    public static <T> JavaRDD<T> repartitionBalanceIfRequired(JavaRDD<T> rdd, Repartition repartition,
                    int objectsPerPartition, int numPartitions) {
        int origNumPartitions = rdd.partitions().size();
        switch (repartition) {
            case Never:
                return rdd;
            case NumPartitionsWorkersDiffers:
                if (origNumPartitions == numPartitions)
                    return rdd;
                //Partition counts differ: handled exactly like Always (intentional fall-through)
            case Always:
                //Repartition: either always, or origNumPartitions != numWorkers

                //First: count number of elements in each partition. Need to know this so we can work out how to properly index each example,
                // so we can in turn create properly balanced partitions after repartitioning
                //Because the objects (DataSets etc) should be small, this should be OK

                //Count each partition...
                List<Tuple2<Integer, Integer>> partitionCounts =
                                rdd.mapPartitionsWithIndex(new CountPartitionsFunction<T>(), true).collect();
                int initialPartitions = partitionCounts.size();

                //Single pass: total number of objects, and whether every partition already has the target size
                int totalObjects = 0;
                boolean allCorrectSize = true;
                for (Tuple2<Integer, Integer> t2 : partitionCounts) {
                    int partitionSize = t2._2();
                    allCorrectSize &= (partitionSize == objectsPerPartition);
                    totalObjects += partitionSize;
                }

                if (initialPartitions == numPartitions && allCorrectSize) {
                    //Don't need to do any repartitioning here - already in the format we want
                    return rdd;
                }

                //Index each element for repartitioning (can only do manual repartitioning on a JavaPairRDD)
                JavaPairRDD<Integer, T> pairIndexed = indexedRDD(rdd);

                int remainder = (totalObjects - numPartitions * objectsPerPartition) % numPartitions;
                log.trace("About to rebalance: numPartitions={}, objectsPerPartition={}, remainder={}", numPartitions, objectsPerPartition, remainder);
                pairIndexed = pairIndexed
                                .partitionBy(new BalancedPartitioner(numPartitions, objectsPerPartition, remainder));
                return pairIndexed.values();
            default:
                throw new RuntimeException("Unknown setting for repartition: " + repartition);
        }
    }

    /**
     * Key every element of the RDD by its index, as assigned by {@code zipWithIndex()}.
     * Note that the {@code long} index is narrowed to an {@code int} key.
     *
     * @param rdd RDD to index
     * @param <T> Type of RDD
     * @return Pair RDD of (index, element)
     */
    public static <T> JavaPairRDD<Integer, T> indexedRDD(JavaRDD<T> rdd) {
        PairFunction<Tuple2<T, Long>, Integer, T> indexAsKey = new PairFunction<Tuple2<T, Long>, Integer, T>() {
            @Override
            public Tuple2<Integer, T> call(Tuple2<T, Long> elementAndIndex) {
                Integer key = elementAndIndex._2().intValue();
                T element = elementAndIndex._1();
                return new Tuple2<>(key, element);
            }
        };
        return rdd.zipWithIndex().mapToPair(indexAsKey);
    }

    /**
     * Repartition a RDD (given the {@link Repartition} setting) into {@code numPartitions} partitions
     * using an {@link EqualRepartitioner}.
     *
     * @param rdd           RDD to repartition
     * @param repartition   Repartitioning setting
     * @param numPartitions Number of partitions to have
     * @param <T>           Type of RDD
     * @return Repartitioned RDD, or the original RDD if no repartitioning was performed
     * @throws RuntimeException if the repartition setting is not one of the handled values
     */
    public static <T> JavaRDD<T> repartitionEqually(JavaRDD<T> rdd, Repartition repartition, int numPartitions){
        final int currentPartitions = rdd.partitions().size();
        switch (repartition) {
            case Never: {
                return rdd;
            }
            case NumPartitionsWorkersDiffers: {
                if (currentPartitions == numPartitions) {
                    return rdd;
                }
                //Partition counts differ: handled exactly like Always (intentional fall-through)
            }
            case Always: {
                EqualRepartitioner repartitioner = new EqualRepartitioner();
                return repartitioner.repartition(rdd, -1, numPartitions);
            }
            default: {
                throw new RuntimeException("Unknown setting for repartition: " + repartition);
            }
        }
    }

    /**
     * Random split the specified RDD into a number of RDDs, where each has {@code numObjectsPerSplit} in them.
     * <p>
     * This is similar to how RDD.randomSplit works (i.e., split via filtering), but this should result in more
     * equal splits (instead of independent binomial sampling that is used there, based on weighting).
     * This balanced splitting approach is important when the number of DataSet objects we want in each split is small,
     * as random sampling variance of {@link JavaRDD#randomSplit(double[])} is quite large relative to the number of examples
     * in each split. Note however that this method doesn't <i>guarantee</i> that partitions will be balanced
     * <p>
     * Downside is we need total object count (whereas {@link JavaRDD#randomSplit(double[])} does not). However, randomSplit
     * requires a full pass of the data anyway (in order to do filtering upon it) so this should not add much overhead in practice
     * <p>
     * The RNG seed is drawn from a new {@link Random}; use
     * {@link #balancedRandomSplit(int, int, JavaRDD, long)} to control it.
     *
     * @param totalObjectCount   Total number of objects in the RDD to split
     * @param numObjectsPerSplit Number of objects in each split
     * @param data               Data to split
     * @param <T>                Generic type for the RDD
     * @return The RDD split up (without replacement) into a number of smaller RDDs
     */
    public static <T> JavaRDD<T>[] balancedRandomSplit(int totalObjectCount, int numObjectsPerSplit, JavaRDD<T> data) {
        long randomSeed = new Random().nextLong();
        return balancedRandomSplit(totalObjectCount, numObjectsPerSplit, data, randomSeed);
    }

    /**
     * Equivalent to {@link #balancedRandomSplit(int, int, JavaRDD)} with control over the RNG seed.
     * <p>
     * If {@code totalObjectCount <= numObjectsPerSplit} the result is a single-element array holding {@code data}
     * itself; otherwise {@code totalObjectCount / numObjectsPerSplit} (rounded down) splits are created.
     *
     * @param totalObjectCount   Total number of objects in the RDD to split
     * @param numObjectsPerSplit Number of objects in each split
     * @param data               Data to split
     * @param rngSeed            Seed passed to the function that performs the splitting
     * @param <T>                Generic type for the RDD
     * @return The RDD split up into a number of smaller RDDs
     */
    public static <T> JavaRDD<T>[] balancedRandomSplit(int totalObjectCount, int numObjectsPerSplit, JavaRDD<T> data,
                    long rngSeed) {
        if (totalObjectCount <= numObjectsPerSplit) {
            //Everything fits in one split: hand back the input as is
            JavaRDD<T>[] single = (JavaRDD<T>[]) Array.newInstance(JavaRDD.class, 1);
            single[0] = data;
            return single;
        }

        final int splitCount = totalObjectCount / numObjectsPerSplit; //Intentional round down
        JavaRDD<T>[] result = (JavaRDD<T>[]) Array.newInstance(JavaRDD.class, splitCount);
        for (int splitIdx = 0; splitIdx < splitCount; splitIdx++) {
            SplitPartitionsFunction<T> splitter = new SplitPartitionsFunction<T>(splitIdx, splitCount, rngSeed);
            result[splitIdx] = data.mapPartitionsWithIndex(splitter, true);
        }
        return result;
    }

    /**
     * Equivalent to {@link #balancedRandomSplit(int, int, JavaRDD)} but for Pair RDDs.
     * The RNG seed is drawn from a new {@link Random}.
     */
    public static <T, U> JavaPairRDD<T, U>[] balancedRandomSplit(int totalObjectCount, int numObjectsPerSplit,
                    JavaPairRDD<T, U> data) {
        long randomSeed = new Random().nextLong();
        return balancedRandomSplit(totalObjectCount, numObjectsPerSplit, data, randomSeed);
    }

    /**
     * Equivalent to {@link #balancedRandomSplit(int, int, JavaRDD)} but for pair RDDs, and with control over the RNG seed.
     * <p>
     * If {@code totalObjectCount <= numObjectsPerSplit} the result is a single-element array holding {@code data}
     * itself; otherwise {@code totalObjectCount / numObjectsPerSplit} (rounded down) splits are created.
     */
    public static <T, U> JavaPairRDD<T, U>[] balancedRandomSplit(int totalObjectCount, int numObjectsPerSplit,
                    JavaPairRDD<T, U> data, long rngSeed) {
        if (totalObjectCount <= numObjectsPerSplit) {
            //Everything fits in one split: hand back the input as is
            JavaPairRDD<T, U>[] single = (JavaPairRDD<T, U>[]) Array.newInstance(JavaPairRDD.class, 1);
            single[0] = data;
            return single;
        }

        final int splitCount = totalObjectCount / numObjectsPerSplit; //Intentional round down
        JavaPairRDD<T, U>[] result = (JavaPairRDD<T, U>[]) Array.newInstance(JavaPairRDD.class, splitCount);
        for (int splitIdx = 0; splitIdx < splitCount; splitIdx++) {
            //What we really need is a .mapPartitionsToPairWithIndex function
            //but, of course Spark doesn't provide this
            //So we need to do a two-step process here...

            //Step 1: select the tuples belonging to this split
            SplitPartitionsFunction2<T, U> splitter = new SplitPartitionsFunction2<T, U>(splitIdx, splitCount, rngSeed);
            JavaRDD<Tuple2<T, U>> tuplesInSplit = data.mapPartitionsWithIndex(splitter, true);

            //Step 2: convert the RDD of tuples back into a pair RDD
            result[splitIdx] = tuplesInSplit.mapPartitionsToPair(new MapTupleToPairFlatMap<T, U>(), true);
        }
        return result;
    }

    /**
     * List of the files in the given directory (path), as a {@code JavaRDD<String>}.
     * Subdirectories are not walked; see {@link #listPaths(JavaSparkContext, String, boolean)}.
     *
     * @param sc      Spark context
     * @param path    Path to list files in
     * @return        Paths in the directory
     * @throws IOException If error occurs getting directory contents
     */
    public static JavaRDD<String> listPaths(JavaSparkContext sc, String path) throws IOException {
        final boolean recursive = false;
        return listPaths(sc, path, recursive);
    }

    /**
     * List of the files in the given directory (path), as a {@code JavaRDD<String>}.
     * No extension filter is applied: all files are accepted.
     *
     * @param sc        Spark context
     * @param path      Path to list files in
     * @param recursive Whether to walk the directory tree recursively (i.e., include subdirectories)
     * @return Paths in the directory
     * @throws IOException If error occurs getting directory contents
     */
    public static JavaRDD<String> listPaths(JavaSparkContext sc, String path, boolean recursive) throws IOException {
        //A null extension set means "accept every file"
        final Set<String> acceptAll = null;
        return listPaths(sc, path, recursive, acceptAll);
    }

    /**
     * List of the files in the given directory (path), as a {@code JavaRDD<String>}.
     * Array-based convenience overload of {@link #listPaths(JavaSparkContext, String, boolean, Set)}.
     *
     * @param sc                Spark context
     * @param path              Path to list files in
     * @param recursive         Whether to walk the directory tree recursively (i.e., include subdirectories)
     * @param allowedExtensions If null: all files will be accepted. If non-null: only files with the specified extension will be allowed.
     *                          Exclude the extension separator - i.e., use "txt" not ".txt" here.
     * @return Paths in the directory
     * @throws IOException If error occurs getting directory contents
     */
    public static JavaRDD<String> listPaths(JavaSparkContext sc, String path, boolean recursive, String[] allowedExtensions) throws IOException {
        Set<String> extensionSet = null;
        if (allowedExtensions != null) {
            extensionSet = new HashSet<>(Arrays.asList(allowedExtensions));
        }
        return listPaths(sc, path, recursive, extensionSet);
    }

    /**
     * List of the files in the given directory (path), as a {@code JavaRDD<String>}.
     * Uses the Hadoop configuration of the given Spark context.
     *
     * @param sc                Spark context
     * @param path              Path to list files in
     * @param recursive         Whether to walk the directory tree recursively (i.e., include subdirectories)
     * @param allowedExtensions If null: all files will be accepted. If non-null: only files with the specified extension will be allowed.
     *                          Exclude the extension separator - i.e., use "txt" not ".txt" here.
     * @return Paths in the directory
     * @throws IOException If error occurs getting directory contents
     */
    public static JavaRDD<String> listPaths(JavaSparkContext sc, String path, boolean recursive, Set<String> allowedExtensions) throws IOException {
        Configuration hadoopConfig = sc.hadoopConfiguration();
        return listPaths(sc, path, recursive, allowedExtensions, hadoopConfig);
    }

    /**
     * List of the files in the given directory (path), as a {@code JavaRDD<String>}
     * <p>
     * The listing itself is done on the driver; the collected paths are then parallelized into an RDD.
     *
     * @param sc                Spark context
     * @param path              Path to list files in
     * @param recursive         Whether to walk the directory tree recursively (i.e., include subdirectories)
     * @param allowedExtensions If null: all files will be accepted. If non-null: only files with the specified extension will be allowed.
     *                          Exclude the extension separator - i.e., use "txt" not ".txt" here.
     * @param config            Hadoop configuration to use. Must not be null.
     * @return Paths in the directory
     * @throws IOException If error occurs getting directory contents
     */
    public static JavaRDD<String> listPaths(@NonNull JavaSparkContext sc, String path, boolean recursive,
                                            Set<String> allowedExtensions, @NonNull Configuration config) throws IOException {
        List<String> paths = new ArrayList<>();
        FileSystem hdfs = FileSystem.get(URI.create(path), config);
        RemoteIterator<LocatedFileStatus> fileIter = hdfs.listFiles(new Path(path), recursive);

        while (fileIter.hasNext()) {
            Path file = fileIter.next().getPath();
            if (allowedExtensions == null) {
                paths.add(file.toString());
                continue;
            }

            //Filter on the extension of each listed file (final path component only), not on the
            // directory that was passed in
            String ext = FilenameUtils.getExtension(file.getName());
            if (allowedExtensions.contains(ext)) {
                paths.add(file.toString());
            }
        }
        return sc.parallelize(paths);
    }


    /**
     * Randomly shuffle the examples in each DataSet object, and recombine them into new DataSet objects
     * with the specified BatchSize
     *
     * @param rdd DataSets to shuffle/recombine
     * @param newBatchSize New batch size for the DataSet objects, after shuffling/recombining
     * @param numPartitions Number of partitions to use when splitting/recombining
     * @return A new {@code JavaRDD<DataSet>}, with the examples shuffled/combined in each
     */
    public static JavaRDD<DataSet> shuffleExamples(JavaRDD<DataSet> rdd, int newBatchSize, int numPartitions) {
        //Step 1: split into individual examples, mapping to a pair RDD (random key in range 0 to numPartitions)
        JavaPairRDD<Integer, DataSet> keyedExamples =
                        rdd.flatMapToPair(new SplitDataSetExamplesPairFlatMapFunction(numPartitions));

        //Step 2: repartition according to the random keys
        HashPartitioner byKey = new HashPartitioner(numPartitions);
        JavaPairRDD<Integer, DataSet> redistributed = keyedExamples.partitionBy(byKey);

        //Step 3: Recombine
        JavaRDD<DataSet> examples = redistributed.values();
        return examples.mapPartitions(new BatchDataSetsFunction(newBatchSize));
    }

    /**
     * Get the Spark executor ID<br>
     * The ID is parsed from the JVM launch command (system property {@code sun.java.command}): it is the token
     * following the first occurrence of {@code executor-id}. If that is not specified (or can't be obtained)
     * then the value from {@link UIDProvider#getJVMUID()} is returned.<br>
     * The value is computed once and cached for subsequent calls.
     *
     * @return The executor ID, or the JVM UID as fallback
     */
    public static String getSparkExecutorId(){
        //Read the shared field exactly once on the unsynchronized path, so that the value that was
        // null-checked is also the value that is returned
        String id = sparkExecutorId;
        if (id != null)
            return id;

        synchronized (SparkUtils.class){
            //re-check, in case some other thread set it while waiting for lock
            id = sparkExecutorId;
            if (id != null)
                return id;

            id = parseExecutorId(System.getProperty("sun.java.command"));
            if (id == null) {
                id = UIDProvider.getJVMUID();
            }
            sparkExecutorId = id;
            return id;
        }
    }

    /**
     * Extract the token that follows the first occurrence of "executor-id" in a launch command.
     *
     * @param command JVM launch command, may be null
     * @return The executor ID, or null if the command is null or has no token after "executor-id"
     */
    private static String parseExecutorId(String command) {
        if (command == null)
            return null;

        int idx = command.indexOf("executor-id");
        if (idx < 0)
            return null;

        //Split on runs of whitespace, so that repeated separators cannot produce an empty ID
        String[] tokens = command.substring(idx).split("\\s+");
        return tokens.length < 2 ? null : tokens[1];
    }

    /**
     * Write the given array to a byte[] using {@code Nd4j.write} and broadcast those bytes.
     *
     * @param sc    Spark context used to create the broadcast
     * @param array Array to broadcast
     * @return Broadcast of the written form of the array
     */
    public static Broadcast<byte[]> asByteArrayBroadcast(JavaSparkContext sc, INDArray array){
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        DataOutputStream dataOut = new DataOutputStream(buffer);
        try {
            Nd4j.write(array, dataOut);
        } catch (IOException e){
            throw new RuntimeException(e);  //Should never happen
        }
        //See docs in EvaluationRunner for why we use byte[] instead of INDArray (thread locality etc)
        return sc.broadcast(buffer.toByteArray());
    }
}