package edu.tsinghua.dbgroup; import java.lang.Math; import java.util.*; import java.util.Map.Entry; import java.io.*; import java.util.concurrent.*; import java.util.concurrent.ThreadPoolExecutor.CallerRunsPolicy; import edu.tsinghua.dbgroup.*; public class EditDistanceJoiner { private List<String> mStrings; private TreeMap<Integer, ArrayList<HashMap<String, ArrayList<Integer>>>> mGlobalIndex; private int mThreshold; private int[][][] mDistanceBuffers; private int mNumThreads; private int mMaxLength; private ArrayList<EditDistanceJoinResult> mResults; private ArrayList<FilteredRawResult> mRawResults; static class UnfilteredResult { public int dstId; public int dstMatchPos; public int srcMatchPos; public int gramLen; } static class FilteredRawResult { public int srcId; public int dstId; public int similarity; } public EditDistanceJoiner(int threshold, int numThreads) { mGlobalIndex = new TreeMap<Integer, ArrayList<HashMap<String, ArrayList<Integer>>>>(); mStrings = new ArrayList<String>(); mMaxLength = 0; mThreshold = threshold; if (numThreads <= 0) { mNumThreads = 0; } else if (numThreads > Runtime.getRuntime().availableProcessors()) { mNumThreads = Runtime.getRuntime().availableProcessors(); } else { mNumThreads = numThreads; } } public EditDistanceJoiner(int threshold) { this(threshold, Runtime.getRuntime().availableProcessors()); } public int calculateEditDistanceWithThreshold(String s1, int start1, int l1, String s2, int start2, int l2, int threshold, int[][] distanceBuffer){ if (threshold < 0) { return 0; } if (threshold == 0) { String sub1 = s1.substring(start1, start1 + l1); String sub2 = s2.substring(start2, start2 + l2); return sub1.equals(sub2) ? 0 : 1; } if (l1 == 0) { return l2; } if (l2 == 0) { return l1; } for (int j = 1; j <= l1; j++) { int start = Math.max(j - threshold, 1); int end = Math.min(l2, j + threshold); if (j - threshold - 1 >= 1) { distanceBuffer[j - threshold - 1][j] = threshold + 1; } for (int i = start; i <= end; i++) { if (s1.charAt(start1 + j - 1) == s2.charAt(start2 + i - 1)) { distanceBuffer[i][j] = distanceBuffer[i - 1][j - 1]; } else { distanceBuffer[i][j] = Math.min(distanceBuffer[i - 1][j - 1] + 1, Math.min(distanceBuffer[i - 1][j] + 1, distanceBuffer[i][j - 1] + 1)); } } if (end < l2) distanceBuffer[end + 1][j] = threshold + 1; boolean earlyTerminateFlag = true; for (int i = start; i <= end; i++) { if (distanceBuffer[i][j] <= threshold) { earlyTerminateFlag = false; break; } } if (earlyTerminateFlag) return threshold + 1; } return distanceBuffer[l2][l1]; } private void indexStringById(int stringId){ String stringIndexing = mStrings.get(stringId); int l = stringIndexing.length();//3 3 2 if (!mGlobalIndex.containsKey(l)) { int strLen = 0; ArrayList<HashMap<String, ArrayList<Integer>>> subIndex = new ArrayList<HashMap<String, ArrayList<Integer>>>(); while (strLen < mThreshold + 1) { subIndex.add(new HashMap<String, ArrayList<Integer>>()); strLen++; } mGlobalIndex.put(l, subIndex); } for (int i = 0; i < mThreshold + 1; i++) { int gramLen = getGramLen(l, i); int startPos = getGramPos(l, i); String gram = stringIndexing.substring(startPos, startPos + gramLen); if (mGlobalIndex.get(l).get(i).containsKey(gram)) { mGlobalIndex.get(l).get(i).get(gram).add(stringId); } else { ArrayList<Integer> invertedList = new ArrayList<Integer>(); invertedList.add(stringId); mGlobalIndex.get(l).get(i).put(gram, invertedList); } } } public void initEditDistanceBuffer(){ mDistanceBuffers = new int[mNumThreads + 1][mMaxLength][mMaxLength]; for (int n = 0; n <= mNumThreads; n++) { for (int i = 0; i < mMaxLength; i++) { mDistanceBuffers[n][0][i] = i; mDistanceBuffers[n][i][0] = i; } } } public ArrayList<EditDistanceJoinResult> getJoinResults() { if(mStrings.size() == 0){ return new ArrayList<EditDistanceJoinResult>(); } long resultsBeforeRefiningNum = 0; long resultsRefinedNum = 0; long mainTid = Thread.currentThread().getId(); initEditDistanceBuffer(); mStrings = new ArrayList<String>(new TreeSet<String>(mStrings)); Collections.sort(mStrings, new Comparator<String>(){ @Override public int compare(String o1, String o2) { return compareString(o1, o2); } }); mResults = new ArrayList<EditDistanceJoinResult>(); mRawResults = new ArrayList<FilteredRawResult>(); int srcId = 1; ThreadPoolExecutor executor = null; if(mNumThreads != 0) { executor = new ThreadPoolExecutor(mNumThreads, mNumThreads, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<Runnable>(3000), new ThreadPoolExecutor.CallerRunsPolicy()); } indexStringById(0); while(srcId < mStrings.size()){ int srcLen = mStrings.get(srcId).length(); ArrayList<UnfilteredResult> resultsBeforeRefining = new ArrayList<UnfilteredResult>(); getResultsFromIndex(srcId, resultsBeforeRefining); resultsBeforeRefiningNum += resultsBeforeRefining.size(); final int currentId = srcId; if (mNumThreads != 0) { executor.submit(() -> { long tid = Thread.currentThread().getId(); int[][] buffer = mDistanceBuffers[(int)(tid % mNumThreads)]; if (tid == mainTid) { buffer = mDistanceBuffers[mNumThreads]; } ArrayList<FilteredRawResult> resultsRefined = null; synchronized(buffer) { resultsRefined = refineResults(currentId, resultsBeforeRefining, buffer); } synchronized(mRawResults){ mRawResults.addAll(resultsRefined); } }); } else { int[][] buffer = mDistanceBuffers[0]; ArrayList<FilteredRawResult> resultsRefined = null; resultsRefined = refineResults(currentId, resultsBeforeRefining, buffer); mRawResults.addAll(resultsRefined); } mGlobalIndex.subMap(0, true, Math.max(1, srcLen - mThreshold), false).clear(); indexStringById(srcId); srcId++; } if(mNumThreads != 0) { executor.shutdown(); try { executor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS); } catch (InterruptedException e) { System.err.println(e.getMessage()); } } Collections.sort(mRawResults, new Comparator<FilteredRawResult>(){ @Override public int compare(FilteredRawResult o1, FilteredRawResult o2) { if (o1.srcId < o2.srcId) return -1; if (o1.srcId > o2.srcId) return 1; if (o1.dstId < o2.dstId) return -1; if (o1.dstId > o2.dstId) return 1; return 0; } }); for(FilteredRawResult rawResult : mRawResults) { EditDistanceJoinResult r = new EditDistanceJoinResult(); r.src = mStrings.get(rawResult.srcId); r.dst = mStrings.get(rawResult.dstId); r.similarity = rawResult.similarity; mResults.add(r); } return mResults; } private void getResultsFromIndex(int srcId, ArrayList<UnfilteredResult> resultsBeforeRefining){ String src = mStrings.get(srcId); int srcLen = src.length(); for (int dstLen = Math.max(srcLen - mThreshold, mThreshold + 1); dstLen <= mGlobalIndex.lastKey(); dstLen++) { if(!mGlobalIndex.containsKey(dstLen)){ continue; } int delta = srcLen - dstLen; for (int gramNo = 0; gramNo <= mThreshold; gramNo++) { int candidateGramPos = getGramPos(dstLen, gramNo); int candidateGramLen = getGramLen(dstLen, gramNo); int startPos = Math.max(Math.max(candidateGramPos - gramNo, candidateGramPos + delta + gramNo - mThreshold), 0); int endPos = Math.min(Math.min(candidateGramPos + gramNo, candidateGramPos + delta - gramNo + mThreshold), srcLen - candidateGramLen); for (; startPos <= endPos; startPos++) { String gram = src.substring(startPos, startPos + candidateGramLen); ArrayList<Integer> invertedList = mGlobalIndex.get(dstLen).get(gramNo).get(gram); if (invertedList != null) { for (int k = 0; k < invertedList.size(); k++) { int dstId = invertedList.get(k); UnfilteredResult t = new UnfilteredResult(); t.dstId = dstId; t.dstMatchPos = candidateGramPos; t.srcMatchPos = startPos; t.gramLen = candidateGramLen; resultsBeforeRefining.add(t); } } } } } Collections.sort(resultsBeforeRefining, new Comparator<UnfilteredResult>() { @Override public int compare(UnfilteredResult a, UnfilteredResult b) { if (a.dstId < b.dstId) return -1; if (a.dstId > b.dstId) return 1; return 0; } }); } private int filterCandidate(String src, String dst, int srcMatchPos, int dstMatchPos, int len, int[][] distanceBuffer){ int srcRightLen = src.length() - srcMatchPos - len; int dstRightLen = dst.length() - dstMatchPos - len; int leftThreshold = mThreshold - Math.abs(srcRightLen - dstRightLen); int leftDistance = calculateEditDistanceWithThreshold(src, 0, srcMatchPos, dst, 0, dstMatchPos, leftThreshold, distanceBuffer); if (leftDistance > leftThreshold) { return -1; } int rightThreshold = mThreshold - leftDistance; int rightDistance = calculateEditDistanceWithThreshold( src, srcMatchPos + len, src.length() - srcMatchPos - len, dst, dstMatchPos + len, dst.length() - dstMatchPos - len, rightThreshold, distanceBuffer); if (rightDistance > rightThreshold) { return -1; } return leftDistance + rightDistance; } private ArrayList<FilteredRawResult> refineResults(int srcId, ArrayList<UnfilteredResult> resultsBeforeRefining, int[][] distanceBuffer){ ArrayList<FilteredRawResult> resultsRefined = new ArrayList<FilteredRawResult>(); HashSet<Integer> matchStringIds = new HashSet<Integer>(); for (UnfilteredResult t : resultsBeforeRefining) { int dstId = t.dstId; if(matchStringIds.contains(dstId)){ continue; } int dstMatchPos = t.dstMatchPos; int srcMatchPos = t.srcMatchPos; String dst = mStrings.get(dstId); String src = mStrings.get(srcId); int len = t.gramLen; int distance = filterCandidate(src, dst, srcMatchPos, dstMatchPos, len, distanceBuffer); if(distance != -1){ FilteredRawResult r = new FilteredRawResult(); r.srcId = dstId; r.dstId = srcId; r.similarity = distance; resultsRefined.add(r); matchStringIds.add(dstId); } } return resultsRefined; } public void populate(String s) { if(s.length() > mThreshold){ mStrings.add(s); mMaxLength = Math.max(mMaxLength, s.length()); } } public void populate(List<String> strings){ for(String s : strings){ populate(s); } } static private int compareString(String o1, String o2) { if (o1.length() > o2.length()) { return 1; } else if (o1.length() < o2.length()) { return -1; } return o1.compareTo(o2); } private int getGramPos(int strLen, int gramNo){ int shortGramLen = strLen / (mThreshold + 1); int longGramOffset = gramNo - (mThreshold + 1 - strLen % (mThreshold + 1)); if(longGramOffset > 0){ return shortGramLen * gramNo + longGramOffset; } return shortGramLen * gramNo; } private int getGramLen(int strLen, int gramNo){ int shortGramLen = strLen / (mThreshold + 1); if(gramNo + strLen % (mThreshold + 1) >= mThreshold + 1){ return shortGramLen + 1; } return shortGramLen; } }