import numpy as np
from scipy import signal
import torch
import torch.nn as nn
from torch.autograd import Variable
from tqdm import trange

from . import sequences
from . import initialize


class TCM:
    def __init__(self, n_seeds, n_motifs, motif_width, min_sites, max_sites,
                 batch_size, half_length, fudge, alpha, revcomp, tolerance,
                 maxiter, erasewhole, cuda):
        self.n_seeds = n_seeds
        self.n_motifs = n_motifs
        self.motif_width = motif_width
        self.min_sites = min_sites
        self.max_sites = max_sites
        self.batch_size = batch_size
        self.half_length = half_length
        self.fudge = fudge
        self.alpha = alpha
        self.revcomp = revcomp
        self.tolerance = tolerance
        self.maxiter = maxiter
        self.erasewhole = erasewhole
        self.cuda = cuda
        self.ppms_ = None
        self.ppms_bg_ = None
        self.fracs_ = None
        self.n_sites_ = None

    def fit(self, X, X_neg=None):
        """Fit the model to the data X. Discover n_motifs motifs.

        The fitted parameters are stored in ppms_, ppms_bg_, fracs_, and
        n_sites_.

        Parameters
        ----------
        X : list of string sequences
            Positive (training) sequences.
        X_neg : list of string sequences, optional
            Negative (background) sequences.

        Returns
        -------
        X_seq : list of string sequences
            The positive sequences with discovered motif occurrences erased.
        X_neg_seq : list of string sequences or None
            The negative sequences with discovered motif occurrences erased,
            if negative sequences were provided.
        """
        ppms_final = []
        ppms_bg_final = []
        fracs_final = []
        n_sites_final = []
        converged_early = False
        for i_motif in range(self.n_motifs):
            print('\nFinding motif %i of %i' % (i_motif + 1, self.n_motifs))
            X_seq = X
            X_neg_seq = X_neg
            N = len(X)
            if X_neg is not None:
                top_words = initialize.find_enriched_gapped_kmers(
                    X, X_neg, self.half_length, 0,
                    self.motif_width - 2 * self.half_length,
                    self.alpha, self.revcomp, self.n_seeds)
            X = sequences.encode(X, self.alpha)  # TODO: need to change this to X_seq
            X_seqs_onehot = X  # Need to use one-hot coded positive sequences later
            if X_neg is not None:
                # Need to use one-hot coded negative sequences later
                X_neg_seqs_onehot = sequences.encode(X_neg, self.alpha)
            # Extract valid one-hot subsequences
            X = sequences.get_onehot_subsequences(X, self.motif_width)
            M, L, W = X.shape
            if self.revcomp:
                M *= 2
            # Compute motif fraction seeds
            min_sites = self.min_sites
            min_frac = min_sites / M
            if self.max_sites is None:
                # Expect at most one motif occurrence per sequence by default
                max_sites = N
            else:
                max_sites = self.max_sites
            max_frac = max_sites / M
            fracs_seeds = np.geomspace(min_sites, max_sites, 5) / M
            n_uniq_fracs_seeds = len(fracs_seeds)
            fracs_seeds = np.repeat(fracs_seeds, self.n_seeds)
            fracs_seeds = torch.from_numpy(fracs_seeds.astype(np.float32))
            # Compute background frequencies
            letter_frequency = X.sum(axis=(0, 2))
            if self.revcomp:
                # If reverse complements are considered, set complementary
                # letter frequencies to the same value
                letter_frequency[[0, 3]] = letter_frequency[0] + letter_frequency[3]
                letter_frequency[[1, 2]] = letter_frequency[1] + letter_frequency[2]
                X = np.concatenate((X, X[:, ::-1, ::-1]), axis=0)
            bg_probs = 1.0 * letter_frequency / letter_frequency.sum()
            ppms_bg_seeds = bg_probs.reshape([1, L, 1]).repeat(
                self.n_seeds * n_uniq_fracs_seeds, axis=0).astype(np.float32)
            ppms_bg_seeds = torch.from_numpy(ppms_bg_seeds)
            # Initialize PPMs
            large_prob = 0.9
            small_prob = (1 - large_prob) / (L - 1)
            if X_neg is not None:
                ppms_seeds = sequences.encode(top_words, self.alpha)
                ppms_seeds = sequences.pad_onehot_sequences(
                    ppms_seeds, W).astype(np.float32) * large_prob
                for ppm in ppms_seeds:
                    ppm[:, ppm.sum(axis=0) == 0] = bg_probs.reshape((L, 1))
                ppms_seeds[ppms_seeds == 0] = small_prob
            else:
                ppms_seeds = X[0:self.n_seeds].astype(np.float32) * large_prob
                ppms_seeds[ppms_seeds == 0] = small_prob
            ppms_seeds = np.tile(ppms_seeds, (n_uniq_fracs_seeds, 1, 1))
            ppms_seeds_original = ppms_seeds.copy()
            ppms_seeds = torch.from_numpy(ppms_seeds)
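        # Worked example of the seeding scheme above (illustrative numbers,
        # not from the original code): with a DNA alphabet (L = 4) and
        # large_prob = 0.9, small_prob = (1 - 0.9) / (4 - 1) = 0.0333..., so
        # a seed word such as ACG becomes a PPM whose columns put probability
        # 0.9 on the seed letter and ~0.033 on each other letter. When seeding
        # from enriched gapped k-mers (X_neg given), positions not covered by
        # the seed word fall back to the background probabilities bg_probs.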
            # If using CUDA, move the three parameter tensors to the GPU
            if self.cuda:
                ppms_bg_seeds = ppms_bg_seeds.cuda()
                ppms_seeds = ppms_seeds.cuda()
                fracs_seeds = fracs_seeds.cuda()
            ppms_bg_seeds = ppms_bg_seeds.expand(len(ppms_bg_seeds), L, W)
            # Perform one on-line and one batch EM pass
            ppms_seeds, ppms_bg_seeds, fracs_seeds = \
                self._online_em(X, ppms_seeds, ppms_bg_seeds, fracs_seeds, 1)
            ppms, ppms_bg, fracs = \
                self._batch_em(X, ppms_seeds, ppms_bg_seeds, fracs_seeds, 1)
            log_likelihoods = self._compute_log_likelihood(X, ppms, ppms_bg, fracs)
            # Filter away all invalid parameter sets
            # Removed the right-most filter since it was causing issues for some people
            bool_mask = (log_likelihoods != np.inf)  # & (fracs > min_frac) & (fracs < max_frac)
            indices = torch.arange(0, len(bool_mask), 1).long()
            if self.cuda:
                indices = indices.cuda()
            indices = indices[bool_mask]
            if len(indices) == 0:
                converged_early = True
                break
            log_likelihoods = log_likelihoods[indices]
            ppms = ppms[indices]
            ppms_bg = ppms_bg[indices]
            fracs = fracs[indices]
            ppms_seeds = ppms_seeds[indices]
            # Select the seed that yields the highest log likelihood after one
            # on-line and one batch EM pass
            max_log_likelihoods, max_log_likelihoods_index = log_likelihoods.max(dim=0)
            # Replaced [0] with .item() for PyTorch >= 0.4
            max_log_likelihoods_index = max_log_likelihoods_index.item()
            word_seed_best = sequences.decode(
                [ppms_seeds_original[max_log_likelihoods_index].round().astype(np.uint8)],
                self.alpha)[0]
            print('Using seed originating from word: %s' % word_seed_best)
            ppm_best = ppms[[max_log_likelihoods_index]]
            ppm_bg_best = ppms_bg[[max_log_likelihoods_index]]
            frac_best = fracs[[max_log_likelihoods_index]]
            # Refine the best seed with batch EM passes
            ppm_best, ppm_bg_best, frac_best = \
                self._batch_em(X, ppm_best, ppm_bg_best, frac_best, self.maxiter)
            if np.isnan(ppm_best[0].cpu().numpy()).any():
                converged_early = True
                break
            ppms_final.append(ppm_best[0].cpu().numpy())
            ppms_bg_final.append(ppm_bg_best[0].cpu().numpy())
            fracs_final.append(frac_best[0])
            n_sites = M * fracs_final[-1].cpu()
            if np.isnan(n_sites):
                n_sites = 0
            else:
                n_sites = int(n_sites)
            n_sites_final.append(n_sites)
            if self.erasewhole:
                print('\nRemoving sequences containing at least one motif occurrence')
                X = self._erase_seqs_containing_motifs(
                    X_seqs_onehot, ppms_final[-1], ppms_bg_final[-1], fracs_final[-1])
                if X_neg is not None:
                    X_neg = self._erase_seqs_containing_motifs(
                        X_neg_seqs_onehot, ppms_final[-1], ppms_bg_final[-1],
                        fracs_final[-1])
            else:
                print('\nRemoving individual motif occurrences')
                X = self._erase_motif_occurrences(
                    X_seqs_onehot, ppms_final[-1], ppms_bg_final[-1], fracs_final[-1])
                if X_neg is not None:
                    X_neg = self._erase_motif_occurrences(
                        X_neg_seqs_onehot, ppms_final[-1], ppms_bg_final[-1],
                        fracs_final[-1])
            X_seq = X
            X_neg_seq = X_neg
        if converged_early:
            print('\n\nYou asked to find %i motifs, but YAMDA found only %i motifs'
                  % (self.n_motifs, len(ppms_final)))
        self.ppms_ = ppms_final
        self.ppms_bg_ = ppms_bg_final
        self.fracs_ = fracs_final
        self.n_sites_ = n_sites_final
        return X_seq, X_neg_seq

    def _batch_em(self, X, ppms, ppms_bg, fracs, epochs):
        M, L, W = X.shape
        n_filters = len(ppms)
        m_log_ratios = nn.Conv1d(L, n_filters, W, stride=W, bias=False)
        fracs = fracs.view((1, n_filters, 1))
        pfms = torch.zeros((n_filters, L, W))
        pfms_bg = torch.zeros((n_filters, L, W))
        counts = torch.zeros((n_filters, 1))
        if self.cuda:
            m_log_ratios.cuda()
            pfms = pfms.cuda()
            pfms_bg = pfms_bg.cuda()
            counts = counts.cuda()
        converged = False
        pbar_epoch = trange(0, epochs, 1, desc='Batch EM')
        for i in pbar_epoch:
            if converged:
                continue
            old_ppms = ppms
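            # How the E-step below works: every length-W window is scored by a
            # single strided convolution. For a one-hot window x, the filter
            # log(ppms) - log(ppms_bg) yields the log-odds
            #   log P(x | motif) - log P(x | background),
            # and the posterior probability that the window is a motif site is
            #   c / (1 + c),  where  c = fudge * fracs / (1 - fracs) * exp(log-odds)
            # (fudge is a user-supplied damping factor on the prior odds).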
            # E-step: compute membership weights and letter frequencies
            pfms.zero_()
            pfms_bg.zero_()
            counts.zero_()
            m_log_ratios.weight.data = torch.log(ppms) - torch.log(ppms_bg)
            fracs_ratio = fracs / (1 - fracs)
            for j in trange(0, M, self.batch_size,
                            desc='Pass %i/%i' % (i + 1, epochs)):
                batch = X[j:j + self.batch_size]
                x = Variable(torch.from_numpy(batch).float())
                if self.cuda:
                    x = x.cuda()
                log_ratios = m_log_ratios(x).data
                ratios = torch.exp(log_ratios)
                c = self.fudge * fracs_ratio * ratios
                state_probs = c / (1 + c)
                counts.add_(state_probs.sum(dim=0))
                batch_motif_matrix_counts = (state_probs.unsqueeze(-1) *
                                             x.data.unsqueeze(1)).sum(dim=0)
                pfms.add_(batch_motif_matrix_counts)
                pfms_bg.add_(x.data.sum(dim=0).unsqueeze(0) -
                             batch_motif_matrix_counts)
            # M-step: update parameters
            fracs = (counts / M).unsqueeze(0)
            ppms = pfms / counts.unsqueeze(2)
            ppms_bg = (pfms_bg.sum(dim=-1) /
                       (W * (M - counts))).unsqueeze(2).expand(n_filters, L, W)
            ppms_diff_norm = (ppms - old_ppms).view(n_filters, -1).norm(p=2, dim=1)
            max_ppms_diff_norm = ppms_diff_norm.max()
            if max_ppms_diff_norm < self.tolerance:
                pbar_epoch.set_description(
                    'Batch EM - convergence reached after %i epochs' % (i + 1))
                converged = True
        fracs = fracs.view(-1)
        return ppms, ppms_bg, fracs

    def _online_em(self, X, ppms, ppms_bg, fracs, epochs):
        M, L, W = X.shape
        n_filters = len(ppms)
        m_log_ratios = nn.Conv1d(L, n_filters, W, stride=W, bias=False)
        fracs = fracs.view((1, n_filters, 1))
        # On-line EM-specific parameters
        gamma_0 = 0.5
        alpha = 0.85
        s_0 = fracs.clone()[0].unsqueeze(-1)
        s_1 = s_0 * ppms
        s_1_bg = (1 - s_0) * ppms_bg
        k = 0
        indices = np.random.permutation(M)
        if self.cuda:
            m_log_ratios.cuda()
            s_0 = s_0.cuda()
            s_1 = s_1.cuda()
            s_1_bg = s_1_bg.cuda()
        pbar_epoch = trange(0, epochs, 1, desc='On-line EM')
        converged = False
        for i in pbar_epoch:
            if converged:
                continue
            old_ppms = ppms
            for j in trange(0, M, self.batch_size,
                            desc='Pass %i/%i' % (i + 1, epochs)):
                k += 1
                m_log_ratios.weight.data = torch.log(ppms) - torch.log(ppms_bg)
                fracs_ratio = fracs / (1 - fracs)
                # E-step: compute membership weights and letter frequencies
                # for a batch
                batch = X[indices[j:j + self.batch_size]]
                actual_batch_size = len(batch)
                gamma = (1.0 * actual_batch_size / self.batch_size *
                         gamma_0 / (k ** alpha))
                x = Variable(torch.from_numpy(batch).float())
                if self.cuda:
                    x = x.cuda()
                log_ratios = m_log_ratios(x).data
                ratios = torch.exp(log_ratios)
                c = self.fudge * fracs_ratio * ratios
                state_probs = c / (1 + c)
                s_0_temp = state_probs.mean(dim=0).unsqueeze(-1)
                s_1_temp = (state_probs.unsqueeze(-1) *
                            x.data.unsqueeze(1)).mean(dim=0)
                s_1_bg_temp = x.data.mean(dim=0).unsqueeze(0) - s_1_temp
                # M-step: update parameters based on the batch
                s_0.add_(gamma * (s_0_temp - s_0))
                s_1.add_(gamma * (s_1_temp - s_1))
                s_1_bg.add_(gamma * (s_1_bg_temp - s_1_bg))
                fracs = s_0.view((1, n_filters, 1))
                ppms = s_1 / s_0
                ppms_bg = (s_1_bg / (1 - s_0)).mean(-1, keepdim=True).expand(
                    (n_filters, L, W))
            ppms_diff_norm = (ppms - old_ppms).view(n_filters, -1).norm(p=2, dim=1)
            max_ppms_diff_norm = ppms_diff_norm.max()
            if max_ppms_diff_norm < self.tolerance:
                pbar_epoch.set_description('On-line EM - convergence reached')
                converged = True
        fracs = fracs.view(-1)
        return ppms, ppms_bg, fracs

    def _compute_log_likelihood(self, X, ppms, ppms_bg, fracs):
        M, L, W = X.shape
        n_filters = len(ppms)
        m_log_ppms_bg = nn.Conv1d(L, n_filters, W, bias=False)
        m_log_ppms_bg.weight.data = torch.log(ppms_bg)
        m_log_ratios = nn.Conv1d(L, n_filters, W, bias=False)
        m_log_ratios.weight.data = torch.log(ppms) - torch.log(ppms_bg)
        fracs = fracs.view((1, n_filters, 1))
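        # Per-window log likelihood under the two-component mixture:
        #   log P(x) = log[(1 - frac) * P_bg(x) + frac * P_motif(x)]
        #            = log(1 - frac) + log P_bg(x)
        #              + log(1 + (frac / (1 - frac)) * P_motif(x) / P_bg(x)),
        # which is the quantity the loop below accumulates (with the ratio
        # scaled by self.fudge, matching what EM actually optimizes).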
        log_likelihoods = torch.zeros(n_filters)
        fracs_ratio = fracs / (1 - fracs)
        log_fracs_bg = torch.log(1 - fracs)
        if self.cuda:
            m_log_ppms_bg.cuda()
            log_likelihoods = log_likelihoods.cuda()
            m_log_ratios.cuda()
        for j in trange(0, M, self.batch_size, desc='Computing log likelihood'):
            batch = X[j:j + self.batch_size]
            x = Variable(torch.from_numpy(batch).float())
            if self.cuda:
                x = x.cuda()
            ppms_bg_logprob = m_log_ppms_bg(x).data
            log_ratios = m_log_ratios(x).data
            ratios = torch.exp(log_ratios)
            # Added back self.fudge here, since this is the quantity that EM
            # is technically optimizing
            log_likelihoods.add_(
                (log_fracs_bg + ppms_bg_logprob +
                 torch.log(1 + self.fudge * fracs_ratio * ratios)).sum(dim=0).view(-1))
        return log_likelihoods

    def _erase_motif_occurrences(self, seqs_onehot, ppm, ppm_bg, frac):
        frac = np.array(frac.cpu())
        t = np.log((1 - frac) / frac)  # Threshold
        ppm[ppm < 1e-12] = 1e-12  # Handle small probabilities (avoid log(0))
        spec = np.log(ppm) - np.log(ppm_bg)  # spec (log-odds) matrix
        spec_revcomp = spec[::-1, ::-1]
        L, W = ppm.shape
        for i in range(len(seqs_onehot)):
            s = seqs_onehot[i]  # the one-hot coded sequence
            seqlen = s.shape[1]
            if seqlen < W:  # leave short sequences alone
                continue
            indices = np.arange(seqlen - W + 1)
            conv_signal = signal.correlate2d(spec, s, 'valid')[0]
            seq_motif_sites = indices[conv_signal > t]
            if self.revcomp:
                conv_signal_revcomp = signal.correlate2d(spec_revcomp, s, 'valid')[0]
                seq_motif_sites_revcomp = indices[conv_signal_revcomp > t]
                seq_motif_sites = np.concatenate((seq_motif_sites,
                                                  seq_motif_sites_revcomp))
            for motif_site in seq_motif_sites:
                s[:, motif_site:motif_site + W] = 0
        seqs = sequences.decode(seqs_onehot, self.alpha)
        return seqs

    def _erase_seqs_containing_motifs(self, seqs_onehot, ppm, ppm_bg, frac):
        frac = np.array(frac.cpu())
        t = np.log((1 - frac) / frac)  # Threshold
        ppm[ppm < 1e-12] = 1e-12  # Handle small probabilities (avoid log(0))
        spec = np.log(ppm) - np.log(ppm_bg)  # spec (log-odds) matrix
        spec_revcomp = spec[::-1, ::-1]
        L, W = ppm.shape
        seqs_onehot_filtered = []
        for i in range(len(seqs_onehot)):
            s = seqs_onehot[i]  # the one-hot coded sequence
            if s.shape[1] < W:  # keep short sequences as-is
                seqs_onehot_filtered.append(s)
                continue
            conv_signal = signal.correlate2d(spec, s, 'valid')[0]
            s_has_motif = (conv_signal > t).any()
            if self.revcomp:
                conv_signal_revcomp = signal.correlate2d(spec_revcomp, s, 'valid')[0]
                s_has_motif = s_has_motif or (conv_signal_revcomp > t).any()
            if not s_has_motif:
                seqs_onehot_filtered.append(s)
        seqs = sequences.decode(seqs_onehot_filtered, self.alpha)
        return seqs
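

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes this module sits in a
# package that also provides the `sequences` and `initialize` modules, that
# the import path below is correct for that package, and that `alpha` is the
# alphabet argument expected by `sequences.encode` (e.g. 'dna'); none of the
# hyperparameter values shown here are recommendations from the original code.
#
#     from yamda.mixture import TCM  # hypothetical import path
#
#     model = TCM(n_seeds=100, n_motifs=2, motif_width=8, min_sites=10,
#                 max_sites=None, batch_size=1000, half_length=3, fudge=0.1,
#                 alpha='dna', revcomp=True, tolerance=1e-5, maxiter=20,
#                 erasewhole=False, cuda=False)
#     X_remaining, _ = model.fit(positive_seqs, X_neg=negative_seqs)
#     ppm = model.ppms_[0]          # (L, W) motif position probability matrix
#     n_sites = model.n_sites_[0]   # estimated number of motif occurrences
# ----------------------------------------------------------------------------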