import warnings

import numpy as np
import numba

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import check_array, check_random_state
from sklearn.utils.validation import _check_sample_weight
from scipy.sparse import issparse, csr_matrix, coo_matrix

from enstop.utils import (
    normalize,
    coherence,
    mean_coherence,
    log_lift,
    mean_log_lift,
)
from enstop.plsa import plsa_init


@numba.njit(
    [
        "f4[:,::1](i4[::1],i4[::1],f4[:,::1],f4[:,::1],f4[:,::1],f4)",
        "f4[:,::1](i4[::1],i4[::1],f4[:,:],f4[:,::1],f4[:,::1],f4)",
    ],
    locals={
        "k": numba.types.uint16,
        "w": numba.types.uint32,
        "d": numba.types.uint32,
        "z": numba.types.uint16,
        "v": numba.types.float32,
        "nz_idx": numba.types.uint32,
        "norm": numba.types.float32,
    },
    fastmath=True,
    nogil=True,
)
def plsa_e_step_on_a_block(
    block_rows,
    block_cols,
    p_w_given_z_block,
    p_z_given_d_block,
    p_z_given_wd_block,
    probability_threshold=1e-32,
):
    """Run the E step of pLSA for the non-zero entries of a single block.

    For every stored entry ``(d, w)`` of the block, the products
    ``p_w_given_z_block[z, w] * p_z_given_d_block[d, z]`` are written to
    ``p_z_given_wd_block[nz_idx, z]`` and then normalised to sum to one over
    ``z``. Products not exceeding ``probability_threshold`` are set to zero.

    Parameters
    ----------
    block_rows: int32 array of shape (max_nnz,)
        Block-local row index of each stored entry. The first negative value
        marks the end of the valid entries (the rest is padding).

    block_cols: int32 array of shape (max_nnz,)
        Block-local column index of each stored entry.

    p_w_given_z_block: float32 array, indexed ``[z, w]``
        Current word-given-topic values for the columns of this block.

    p_z_given_d_block: float32 array, indexed ``[d, z]``
        Current topic-given-document values for the rows of this block.

    p_z_given_wd_block: float32 array, indexed ``[nz_idx, z]``
        Output buffer; overwritten in place for every valid entry.

    probability_threshold: float32 (optional, default=1e-32)
        Products at or below this value are zeroed.

    Returns
    -------
    p_z_given_wd_block: float32 array
        The same buffer that was passed in, after the update.
    """
    k = p_w_given_z_block.shape[0]

    for nz_idx in range(block_rows.shape[0]):
        # A negative row index is the padding sentinel: no more entries.
        if block_rows[nz_idx] < 0:
            break

        d = block_rows[nz_idx]
        w = block_cols[nz_idx]

        norm = 0.0
        for z in range(k):
            v = p_w_given_z_block[z, w] * p_z_given_d_block[d, z]
            if v > probability_threshold:
                p_z_given_wd_block[nz_idx, z] = v
                norm += v
            else:
                p_z_given_wd_block[nz_idx, z] = 0.0

        # Only normalise when something survived the threshold; otherwise the
        # row stays all-zero.
        if norm > 0:
            for z in range(k):
                p_z_given_wd_block[nz_idx, z] /= norm

    return p_z_given_wd_block


@numba.njit(
    [
        "void(i4[::1],i4[::1],f4[::1],f4[:,::1],f4[:,::1],f4[:,::1],f4[::1],f4[::1])",
        "void(i4[::1],i4[::1],f4[::1],f4[:,:],f4[:,:],f4[:,::1],f4[::1],f4[::1])",
    ],
    locals={
        "k": numba.types.uint16,
        "w": numba.types.uint32,
        "d": numba.types.uint32,
        "x": numba.types.float32,
        "z": numba.types.uint16,
        "nz_idx": numba.types.uint32,
        "s": numba.types.float32,
    },
    fastmath=True,
    nogil=True,
)
def plsa_partial_m_step_on_a_block(
    block_rows,
    block_cols,
    block_vals,
    p_w_given_z_block,
    p_z_given_d_block,
    p_z_given_wd_block,
    norm_pwz,
    norm_pdz_block,
):
    """Accumulate the unnormalised M step contributions of a single block.

    For every stored entry ``(d, w)`` with value ``x`` and every topic ``z``,
    ``s = x * p_z_given_wd_block[nz_idx, z]`` is added to
    ``p_w_given_z_block[z, w]``, ``p_z_given_d_block[d, z]``, ``norm_pwz[z]``
    and ``norm_pdz_block[d]``. All outputs are updated in place and nothing is
    normalised here.

    Parameters
    ----------
    block_rows: int32 array of shape (max_nnz,)
        Block-local row index of each stored entry; the first negative value
        ends the valid entries.

    block_cols: int32 array of shape (max_nnz,)
        Block-local column index of each stored entry.

    block_vals: float32 array of shape (max_nnz,)
        Value of each stored entry.

    p_w_given_z_block: float32 array, indexed ``[z, w]``
        Accumulator for the word-given-topic update.

    p_z_given_d_block: float32 array, indexed ``[d, z]``
        Accumulator for the topic-given-document update.

    p_z_given_wd_block: float32 array, indexed ``[nz_idx, z]``
        E step result for this block.

    norm_pwz: float32 array, indexed ``[z]``
        Accumulator for the per-topic normalisation constant.

    norm_pdz_block: float32 array, indexed ``[d]``
        Accumulator for the per-document normalisation constant.
    """
    k = p_w_given_z_block.shape[0]

    for nz_idx in range(block_rows.shape[0]):
        # A negative row index is the padding sentinel: no more entries.
        if block_rows[nz_idx] < 0:
            break

        d = block_rows[nz_idx]
        w = block_cols[nz_idx]
        x = block_vals[nz_idx]

        for z in range(k):
            s = x * p_z_given_wd_block[nz_idx, z]

            p_w_given_z_block[z, w] += s
            p_z_given_d_block[d, z] += s

            norm_pwz[z] += s
            norm_pdz_block[d] += s


@numba.njit(
    "void(i4[:,:,::1],i4[:,:,::1],f4[:,:,::1],f4[:,:,::1],f4[:,:,::1],f4[:,:,:,::1],"
    "f4[:,:,:,::1],f4[:,:,:,::1],f4[:,::1],f4[:,:,::1],f4)",
    locals={
        "n": numba.types.uint32,
        "m": numba.types.uint32,
        "k": numba.types.uint16,
        "z": numba.types.uint16,
        "d": numba.types.uint32,
        "i": numba.types.uint16,
        "j": numba.types.uint16,
        "n_w_blocks": numba.types.uint16,
        "n_d_blocks": numba.types.uint16,
    },
    parallel=True,
    fastmath=True,
    nogil=True,
)
def plsa_em_step_by_blocks(
    block_rows_ndarray,
    block_cols_ndarray,
    block_vals_ndarray,
    prev_p_w_given_z,
    prev_p_z_given_d,
    blocked_next_p_w_given_z,
    blocked_next_p_z_given_d,
    p_z_given_wd_block,
    blocked_norm_pwz,
    blocked_norm_pdz,
    e_step_thresh=1e-32,
):
    """Perform one full EM iteration over all blocks, updating in place.

    Each ``(i, j)`` block runs its E step followed by its partial M step,
    writing into per-block accumulators. The accumulators are then summed
    over blocks, written back into ``prev_p_w_given_z`` / ``prev_p_z_given_d``
    and normalised. Finally the accumulators are reset to zero so they can be
    reused by the next call.

    Parameters
    ----------
    block_rows_ndarray: int32 array, indexed ``[i, j, nz_idx]``
        Block-local row indices of the stored entries of block ``(i, j)``,
        padded with negative values.

    block_cols_ndarray: int32 array, indexed ``[i, j, nz_idx]``
        Block-local column indices of the stored entries of block ``(i, j)``.

    block_vals_ndarray: float32 array, indexed ``[i, j, nz_idx]``
        Values of the stored entries of block ``(i, j)``.

    prev_p_w_given_z: float32 array, indexed ``[j, z, w]``
        Current word-given-topic values; overwritten with the new estimate.

    prev_p_z_given_d: float32 array, indexed ``[i, d, z]``
        Current topic-given-document values; overwritten with the new
        estimate.

    blocked_next_p_w_given_z: float32 array, indexed ``[i, j, z, w]``
        Per-block accumulator; expected to be all zero on entry and zeroed
        again on exit.

    blocked_next_p_z_given_d: float32 array, indexed ``[j, i, d, z]``
        Per-block accumulator; expected to be all zero on entry and zeroed
        again on exit.

    p_z_given_wd_block: float32 array, indexed ``[i, j, nz_idx, z]``
        Scratch space for the E step results.

    blocked_norm_pwz: float32 array, indexed ``[i, z]``
        Per-row-block accumulator of topic normalisers (zeroed on entry).

    blocked_norm_pdz: float32 array, indexed ``[j, i, d]``
        Per-block accumulator of document normalisers (zeroed on entry).

    e_step_thresh: float32 (optional, default=1e-32)
        Threshold forwarded to ``plsa_e_step_on_a_block``.
    """
    n_d_blocks = block_rows_ndarray.shape[0]
    n_w_blocks = block_rows_ndarray.shape[1]

    k = prev_p_z_given_d.shape[2]

    # Zero out the norms for recomputation
    blocked_norm_pdz[:] = 0.0
    blocked_norm_pwz[:] = 0.0

    # NOTE: ``blocked_norm_pwz[i]`` is shared by every ``j`` of a given ``i``,
    # so the update is only race-free if the inner loop runs serially within
    # each ``i``. Numba documents that only the outermost ``prange`` of a
    # nest is parallelised -- keep it that way if this loop is restructured.
    for i in numba.prange(n_d_blocks):
        for j in numba.prange(n_w_blocks):
            block_rows = block_rows_ndarray[i, j]
            block_cols = block_cols_ndarray[i, j]
            block_vals = block_vals_ndarray[i, j]

            plsa_e_step_on_a_block(
                block_rows,
                block_cols,
                prev_p_w_given_z[j],
                prev_p_z_given_d[i],
                p_z_given_wd_block[i, j],
                np.float32(e_step_thresh),
            )
            plsa_partial_m_step_on_a_block(
                block_rows,
                block_cols,
                block_vals,
                blocked_next_p_w_given_z[i, j],
                blocked_next_p_z_given_d[j, i],
                p_z_given_wd_block[i, j],
                blocked_norm_pwz[i],
                blocked_norm_pdz[j, i],
            )

    # Reduce the per-block accumulators over the block axis they were split on.
    prev_p_z_given_d[:] = blocked_next_p_z_given_d.sum(axis=0)
    norm_pdz = blocked_norm_pdz.sum(axis=0)
    prev_p_w_given_z[:] = blocked_next_p_w_given_z.sum(axis=0)
    norm_pwz = blocked_norm_pwz.sum(axis=0)

    # Once complete we can normalize to complete the M step
    for z in numba.prange(k):
        if norm_pwz[z] > 0:
            for w_block in range(prev_p_w_given_z.shape[0]):
                for w_offset in range(prev_p_w_given_z.shape[2]):
                    prev_p_w_given_z[w_block, z, w_offset] /= norm_pwz[z]
        for d_block in range(prev_p_z_given_d.shape[0]):
            for d_offset in range(prev_p_z_given_d.shape[1]):
                if norm_pdz[d_block, d_offset] > 0:
                    prev_p_z_given_d[d_block, d_offset, z] /= norm_pdz[
                        d_block, d_offset
                    ]

    # Zero out the accumulator matrices ready for the next iteration
    blocked_next_p_z_given_d[:] = 0.0
    blocked_next_p_w_given_z[:] = 0.0


@numba.njit(
    locals={
        "i": numba.types.uint16,
        "j": numba.types.uint16,
        "k": numba.types.uint16,
        "w": numba.types.uint32,
        "d": numba.types.uint32,
        "z": numba.types.uint16,
        "nz_idx": numba.types.uint32,
        "x": numba.types.float32,
        "result": numba.types.float32,
        "p_w_given_d": numba.types.float32,
    },
    fastmath=True,
    nogil=True,
    parallel=True,
)
def log_likelihood_by_blocks(
    block_rows_ndarray,
    block_cols_ndarray,
    block_vals_ndarray,
    p_w_given_z,
    p_z_given_d,
):
    """Compute the log-likelihood of the blocked data under the model.

    Sums ``x * log(sum_z p_w_given_z[j, z, w] * p_z_given_d[i, d, z])`` over
    every stored entry ``(d, w)`` with value ``x`` of every block ``(i, j)``.

    Parameters
    ----------
    block_rows_ndarray: array, indexed ``[i, j, nz_idx]``
        Block-local row indices, padded with negative values.

    block_cols_ndarray: array, indexed ``[i, j, nz_idx]``
        Block-local column indices.

    block_vals_ndarray: array, indexed ``[i, j, nz_idx]``
        Values of the stored entries.

    p_w_given_z: array, indexed ``[j, z, w]``
        Blocked word-given-topic values.

    p_z_given_d: array, indexed ``[i, d, z]``
        Blocked topic-given-document values.

    Returns
    -------
    result: float32
        The accumulated log-likelihood.
    """
    result = 0.0
    k = p_z_given_d.shape[2]

    for i in numba.prange(block_rows_ndarray.shape[0]):
        for j in range(block_rows_ndarray.shape[1]):
            for nz_idx in range(block_rows_ndarray.shape[2]):
                # A negative row index is the padding sentinel.
                if block_rows_ndarray[i, j, nz_idx] < 0:
                    break

                d = block_rows_ndarray[i, j, nz_idx]
                w = block_cols_ndarray[i, j, nz_idx]
                x = block_vals_ndarray[i, j, nz_idx]

                p_w_given_d = 0.0
                for z in range(k):
                    p_w_given_d += p_w_given_z[j, z, w] * p_z_given_d[i, d, z]

                result += x * np.log(p_w_given_d)

    return result


@numba.njit(fastmath=True, nogil=True)
def plsa_fit_inner_blockwise(
    block_rows_ndarray,
    block_cols_ndarray,
    block_vals_ndarray,
    p_w_given_z,
    p_z_given_d,
    block_row_size,
    block_col_size,
    n_iter=100,
    n_iter_per_test=10,
    tolerance=0.001,
    e_step_thresh=1e-32,
):
    """Run blocked EM iterations until convergence or ``n_iter`` is reached.

    Allocates the per-block work buffers, then repeatedly calls
    ``plsa_em_step_by_blocks``. Whenever ``i % n_iter_per_test == 0`` the
    log-likelihood is recomputed and iteration stops once its relative change
    since the previous test falls below ``tolerance``.

    Parameters
    ----------
    block_rows_ndarray, block_cols_ndarray, block_vals_ndarray: arrays
        Blocked COO representation of the data, indexed ``[i, j, nz_idx]``.

    p_w_given_z: float32 array, indexed ``[j, z, w]``
        Initial blocked word-given-topic values; updated in place.

    p_z_given_d: float32 array, indexed ``[i, d, z]``
        Initial blocked topic-given-document values; updated in place.

    block_row_size: int
        Number of rows per row block.

    block_col_size: int
        Number of columns per column block.

    n_iter: int (optional, default=100)
        Maximum number of EM iterations.

    n_iter_per_test: int (optional, default=10)
        Number of iterations between convergence tests.

    tolerance: float (optional, default=0.001)
        Relative log-likelihood change below which iteration stops.

    e_step_thresh: float (optional, default=1e-32)
        Threshold forwarded to the E step.

    Returns
    -------
    p_z_given_d, p_w_given_z: arrays
        The (in place updated) input arrays, in that order.
    """
    k = p_z_given_d.shape[2]

    n_d_blocks = block_rows_ndarray.shape[0]
    n_w_blocks = block_rows_ndarray.shape[1]
    block_size = block_rows_ndarray.shape[2]

    p_z_given_wd_block = np.zeros(
        (n_d_blocks, n_w_blocks, block_size, k), dtype=np.float32
    )

    blocked_next_p_w_given_z = np.zeros(
        (
            np.int64(n_d_blocks),
            np.int64(n_w_blocks),
            np.int64(k),
            np.int64(block_col_size),
        ),
        dtype=np.float32,
    )
    blocked_norm_pwz = np.zeros((n_d_blocks, k), dtype=np.float32)
    blocked_next_p_z_given_d = np.zeros(
        (
            np.int64(n_w_blocks),
            np.int64(n_d_blocks),
            np.int64(block_row_size),
            np.int64(k),
        ),
        dtype=np.float32,
    )
    blocked_norm_pdz = np.zeros(
        (np.int64(n_w_blocks), np.int64(n_d_blocks), np.int64(block_row_size)),
        dtype=np.float32,
    )

    previous_log_likelihood = log_likelihood_by_blocks(
        block_rows_ndarray,
        block_cols_ndarray,
        block_vals_ndarray,
        p_w_given_z,
        p_z_given_d,
    )

    for i in range(n_iter):
        plsa_em_step_by_blocks(
            block_rows_ndarray,
            block_cols_ndarray,
            block_vals_ndarray,
            p_w_given_z,
            p_z_given_d,
            blocked_next_p_w_given_z,
            blocked_next_p_z_given_d,
            p_z_given_wd_block,
            blocked_norm_pwz,
            blocked_norm_pdz,
            e_step_thresh,
        )

        if i % n_iter_per_test == 0:
            current_log_likelihood = log_likelihood_by_blocks(
                block_rows_ndarray,
                block_cols_ndarray,
                block_vals_ndarray,
                p_w_given_z,
                p_z_given_d,
            )
            change = np.abs(current_log_likelihood - previous_log_likelihood)
            if change / np.abs(current_log_likelihood) < tolerance:
                break
            else:
                previous_log_likelihood = current_log_likelihood

    return p_z_given_d, p_w_given_z


def plsa_fit(
    X,
    k,
    n_row_blocks=8,
    n_col_blocks=8,
    init="random",
    n_iter=100,
    n_iter_per_test=10,
    tolerance=0.001,
    e_step_thresh=1e-32,
    random_state=None,
):
    """Fit a pLSA model to ``X`` using the block-parallel EM implementation.

    ``X`` is cut into an ``n_row_blocks`` by ``n_col_blocks`` grid. Each grid
    cell is stored in COO form, padded with ``-1`` indices to the size of the
    fullest cell, and the factor matrices are reshaped to match before being
    handed to ``plsa_fit_inner_blockwise``.

    Parameters
    ----------
    X: sparse matrix of shape (n_docs, n_words)
        The data matrix pLSA is attempting to fit to. Must support
        ``tocsr()``.

    k: int
        The number of topics.

    n_row_blocks: int (optional, default=8)
        Number of blocks the rows (documents) are split into.

    n_col_blocks: int (optional, default=8)
        Number of blocks the columns (words) are split into.

    init: string or tuple (optional, default="random")
        Initialisation specification, forwarded unchanged to ``plsa_init``.

    n_iter: int (optional, default=100)
        Maximum number of EM iterations.

    n_iter_per_test: int (optional, default=10)
        Number of iterations between convergence tests.

    tolerance: float (optional, default=0.001)
        Relative log-likelihood change below which iteration stops.

    e_step_thresh: float (optional, default=1e-32)
        Threshold below which E step products are zeroed.

    random_state: int, RandomState instance or None (optional, default=None)
        Passed through ``check_random_state`` and then to ``plsa_init``.

    Returns
    -------
    p_z_given_d: array of shape (n_docs, k)
        Topic-given-document values.

    p_w_given_z: array of shape (k, n_words)
        Word-given-topic values.
    """
    rng = check_random_state(random_state)
    p_z_given_d_init, p_w_given_z_init = plsa_init(X, k, init=init, rng=rng)

    A = X.tocsr().astype(np.float32)

    n = A.shape[0]
    m = A.shape[1]

    # Use plain Python ints for the block sizes. A fixed-width ``uint16``
    # wraps as soon as a block holds more than 65535 rows/columns, and (under
    # NumPy 2 scalar promotion) so does ``block_size * block_index``.
    block_row_size = int(np.ceil(n / n_row_blocks))
    block_col_size = int(np.ceil(m / n_col_blocks))

    # Pad the initial factors up to a whole number of blocks, then reshape so
    # that the leading axis indexes the block.
    p_z_given_d = np.zeros((block_row_size * n_row_blocks, k), dtype=np.float32)
    p_z_given_d[: p_z_given_d_init.shape[0]] = p_z_given_d_init
    p_z_given_d = p_z_given_d.reshape(n_row_blocks, block_row_size, k)

    p_w_given_z = np.zeros((k, block_col_size * n_col_blocks), dtype=np.float32)
    p_w_given_z[:, : p_w_given_z_init.shape[1]] = p_w_given_z_init
    # Resulting layout is (n_col_blocks, k, block_col_size), C-contiguous.
    p_w_given_z = np.transpose(
        p_w_given_z.T.reshape(n_col_blocks, block_col_size, k), axes=[0, 2, 1]
    ).astype(np.float32, order="C")

    A_blocks = [[0] * n_col_blocks for _ in range(n_row_blocks)]
    max_nnz_per_block = 0

    for i in range(n_row_blocks):
        row_start = block_row_size * i
        row_end = min(row_start + block_row_size, n)

        for j in range(n_col_blocks):
            col_start = block_col_size * j
            col_end = min(col_start + block_col_size, m)

            A_blocks[i][j] = A[row_start:row_end, col_start:col_end].tocoo()
            if A_blocks[i][j].nnz > max_nnz_per_block:
                max_nnz_per_block = A_blocks[i][j].nnz

    # Index arrays are padded with -1, which the numba kernels treat as the
    # end-of-block sentinel.
    block_rows_ndarray = np.full(
        (n_row_blocks, n_col_blocks, max_nnz_per_block), -1, dtype=np.int32
    )
    block_cols_ndarray = np.full(
        (n_row_blocks, n_col_blocks, max_nnz_per_block), -1, dtype=np.int32
    )
    block_vals_ndarray = np.zeros(
        (n_row_blocks, n_col_blocks, max_nnz_per_block), dtype=np.float32
    )

    for i in range(n_row_blocks):
        for j in range(n_col_blocks):
            nnz = A_blocks[i][j].nnz
            block_rows_ndarray[i, j, :nnz] = A_blocks[i][j].row
            block_cols_ndarray[i, j, :nnz] = A_blocks[i][j].col
            block_vals_ndarray[i, j, :nnz] = A_blocks[i][j].data

    p_z_given_d, p_w_given_z = plsa_fit_inner_blockwise(
        block_rows_ndarray,
        block_cols_ndarray,
        block_vals_ndarray,
        p_w_given_z,
        p_z_given_d,
        block_row_size,
        block_col_size,
        n_iter=n_iter,
        n_iter_per_test=n_iter_per_test,
        tolerance=tolerance,
        e_step_thresh=e_step_thresh,
    )

    # Undo the blocking and strip the padding rows/columns.
    p_z_given_d = p_z_given_d.reshape(-1, k)[:n, :]
    p_w_given_z = (
        np.transpose(p_w_given_z, axes=[0, 2, 1]).reshape(-1, k).T[:, :m]
    )

    return p_z_given_d, p_w_given_z


class BlockParallelPLSA(BaseEstimator, TransformerMixin):
    """Scikit-learn style estimator wrapping the block-parallel pLSA fit.

    Parameters
    ----------
    n_components: int (optional, default=10)
        The number of topics.

    init: string or tuple (optional, default="random")
        Initialisation specification, forwarded to ``plsa_fit``.

    n_row_blocks: int (optional, default=8)
        Number of blocks the rows (documents) are split into.

    n_col_blocks: int (optional, default=8)
        Number of blocks the columns (words) are split into.

    n_iter: int (optional, default=100)
        Maximum number of EM iterations.

    n_iter_per_test: int (optional, default=10)
        Number of iterations between convergence tests.

    tolerance: float (optional, default=0.001)
        Relative log-likelihood change below which iteration stops.

    e_step_thresh: float (optional, default=1e-32)
        Threshold below which E step products are zeroed.

    transform_random_seed: int (optional, default=42)
        Stored on the estimator; not read by ``fit`` or ``fit_transform``.

    random_state: int, RandomState instance or None (optional, default=None)
        Forwarded to ``plsa_fit``.

    Attributes
    ----------
    embedding_: array of shape (n_docs, n_components)
        Document vectors; rows of ``X`` that sum to zero are all-zero.

    components_: array of shape (n_components, n_words)
        Word-given-topic values returned by ``plsa_fit``.

    training_data_: sparse matrix of shape (n_docs, n_words)
        The validated training matrix.
    """

    def __init__(
        self,
        n_components=10,
        init="random",
        n_row_blocks=8,
        n_col_blocks=8,
        n_iter=100,
        n_iter_per_test=10,
        tolerance=0.001,
        e_step_thresh=1e-32,
        transform_random_seed=42,
        random_state=None,
    ):
        self.n_components = n_components
        self.init = init
        self.n_row_blocks = n_row_blocks
        self.n_col_blocks = n_col_blocks
        self.n_iter = n_iter
        self.n_iter_per_test = n_iter_per_test
        self.tolerance = tolerance
        self.e_step_thresh = e_step_thresh
        self.transform_random_seed = transform_random_seed
        self.random_state = random_state

    def fit(self, X, y=None, sample_weight=None):
        """Learn the pLSA model for the data X.

        Delegates to ``fit_transform`` and discards the returned embedding
        (it remains available as ``self.embedding_``).

        Parameters
        ----------
        X: array or sparse matrix of shape (n_docs, n_words)
            The data matrix pLSA is attempting to fit to.

        y: Ignored

        sample_weight: array of shape (n_docs,)
            Input document weights. Validated but currently not used by the
            block-parallel fit; see ``fit_transform``.

        Returns
        -------
        self
        """
        self.fit_transform(X, sample_weight=sample_weight)
        return self

    def fit_transform(self, X, y=None, sample_weight=None):
        """Learn the pLSA model for the data X and return the document vectors.

        Parameters
        ----------
        X: array or sparse matrix of shape (n_docs, n_words)
            The data matrix pLSA is attempting to fit to.

        y: Ignored

        sample_weight: array of shape (n_docs,)
            Input document weights. The weights are validated, but the
            block-parallel fit does not use them: every document is weighted
            equally. A ``UserWarning`` is issued when weights are supplied.

        Returns
        -------
        embedding: array of shape (n_docs, n_topics)
            An embedding of the documents into a topic space.

        Raises
        ------
        ValueError
            If ``X`` contains negative entries.
        """
        X = check_array(X, accept_sparse="csr")

        if not issparse(X):
            X = csr_matrix(X)

        if sample_weight is not None:
            # Make the limitation visible instead of silently dropping the
            # caller's weights.
            warnings.warn(
                "sample_weight is validated but not used by "
                "BlockParallelPLSA; all documents are weighted equally.",
                UserWarning,
                stacklevel=2,
            )
        # Still validate so malformed weights are rejected as before.
        _check_sample_weight(sample_weight, X, dtype=np.float32)

        if np.any(X.data < 0):
            raise ValueError(
                "PLSA is only valid for matrices with non-negative " "entries"
            )

        # Rows with no mass are left out of the fit and receive an all-zero
        # embedding afterwards.
        row_sums = np.array(X.sum(axis=1).T)[0]
        good_rows = row_sums != 0

        if not np.all(good_rows):
            zero_rows_found = True
            data_for_fitting = X[good_rows]
        else:
            zero_rows_found = False
            data_for_fitting = X

        U, V = plsa_fit(
            data_for_fitting,
            self.n_components,
            n_row_blocks=self.n_row_blocks,
            n_col_blocks=self.n_col_blocks,
            init=self.init,
            n_iter=self.n_iter,
            n_iter_per_test=self.n_iter_per_test,
            tolerance=self.tolerance,
            e_step_thresh=self.e_step_thresh,
            random_state=self.random_state,
        )

        if zero_rows_found:
            self.embedding_ = np.zeros((X.shape[0], self.n_components))
            self.embedding_[good_rows] = U
        else:
            self.embedding_ = U

        self.components_ = V
        self.training_data_ = X

        return self.embedding_