import logging

import numpy as np
import pandas as pd
import scipy.linalg as lin

from itertools import chain, combinations
from numpy.linalg import multi_dot as mdot

import pyfocus as pf

__all__ = ["fine_map"]


def add_credible_set(df, credible_set=0.9):
    """
    Compute the credible gene set and add it to the dataframe.

    :param df: pandas.DataFrame containing TWAS summary results
    :param credible_set: float credible level at which to compute the credible set

    :return: pandas.DataFrame containing TWAS summary results augmented with `in_cred_set` flag
    """
    df = df.sort_values(by=["pip"], ascending=False)

    # add credible set flag
    psum = np.sum(df.pip.values)
    npost = df.pip.values / psum
    csum = np.cumsum(npost)

    # include genes until the cumulative normalized posterior first reaches the
    # credible level; comparing against the cumulative mass *preceding* each gene
    # ensures the gene that crosses the threshold is itself included
    in_cred_set = (csum - npost) < credible_set
    df["in_cred_set"] = in_cred_set.astype(int)

    return df
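
# A minimal sketch of how `add_credible_set` behaves on toy summary results; the
# column names mirror the real output but the values are made up. This helper is
# illustrative only and is never called on import.
def _demo_add_credible_set():
    toy = pd.DataFrame({"mol_name": ["GENE_A", "GENE_B", "GENE_C"],
                        "pip": [0.85, 0.10, 0.05]})
    # with credible_set=0.9, GENE_A (0.85) and GENE_B (cumulative 0.95, the first
    # to cross 0.9) are flagged, while GENE_C is excluded
    return add_credible_set(toy, credible_set=0.9)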

def create_output(meta_data, attr, zscores, pips, null_res, region):
    """
    Create the TWAS pandas.DataFrame output.

    :param meta_data: pandas.DataFrame metadata about the gene models
    :param attr: pandas.DataFrame prediction-performance metadata about gene models
    :param zscores: numpy.ndarray of TWAS zscores
    :param pips: numpy.ndarray of posterior inclusion probabilities (PIPs)
    :param null_res: float posterior probability of the null
    :param region: str region identifier

    :return: pandas.DataFrame TWAS summary results
    """
    # merge attributes
    df = pd.merge(meta_data, attr, left_on="model_id", right_index=True)
    df["twas_z"] = zscores
    df["pip"] = pips
    df["in_cred_set"] = 0
    df["region"] = region

    # sort by tx start site and we're good to go
    df = df.sort_values(by="tx_start")

    # chrom is listed twice (once from weight and once from molfeature); drop the first
    idxs = np.where(df.columns == "chrom")[0]
    if len(idxs) > 1:
        df = df.iloc[:, [j for j in range(len(df.columns)) if j != idxs[0]]]

    # drop model_id
    df = df.drop("model_id", axis=1)

    # add null model result
    null_dict = {c: None for c in df.columns}
    null_dict["ens_gene_id"] = "NULL.MODEL"
    null_dict["mol_name"] = "NULL"
    null_dict["type"] = "NULL"
    null_dict["pip"] = null_res
    null_dict["region"] = region
    null_dict["twas_z"] = 0
    null_dict["chrom"] = df["chrom"].values[0]

    # append the null-model row; DataFrame.append was removed in pandas >= 2.0
    df = pd.concat([df, pd.DataFrame([null_dict])], ignore_index=True)

    return df


def align_data(gwas, ref_geno, wcollection, ridge=0.1):
    """
    Align and merge GWAS, LD reference, and eQTL weight data to the same reference alleles.

    :param gwas: pyfocus.GWAS object containing a risk region
    :param ref_geno: pyfocus.LDRefPanel object containing reference genotypes at the risk region
    :param wcollection: pandas.DataFrame object containing overlapping eQTL weights
    :param ridge: float ridge adjustment for LD estimation (default = 0.1)

    :return: tuple of aligned GWAS, eQTL weight matrix W, gene meta-data, and LD matrix V;
        None if no usable data overlap
    """
    log = logging.getLogger(pf.LOG)

    # align gwas with ref snps
    merged_snps = ref_geno.overlap_gwas(gwas)
    if len(merged_snps) == 0:
        log.info("No overlap between LD reference and GWAS")
        return None

    ref_snps = merged_snps.loc[~pd.isna(merged_snps.i)]

    # filter out mis-matched SNPs
    matched = pf.check_valid_alleles(ref_snps[pf.GWAS.A1COL],
                                     ref_snps[pf.GWAS.A2COL],
                                     ref_snps[pf.LDRefPanel.A1COL],
                                     ref_snps[pf.LDRefPanel.A2COL])
    n_miss = sum(np.logical_not(matched))
    if n_miss > 0:
        log.debug("Pruned {} SNPs due to invalid allele pairs between GWAS/RefPanel".format(n_miss))
    ref_snps = ref_snps.loc[matched]

    # flip z-scores to match the reference panel alleles
    ref_snps[pf.GWAS.ZCOL] = pf.flip_alleles(ref_snps[pf.GWAS.ZCOL].values,
                                             ref_snps[pf.GWAS.A1COL],
                                             ref_snps[pf.GWAS.A2COL],
                                             ref_snps[pf.LDRefPanel.A1COL],
                                             ref_snps[pf.LDRefPanel.A2COL])

    # collapse the gene models into a single weight matrix
    idxs = []
    final_df = None
    for eid, model in wcollection.groupby(["ens_gene_id", "tissue", "inference", "ref_name"]):
        log.debug("Aligning weights for gene {}".format(eid))

        # merge local model with the reference panel
        # (model columns: effect_allele, alt_allele, effect)
        m_merged = pd.merge(ref_snps, model, how="inner", left_on=pf.GWAS.SNPCOL, right_on="snp")

        m_matched = pf.check_valid_alleles(m_merged["effect_allele"],
                                           m_merged["alt_allele"],
                                           m_merged[pf.LDRefPanel.A1COL],
                                           m_merged[pf.LDRefPanel.A2COL])
        n_miss = sum(np.logical_not(m_matched))
        if n_miss > 0:
            log.debug("Gene {} pruned {} SNPs due to invalid allele pairs between weight-db/GWAS".format(eid, n_miss))
        m_merged = m_merged.loc[m_matched]

        # make sure effects are for the same ref allele as GWAS + reference panel
        m_merged["effect"] = pf.flip_alleles(m_merged["effect"].values,
                                             m_merged["effect_allele"],
                                             m_merged["alt_allele"],
                                             m_merged[pf.LDRefPanel.A1COL],
                                             m_merged[pf.LDRefPanel.A2COL])

        # skip genes whose overlapping weights are all zeros; these break the variance estimate
        if len(m_merged) > 0 and all(np.isclose(m_merged["effect"], 0)):
            log.debug("Gene {} has only zero weights. Skipping.".format(eid))
            continue

        # skip genes that do not have weights at referenced SNPs
        if all(pd.isnull(m_merged["effect"])):
            log.debug("Gene {} has no overlapping weights. Skipping.".format(eid))
            continue

        # keep model_id around to grab other attributes (pred-R2, etc.) later on
        cur_idx = model.index[0]
        idxs.append(cur_idx)

        # perform a union (outer merge) to build the aligned/flipped, possibly jagged, weight matrix
        if final_df is None:
            final_df = m_merged[[pf.GWAS.SNPCOL, "effect"]]
        else:
            final_df = pd.merge(final_df, m_merged[[pf.GWAS.SNPCOL, "effect"]],
                                how="outer", on=pf.GWAS.SNPCOL)
        final_df = final_df.rename(index=str, columns={"effect": "model_{}".format(cur_idx)})

    # break out early if no genes survived alignment
    if len(idxs) == 0:
        log.info("No weights overlapped GWAS data")
        return None

    # final align back with GWAS + reference panel
    ref_snps = pd.merge(ref_snps, final_df, how="inner", on=pf.GWAS.SNPCOL)

    # compute linkage-disequilibrium estimate
    log.debug("Estimating LD for {} SNPs".format(len(ref_snps)))
    ldmat = ref_geno.estimate_ld(ref_snps, adjust=ridge)

    # subset down to just actual GWAS data
    gwas = ref_snps[pf.GWAS.REQ_COLS]

    # need to replace NA with 0 due to jaggedness across genes
    wmat = ref_snps.filter(like="model").values
    wmat[np.isnan(wmat)] = 0.0

    # meta-data on the current models
    # what other things should we include in here?
    meta_data = wcollection.loc[idxs, ["ens_gene_id", "ens_tx_id", "mol_name", "tissue", "ref_name",
                                       "type", "chrom", "tx_start", "tx_stop", "inference", "model_id"]]

    # re-order by tx_start
    ranks = np.argsort(meta_data["tx_start"].values)
    wmat = wmat.T[ranks].T
    meta_data = meta_data.iloc[ranks]

    return gwas, wmat, meta_data, ldmat
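
# A minimal sketch of the allele-alignment idea used in `align_data`, assuming
# `pf.flip_alleles` negates a statistic whenever its A1/A2 pair is swapped
# relative to the reference panel. The data are made up; this helper is
# illustrative only and is never called on import.
def _demo_allele_flip():
    z = np.array([1.5, -2.0])
    gwas_a1, gwas_a2 = ["A", "C"], ["G", "T"]
    ref_a1, ref_a2 = ["A", "T"], ["G", "C"]
    # a statistic keeps its sign when its alleles match the panel and flips sign
    # when the A1/A2 pair is swapped
    swapped = [a1 == r2 and a2 == r1
               for a1, a2, r1, r2 in zip(gwas_a1, gwas_a2, ref_a1, ref_a2)]
    return np.where(swapped, -z, z)  # array([1.5, 2.0])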

def estimate_cor(wmat, ldmat, intercept=False):
    """
    Estimate the sample correlation structure for predicted expression.

    :param wmat: numpy.ndarray eQTL weight matrix for a risk region
    :param ldmat: numpy.ndarray LD matrix for a risk region
    :param intercept: bool whether to return the intercept variable

    :return: tuple (pred_expr correlation, intercept variable; None if intercept=False)
    """
    wcov = mdot([wmat.T, ldmat, wmat])
    scale = np.diag(1 / np.sqrt(np.diag(wcov)))
    wcor = mdot([scale, wcov, scale])
    if intercept:
        inter = mdot([scale, wmat.T, ldmat])
        return wcor, inter
    else:
        return wcor, None


def assoc_test(weights, gwas, ldmat, heterogeneity=False):
    """
    TWAS association test.

    :param weights: numpy.ndarray of eQTL weights
    :param gwas: pyfocus.GWAS object
    :param ldmat: numpy.ndarray LD matrix
    :param heterogeneity: bool estimate variance from a multiplicative random effect

    :return: tuple (beta, se)
    """
    p = ldmat.shape[0]
    assoc = np.dot(weights, gwas.Z)
    if heterogeneity:
        resid = assoc - gwas.Z
        resid_var = mdot([resid, lin.pinvh(ldmat), resid]) / p
    else:
        resid_var = 1

    se = np.sqrt(resid_var * mdot([weights, ldmat, weights]))

    return assoc, se
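
# A minimal sketch of the fixed-effect TWAS statistic computed by `assoc_test`:
# z_twas = w'z / sqrt(w'Vw), where w are eQTL weights, z are GWAS z-scores, and
# V is SNP LD. The numbers are made up and the GWAS object is faked with a
# namedtuple purely for illustration; this helper is never called on import.
def _demo_assoc_test():
    from collections import namedtuple
    FakeGWAS = namedtuple("FakeGWAS", ["Z"])
    weights = np.array([0.5, -0.25, 0.1])
    ldmat = np.array([[1.0, 0.2, 0.1],
                      [0.2, 1.0, 0.3],
                      [0.1, 0.3, 1.0]])
    gwas = FakeGWAS(Z=np.array([2.0, -1.0, 0.5]))
    beta, se = assoc_test(weights, gwas, ldmat)
    return beta / se  # the TWAS z-score for this toy gene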

def get_resid(zscores, swld, wcor):
    """
    Regress out the average pleiotropic signal tagged by TWAS at the region.

    :param zscores: numpy.ndarray TWAS zscores
    :param swld: numpy.ndarray intercept variable
    :param wcor: numpy.ndarray predicted expression correlation

    :return: tuple (residual TWAS zscores, intercept z-score)
    """
    m, p = swld.shape

    # create mean factor
    intercept = swld.dot(np.ones(p))

    # estimate under the null for variance components, i.e. V = SW LD SW
    wcor_inv, rank = lin.pinvh(wcor, return_rank=True)

    numer = mdot([intercept.T, wcor_inv, zscores])
    denom = mdot([intercept.T, wcor_inv, intercept])
    alpha = numer / denom
    resid = zscores - intercept * alpha

    s2 = mdot([resid, wcor_inv, resid]) / (rank - 1)
    inter_se = np.sqrt(s2 / denom)
    inter_z = alpha / inter_se

    return resid, inter_z


def bayes_factor(zscores, idx_set, wcor, prior_chisq, prb, use_log=True):
    """
    Compute the Bayes factor for the evidence that a set of genes explains the
    observed association signal under the given correlation structure.

    :param zscores: numpy.ndarray TWAS zscores
    :param idx_set: list of indices representing the causal gene set
    :param wcor: numpy.ndarray predicted expression correlation
    :param prior_chisq: float prior effect-size variance scaled by GWAS sample size
    :param prb: float prior probability for a gene to be causal
    :param use_log: bool whether to compute the log Bayes factor

    :return: float the Bayes factor (log Bayes factor if use_log = True)
    """
    idx_set = np.array(idx_set)

    m = len(zscores)

    # only need genes in the causal configuration using the FINEMAP BF trick
    nc = len(idx_set)
    cur_chi2 = prior_chisq / nc
    cur_wcor = wcor[idx_set].T[idx_set].T
    cur_zscores = zscores[idx_set]

    # compute SVD for robust estimation
    if nc > 1:
        cur_U, cur_EIG, _ = lin.svd(cur_wcor)
        scaled_chisq = (cur_zscores.T.dot(cur_U)) ** 2
    else:
        cur_U, cur_EIG = 1, cur_wcor
        scaled_chisq = cur_zscores ** 2

    # log BF + log prior
    cur_bf = 0.5 * -np.sum(np.log(1 + cur_chi2 * cur_EIG)) + \
        0.5 * np.sum((cur_chi2 / (1 + cur_chi2 * cur_EIG)) * scaled_chisq) + \
        nc * np.log(prb) + (m - nc) * np.log(1 - prb)

    if use_log:
        return cur_bf
    else:
        return np.exp(cur_bf)
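
# A minimal sketch of how per-configuration log Bayes factors combine into PIPs:
# enumerate causal subsets, accumulate the marginal likelihood with logaddexp in
# log-space, then normalize. Toy z-scores and an identity correlation matrix;
# illustrative only and never called on import.
def _demo_pips():
    z = np.array([4.0, 1.0, 0.5])
    wcor = np.eye(3)
    prior_prob, prior_chisq = 1e-3, 40
    marginal = 3 * np.log(1 - prior_prob)  # the null configuration
    pips = np.full(3, -np.inf)  # log(0)
    for subset in chain.from_iterable(combinations(range(3), n) for n in (1, 2, 3)):
        local = bayes_factor(z, subset, wcor, prior_chisq, prior_prob)
        marginal = np.logaddexp(local, marginal)
        for idx in subset:
            pips[idx] = np.logaddexp(pips[idx], local)
    return np.exp(pips - marginal)  # marginal PIP per gene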

def fine_map(gwas, wcollection, ref_geno, intercept=False, heterogeneity=False, max_genes=3,
             ridge=0.1, prior_prob=1e-3, prior_chisq=40, credible_level=0.9, plot=False):
    """
    Perform a TWAS and fine-map the results.

    :param gwas: pyfocus.GWAS object for the risk region
    :param wcollection: pandas.DataFrame containing overlapping eQTL weight information for the risk region
    :param ref_geno: pyfocus.LDRefPanel object for the risk region
    :param intercept: bool flag to estimate the average TWAS signal due to tagged pleiotropy
    :param heterogeneity: bool flag to compute sample variance in the TWAS test assuming a multiplicative random effect
    :param max_genes: int or None the maximum number of genes to include in any causal configuration; None for all genes
    :param ridge: float ridge adjustment for LD estimation (default = 0.1)
    :param prior_prob: float prior probability for a gene to be causal
    :param prior_chisq: float prior effect-size variance scaled by GWAS sample size
    :param credible_level: float the credible level at which to compute credible gene sets (default = 0.9)
    :param plot: bool whether or not to generate visualizations/plots at the risk region

    :return: pandas.DataFrame containing the TWAS statistics and fine-mapping results if plot=False;
        (pandas.DataFrame, list of plot objects) if plot=True
    """
    log = logging.getLogger(pf.LOG)
    log.info("Starting fine-mapping at region {}".format(ref_geno))

    # align all GWAS, LD reference, and overlapping molecular weights
    parameters = align_data(gwas, ref_geno, wcollection, ridge=ridge)
    if parameters is None:
        # bail out; logging of the specific reason happens in align_data
        return None
    else:
        gwas, wmat, meta_data, ldmat = parameters

    # run local TWAS
    zscores = []
    for idx, weights in enumerate(wmat.T):
        log.debug("Computing TWAS association statistic for gene {}".format(meta_data.iloc[idx]["ens_gene_id"]))
        beta, se = assoc_test(weights, gwas, ldmat, heterogeneity)
        zscores.append(beta / se)

    # perform fine-mapping
    zscores = np.array(zscores)
    log.debug("Estimating local TWAS correlation structure")
    wcor, swld = estimate_cor(wmat, ldmat, intercept)
    m = len(zscores)
    rm = range(m)

    if intercept:
        # this should really be done at the SNP level first, a la Barfield et al. 2018
        log.debug("Regressing out average tagged pleiotropic associations")
        zscores, inter_z = get_resid(zscores, swld, wcor)
    else:
        inter_z = None

    # cap configuration size; guard against max_genes=None (all genes)
    k = m if max_genes is None or max_genes > m else max_genes

    null_res = m * np.log(1 - prior_prob)
    marginal = null_res

    # accumulate posterior mass in log-space; -inf is log(0)
    pips = np.full(m, -np.inf)

    # enumerate all causal configurations with up to k genes
    for subset in chain.from_iterable(combinations(rm, n) for n in range(1, k + 1)):
        local = bayes_factor(zscores, subset, wcor, prior_chisq, prior_prob)

        # keep track for marginal likelihood
        marginal = np.logaddexp(local, marginal)

        # marginalize the posterior for the marginal-posterior on causals
        for idx in subset:
            pips[idx] = np.logaddexp(pips[idx], local)

    pips = np.exp(pips - marginal)
    null_res = np.exp(null_res - marginal)

    # query the db to grab model attributes;
    # we might want to filter to only certain attributes at some point
    session = pf.get_session()
    attr = pd.read_sql(session.query(pf.ModelAttribute)
                       .filter(pf.ModelAttribute.model_id.in_(meta_data.model_id.values.astype(object)))  # why doesn't int64 work!?!
                       .statement, con=session.connection())

    # convert from long to wide format
    attr = attr.pivot(index="model_id", columns="attr_name", values="value")

    # clean up and return results
    region = str(ref_geno).replace(" ", "")

    # don't sort here, to make plotting easier
    df = create_output(meta_data, attr, zscores, pips, null_res, region)

    log.info("Completed fine-mapping at region {}".format(ref_geno))
    if plot:
        log.info("Creating FOCUS plots at region {}".format(ref_geno))
        plot_arr = pf.focus_plot(wcor, df)
        # sort here and create the credible set
        df = add_credible_set(df, credible_set=credible_level)
        return df, plot_arr
    else:
        # sort here and create the credible set
        df = add_credible_set(df, credible_set=credible_level)
        return df
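
# A minimal sketch of how `fine_map` might be driven, assuming the surrounding
# FOCUS pipeline has already built the per-region GWAS, weight-collection, and
# LD-reference inputs (this module does not construct them). Illustrative only
# and never called on import.
def _demo_fine_map(gwas, wcollection, ref_geno):
    result = fine_map(gwas, wcollection, ref_geno,
                      intercept=False, max_genes=3, credible_level=0.9, plot=False)
    if result is not None:
        # genes flagged `in_cred_set` form the 90%-credible causal gene set
        return result[result.in_cred_set == 1]
    return None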