/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.mitre.quaerite.analysis; import static org.mitre.quaerite.core.util.CommandLineUtil.getInt; import static org.mitre.quaerite.core.util.CommandLineUtil.getLong; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.Reader; import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorCompletionService; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.DefaultParser; import org.apache.commons.cli.HelpFormatter; import org.apache.commons.cli.Option; import org.apache.commons.cli.Options; import org.apache.commons.cli.ParseException; import org.apache.commons.csv.CSVFormat; import org.apache.commons.csv.CSVRecord; import 
org.apache.commons.io.input.BOMInputStream;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.mutable.MutableLong;
import org.apache.log4j.Logger;
import org.mitre.quaerite.connectors.SearchClient;
import org.mitre.quaerite.connectors.SearchClientException;
import org.mitre.quaerite.connectors.SearchClientFactory;
import org.mitre.quaerite.core.stats.TokenDF;
import org.mitre.quaerite.core.util.CommandLineUtil;
import org.mitre.quaerite.core.util.MapUtil;

/**
 * Command-line tool that compares two analysis chains configured on a search server.
 * <p>
 * Every term indexed in the base field is fetched page by page and re-analyzed with the
 * analyzer of the filtered field. Base terms whose filtered output is identical are grouped
 * into an {@link EquivalenceSet} keyed by that output. Optionally, a CSV of queries restricts
 * the comparison to the filtered tokens those queries produce, and the equivalences for each
 * query token are printed.
 */
public class CompareAnalyzers {

    static final Options OPTIONS = new Options();

    static final Logger LOG = Logger.getLogger(CompareAnalyzers.class);

    private static final int DEFAULT_NUM_THREADS = 10;
    // Equivalence sets smaller than this are not reported; 2 skips single-term sets.
    private static final int DEFAULT_MIN_SET_SIZE = 2;
    private static final long DEFAULT_MIN_DF = 0;
    // Maximum number of equivalent base terms printed per query token.
    private static final int MAX_EQUIVALENCES_PER_TOKEN = 10;
    // Delimiter used to join multi-token analyzer output into a single key. The same
    // delimiter must be used for query tokens and for re-analyzed index terms, otherwise
    // the lookups between the two never match.
    private static final String TOKEN_DELIMITER = "|";

    static {
        OPTIONS.addOption(
                Option.builder("s")
                        .hasArg()
                        .required()
                        .desc("search server").build()
        );
        OPTIONS.addOption(Option.builder("bf")
                .longOpt("baseField")
                .hasArg(true)
                .required()
                .desc("baseField").build()
        );
        OPTIONS.addOption(
                Option.builder("ff").longOpt("filteredField")
                        .hasArg()
                        .required()
                        .desc("filtered field").build()
        );
        OPTIONS.addOption(
                Option.builder("q")
                        .longOpt("queries")
                        .hasArg()
                        .required(false)
                        .desc("query csv file to filter results -- UTF-8 csv with at least "
                                + "'query' column header").build()
        );
        OPTIONS.addOption(
                Option.builder("n")
                        .longOpt("numThreads")
                        .hasArg()
                        .required(false)
                        .desc("number of threads").build()
        );
        OPTIONS.addOption(
                Option.builder("minSetSize")
                        .longOpt("minEquivalenceSetSize")
                        .hasArg()
                        .required(false)
                        .desc("minimum size for an equivalence set (default ="
                                + DEFAULT_MIN_SET_SIZE + ")").build()
        );
        OPTIONS.addOption(
                Option.builder("minDF")
                        .longOpt("minDocumentFrequency")
                        .hasArg()
                        .required(false)
                        .desc("minimum document frequency (default = "
                                + DEFAULT_MIN_DF + ")").build()
        );
    }

    private int numThreads = DEFAULT_NUM_THREADS;
    // Filtered tokens of interest; an empty set means every filtered token is kept.
    private Set<String> targetTokens = Collections.emptySet();

    /**
     * Entry point: parses the command line, runs the comparison and prints the results
     * to standard out.
     *
     * @param args command-line arguments, see {@link #OPTIONS}
     * @throws Exception if the search client cannot be created or a query file cannot be read
     */
    public static void main(String[] args) throws Exception {
        CommandLine commandLine;
        try {
            commandLine = new DefaultParser().parse(OPTIONS, args);
        } catch (ParseException e) {
            HelpFormatter helpFormatter = new HelpFormatter();
            helpFormatter.printHelp(
                    "java -jar org.mitre.quaerite.analysis.CompareAnalyzers",
                    OPTIONS);
            return;
        }

        SearchClient client =
                SearchClientFactory.getClient(commandLine.getOptionValue("s"));
        CompareAnalyzers compareAnalyzers = new CompareAnalyzers();
        int minSetSize = getInt(commandLine, "minSetSize", DEFAULT_MIN_SET_SIZE);
        long minDF = getLong(commandLine, "minDF", DEFAULT_MIN_DF);
        compareAnalyzers.setNumThreads(
                getInt(commandLine, "n", DEFAULT_NUM_THREADS));
        String filteredField = commandLine.getOptionValue("ff");
        String baseField = commandLine.getOptionValue("bf");

        List<QueryTokenPair> queryTokenPairs;
        Set<String> targetTokens;
        if (commandLine.hasOption("q")) {
            queryTokenPairs = loadQueries(
                    CommandLineUtil.getPath(commandLine, "q", true),
                    client, baseField, filteredField);
            targetTokens = ConcurrentHashMap.newKeySet();
            for (QueryTokenPair p : queryTokenPairs) {
                targetTokens.addAll(p.getTokens());
            }
        } else {
            queryTokenPairs = Collections.emptyList();
            targetTokens = Collections.emptySet();
        }
        compareAnalyzers.setTargetTokens(targetTokens);

        Map<String, EquivalenceSet> map =
                compareAnalyzers.compare(client, baseField, filteredField);

        printEquivalenceSets(map, minSetSize, minDF);
        System.out.println("\n\nQUERY...\n\n\n");
        printQueries(queryTokenPairs, map);
    }

    /**
     * Prints every equivalence set with at least {@code minSetSize} base terms, listing
     * only the base terms whose document frequency is at least {@code minDF}. A set whose
     * terms are all filtered out by {@code minDF} is not printed at all.
     */
    private static void printEquivalenceSets(Map<String, EquivalenceSet> map,
                                             int minSetSize, long minDF) {
        for (Map.Entry<String, EquivalenceSet> e : map.entrySet()) {
            if (e.getValue().getMap().size() < minSetSize) {
                continue;
            }
            boolean printedKey = false;
            for (Map.Entry<String, MutableLong> orig
                    : e.getValue().getSortedMap().entrySet()) {
                if (orig.getValue().longValue() < minDF) {
                    continue;
                }
                if (!printedKey) {
                    System.out.println(e.getKey());
                    printedKey = true;
                }
                System.out.println("\t" + orig.getKey() + ": " + orig.getValue());
            }
        }
    }

    /**
     * Prints each query followed by its filtered tokens and, for each token, at most
     * {@link #MAX_EQUIVALENCES_PER_TOKEN} equivalent base terms. A trailing "..." marks
     * that more equivalences exist than were printed.
     */
    private static void printQueries(List<QueryTokenPair> queryTokenPairs,
                                     Map<String, EquivalenceSet> map) {
        for (QueryTokenPair q : queryTokenPairs) {
            System.out.println(q.getQuery());
            for (String token : q.getTokens()) {
                System.out.println("\t" + token);
                EquivalenceSet es = map.get(token);
                if (es == null) {
                    continue;
                }
                int printed = 0;
                for (Map.Entry<String, MutableLong> orig : es.getSortedMap().entrySet()) {
                    if (printed >= MAX_EQUIVALENCES_PER_TOKEN) {
                        // only reached when at least one more equivalence remains
                        System.out.println("\t\t...");
                        break;
                    }
                    System.out.println("\t\t" + orig.getKey() + ": " + orig.getValue());
                    printed++;
                }
            }
            System.out.println("\n");
        }
    }

    /**
     * Reads the distinct values of the {@code query} column of a UTF-8 CSV (a BOM is
     * tolerated) and analyzes each query: first with {@code baseField}, then each resulting
     * base token again with {@code filterField}.
     *
     * @return one pair per distinct query, in file order; each pair holds the joined
     * filtered output for every base token ("" when the filtered analyzer drops the token)
     * @throws IOException if the file cannot be read or analysis fails on I/O
     * @throws SearchClientException if the search client rejects an analyze request
     */
    private static List<QueryTokenPair> loadQueries(Path path,
                                                    SearchClient searchClient,
                                                    String baseField,
                                                    String filterField)
            throws IOException, SearchClientException {
        // list keeps file order, set drops duplicates
        List<String> queries = new ArrayList<>();
        Set<String> seen = new HashSet<>();
        try (InputStream is = Files.newInputStream(path);
             Reader reader = new InputStreamReader(new BOMInputStream(is), "UTF-8")) {
            for (CSVRecord record : CSVFormat.EXCEL.withFirstRecordAsHeader().parse(reader)) {
                String query = record.get("query");
                if (seen.add(query)) {
                    queries.add(query);
                }
            }
        }

        List<QueryTokenPair> queryTokenPairs = new ArrayList<>(queries.size());
        for (String query : queries) {
            List<String> filteredTokens = new ArrayList<>();
            for (String baseToken : searchClient.analyze(baseField, query)) {
                // joining an empty result yields "", so dropped tokens still occupy a slot;
                // the analyzer's returned list is never mutated (it may be immutable)
                filteredTokens.add(joinTokens(searchClient.analyze(filterField, baseToken)));
            }
            queryTokenPairs.add(new QueryTokenPair(query, filteredTokens));
        }
        return queryTokenPairs;
    }

    /**
     * Joins analyzer output into a single lookup key with {@link #TOKEN_DELIMITER}.
     * Returns "" for an empty list and null for a null list.
     */
    private static String joinTokens(List<String> tokens) {
        return StringUtils.join(tokens, TOKEN_DELIMITER);
    }

    private void setTargetTokens(Set<String> targetTokens) {
        this.targetTokens = targetTokens;
    }

    /**
     * @param numThreads number of re-analyzer threads; must be at least 1, otherwise the
     *                   term producer would block forever with no consumer
     * @throws IllegalArgumentException if {@code numThreads < 1}
     */
    private void setNumThreads(int numThreads) {
        if (numThreads < 1) {
            throw new IllegalArgumentException("numThreads must be >= 1: " + numThreads);
        }
        this.numThreads = numThreads;
    }

    /**
     * Re-analyzes every term of {@code baseField} with the analyzer of
     * {@code filteredField} and groups base terms by their filtered output.
     * <p>
     * One {@link TermGetter} pages terms into a bounded queue, and {@code numThreads}
     * {@link ReAnalyzer}s consume it. Their per-thread maps are merged after all tasks
     * finish. When target tokens are set, only those filtered outputs are kept.
     *
     * @return filtered output to equivalence set, ordered by
     * {@code MapUtil.sortByDescendingValue}
     * @throws RuntimeException if a task fails or the calling thread is interrupted;
     *                          the worker pool is shut down in every case
     */
    public Map<String, EquivalenceSet> compare(SearchClient client,
                                               String baseField, String filteredField) {
        ArrayBlockingQueue<Set<TokenDF>> queue = new ArrayBlockingQueue<>(100);
        List<ReAnalyzer> reAnalyzers = new ArrayList<>(numThreads);
        for (int i = 0; i < numThreads; i++) {
            reAnalyzers.add(new ReAnalyzer(queue, client, filteredField));
        }

        ExecutorService executorService = Executors.newFixedThreadPool(numThreads + 1);
        int totalAnalyzed = 0;
        try {
            ExecutorCompletionService<Integer> completionService =
                    new ExecutorCompletionService<>(executorService);
            completionService.submit(new TermGetter(queue, numThreads, client, baseField));
            for (ReAnalyzer reAnalyzer : reAnalyzers) {
                completionService.submit(reAnalyzer);
            }

            //map
            for (int completed = 0; completed < numThreads + 1; completed++) {
                Future<Integer> future = completionService.take();
                int analyzed = future.get();
                // the TermGetter reports -1; only re-analyzers report counts
                if (analyzed > 0) {
                    totalAnalyzed += analyzed;
                }
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (java.util.concurrent.ExecutionException e) {
            throw new RuntimeException(e);
        } finally {
            // also interrupts workers blocked on the queue if another task failed
            executorService.shutdownNow();
        }
        LOG.info("Analyzed " + totalAnalyzed);

        //reduce
        Map<String, EquivalenceSet> overall = new HashMap<>();
        for (ReAnalyzer reAnalyzer : reAnalyzers) {
            reduce(reAnalyzer.getMap(), overall);
        }
        return MapUtil.sortByDescendingValue(overall);
    }

    /** Merges one worker's map into {@code overall}, reusing {@code src}'s sets when new. */
    private void reduce(Map<String, EquivalenceSet> src,
                        Map<String, EquivalenceSet> overall) {
        for (Map.Entry<String, EquivalenceSet> e : src.entrySet()) {
            EquivalenceSet overallEs = overall.get(e.getKey());
            if (overallEs == null) {
                overall.put(e.getKey(), e.getValue());
            } else {
                mergeInto(e.getValue(), overallEs);
            }
        }
    }

    /** Adds every term/count of {@code es} to {@code overallEs}. */
    private void mergeInto(EquivalenceSet es, EquivalenceSet overallEs) {
        for (Map.Entry<String, MutableLong> e : es.getMap().entrySet()) {
            overallEs.addTerm(e.getKey(), e.getValue().longValue());
        }
    }

    /**
     * Producer: pages through the terms of a field and puts each page onto the queue,
     * then puts one empty set per consumer as a poison pill.
     */
    private static class TermGetter implements Callable<Integer> {

        private static final int TERM_SET_SIZE = 100;
        private static final int MIN_DF = 0;

        private final ArrayBlockingQueue<Set<TokenDF>> queue;
        private final int numThreads;
        private final SearchClient client;
        private final String field;

        TermGetter(ArrayBlockingQueue<Set<TokenDF>> queue, int numThreads,
                   SearchClient client, String field) {
            this.queue = queue;
            this.numThreads = numThreads;
            this.client = client;
            this.field = field;
        }

        /** @return always -1, so its result is excluded from the analyzed total */
        @Override
        public Integer call() throws Exception {
            String lower = "";
            while (true) {
                List<TokenDF> terms = client.getTerms(field, lower, TERM_SET_SIZE, MIN_DF);
                if (terms.isEmpty()) {
                    break;
                }
                // put blocks until a consumer frees space; interrupted by shutdownNow
                queue.put(new HashSet<>(terms));
                String last = terms.get(terms.size() - 1).getToken();
                // NOTE(review): guards against an infinite loop should getTerms treat
                // `lower` as inclusive and return only that term; harmless otherwise
                if (last.equals(lower)) {
                    break;
                }
                lower = last;
            }
            for (int i = 0; i < numThreads; i++) {
                queue.put(Collections.emptySet());
            }
            return -1;
        }
    }

    /** A query and the joined filtered output for each of its base tokens. */
    private static class QueryTokenPair {
        private final String query;
        private final List<String> filteredTokens;

        QueryTokenPair(String query, List<String> filteredTokens) {
            this.query = query;
            this.filteredTokens = filteredTokens;
        }

        String getQuery() {
            return query;
        }

        public List<String> getTokens() {
            return filteredTokens;
        }
    }

    /**
     * Consumer: takes pages of base terms, re-analyzes each with {@code field} and groups
     * terms by the joined output. Stops on the first empty set (poison pill). Its map is
     * confined to this task and read only after the task's future completes.
     */
    private class ReAnalyzer implements Callable<Integer> {

        private final ArrayBlockingQueue<Set<TokenDF>> queue;
        private final SearchClient client;
        private final String field;
        private final Map<String, EquivalenceSet> equivalenceMap = new HashMap<>();

        ReAnalyzer(ArrayBlockingQueue<Set<TokenDF>> queue, SearchClient client,
                   String field) {
            this.queue = queue;
            this.client = client;
            this.field = field;
        }

        /** @return the number of terms analyzed, including ones whose analysis failed */
        @Override
        public Integer call() throws Exception {
            int analyzed = 0;
            while (true) {
                Set<TokenDF> batch = queue.take();
                if (batch.isEmpty()) {
                    break;
                }
                for (TokenDF tdf : batch) {
                    String filtered = analyze(client, field, tdf.getToken());
                    analyzed++;
                    if (filtered == null) {
                        continue;
                    }
                    if (targetTokens.isEmpty() || targetTokens.contains(filtered)) {
                        equivalenceMap.computeIfAbsent(filtered, k -> new EquivalenceSet())
                                .addTerm(tdf.getToken(), tdf.getDf());
                    }
                }
            }
            return analyzed;
        }

        /** @return the joined analyzer output, or null if analysis failed (logged) */
        private String analyze(SearchClient client, String field, String s) {
            try {
                return joinTokens(client.analyze(field, s));
            } catch (IOException | SearchClientException e) {
                LOG.warn("Failed to analyze '" + s + "' with field " + field, e);
                return null;
            }
        }

        Map<String, EquivalenceSet> getMap() {
            return equivalenceMap;
        }
    }
}