package com.asher_stern.crf.crf; import static com.asher_stern.crf.utilities.ArithmeticUtilities.*; import java.math.BigDecimal; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import org.apache.log4j.Logger; import com.asher_stern.crf.utilities.CrfException; import com.asher_stern.crf.utilities.TaggedToken; /** * Calculates, for each feature, the expected sum of its values over the whole corpus. * For example, let's assume that the corpus contains 10 sentences, and each sentence contains 8 tokens, * so there are 80 tokens, and in each of them the feature is expected to get some value. In this class the sum of all those * expected values is calculated, for each feature. * * @see CrfLogLikelihoodFunction * * @author Asher Stern * Date: Nov 9, 2014 * * @param <K> * @param <G> */ public class CrfFeatureValueExpectationByModel<K, G> { public CrfFeatureValueExpectationByModel( Iterator<? extends List<? extends TaggedToken<K, G>>> corpusIterator, CrfModel<K, G> model) { super(); this.corpusIterator = corpusIterator; this.model = model; } public void calculate() { featureValueExpectation = new BigDecimal[model.getFeatures().getFilteredFeatures().length]; for (int i=0;i<featureValueExpectation.length;++i) {featureValueExpectation[i]=BigDecimal.ZERO;} // Explicit initialization to zero, just to be on the safe side. ExecutorService executor = Executors.newWorkStealingPool(); List<Future<?>> futures = new LinkedList<>(); while (corpusIterator.hasNext()) { final List<? extends TaggedToken<K, G>> sentence = corpusIterator.next(); futures.add(executor.submit( new Runnable() { @Override public void run() { addValueForSentence(sentence); } })); } for (Future<?> future: futures) { try { future.get(); } catch (InterruptedException | ExecutionException e) { throw new CrfException(e); } } } public BigDecimal[] getFeatureValueExpectation() { return featureValueExpectation; } private void addValueForSentence(List<? extends TaggedToken<K, G>> sentence) { K[] sentenceTokens = CrfUtilities.extractSentence(sentence); // Find the "active" features for each triple of {token-index,tag-of-token,tag-of-previous-token} CrfRememberActiveFeatures<K, G> activeFeaturesForSentence = CrfRememberActiveFeatures.findForSentence(model.getFeatures(), model.getCrfTags(), sentenceTokens); // Calculate the CRF formula: e^{\Sum_{i=0}^{F-1}\theta_i*f_i(j,g,g')} where F is the number of features, j is a token index, g is a tag for that token, and g' is a tag for the previous token. CrfPsi_FormulaAllTokens<K, G> allTokensFormula = CrfPsi_FormulaAllTokens.createAndCalculate(model,sentenceTokens,activeFeaturesForSentence); CrfForwardBackward<K,G> forwardBackward = new CrfForwardBackward<K,G>(model,sentenceTokens,activeFeaturesForSentence); forwardBackward.setAllTokensFormulaValues(allTokensFormula); forwardBackward.calculateForwardAndBackward(); final BigDecimal normalizationFactor = forwardBackward.getCalculatedNormalizationFactor(); for (int tokenIndex=0;tokenIndex<sentenceTokens.length;++tokenIndex) { for (G currentTag : model.getCrfTags().getTags()) { Set<G> possiblePreviousTags = CrfUtilities.getPreviousTags(sentenceTokens, tokenIndex, currentTag, model.getCrfTags()); for (G previousTag : possiblePreviousTags) { //Set<Integer> activeFeatures = CrfUtilities.getActiveFeatureIndexes(model.getFeatures(),sentenceTokens,tokenIndex,currentTag,previousTag); Set<Integer> activeFeatures = activeFeaturesForSentence.getOneTokenActiveFeatures(tokenIndex, currentTag, previousTag); for (int featureIndex : activeFeatures) { double featureValue = 0.0; if (model.getFeatures().getFilteredFeatures()[featureIndex].isWhenNotFilteredIsAlwaysOne()) { featureValue = 1.0; } else { featureValue = model.getFeatures().getFilteredFeatures()[featureIndex].getFeature().value(sentenceTokens,tokenIndex,currentTag,previousTag); } BigDecimal probabilityUnderModel = null; if (featureValue!=0.0) { // Calculate probabilityUnderModel if (null==probabilityUnderModel) { BigDecimal alpha_forward_previousValue = BigDecimal.ONE; if (tokenIndex>0) { alpha_forward_previousValue = forwardBackward.getAlpha_forward()[tokenIndex-1].get(previousTag); } BigDecimal beta_backward_value = forwardBackward.getBeta_backward().get(tokenIndex).get(currentTag); //double psi_probabilityForGivenIndexAndTags = CrfUtilities.oneTokenFormula(model,sentenceTokens,tokenIndex,currentTag,previousTag,activeFeatures); BigDecimal psi_probabilityForGivenIndexAndTags = allTokensFormula.getOneTokenFormula(tokenIndex,currentTag,previousTag); //probabilityUnderModel = (alpha_forward_previousValue*psi_probabilityForGivenIndexAndTags*beta_backward_value)/normalizationFactor; probabilityUnderModel = safeDivide(safeMultiply(safeMultiply(alpha_forward_previousValue, psi_probabilityForGivenIndexAndTags),beta_backward_value), normalizationFactor); } BigDecimal addToExpectation = safeMultiply(big(featureValue), probabilityUnderModel); synchronized(locker) { featureValueExpectation[featureIndex] = safeAdd(featureValueExpectation[featureIndex], addToExpectation); } } } // end for-each feature } // end for-each previous-tag } // end for-each current-tag } // end for-each token-index } private final Iterator<? extends List<? extends TaggedToken<K, G>>> corpusIterator; private final CrfModel<K, G> model; private final Object locker = new Object(); private BigDecimal[] featureValueExpectation; @SuppressWarnings("unused") private static final Logger logger = Logger.getLogger(CrfFeatureValueExpectationByModel.class); }