/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package org.mitre.quaerite.solrtools; import java.io.BufferedWriter; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.Reader; import java.io.Writer; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.text.DecimalFormat; import java.text.DecimalFormatSymbols; import java.text.NumberFormat; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.DefaultParser; import org.apache.commons.cli.HelpFormatter; import org.apache.commons.cli.Option; import org.apache.commons.cli.Options; import org.apache.commons.cli.ParseException; import org.apache.commons.csv.CSVFormat; import org.apache.commons.csv.CSVRecord; import org.apache.commons.io.input.BOMInputStream; import org.apache.commons.lang3.StringUtils; import org.apache.log4j.Logger; import org.mitre.quaerite.connectors.QueryRequest; import org.mitre.quaerite.connectors.SearchClient; import org.mitre.quaerite.connectors.SearchClientFactory; import org.mitre.quaerite.core.SearchResultSet; import org.mitre.quaerite.core.queries.TermQuery; import org.mitre.quaerite.core.util.MapUtil; public class ElevateQueryComparer { static Logger LOG = Logger.getLogger(ElevateQueryComparer.class); static Options OPTIONS = new Options(); static { OPTIONS.addOption( Option.builder("s") .hasArg().required().desc("solr url").build() ); OPTIONS.addOption( Option.builder("e") .longOpt("elevate") .hasArg(true) .required(true) .desc("elevate file (xml)").build() ); OPTIONS.addOption( Option.builder("q") .longOpt("queries") .hasArg(true) .required(true) .desc("queries (with optional counts)").build() ); OPTIONS.addOption( Option.builder("d") .longOpt("outputDirectory") .hasArg(true) .required(false) .desc("directory to which to write reports").build() ); //if you are analyzing e.g. top 10k of GoogleAnalytics //you'll need to supply the actual total number of queries OPTIONS.addOption( Option.builder("t") .longOpt("totalQueries") .hasArg(true) .required(false) .desc("denominator for total number of queries -- " + "sum of queries if used if this is not specified").build() ); //let's say you have a single solr index that hosts //several logical indices: "general", "catlovers", "doglovers", //and there ids include the logical index, e.g. general-1; catlovers-1 //you may only want to focus on ids that match this regex: //~/(?i)general-\\d+/ OPTIONS.addOption( Option.builder("r") .longOpt("regex") .hasArg(true) .required(false) .desc("regex to subset ids (in case of multiple logical " + "indices stored in a single Solr index)").build() ); } public static void main(String[] args) throws Exception { CommandLine commandLine = null; try { commandLine = new DefaultParser().parse(OPTIONS, args); } catch (ParseException e) { HelpFormatter helpFormatter = new HelpFormatter(); helpFormatter.printHelp( "java -jar org.mitre.quaerite.solrtools.ElevateQueryComparer", OPTIONS); return; } Matcher idMatcher = null; if (commandLine.hasOption("r")) { idMatcher = Pattern.compile(commandLine.getOptionValue("r")).matcher(""); } Path reportsRoot = Paths.get("."); if (commandLine.hasOption("d")) { reportsRoot = Paths.get(commandLine.getOptionValue("d")); } DecimalFormat df = new DecimalFormat("##.###%", DecimalFormatSymbols.getInstance(Locale.ROOT)); //TODO: lowercase queries or run them through an analyzer from a specific field? QuerySet queries = loadQueries(Paths.get(commandLine.getOptionValue("q"))); Map<String, Elevate> elevateMap = ElevateScraper.scrape(Paths.get( commandLine.getOptionValue("e")), idMatcher); int totalQueries = queries.total; if (commandLine.hasOption("t")) { totalQueries = Integer.parseInt(commandLine.getOptionValue("t")); } List<Query> sorted = new ArrayList<>(queries.queries.values()); Collections.sort(sorted); if (!Files.isDirectory(reportsRoot)) { Files.createDirectories(reportsRoot); } dumpAllElevated(elevateMap, queries, totalQueries, df, reportsRoot); dumpElevatedQueries(sorted, totalQueries, elevateMap, df, reportsRoot); dumpElevatedButNoQueries(elevateMap.keySet(), queries.queries.keySet(), reportsRoot); dumpElevatedCountDistributions(elevateMap, reportsRoot); Set<String> ids = new HashSet<>(); int elevated = 0; for (Elevate e : elevateMap.values()) { List<String> docs = e.getIds(); elevated += docs.size(); ids.addAll(docs); } LOG.info(String.format(Locale.US, "There are %s elevate entries", elevateMap.keySet().size())); LOG.info(String.format(Locale.US, "There are %s unique elevated document ids " + "and %s total elevated document ids", ids.size(), elevated)); if (commandLine.hasOption("s")) { dumpElevateVsIndex(commandLine.getOptionValue("s"), sorted, elevateMap, df, totalQueries, reportsRoot); } } private static void dumpElevatedCountDistributions(Map<String, Elevate> elevateMap, Path reportsRoot) throws IOException { //histogram of document ids per query //<number of ids, number of entries Map<Integer, Integer> m = new HashMap<>(); for (Elevate e : elevateMap.values()) { int numOfDocs = e.getIds().size(); Integer cnt = m.get(numOfDocs); if (cnt == null) { cnt = 1; } else { cnt++; } m.put(numOfDocs, cnt); } try (BufferedWriter writer = Files.newBufferedWriter( reportsRoot.resolve("elevated_num_docs_histogram.csv"), StandardCharsets.UTF_8 )) { writer.write(StringUtils.joinWith(",", "Number of Documents in an Elevate Entry,Number of Entries\n")); for (Map.Entry<Integer, Integer> e : MapUtil.sortByDescendingValue(m).entrySet()) { writer.write( String.format(Locale.US, "%s,%s\n", e.getKey(), e.getValue())); } } } private static void dumpAllElevated(Map<String, Elevate> elevateMap, QuerySet queries, int totalCount, NumberFormat df, Path reportsRoot) throws Exception { try (BufferedWriter writer = Files.newBufferedWriter( reportsRoot.resolve("elevated.csv"), StandardCharsets.UTF_8 )) { writer.write(StringUtils.joinWith(",", "Elevated", "QueryCount", "QueryPercentage", "\n")); for (String elevated : elevateMap.keySet()) { if (elevateMap.get(elevated).ids.size() == 0) { LOG.warn("no ids for this elevated item >" + elevated + "<"); continue; } int cnt = 0; Query q = queries.queries.get(elevated); if (q != null) { cnt = q.getCount(); } writer.write( StringUtils.joinWith(",", clean(elevated), cnt, df.format(((double) cnt / (double) totalCount)) ) + "\n" ); } } } private static void dumpElevateVsIndex(String searchServer, List<Query> sorted, Map<String, Elevate> elevateMap, DecimalFormat df, int totalCount, Path reportsRoot) throws Exception { SearchClient searchClient = SearchClientFactory.getClient(searchServer); Set<String> indexContains = new HashSet<>(); Set<String> indexMissing = new HashSet<>(); try (BufferedWriter writer = Files.newBufferedWriter( reportsRoot.resolve("elevated_vs_index.csv"), StandardCharsets.UTF_8 )) { writer.write(StringUtils.joinWith(",", "Query", "Id", "IndexContainsId", "QueryCount", "QueryPercentage", "\n")); for (Query q : sorted) { if (elevateMap.containsKey(q.q)) { Elevate e = elevateMap.get(q.q); for (String id : e.getIds()) { if (!indexContains.contains(id) && !indexMissing.contains(id)) { boolean contains = indexContains(id, searchClient); if (contains) { indexContains.add(id); } else { indexMissing.add(id); } } String contains = "index contains"; if (indexMissing.contains(id)) { contains = "index missing"; } writer.write( StringUtils.joinWith(",", clean(q.getQ()), clean(id), contains, q.getCount(), df.format(((double) q.getCount() / (double) totalCount)) ) + "\n" ); } } } } //now go get all the elevated irrespective of queries for (Elevate e : elevateMap.values()) { for (String id : e.getIds()) { if (!indexContains.contains(id) && !indexMissing.contains(id)) { boolean contains = indexContains(id, searchClient); if (contains) { indexContains.add(id); } else { indexContains.add(id); } } } } int zeroValidDocs = 0; int totalValidDocs = 0; int totalInvalidDocs = 0; Map<String, Integer> valid = new HashMap<>(); Map<String, Integer> invalid = new HashMap<>(); for (Elevate e : elevateMap.values()) { int v = 0; for (String id : e.getIds()) { if (indexContains.contains(id)) { v++; increment(valid, id); totalValidDocs++; } else { increment(invalid, id); totalInvalidDocs++; } } if (v == 0) { zeroValidDocs++; } } LOG.info( String.format(Locale.US, "There are %s unique valid docs and %s " + "total docs in the elevate file.", valid.size(), totalValidDocs) ); LOG.info( String.format(Locale.US, "There are %s unique missing docs and %s " + "missing docs in the elevate file.", invalid.size(), totalInvalidDocs) ); LOG.info( String.format(Locale.US, "There are %s entries with zero valid docs.", zeroValidDocs) ); ; } private static void increment(Map<String, Integer> m, String k) { Integer val = m.get(k); if (val == null) { m.put(k, 1); } else { m.put(k, ++val); } } private static void dumpElevatedButNoQueries(Set<String> elevated, Set<String> queries, Path reportsRoot) throws IOException { try (BufferedWriter writer = Files.newBufferedWriter( reportsRoot.resolve("elevated_zero_queries.csv"), StandardCharsets.UTF_8)) { writer.write("ElevatedQueryNotInQueryLog\n"); List<String> sorted = new ArrayList<>(elevated); Collections.sort(sorted); for (String q : sorted) { if (!queries.contains(q)) { writer.write(clean(q) + "\n"); } } } } private static void dumpElevatedQueries(List<Query> sorted, int totalQueries, Map<String, Elevate> elevateMap, DecimalFormat df, Path reportsRoot) throws Exception { try (Writer writer = Files.newBufferedWriter( reportsRoot.resolve("queries_elevated_or_not.csv"), StandardCharsets.UTF_8)) { //header writer.write( StringUtils.joinWith(",", "Query", "ElevatedOrNot", "QueryCount", "QueryPercentage", "\n") ); for (Query q : sorted) { String elevated = "not_elevated"; if (elevateMap.containsKey(q.q)) { elevated = "elevated"; } writer.write(StringUtils.joinWith(",", clean(q.getQ()), elevated, q.getCount(), clean(df.format(((double) q.getCount() / (double) totalQueries))) )); writer.write("\n"); } } } private static boolean indexContains(String id, SearchClient searchClient) throws Exception { QueryRequest qr = new QueryRequest(new TermQuery("id", id)); SearchResultSet rs = searchClient.search(qr); return rs.getIds().size() > 0; } private static String clean(String s) { if (s == null) { return StringUtils.EMPTY; } if (s.contains(",") || s.contains("\n") || s.contains("\r") || s.contains("\"")) { s = "\"" + s.replaceAll("\"", "\"\"") + "\""; } return s; } private static QuerySet loadQueries(Path file) throws Exception { QuerySet querySet = new QuerySet(); Matcher uc = Pattern.compile("[A-Z]").matcher(""); try (InputStream is = Files.newInputStream(file)) { try (Reader reader = new InputStreamReader(new BOMInputStream(is), "UTF-8")) { Iterable<CSVRecord> records = CSVFormat.EXCEL .withFirstRecordAsHeader().parse(reader); for (CSVRecord record : records) { String q = record.get("query"); Integer c = Integer.parseInt(record.get("count")); if (querySet.queries.containsKey(q)) { LOG.warn("duplicate queries?! >" + q + "<"); } querySet.set(q, c); } } } LOG.info("loaded " + querySet.queries.size() + " queries"); return querySet; } private static class ElevateSet { Map<String, List<String>> queryToIds = new HashMap<>(); public void add(String query, String id) { List<String> ids = queryToIds.get(query); if (ids == null) { ids = new ArrayList<>(); queryToIds.put(query, ids); } ids.add(id); } @Override public String toString() { return "ElevateSet{" + "queryToIds=" + queryToIds + '}'; } } private static class QuerySet { int total; Map<String, Query> queries = new HashMap<>(); public void set(String query, int count) { if (!query.equals("(other)")) { queries.put(query, new Query(query, count)); } total += count; } @Override public String toString() { return "QuerySet{" + "total=" + total + ", queries=" + queries + '}'; } } private static class Query implements Comparable<Query> { String q; int count = -1; public Query(String q, int count) { this.q = q; this.count = count; } public String getQ() { return q; } public int getCount() { return count; } @Override public String toString() { return "Query{" + "q='" + q + '\'' + ", count=" + count + '}'; } @Override public int compareTo(Query other) { if (other.getCount() == count) { return q.compareTo(other.q); } return Integer.compare(other.count, count); } @Override public boolean equals(Object o) { if (this == o) return true; if (!(o instanceof Query)) return false; Query query = (Query) o; return count == query.count && q.equals(query.q); } @Override public int hashCode() { return Objects.hash(q, count); } } }