# Copyright 2019 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for information-theoretic preprocessing algorithms.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import math # math.log2 was added in Python 3.3 try: log2 = math.log2 except AttributeError: log2 = lambda x: math.log(x, 2) # TODO(b/157302701): Evaluate optimizations or approximations for this function, # in particular the _hypergeometric_pmf. def calculate_partial_expected_mutual_information(n, x_i, y_j): """Calculates the partial expected mutual information (EMI) of two variables. EMI reflects the MI expected by chance, and is used to compute adjusted mutual information. See www.wikipedia.org/wiki/Adjusted_mutual_information. The EMI for two variables x and y, is the sum of the expected mutual info for each value of x with each value of y. This function computes the EMI for a single value of each variable (x_i, y_j) and is thus considered a partial EMI calculation. Specifically: EMI(x, y) = sum_{n_ij = max(0, x_i + y_j - n) to min(x_i, y_j)} ( n_ij / n * log2((n * n_ij / (x_i * y_j)) * ((x_i! * y_j! * (n - x_i)! * (n - y_j)!) / (n! * n_ij! * (x_i - n_ij)! * (y_j - n_ij)! * (n - x_i - y_j + n_ij)!))) where n_ij is the joint count of x taking on value i and y taking on value j, x_i is the count for x taking on value i, y_j is the count for y taking on value j, and n represents total count. Args: n: The sum of weights for all values. x_i: The sum of weights for the first variable taking on value i y_j: The sum of weights for the second variable taking on value j Returns: Calculated expected mutual information for x_i, y_j. """ if x_i == 0 or y_j == 0: return 0 coefficient = (-log2(x_i) - log2(y_j) + log2(n)) sum_probability = 0.0 partial_result = 0.0 for n_j, p_j in _hypergeometric_pmf(n, x_i, y_j): if n_j != 0: partial_result += n_j * (coefficient + log2(n_j)) * p_j sum_probability += p_j # The values of p_j should sum to 1, but given approximate calculations for # log2(x) and exp2(x) with large x, the full pmf might not sum to exactly 1. # We correct for this by dividing by the sum of the probabilities. return partial_result / sum_probability def calculate_partial_mutual_information(n_ij, x_i, y_j, n): """Calculates Mutual Information for x=i, y=j from sample counts. The standard formulation of mutual information is: MI(X,Y) = Sum_i,j {p_ij * log2(p_ij / p_i * p_j)} We are operating over counts (p_ij = n_ij / n), so this is transformed into MI(X,Y) = Sum_i,j {n_ij * (log2(n_ij) + log2(n) - log2(x_i) - log2(y_j))} / n This function returns the argument to the summation, the mutual information for a particular pair of values x_i, y_j (the caller is expected to divide the summation by n to compute the final mutual information result). Args: n_ij: The co-occurrence of x=i and y=j x_i: The frequency of x=i. y_j: The frequency of y=j. n: The total # observations Returns: Mutual information for the cell x=i, y=j. """ if n_ij == 0: return 0 return n_ij * ((log2(n_ij) + log2(n)) - (log2(x_i) + log2(y_j))) def _hypergeometric_pmf(n, x_i, y_j): """Probablity for expectation computation under hypergeometric distribution. Args: n: The sum of weights for all values. x_i: The sum of weights for the first variable taking on value i y_j: The sum of weights for the second variable taking on value j Yields: The probability p_j at point n_j in the hypergeometric distribution. """ start = int(round(max(0, x_i + y_j - n))) end = int(round(min(x_i, y_j))) # Use log factorial to preserve calculation precision. # Note: because the factorials are expensive to compute, we compute the # denominator incrementally, at the cost of some readability. numerator = ( _logfactorial(x_i) + _logfactorial(y_j) + _logfactorial(n - x_i) + _logfactorial(n - y_j)) denominator = ( _logfactorial(n) + _logfactorial(start) + _logfactorial(x_i - start) + _logfactorial(y_j - start) + _logfactorial(n - x_i - y_j + start)) for n_j in range(start, end + 1): p_j = math.exp(numerator - denominator) if n_j != end: denominator += ( math.log(n_j + 1) - math.log(x_i - n_j) - math.log(y_j - n_j) + math.log(n - x_i - y_j + n_j + 1)) yield n_j, p_j def _logfactorial(n): """Calculate natural logarithm of n!.""" return math.lgamma(n + 1)