#!/usr/bin/env python
# Copyright 2017 Calico LLC

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
from __future__ import print_function

from optparse import OptionParser
import copy, os, pdb, random, shutil, subprocess, time

import h5py
import matplotlib
matplotlib.use('PDF')
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
import seaborn as sns
from sklearn import preprocessing
import tensorflow as tf

import basenji

'''
basenji_motifs.py

Collect statistics and make plots to explore the first convolution layer
of the given model using the given sequences.
'''

# Command-line options handed to the `weblogo` executable by plot_filter_logo.
# The first entry carries the general flags; each -C entry assigns a color to
# one nucleotide.
weblogo_opts = ' '.join([
    '-X NO -Y NO --errorbars NO --fineprint ""',
    '-C "#CB2026" A A',
    '-C "#34459C" C C',
    '-C "#FBB116" G G',
    '-C "#0C8040" T T',
])

################################################################################
# main
################################################################################
def main():
  """Analyze the first convolution layer of a trained model.

  Command-line arguments are <params_file> <model_file> <data_file>. The test
  sequences and targets are read from the HDF5 data file, the model is built
  and restored, and the first-layer filter weights and activations are
  collected. For every filter the script writes a weight heatmap, a Possum
  motif file, a FASTA/weblogo of its top-scoring sites and a score density
  plot; it also writes a MEME file of the filter PWMs, runs Tomtom against the
  motif database, and prints a summary table. With -p, global heatmaps are
  plotted too.
  """
  usage = 'usage: %prog [options] <params_file> <model_file> <data_file>'
  parser = OptionParser(usage)
  parser.add_option(
      '-a',
      dest='act_t',
      default=0.5,
      type='float',
      help=
      'Activation threshold (as proportion of max) to consider for PWM [Default: %default]'
  )
  parser.add_option(
      '-d',
      dest='model_hdf5_file',
      default=None,
      help='Pre-computed model output as HDF5.')
  parser.add_option('-o', dest='out_dir', default='.')
  parser.add_option(
      '-m',
      dest='meme_db',
      default='%s/data/motifs/Homo_sapiens.meme' % os.environ['BASENJIDIR'],
      help='MEME database used to annotate motifs')
  parser.add_option(
      '-p',
      dest='plot_heats',
      default=False,
      action='store_true',
      help=
      'Plot heat maps describing filter activations in the test sequences [Default: %default]'
  )
  parser.add_option(
      '-s',
      dest='sample',
      default=None,
      type='int',
      help='Sample sequences from the test set [Default:%default]')
  parser.add_option(
      '-t',
      dest='trim_filters',
      default=False,
      action='store_true',
      help='Trim uninformative positions off the filter ends [Default: %default]'
  )
  (options, args) = parser.parse_args()

  if len(args) != 3:
    parser.error(
        'Must provide Basenji parameters and model files and test data in HDF5'
        ' format.'
    )
  else:
    params_file = args[0]
    model_file = args[1]
    data_file = args[2]

  if not os.path.isdir(options.out_dir):
    os.mkdir(options.out_dir)

  #################################################################
  # load data

  # the data file is only read here, so open it read-only
  data_open = h5py.File(data_file, 'r')

  test_seqs1 = data_open['test_in']
  test_targets = data_open['test_out']

  try:
    target_names = list(data_open['target_labels'])
  except KeyError:
    # targets are on the last axis (the same axis used for num_targets below)
    target_names = ['t%d' % ti for ti in range(test_targets.shape[2])]

  if options.sample is not None:
    # choose sampled indexes
    sample_i = sorted(random.sample(range(test_seqs1.shape[0]), options.sample))

    # filter
    test_seqs1 = test_seqs1[sample_i]
    test_targets = test_targets[sample_i]

  # convert to letters
  test_seqs = basenji.dna_io.hot1_dna(test_seqs1)

  #################################################################
  # model parameters and placeholders

  job = basenji.dna_io.read_job_params(params_file)

  job['seq_length'] = test_seqs1.shape[1]
  job['seq_depth'] = test_seqs1.shape[2]
  job['num_targets'] = test_targets.shape[2]
  job['target_pool'] = int(np.array(data_open.get('pool_width', 1)))

  t0 = time.time()
  dr = basenji.seqnn.SeqNN()
  dr.build(job)
  print('Model building time %ds' % (time.time() - t0))

  # adjust for fourier
  job['fourier'] = 'train_out_imag' in data_open
  if job['fourier']:
    # This script defines no validation option and always works on the test
    # set, so only the test imaginary targets apply.
    test_targets_imag = data_open['test_out_imag']
    if options.sample is not None:
      # keep the imaginary targets aligned with the sampled sequences
      test_targets_imag = test_targets_imag[sample_i]

  #################################################################
  # predict

  # initialize batcher
  if job['fourier']:
    batcher_test = basenji.batcher.BatcherF(
        test_seqs1,
        test_targets,
        test_targets_imag,
        batch_size=dr.batch_size,
        pool_width=job['target_pool'])
  else:
    batcher_test = basenji.batcher.Batcher(
        test_seqs1,
        test_targets,
        batch_size=dr.batch_size,
        pool_width=job['target_pool'])

  # initialize saver
  saver = tf.train.Saver()

  with tf.Session() as sess:
    # load variables into session
    saver.restore(sess, model_file)

    # get weights
    filter_weights = sess.run(dr.filter_weights[0])
    filter_weights = np.transpose(np.squeeze(filter_weights), [2, 1, 0])
    print(filter_weights.shape)

    # test
    t0 = time.time()
    layer_filter_outs, _ = dr.hidden(sess, batcher_test, layers=[0])
    filter_outs = layer_filter_outs[0]
    print(filter_outs.shape)

  # store useful variables
  num_filters = filter_weights.shape[0]
  filter_size = filter_weights.shape[2]

  #################################################################
  # individual filter plots
  #################################################################
  # also save information contents
  filters_ic = []
  meme_out = meme_intro('%s/filters_meme.txt' % options.out_dir, test_seqs)

  for f in range(num_filters):
    print('Filter %d' % f)

    # plot filter parameters as a heatmap
    plot_filter_heat(filter_weights[f, :, :],
                     '%s/filter%d_heat.pdf' % (options.out_dir, f))

    # write possum motif file
    filter_possum(filter_weights[f, :, :], 'filter%d' % f,
                  '%s/filter%d_possum.txt' % (options.out_dir,
                                              f), options.trim_filters)

    # plot weblogo of high scoring outputs
    plot_filter_logo(
        filter_outs[:, :, f],
        filter_size,
        test_seqs,
        '%s/filter%d_logo' % (options.out_dir, f),
        maxpct_t=options.act_t)

    # make a PWM for the filter
    filter_pwm, nsites = make_filter_pwm('%s/filter%d_logo.fa' %
                                         (options.out_dir, f))

    if nsites < 10:
      # no information
      filters_ic.append(0)
    else:
      # compute and save information content
      filters_ic.append(info_content(filter_pwm))

      # add to the meme motif file
      meme_add(meme_out, f, filter_pwm, nsites, options.trim_filters)

  meme_out.close()

  #################################################################
  # annotate filters
  #################################################################
  # run tomtom
  subprocess.call(
      'tomtom -dist pearson -thresh 0.1 -oc %s/tomtom %s/filters_meme.txt %s' %
      (options.out_dir, options.out_dir, options.meme_db),
      shell=True)

  # read in annotations
  filter_names = name_filters(
      num_filters, '%s/tomtom/tomtom.txt' % options.out_dir, options.meme_db)

  #################################################################
  # print a table of information
  #################################################################
  with open('%s/table.txt' % options.out_dir, 'w') as table_out:
    # print header for later panda reading
    header_cols = ('', 'consensus', 'annotation', 'ic', 'mean', 'std')
    print('%3s  %19s  %10s  %5s  %6s  %6s' % header_cols, file=table_out)

    for f in range(num_filters):
      # collapse to a consensus motif
      consensus = filter_motif(filter_weights[f, :, :])

      # grab annotation
      annotation = '.'
      name_pieces = filter_names[f].split('_')
      if len(name_pieces) > 1:
        annotation = name_pieces[1]

      # plot density of filter output scores
      fmean, fstd = plot_score_density(
          np.ravel(filter_outs[:, :, f]),
          '%s/filter%d_dens.pdf' % (options.out_dir, f))

      row_cols = (f, consensus, annotation, filters_ic[f], fmean, fstd)
      print('%-3d  %19s  %10s  %5.2f  %6.4f  %6.4f' % row_cols, file=table_out)

  #################################################################
  # global filter plots
  #################################################################
  if options.plot_heats:
    # The plotting helpers treat axis 1 as filters and axis 2 as positions,
    # while filter_outs is indexed above as (sequence, position, filter).
    filter_outs_sfp = np.transpose(filter_outs, [0, 2, 1])

    # One value per sequence and target for the correlation plots.
    # NOTE(review): averaging the targets over positions is an assumption
    # about the intended per-sequence summary - confirm.
    seq_targets = np.asarray(test_targets).mean(axis=1)

    # plot filter-sequence heatmap
    plot_filter_seq_heat(filter_outs_sfp,
                         '%s/filter_seqs.pdf' % options.out_dir)

    # plot filter-segment heatmap
    plot_filter_seg_heat(filter_outs_sfp,
                         '%s/filter_segs.pdf' % options.out_dir)
    plot_filter_seg_heat(
        filter_outs_sfp,
        '%s/filter_segs_raw.pdf' % options.out_dir,
        whiten=False)

    # plot filter-target correlation heatmap
    plot_target_corr(filter_outs_sfp, seq_targets, filter_names, target_names,
                     '%s/filter_target_cors_mean.pdf' % options.out_dir, 'mean')
    plot_target_corr(filter_outs_sfp, seq_targets, filter_names, target_names,
                     '%s/filter_target_cors_max.pdf' % options.out_dir, 'max')

  # all reads from the HDF5 datasets are finished
  data_open.close()


def get_motif_proteins(meme_db_file):
  """ Hash motif_id's to protein names using the MEME DB file.

    Only lines whose first token is 'MOTIF' are used. The second token is the
    motif id and the third, when present, supplies the name: if it begins
    with '(' the text up to the first ')' is taken, otherwise the whole
    token. A MOTIF line without a third token maps the id to itself.

    Attrs:
        meme_db_file (str) : filename of MEME db

    Returns:
        motif_protein {str: str} : motif id -> protein name
    """
  motif_protein = {}
  with open(meme_db_file) as meme_db_in:
    for line in meme_db_in:
      a = line.split()
      if len(a) < 2 or a[0] != 'MOTIF':
        continue

      motif_id = a[1]
      if len(a) < 3:
        # no alternate name on this line; fall back on the id
        motif_protein[motif_id] = motif_id
      elif a[2][0] == '(':
        # split rather than find(), so a missing ')' does not drop a letter
        motif_protein[motif_id] = a[2][1:].split(')')[0]
      else:
        motif_protein[motif_id] = a[2]
  return motif_protein


def info_content(pwm, transpose=False, bg_gc=0.415):
  """ Compute PWM information content.

    Every cell adds a background term, -bg*log2(bg), plus the PWM term
    p*log2(p + pseudocount); the total over all positions is returned.

    The original analysis used bg_gc=0.5; 0.415 is the hg19 value adopted
    as the default here.

    NOTE(review): the background vector built below sums to 2, not 1.
    With bg_gc=0.5 this still gives 2 bits per position, but for other
    values it is not the entropy of a nucleotide distribution - confirm
    whether each entry was meant to be halved.

    Attrs:
        pwm (array) : positions x 4 matrix (4 x positions if transpose)
        transpose (bool) : transpose pwm before use
        bg_gc (float) : background G/C value

    Returns:
        ic : summed information content (int 0 for an empty PWM)
    """
  pseudoc = 1e-9

  if transpose:
    pwm = np.transpose(pwm)

  at_bg = 1 - bg_gc
  bg_pwm = [at_bg, bg_gc, bg_gc, at_bg]

  # background contribution of each nucleotide column
  bg_terms = [-bg * np.log2(bg) for bg in bg_pwm]

  ic = 0
  num_positions = pwm.shape[0]
  for pos in range(num_positions):
    row = pwm[pos]
    for nt in range(4):
      ic += bg_terms[nt] + row[nt] * np.log2(pseudoc + row[nt])

  return ic


def make_filter_pwm(filter_fasta):
  """ Make a PWM for this filter from its top hits.

    Every sequence line of the FASTA file adds one count per position.
    Each position starts with a pseudocount of 1 per nucleotide, and any
    character other than A/C/G/T adds 0.25 to all four nucleotides. Header
    lines ('>') and blank lines are skipped. The PWM width is taken from
    the first sequence; the sequences are expected to share that length.

    Attrs:
        filter_fasta (str) : FASTA filename of the filter's sites

    Returns:
        pwm_freqs (array) : positions x 4 frequencies (empty if no sites)
        nsites (int) : number of sequences, excluding pseudocounts
    """
  nts = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
  pwm_counts = []
  nsites = 4  # pseudocounts

  with open(filter_fasta) as fasta_in:
    for line in fasta_in:
      seq = line.rstrip()
      if not seq or seq[0] == '>':
        # a blank line is not a site
        continue

      nsites += 1
      if not pwm_counts:
        # initialize with the length
        pwm_counts = [np.array([1.0] * 4) for _ in seq]

      # count
      for i, nt in enumerate(seq):
        try:
          pwm_counts[i][nts[nt]] += 1
        except KeyError:
          pwm_counts[i] += np.array([0.25] * 4)

  # normalize
  pwm_freqs = [[counts[j] / float(nsites) for j in range(4)]
               for counts in pwm_counts]

  return np.array(pwm_freqs), nsites - 4


def meme_add(meme_out, f, filter_pwm, nsites, trim_filters=False):
  """ Append one filter's PWM to the MEME file being written.

    With trim_filters set, leading and trailing columns whose information
    content is below 0.2 are dropped first. Nothing is written unless at
    least two columns remain.

    Attrs:
        meme_out : open file
        f (int) : filter index #
        filter_pwm (array) : filter PWM array
        nsites (int) : number of filter sites
        trim_filters (bool) : trim uninformative end columns
    """
  num_cols = filter_pwm.shape[0]
  first_col = 0
  last_col = num_cols - 1

  if trim_filters:
    ic_t = 0.2

    # walk in from the left over uninformative columns
    while (first_col < num_cols and
           info_content(filter_pwm[first_col:first_col + 1]) < ic_t):
      first_col += 1

    # walk in from the right over uninformative columns
    while (last_col >= 0 and
           info_content(filter_pwm[last_col:last_col + 1]) < ic_t):
      last_col -= 1

  if first_col >= last_col:
    return

  motif_lines = [
      'MOTIF filter%d' % f,
      'letter-probability matrix: alength= 4 w= %d nsites= %d' %
      (last_col - first_col + 1, nsites),
  ]
  for col in range(first_col, last_col + 1):
    motif_lines.append('%.4f %.4f %.4f %.4f' % tuple(filter_pwm[col]))
  motif_lines.append('')

  for motif_line in motif_lines:
    print(motif_line, file=meme_out)


def meme_intro(meme_file, seqs):
  """ Create a MEME motif format file and write its preamble.

    Background frequencies come from counting A/C/G/T over all sequences,
    starting from one pseudocount each; other characters are ignored.

    Attrs:
        meme_file (str) : filename
        seqs [str] : list of strings for obtaining background freqs

    Returns:
        meme_out : open MEME file, left open for the caller to extend
    """
  nt_index = {'A': 0, 'C': 1, 'G': 2, 'T': 3}

  # tally nucleotides on top of a pseudocount of one each
  nt_counts = [1, 1, 1, 1]
  for seq in seqs:
    for nt in seq:
      col = nt_index.get(nt)
      if col is not None:
        nt_counts[col] += 1

  # turn counts into frequencies
  total = float(sum(nt_counts))
  nt_freqs = tuple(count / total for count in nt_counts)

  intro_lines = [
      'MEME version 4',
      '',
      'ALPHABET= ACGT',
      '',
      'Background letter frequencies:',
      'A %.4f C %.4f G %.4f T %.4f' % nt_freqs,
      '',
  ]

  meme_out = open(meme_file, 'w')
  for intro_line in intro_lines:
    print(intro_line, file=meme_out)

  return meme_out


def name_filters(num_filters, tomtom_file, meme_db_file):
  """ Name the filters using Tomtom matches.

    Every filter starts out named 'f<index>'. When both files are given,
    the Tomtom match with the smallest q-value for a filter has its protein
    name appended as '_<protein>'. The first line of the Tomtom table is
    treated as a header; blank lines, comment lines ('#') and rows with too
    few columns are skipped. A matched motif missing from the MEME db is
    annotated with its own id.

    Attrs:
        num_filters (int) : total number of filters
        tomtom_file (str) : filename of Tomtom output table.
        meme_db_file (str) : filename of MEME db

    Returns:
        filter_names (array of str) : one name per filter
    """
  # name by number
  filter_names = ['f%d' % fi for fi in range(num_filters)]

  # name by protein
  if tomtom_file is not None and meme_db_file is not None:
    motif_protein = get_motif_proteins(meme_db_file)

    # best (q-value, motif_id) seen so far for each filter
    filter_best = {}

    with open(tomtom_file) as tt_in:
      tt_in.readline()  # header
      for line in tt_in:
        a = line.split()
        if len(a) < 6 or a[0].startswith('#'):
          # trailing blank/comment lines carry no match
          continue

        # query ids look like 'filter<index>'
        fi = int(a[0][6:])
        match = (float(a[5]), a[1])

        if fi not in filter_best or match < filter_best[fi]:
          filter_best[fi] = match

    # assign filter's best match
    for fi, (_, top_motif) in filter_best.items():
      filter_names[fi] += '_%s' % motif_protein.get(top_motif, top_motif)

  return np.array(filter_names)


################################################################################
# plot_target_corr
#
# Plot a clustered heatmap of correlations between filter activations and
# targets.
#
# Input
#  filter_outs:  activations; axis 0 is sequences, axis 1 filters, and axis 2
#                is reduced to one value per sequence
#  seq_targets:  target values, indexed [sequence, target]
#  filter_names: np.array of filter names (indexed with a boolean mask)
#  target_names: list of target names
#  out_pdf:      output PDF filename
#  seq_op:       'mean' to average over axis 2; anything else takes the max
################################################################################
def plot_target_corr(filter_outs, seq_targets, filter_names, target_names, out_pdf, seq_op='mean'):
  num_seqs = filter_outs.shape[0]
  num_targets = len(target_names)

  if seq_op == 'mean':
    filter_outs_seq = filter_outs.mean(axis=2)
  else:
    filter_outs_seq = filter_outs.max(axis=2)

  # drop filters whose per-sequence value never varies; their rank
  # correlation is undefined
  filter_seqs_std = filter_outs_seq.std(axis=0)
  live_mask = filter_seqs_std > 0
  filter_outs_seq = filter_outs_seq[:, live_mask]
  filter_names_live = filter_names[live_mask]

  filter_target_cors = np.zeros((len(filter_names_live), num_targets))
  for fi in range(len(filter_names_live)):
    for ti in range(num_targets):
      # only the first num_seqs target rows are paired with activations
      cor, _ = spearmanr(filter_outs_seq[:, fi], seq_targets[:num_seqs, ti])
      filter_target_cors[fi, ti] = cor

  cor_df = pd.DataFrame(
      filter_target_cors, index=filter_names_live, columns=target_names)

  sns.set(font_scale=0.3)
  # clustermap creates its own figure, so none is opened beforehand; the
  # savefig/close below then act on that figure and leave nothing behind
  sns.clustermap(cor_df, cmap='BrBG', center=0, figsize=(8, 10))
  plt.savefig(out_pdf)
  plt.close()


################################################################################
# plot_filter_seq_heat
#
# Plot a clustered heatmap of mean filter activations across sequences.
#
# Input
#  filter_outs: activations; axis 0 is sequences, axis 1 filters, and axis 2
#               is averaged away
#  out_pdf:     output PDF filename
#  whiten:      standardize each filter's values before plotting
#  drop_dead:   drop filters whose values do not vary
################################################################################
def plot_filter_seq_heat(filter_outs, out_pdf, whiten=True, drop_dead=True):
  # compute filter output means per sequence
  filter_seqs = filter_outs.mean(axis=2)

  # whiten
  if whiten:
    filter_seqs = preprocessing.scale(filter_seqs)

  # transpose so that rows are filters
  filter_seqs = np.transpose(filter_seqs)

  if drop_dead:
    filter_stds = filter_seqs.std(axis=1)
    filter_seqs = filter_seqs[filter_stds > 0]

  # pick 500 sequence columns at random (with replacement)
  seqs_i = np.random.randint(0, filter_seqs.shape[1], 500)
  filter_seqs_sample = filter_seqs[:, seqs_i]

  # clip the color scale to the central 99.8% of values
  hmin = np.percentile(filter_seqs_sample, 0.1)
  hmax = np.percentile(filter_seqs_sample, 99.9)

  sns.set(font_scale=0.3)

  # clustermap creates its own figure, so none is opened beforehand; the
  # savefig/close below then act on that figure and leave nothing behind
  sns.clustermap(
      filter_seqs_sample,
      row_cluster=True,
      col_cluster=True,
      linewidths=0,
      xticklabels=False,
      vmin=hmin,
      vmax=hmax)
  plt.savefig(out_pdf)
  plt.close()


################################################################################
# plot_filter_seg_heat
#
# Plot a clustered heatmap of filter activations in sequence segments.
#
# Mean doesn't work well for the smaller segments for some reason, but taking
# the max looks OK. Still, similar motifs don't cluster quite as well as you
# might expect.
#
# Input
#  filter_outs: activations; axis 0 is sequences, axis 1 filters, and axis 2
#               is the length that gets split into segments
#  out_pdf:     output PDF filename
#  whiten:      standardize each filter's values (also selects the euclidean
#               metric; cosine is used otherwise)
#  drop_dead:   drop filters whose values do not vary
################################################################################
def plot_filter_seg_heat(filter_outs, out_pdf, whiten=True, drop_dead=True):
  b = filter_outs.shape[0]
  f = filter_outs.shape[1]
  l = filter_outs.shape[2]

  # smallest segment count >= 5 that divides the length evenly; lengths
  # under 5 start from the length itself so the search always terminates
  s = 5 if l >= 5 else max(l, 1)
  while l % s != 0:
    s += 1
  seg_len = l // s
  print('%d segments of length %d' % (s, seg_len))

  # split into multiple segments (integer sizes are required by reshape)
  filter_outs_seg = np.reshape(filter_outs, (b, f, s, seg_len))

  # max within each segment
  filter_outs_max = filter_outs_seg.max(axis=3)

  # break each segment into a new instance
  filter_seqs = np.reshape(np.swapaxes(filter_outs_max, 2, 1), (s * b, f))

  # whiten
  if whiten:
    filter_seqs = preprocessing.scale(filter_seqs)

  # transpose so that rows are filters
  filter_seqs = np.transpose(filter_seqs)

  if drop_dead:
    filter_stds = filter_seqs.std(axis=1)
    filter_seqs = filter_seqs[filter_stds > 0]

  # pick 500 segment columns at random (with replacement)
  seqs_i = np.random.randint(0, filter_seqs.shape[1], 500)
  filter_seqs_sample = filter_seqs[:, seqs_i]

  # clip the color scale to the central 99.8% of values
  hmin = np.percentile(filter_seqs_sample, 0.1)
  hmax = np.percentile(filter_seqs_sample, 99.9)

  sns.set(font_scale=0.3)
  if whiten:
    dist = 'euclidean'
  else:
    dist = 'cosine'

  # clustermap creates its own figure, so none is opened beforehand; the
  # savefig/close below then act on that figure and leave nothing behind
  sns.clustermap(
      filter_seqs_sample,
      metric=dist,
      row_cluster=True,
      col_cluster=True,
      linewidths=0,
      xticklabels=False,
      vmin=hmin,
      vmax=hmax)
  plt.savefig(out_pdf)
  plt.close()


################################################################################
# filter_motif
#
# Collapse the filter parameter matrix to a single DNA motif: each column
# contributes the nucleotide with the largest weight, or 'N' when that weight
# is not positive. Ties go to the earlier nucleotide in ACGT order.
#
# Input
#  param_matrix: np.array of the filter's parameter matrix, with rows in
#                A, C, G, T order and one column per position
#
# Output
#  motif string with one character per column
################################################################################
def filter_motif(param_matrix):
  nts = 'ACGT'
  num_cols = param_matrix.shape[1]

  def column_letter(v):
    # max() keeps the first of equal candidates, like a strict '>' scan
    best_n = max(range(4), key=lambda n: param_matrix[n, v])
    return nts[best_n] if param_matrix[best_n, v] > 0 else 'N'

  return ''.join(column_letter(v) for v in range(num_cols))


################################################################################
# filter_possum
#
# Write a Possum-style motif. When trimming, end columns whose max-min weight
# range is below 0.3 are dropped. No file is written unless at least two
# columns remain.
#
# Input
#  param_matrix: np.array of the filter's parameter matrix, one column per
#                position
#  motif_id:     identifier written on the ID line
#  possum_file:  output filename
#  trim_filters: trim uninformative columns off both ends
#  mult:         factor applied to every weight before printing
################################################################################
def filter_possum(param_matrix, motif_id, possum_file, trim_filters=False, mult=200):
  num_cols = param_matrix.shape[1]
  trim_t = 0.3

  def weight_range(ci):
    column = param_matrix[:, ci]
    return np.max(column) - np.min(column)

  # possible trim
  trim_start = 0
  trim_end = num_cols - 1
  if trim_filters:
    # step past the uninformative prefix
    while trim_start < num_cols and weight_range(trim_start) < trim_t:
      trim_start += 1

    # step back over the uninformative suffix
    while trim_end >= 0 and weight_range(trim_end) < trim_t:
      trim_end -= 1

  if trim_start >= trim_end:
    return

  with open(possum_file, 'w') as possum_out:
    for header_line in ('BEGIN GROUP', 'BEGIN FLOAT', 'ID %s' % motif_id,
                        'AP DNA', 'LE %d' % (trim_end + 1 - trim_start)):
      print(header_line, file=possum_out)

    for ci in range(trim_start, trim_end + 1):
      weights = ['%.2f' % (mult * n) for n in param_matrix[:, ci]]
      print('MA %s' % ' '.join(weights), file=possum_out)

    print('END', file=possum_out)
    print('END', file=possum_out)


################################################################################
# plot_filter_heat
#
# Plot a heatmap of the filter's parameters on a color scale symmetric around
# zero.
#
# Input
#  param_matrix: np.array of the filter's parameter matrix, one column per
#                position
#  out_pdf:      output PDF filename
################################################################################
def plot_filter_heat(param_matrix, out_pdf):
  num_cols = param_matrix.shape[1]

  # largest absolute weight sets both ends of the color scale
  param_range = abs(param_matrix).max()

  sns.set(font_scale=2)
  plt.figure(figsize=(num_cols, 4))
  sns.heatmap(
      param_matrix,
      vmin=-param_range,
      vmax=param_range,
      cmap='PRGn',
      linewidths=0.2)

  ax = plt.gca()
  # label positions from 1
  ax.set_xticklabels(range(1, num_cols + 1))
  # NOTE(review): the 'TGCA' order presumably assumes the heatmap lists rows
  # bottom-up - confirm against the installed seaborn version
  ax.set_yticklabels('TGCA', rotation='horizontal')  # , size=10)

  plt.savefig(out_pdf)
  plt.close()


################################################################################
# plot_filter_logo
#
# Plot a weblogo of the filter's occurrences. Every (sequence, position) whose
# activation exceeds the threshold contributes one filter_size-long k-mer,
# padded with 'N' where it runs off either end of the sequence. The k-mers are
# written to <out_prefix>.fa and, if there are any, weblogo is run to produce
# <out_prefix>.eps.
#
# Input
#  filter_outs: 2D np.array of this filter's activations, indexed
#               [sequence, position]
#  filter_size: filter length
#  seqs:        list of sequence strings
#  out_prefix:  output path prefix (interpolated unquoted into a shell command)
#  raw_t:       activation threshold, used when maxpct_t is not given
#  maxpct_t:    if truthy, threshold is this proportion of the way from the
#               mean activation to the max activation
################################################################################
def plot_filter_logo(filter_outs, filter_size, seqs, out_prefix, raw_t=0, maxpct_t=None):
  if maxpct_t:
    all_outs = np.ravel(filter_outs)
    all_outs_mean = all_outs.mean()
    all_outs_norm = all_outs - all_outs_mean
    raw_t = maxpct_t * all_outs_norm.max() + all_outs_mean

  left_pad = (filter_size - 1) // 2

  # (sequence, position) pairs above threshold; argwhere lists them in
  # row-major order, i.e. the order a nested scan would visit them
  hits = np.argwhere(np.asarray(filter_outs) > raw_t)

  # print fasta file of positive outputs
  with open('%s.fa' % out_prefix, 'w') as filter_fasta_out:
    for hit_i, hit_j in hits:
      i = int(hit_i)
      j = int(hit_j)
      seq = seqs[i]

      # determine boundaries, considering padding
      fstart = j - left_pad
      fend = fstart + filter_size

      kmer = ''

      # if it starts in left_pad
      if fstart < 0:
        kmer += 'N' * (-fstart)
        fstart = 0

      # add primary sequence
      kmer += seq[fstart:fend]

      # if it ends in right_pad
      if fend > len(seq):
        kmer += 'N' * (fend - len(seq))

      # output
      print('>%d_%d' % (i, j), file=filter_fasta_out)
      print(kmer, file=filter_fasta_out)

  # make weblogo
  if len(hits) > 0:
    weblogo_cmd = 'weblogo %s < %s.fa > %s.eps' % (weblogo_opts, out_prefix,
                                                   out_prefix)
    subprocess.call(weblogo_cmd, shell=True)

################################################################################
# plot_score_density
#
# Plot a histogram of the filter's scores and return their summary statistics
# for the stats table.
#
# Input
#  f_scores: np.array of the filter's output scores
#  out_pdf:  output PDF filename
#
# Output
#  mean and standard deviation of f_scores
################################################################################
def plot_score_density(f_scores, out_pdf):
  score_mean = f_scores.mean()
  score_std = f_scores.std()

  sns.set(font_scale=1.3)
  plt.figure()
  # histogram only; no density estimate is drawn
  sns.distplot(f_scores, kde=False)
  plt.xlabel('ReLU output')
  plt.savefig(out_pdf)
  plt.close()

  return score_mean, score_std


################################################################################
# __main__
################################################################################
# Run the analysis only when this file is executed as a script, not when it
# is imported.
if __name__ == '__main__':
  main()
  # Debugging alternative: run main under the pdb debugger instead.
  # pdb.runcall(main)