package util;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;

import core.contracts.Dataset;
import datasets.ListDataset;

/**
 * Static helpers for shuffling arrays and drawing random samples from lists and datasets.
 */
public class Sampler {

    /** Shared random source used by {@link #fisherYatesKnuthShuffle(int[])}. */
    private static final Random rand = new Random();

    /**
     * Kept for source compatibility with existing callers.
     *
     * @param rand ignored; every method of this class is static and uses either the shared
     *             generator or an explicitly supplied one
     */
    public Sampler(Random rand) {
        // Intentionally empty: the argument is not stored.
    }

    // reference
    // https://stackoverflow.com/questions/4702036/take-n-random-elements-from-a-liste?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa
    /**
     * Picks {@code n} random elements from {@code list} using a partial Fisher-Yates shuffle.
     *
     * <p>The list is modified in place: its last {@code n} positions are overwritten with the
     * chosen elements, and the returned list is a {@code subList} view onto those positions.
     *
     * @param list list to draw from; must support {@code set}, its order is changed
     * @param n    number of elements to pick
     * @param r    random source
     * @return a view of the last {@code n} elements of {@code list}, or {@code null} if the list
     *         holds fewer than {@code n} elements
     */
    public static <E> List<E> pickNRandomElements(List<E> list, int n, Random r) {
        int length = list.size();
        if (length < n) {
            return null;
        }

        // We don't need to shuffle the whole list: only the last n slots have to be filled.
        for (int i = length - 1; i >= length - n; --i) {
            Collections.swap(list, i, r.nextInt(i + 1));
        }
        return list.subList(length - n, length);
    }

    /**
     * Same as {@link #pickNRandomElements(List, int, Random)} with
     * {@link ThreadLocalRandom#current()} as the random source.
     *
     * @param list list to draw from; its order is changed
     * @param n    number of elements to pick
     * @return a view of the last {@code n} elements of {@code list}, or {@code null} if the list
     *         holds fewer than {@code n} elements
     */
    public static <E> List<E> pickNRandomElements(List<E> list, int n) {
        return pickNRandomElements(list, n, ThreadLocalRandom.current());
    }

    /**
     * An improved version (Durstenfeld) of the Fisher-Yates algorithm with O(n) time complexity.
     * Permutes the given array in place.
     *
     * @param array array to be shuffled
     * @see <a href="http://www.programming-algorithms.net/article/43676/Fisher-Yates-shuffle">
     *      Fisher-Yates shuffle</a>
     */
    public static void fisherYatesKnuthShuffle(int[] array) {
        for (int i = array.length - 1; i > 0; i--) {
            // The bound must include i itself, so an element may stay where it is. Drawing from
            // [0, i) instead gives Sattolo's algorithm, which only produces cyclic permutations.
            int index = rand.nextInt(i + 1);

            // swap
            int tmp = array[index];
            array[index] = array[i];
            array[i] = tmp;
        }
    }

    /**
     * Draws up to {@code n} distinct items from {@code dataset} without replacement.
     *
     * @param dataset source of (class, series) pairs
     * @param n       requested sample size; capped at {@code dataset.size()}
     * @return a new dataset holding the sampled pairs in random order
     */
    public static ListDataset uniform_sample(Dataset dataset, int n) {
        int total = dataset.size();
        n = n > total ? total : n;

        ListDataset sample = new ListDataset(n, dataset.length());

        // Shuffle the indices of the WHOLE dataset so every item can be chosen,
        // then keep the first n of the permutation.
        int[] indices = new int[total];
        for (int i = 0; i < total; i++) {
            indices[i] = i;
        }
        Sampler.fisherYatesKnuthShuffle(indices);

        for (int i = 0; i < n; i++) {
            int picked = indices[i];
            sample.add(dataset.get_class(picked), dataset.get_series(picked));
        }
        return sample;
    }

    // TODO naive implementation, quick fix
    /**
     * Draws a uniform sample and then drops every series that is one of the {@code exclude}
     * arrays. Matching is by reference ({@code ==}), not by content, so the result may contain
     * fewer than {@code n} items.
     *
     * @param dataset source of (class, series) pairs
     * @param n       requested sample size before exclusion
     * @param exclude series instances to drop from the sample; {@code null} excludes nothing
     * @return the sample with excluded series removed
     */
    public static ListDataset uniform_sample(Dataset dataset, int n, double[][] exclude) {
        ListDataset sample = Sampler.uniform_sample(dataset, n);
        if (exclude == null) {
            return sample;
        }

        // Walk backwards so that removing index i does not shift elements still to be visited.
        for (int i = sample.size() - 1; i >= 0; i--) {
            for (double[] excluded : exclude) {
                if (sample.get_series(i) == excluded) {
                    sample.remove(i);
                    break; // slot i is gone; move on to the next index
                }
            }
        }
        return sample;
    }

    /**
     * Draws up to {@code n_per_class} items from every class and merges them into one dataset.
     *
     * @param data_per_class data grouped by class label
     * @param n_per_class    requested number of items per class
     * @param shuffle        whether to shuffle the merged sample before returning it
     * @param exclude        series instances (matched by reference) to leave out, or
     *                       {@code null} to exclude nothing
     * @return the merged sample
     */
    public static ListDataset stratified_sample(Map<Integer, ListDataset> data_per_class,
            int n_per_class, boolean shuffle, double[][] exclude) {
        ListDataset sample = new ListDataset(data_per_class.size() * n_per_class);

        for (Map.Entry<Integer, ListDataset> entry : data_per_class.entrySet()) {
            ListDataset class_sample =
                    Sampler.uniform_sample(entry.getValue(), n_per_class, exclude);

            int class_size = class_sample.size();
            for (int i = 0; i < class_size; i++) {
                sample.add(class_sample.get_class(i), class_sample.get_series(i));
            }
        }

        if (shuffle) {
            sample.shuffle();
        }
        return sample;
    }

    /**
     * Draws up to {@code n_per_class} items from every class and keeps the per-class grouping.
     *
     * @param data_per_class     data grouped by class label
     * @param n_per_class        requested number of items per class
     * @param shuffle_each_class whether to shuffle each per-class sample
     * @param exclude            series instances (matched by reference) to leave out, or
     *                           {@code null} to exclude nothing
     * @return a new map from class label to that class's sample
     */
    public static Map<Integer, ListDataset> stratified_sample_per_class(
            Map<Integer, ListDataset> data_per_class,
            int n_per_class, boolean shuffle_each_class, double[][] exclude) {
        Map<Integer, ListDataset> sample = new HashMap<Integer, ListDataset>();

        for (Map.Entry<Integer, ListDataset> entry : data_per_class.entrySet()) {
            ListDataset class_sample =
                    Sampler.uniform_sample(entry.getValue(), n_per_class, exclude);

            if (shuffle_each_class) {
                class_sample.shuffle();
            }
            sample.put(entry.getKey(), class_sample);
        }
        return sample;
    }
}