/* * _______ _ ____ _ _ * |__ __| | | / ____| | | | * | | __ _ _ __ ___ ___ ___| | | (___ | |___| | * | |/ _` | '__/ __|/ _ \/ __| | \___ \| ___ | * | | (_| | | \__ \ (_) \__ \ |____ ____) | | | | * |_|\__,_|_| |___/\___/|___/_____/|_____/|_| |_| * * ----------------------------------------------------------- * * TarsosLSH is developed by Joren Six. * * ----------------------------------------------------------- * * Info : http://0110.be/tag/TarsosLSH * Github : https://github.com/JorenSix/TarsosLSH * Releases: http://0110.be/releases/TarsosLSH/ * */ package be.tarsos.lsh; import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.PriorityQueue; import java.util.Random; import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import be.tarsos.lsh.families.DistanceComparator; import be.tarsos.lsh.families.DistanceMeasure; import be.tarsos.lsh.families.HashFamily; import be.tarsos.lsh.util.FileUtils; /** * Implements a Locality Sensitive Hash scheme. * @author Joren Six */ public class LSH { List<Vector> dataset; private Index index; private final HashFamily hashFamily; public LSH(List<Vector> dataset,HashFamily hashFamily) { this.dataset = dataset; this.hashFamily = hashFamily; } /** * Build an index by creating a new one and adding each vector. * * @param numberOfHashes * The number of hashes to use in each hash table. * @param numberOfHashTables * The number of hash tables to use. */ public void buildIndex(int numberOfHashes, int numberOfHashTables){ // Do we want to deserialize or build a new index??? // index = new Index(hashFamily, numberOfHashes, numberOfHashTables); // Deserialization can cause duplicates? //index = Index.deserialize(hashFamily, numberOfHashes, numberOfHashTables); index = new Index(hashFamily, numberOfHashes, numberOfHashTables); if(dataset != null){ for(Vector vector : dataset){ index.index(vector); } Index.serialize(index); } } /** * Benchmark the current LSH construction. * @param neighboursSize the expected size of the neighbourhood. * @param measure The measure to use to check for correctness. */ public void benchmark(int neighboursSize,DistanceMeasure measure){ long startTime = 0; double linearSearchTime = 0; double lshSearchTime = 0; int numbercorrect = 0; int falsePositives = 0; int truePositives = 0; int falseNegatives = 0; //int intersectionSize = 0; for(int i = 0 ; i < dataset.size() ; i++){ Vector query = dataset.get(i); startTime = System.currentTimeMillis(); List<Vector> lshResult = index.query(query,neighboursSize); lshSearchTime += System.currentTimeMillis() - startTime; startTime = System.currentTimeMillis(); List<Vector> linearResult = linearSearch(dataset,query,neighboursSize,measure); linearSearchTime += System.currentTimeMillis() - startTime; Set<Vector> set = new HashSet<Vector>(); set.addAll(lshResult); set.addAll(linearResult); //intersectionSize += set.size(); //In the best case, LSH result and linear result contain the exact same elements. //The number of false positives is the number of vectors that exceed the number of linear results. falsePositives += set.size() - linearResult.size(); //The number of true positives is Union of results - intersection. truePositives += lshResult.size() + linearResult.size() - set.size(); //The number of false Negatives the number of vectors that exceed the number of lsh results . falseNegatives += set.size() - lshResult.size(); //result is only correct if all nearest neighbours are the same (rather strict). boolean correct = true; for(int j = 0 ; j < Math.min(lshResult.size(),linearResult.size()); j++){ correct = correct && lshResult.get(j)== linearResult.get(j); } if(correct){ numbercorrect++; } } double numberOfqueries = dataset.size(); double dataSetSize = dataset.size(); double precision = truePositives / Double.valueOf(truePositives+falsePositives) * 100; double recall = truePositives / Double.valueOf(truePositives+falseNegatives) * 100; double percentageCorrect = numbercorrect / dataSetSize * 100; double percentageTouched = index.getTouched()/numberOfqueries/dataSetSize*100; linearSearchTime/=1000.0; lshSearchTime/=1000.0; int hashes = index.getNumberOfHashes(); int hashTables = index.getNumberOfHashTables(); //System.out.printf("%10s%15s%10s%10s%10s%10s%10s%10s\n","#hashes","#hashTables","Correct","Touched","linear","lsh","Precision","Recall"); System.out.printf("%10d%15d%9.2f%%%9.2f%%%9.4fs%9.4fs%9.2f%%%9.2f%%\n",hashes,hashTables,percentageCorrect,percentageTouched,linearSearchTime,lshSearchTime,precision,recall); } /** * Find the nearest neighbours for a query in the index. * * @param query * The query vector. * @param neighboursSize * The size of the neighbourhood. The returned list length * contains the maximum number of elements, or less. Zero * elements are possible. * @return A list of nearest neigbours, according to the index. The returned * list length contains the maximum number of elements, or less. * Zero elements are possible. */ public List<Vector> query(final Vector query,int neighboursSize){ return index.query(query,neighboursSize); } /** * Search for the actual nearest neighbours for a query vector using an * exhaustive linear search. For each vector a priority queue is created, * the distance between the query and other vectors is used to sort the * priority queue. The closest k neighbours show up at the head of the * priority queue. * * @param dataset * The data set with a bunch of vectors. * @param query * The query vector. * @param resultSize * The k nearest neighbours to find. Returns k vectors if the * data set size is larger than k. * @param measure * The distance measure used to sort the priority queue with. * @return The list of k nearest neighbours to the query vector, according * to the given distance measure. */ public static List<Vector> linearSearch(List<Vector> dataset,final Vector query,int resultSize,DistanceMeasure measure){ DistanceComparator dc = new DistanceComparator(query, measure); PriorityQueue<Vector> pq = new PriorityQueue<Vector>(dataset.size(),dc); pq.addAll(dataset); List<Vector> vectors = new ArrayList<Vector>(); for(int i = 0 ; i < resultSize;i++){ vectors.add(pq.poll()); } return vectors; } /** * Read a data set from a text file. The file has the following contents, * with identifier being an optional string identifying the vector and a * list of N coordinates (which should be doubles). This results in an * N-dimensional vector. * * <pre> * [Identifier] coord1 coord2 ... coordN * [Identifier] coord1 coord2 ... coordN * </pre> * * For example a data set with two elements with 4 dimensions looks like * this: * * <pre> * Hans 12 24 18.5 -45.6 * Jane 13 19 -12.0 49.8 * </pre> * * * @param file * The file to read. * @param maxSize * The maximum number of elements in the data set (even if the * file defines more points). * @return a list of vectors, the data set. */ public static List<Vector> readDataset(String file,int maxSize) { List<Vector> ret = new ArrayList<Vector>(); List<String[]> data = FileUtils.readCSVFile(file, " ", -1); if(data.size() > maxSize){ data = data.subList(0, maxSize); } boolean firstColumnIsKey = false; try{ Double.parseDouble(data.get(0)[0]); }catch(Exception e){ firstColumnIsKey = true; } int dimensions = firstColumnIsKey ? data.get(0).length - 1 : data.get(0).length; int startIndex = firstColumnIsKey ? 1 : 0; for(String[] row : data){ Vector item = new Vector(dimensions); if(firstColumnIsKey){ item.setKey(row[0]); } for (int d = startIndex; d < row.length; d++) { double value = Double.parseDouble(row[d]); item.set(d - startIndex, value); } ret.add(item); } return ret; } static double determineRadius(List<Vector> dataset,DistanceMeasure measure,int timeout){ ExecutorService executor = Executors.newSingleThreadExecutor(); double radius = 0.0; DetermineRadiusTask drt = new DetermineRadiusTask(dataset,measure); Future<Double> future = executor.submit(drt); try { System.out.println("Determine radius.."); radius = 0.90 * future.get(timeout, TimeUnit.SECONDS); System.out.println("Determined radius: " + radius); } catch (TimeoutException e) { System.err.println("Terminated!"); radius = 0.90 * drt.getRadius(); } catch (InterruptedException e) { System.err.println("Execution interrupted!" + e.getMessage()); radius = 0.90 * drt.getRadius(); } catch (ExecutionException e) { radius = 0.90 * drt.getRadius(); } executor.shutdownNow(); return radius; } static class DetermineRadiusTask implements Callable<Double> { private double queriesDone = 0; private double radiusSum = 0.0; private final List<Vector> dataset; private final Random rand; private final DistanceMeasure measure; public DetermineRadiusTask(List<Vector> dataset,DistanceMeasure measure){ this.dataset = dataset; this.rand = new Random(); this.measure=measure; } @Override public Double call() throws Exception { for(int i = 0 ; i < 30; i ++){ Vector query = dataset.get(rand.nextInt(dataset.size())); List<Vector> result = linearSearch(dataset, query, 2, measure); //the first vector is the query self, the second the closest. radiusSum += measure.distance(query, result.get(1)); queriesDone++; } return radiusSum/queriesDone; } public double getRadius(){ return radiusSum/queriesDone; } } public static void main(String args[]) { CommandLineInterface cli = new CommandLineInterface(args); cli.parseArguments(); cli.startApplication(); } }