Code automatically generated with cleaning of non-relevant contents
Please cite: Xu et al. De novo visual proteomics of single cells through pattern mining

import os
import sys
import copy
import time
import traceback
import warnings
import uuid
import pickle as pickle
import shutil
from collections import defaultdict
import scipy.stats as SCS
from multiprocessing.pool import Pool as Pool
import numpy as N
import numpy.fft as NF
from aitom.tomominer.io.cache import Cache
from aitom.tomominer.common.obj import Object
import aitom.tomominer.dimension_reduction.util as DU
import aitom.image.vol.util as UV
import aitom.tomominer.io.file as IV
import aitom.geometry.ang_loc as AAL
import aitom.geometry.rotate as GR
import aitom.tomominer.align.fast.full as AFF
import aitom.tomominer.align.refine.gradient_refine as AFGF
import aitom.align.fast.util as AU
import aitom.tomominer.statistics.ssnr as SS
import aitom.tomominer.segmentation.watershed as SW
import aitom.tomominer.segmentation.active_contour.chan_vese.segment as SACS
import aitom.tomominer.filter.gaussian as FG
import aitom.tomominer.model.util as MU
import aitom.tomominer.core.core as core

def align_vols_with_wedge(v1, m1, v2, m2, op=None, logger=None):
    err = None
    if ('fast_align_and_refine' in op):
            faar_op = copy.deepcopy(op['fast_align_and_refine'])
            faar_op['fast_max_l'] = op['L']
            faar_op.update({'v1': v1, 'm1': m1, 'v2': v2, 'm2': m2, })
            re = AFGF.fast_align_and_refine(faar_op)
            ang = re['ang']
            loc = re['loc_r']
            score = re['cor']
        except Exception as e:
            err = traceback.format_exc()
            re = AU.align_vols(v1=v1, m1=m1, v2=v2, m2=m2, L=op['L'])
            ang = re['angle']
            loc = re['loc']
            score = re['score']
        except Exception as e:
            err = traceback.format_exc()
    if (err is not None):
        score = float('nan')
        loc = N.zeros(3)
        ang = (N.random.random(3) * (N.pi * 2))
    return {'angle': ang, 'loc': loc, 'score': score, 'err': err, }

def align_keys(self, v1k, v2k, op, v2_out_key=None):
    if (self is None):
        self = Object()
        self.cache = Cache()
    if op['with_missing_wedge']:
        v1 = self.cache.get_mrc(v1k['subtomogram'])
        if ('mask' in v1k):
            m1 = self.cache.get_mrc(v1k['mask'])
            m1 = None
        v2 = self.cache.get_mrc(v2k['subtomogram'])
        if ('mask' in v2k):
            m2 = self.cache.get_mrc(v2k['mask'])
            m2 = None
        re = align_vols_with_wedge(v1, m1, v2, m2, op=op)
        v1 = self.cache.get_mrc(v1k['subtomogram'])
        v2 = self.cache.get_mrc(v2k['subtomogram'])
        re = align_vols_no_wedge(v1=v1, v2=v2, op=op)
    re['v1_key'] = v1k
    re['v2_key'] = v2k
    re['v2_out_key'] = v2_out_key
    if (v2_out_key is not None):
        v2r = GR.rotate_pad_mean(v2, angle=re['angle'], loc_r=re['loc'])
        IV.put_mrc(v2r, v2_out_key['subtomogram'])
        if op['with_missing_wedge']:
            m2r = GR.rotate_mask(m2, angle=re['angle'])
            IV.put_mrc(m2r, v2_out_key['mask'])
    re['op'] = op
    return re

def rotate_key(self, vk, vk_out, angle, loc):
    if (self is not None):
        v = self.cache.get_mrc(vk['subtomogram'])
        v = IV.get_mrc(vk['subtomogram'])
    vr = GR.rotate_pad_mean(v, angle=angle, loc_r=loc)
    IV.put_mrc(vr, vk_out['subtomogram'])
    if (('mask' in vk) and ('mask' in vk_out)):
        if (self is not None):
            m = self.cache.get_mrc(vk['mask'])
            m = IV.get_mrc(vk['mask'])
        mr = GR.rotate_mask(m, angle=angle)
        IV.put_mrc(mr, vk_out['mask'])

def impute_aligned_vols(t, v, vm, normalize=None):
    assert (normalize is not None)
    if normalize:
        v = ((v - v.mean()) / v.std())
        if (t is not None):
            t = ((t - t.mean()) / t.std())
    if (t is None):
        return v
    t_f = NF.fftshift(NF.fftn(t))
    v_f = NF.fftshift(NF.fftn(v))
    v_f[(vm == 0)] = t_f[(vm == 0)]
    v_f_if = N.real(NF.ifftn(NF.ifftshift(v_f)))
    if normalize:
        v_f_if = ((v_f_if - v_f_if.mean()) / v_f_if.std())
    if N.all(N.isfinite(v_f_if)):
        return v_f_if
        print('warning: imputation failed')
        return v

def impute_vols(v, vm, ang, loc, t=None, align_to_template=True, normalize=None):
    ang = N.array(ang, dtype=N.float)
    loc = N.array(loc, dtype=N.float)
    v_r = None
    vm_r = None
    t_r = None
    if align_to_template:
        v_r = GR.rotate_pad_mean(v, angle=ang, loc_r=loc)
        assert N.all(N.isfinite(v_r))
        vm_r = GR.rotate_mask(vm, angle=ang)
        assert N.all(N.isfinite(vm_r))
        vi = impute_aligned_vols(t=t, v=v_r, vm=(vm_r > 0.5), normalize=normalize)
        (ang_inv, loc_inv) = AAL.reverse_transform_ang_loc(ang, loc)
        if (t is not None):
            assert N.all(N.isfinite(t))
            t_r = GR.rotate_vol_pad_mean(t, angle=ang_inv, loc_r=loc_inv)
        vi = impute_aligned_vols(t=t_r, v=v, vm=(vm > 0.5), normalize=normalize)
    return {'vi': vi, 'v_r': v_r, 'vm_r': vm_r, 't_r': t_r, }

def impute_vol_keys(vk, ang, loc, normalize, tk=None, align_to_template=True, cache=None):
    v = cache.get_mrc(vk['subtomogram'])
    if (not N.all(N.isfinite(v))):
        raise Exception('error loading', vk['subtomogram'])
    vm = cache.get_mrc(vk['mask'])
    if (not N.all(N.isfinite(vm))):
        raise Exception('error loading', vk['mask'])
    t = None
    if (tk is not None):
        t = IV.get_mrc(tk['subtomogram'])
        if (not N.all(N.isfinite(t))):
            raise Exception('error loading', tk['subtomogram'])
    return impute_vols(v=v, vm=vm, ang=ang, loc=loc, t=t, align_to_template=align_to_template, normalize=normalize)

def neighbor_covariance_avg__parallel(self, data_json, segmentation_tg_op, normalize, n_chunk):
    start_time = time.time()
    data_json_copy = [_ for _ in data_json]
    inds = list(range(len(data_json_copy)))
    tasks = []
    while data_json_copy:
        data_json_copy_part = data_json_copy[:n_chunk]
        inds_t = inds[:n_chunk]
        tasks.append(self.runner.task(module='tomominer.pursuit.multi.util', method='neighbor_covariance__collect_info', kwargs={'data_json': data_json_copy_part, 'segmentation_tg_op': segmentation_tg_op, 'normalize': normalize, }))
        data_json_copy = data_json_copy[n_chunk:]
        inds = inds[n_chunk:]
    sum_global = None
    neighbor_prod_sum = None
    for res in self.runner.run__except(tasks):
        with open(res.result) as f:
            re = pickle.load(f)
        if (sum_global is None):
            sum_global = re['sum']
            sum_global += re['sum']
        assert N.all(N.isfinite(sum_global))
        if (neighbor_prod_sum is None):
            neighbor_prod_sum = re['neighbor_prod_sum']
            neighbor_prod_sum += re['neighbor_prod_sum']
        assert N.all(N.isfinite(neighbor_prod_sum))
    avg_global = (sum_global / len(data_json))
    neighbor_prod_avg = (neighbor_prod_sum / len(data_json))
    shift = re['shift']
    cov = N.zeros(neighbor_prod_avg.shape)
    for i in range(shift.shape[0]):
        cov[:, :, :, i] = (neighbor_prod_avg[:, :, :, i] - (avg_global * UV.roll(avg_global, shift[(i, 0)], shift[(i, 1)], shift[(i, 2)])))
    cov_avg = N.mean(cov, axis=3)
    print('Calculated neighbor covariance for', len(data_json), 'subtomograms', (' : %2.6f sec' % (time.time() - start_time)))
    return cov_avg

def neighbor_covariance__collect_info(self, data_json, segmentation_tg_op, normalize):
    sum_local = None
    neighbor_prod_sum = None
    for rec in data_json:
        if ('template' not in rec):
            rec['template'] = None
        vri = impute_vol_keys(vk=rec, ang=rec['angle'], loc=rec['loc'], tk=rec['template'], align_to_template=True, normalize=normalize, cache=self.cache)['vi']
        if ((segmentation_tg_op is not None) and (rec['template'] is not None) and ('segmentation' in rec['template'])):
            phi = IV.read_mrc(rec['template']['segmentation'])['value']
            vri_s = template_guided_segmentation(v=vri, m=(phi > 0.5), op=segmentation_tg_op)
            if (vri_s is not None):
                vri = vri_s
                del vri_s
                assert (normalize is not None)
                if normalize:
                    vri_t = N.zeros(vri.shape)
                    vri_f = N.isfinite(vri)
                    if (vri_f.sum() > 0):
                        vri_t[vri_f] = ((vri[vri_f] - vri[vri_f].mean()) / vri[vri_f].std())
                    vri = vri_t
                    del vri_f, vri_t
                    vri_t = N.zeros(vri.shape)
                    vri_f = N.isfinite(vri)
                    vri_t[vri_f] = vri[vri_f]
                    if (vri_f.sum() > 0):
                        vri_t[N.logical_not(vri_f)] = vri[vri_f].mean()
                    vri = vri_t
                    del vri_f, vri_t
        if (sum_local is None):
            sum_local = vri
            sum_local += vri
        nei_prod = DU.neighbor_product(vri)
        if (neighbor_prod_sum is None):
            neighbor_prod_sum = nei_prod['p']
            neighbor_prod_sum += nei_prod['p']
    re = {'sum': sum_local, 'neighbor_prod_sum': neighbor_prod_sum, 'shift': nei_prod['shift'], }
    re_key = self.cache.save_tmp_data(re, fn_id=self.task.task_id)
    assert (re_key is not None)
    return re_key

def data_matrix_collect__parallel(self, data_json, segmentation_tg_op, normalize, n_chunk, voxel_mask_inds=None):
    start_time = time.time()
    data_json_copy = [_ for _ in data_json]
    inds = list(range(len(data_json_copy)))
    tasks = []
    while data_json_copy:
        data_json_copy_t = data_json_copy[:n_chunk]
        inds_t = inds[:n_chunk]
        tasks.append(self.runner.task(module='tomominer.pursuit.multi.util', method='data_matrix_collect__local', kwargs={'data_json': data_json_copy_t, 'segmentation_tg_op': segmentation_tg_op, 'normalize': normalize, 'inds': inds_t, 'voxel_mask_inds': voxel_mask_inds, }))
        data_json_copy = data_json_copy[n_chunk:]
        inds = inds[n_chunk:]
    red = None
    for res in self.runner.run__except(tasks):
        with open(res.result) as f:
            re = pickle.load(f)
        if (red is None):
            red = N.zeros([len(data_json), re['mat'].shape[1]])
        red[re['inds'], :] = re['mat']
    print('Calculated matrix of', len(data_json), 'subtomograms', ('%2.6f sec' % (time.time() - start_time)))
    return red

def data_matrix_collect__local(self, data_json, inds, segmentation_tg_op, normalize, voxel_mask_inds=None):
    mat = None
    for (i, rec) in enumerate(data_json):
        if ('template' not in rec):
            rec['template'] = None
        vi = impute_vol_keys(vk=rec, ang=rec['angle'], loc=rec['loc'], tk=rec['template'], align_to_template=True, normalize=normalize, cache=self.cache)['vi']
        if ((segmentation_tg_op is not None) and (rec['template'] is not None) and ('segmentation' in rec['template'])):
            phi = IV.read_mrc(rec['template']['segmentation'])['value']
            vi_s = template_guided_segmentation(v=vi, m=(phi > 0.5), op=segmentation_tg_op)
            if (vi_s is not None):
                vi = vi_s
                del vi_s
                assert (normalize is not None)
                if normalize:
                    vi_t = N.zeros(vi.shape)
                    vi_f = N.isfinite(vi)
                    vi_t[vi_f] = ((vi[vi_f] - vi[vi_f].mean()) / vi[vi_f].std())
                    vi = vi_t
                    del vi_f, vi_t
                    vi_t = N.zeros(vi.shape)
                    vi_f = N.isfinite(vi)
                    vi_t[vi_f] = vi[vi_f]
                    vi_t[N.logical_not(vi_f)] = vi[vi_f].mean()
                    vi = vi_t
                    del vi_f, vi_t
        vi = vi.flatten()
        if (voxel_mask_inds is not None):
            vi = vi[voxel_mask_inds]
        if (mat is None):
            mat = N.zeros([len(data_json), vi.size])
        mat[i, :] = vi
    re = {'mat': mat, 'inds': inds, }
    re_key = self.cache.save_tmp_data(re, fn_id=self.task.task_id)
    assert (re_key is not None)
    return re_key

def covariance_filtered_pca(self, data_json_model=None, data_json_embed=None, normalize=None, segmentation_tg_op=None, n_chunk=100, max_feature_num=None, pca_op=None):
    print('Dimension reduction')
    start_time = time.time()
    if (data_json_model is None):
        assert (data_json_embed is not None)
        data_json_model = data_json_embed
        data_json_embed = None
    cov_avg = None
    cov_avg__feature_num_cutoff = None
    if ((max_feature_num == None) or (max_feature_num < 0)):
        voxel_mask_inds = None
        cov_avg = neighbor_covariance_avg__parallel(self=self, data_json=data_json_model, segmentation_tg_op=segmentation_tg_op, normalize=normalize, n_chunk=n_chunk)
        cov_avg_max = cov_avg.max()
        if ((not N.isfinite(cov_avg_max)) or (cov_avg_max <= 0)):
            raise Exception(('cov_avg.max(): ' + repr(cov_avg_max)))
        cov_avg_i = N.argsort((- cov_avg), axis=None)
        cov_avg__feature_num_cutoff = cov_avg.flatten()[cov_avg_i[min(max_feature_num, (cov_avg_i.size - 1))]]
        cov_avg__feature_num_cutoff = max(cov_avg__feature_num_cutoff, 0)
        voxel_mask_inds = N.flatnonzero((cov_avg > cov_avg__feature_num_cutoff))
    mat = data_matrix_collect__parallel(self=self, data_json=data_json_model, segmentation_tg_op=segmentation_tg_op, normalize=normalize, n_chunk=n_chunk, voxel_mask_inds=voxel_mask_inds)
    mat_mean = mat.mean(axis=0)
    for i in range(mat.shape[0]):
        mat[i, :] -= mat_mean
    empca_weight = N.zeros(mat.shape, dtype=float)
    empca_weight[N.isfinite(mat)] = 1.0
    if (cov_avg is not None):
        cov_avg_v = cov_avg.flatten()
        for (i, ind_t) in enumerate(voxel_mask_inds):
            empca_weight[:, i] *= cov_avg_v[ind_t]
    import aitom.tomominer.dimension_reduction.empca as drempca
    pca = drempca.empca(data=mat, weights=empca_weight, nvec=pca_op['n_dims'], niter=pca_op['n_iter'])
    if ((data_json_embed is None) or ()):
        red = pca.coeff
        mat_embed = data_matrix_collect__parallel(self=self, data_json=data_json_embed, segmentation_tg_op=segmentation_tg_op, normalize=normalize, n_chunk=n_chunk, voxel_mask_inds=voxel_mask_inds)
        red = N.dot(mat_embed, pca.eigvec.T)
    print(('PCA with covariange thresholding  : %2.6f sec' % (time.time() - start_time)))
    return {'red': red, 'cov_avg': cov_avg, 'cov_avg__feature_num_cutoff': cov_avg__feature_num_cutoff, 'voxel_mask_inds': voxel_mask_inds, }

def labels_to_clusters(data_json, labels, cluster_mode=None, ignore_negative_labels=True):
    clusters = {}
    for (l, d) in zip(labels, data_json):
        if (ignore_negative_labels and (l < 0)):
        if (l not in clusters):
            clusters[l] = {}
        if ('cluster_mode' not in clusters[l]):
            clusters[l]['cluster_mode'] = cluster_mode
        if ('data_json' not in clusters[l]):
            clusters[l]['data_json'] = []
    return clusters

def kmeans_clustering(x, k):
    from sklearn.cluster import KMeans
    if False:
        import multiprocessing
        km = KMeans(n_clusters=k, n_jobs=multiprocessing.cpu_count())
        km = KMeans(n_clusters=k)
    labels = km.fit_predict(x)
    labels_t = (N.zeros(labels.shape, dtype=N.int) + N.nan)
    label_count = 0
    for l in N.unique(labels):
        labels_t[(labels == l)] = label_count
        label_count += 1
    labels = labels_t.astype(N.int)
    return labels

def cluster_ssnr_fsc(self, clusters, n_chunk, op=None):
    start_time = time.time()
    sps = SS.ssnr_parallel(self=self, clusters=clusters, n_chunk=n_chunk, op=op)
    return sps

def vol_avg__local(self, data_json, op=None, return_key=True):
    vol_sum = None
    mask_sum = None
    for rec in data_json:
        if self.work_queue.done_tasks_contains(self.task.task_id):
            raise Exception('Duplicated task')
        if (op['with_missing_wedge'] or op['use_fft']):
            in_re = impute_vol_keys(vk=rec, ang=rec['angle'], loc=rec['loc'], tk=None, align_to_template=True, normalize=False, cache=self.cache)
            vt = in_re['v_r']
            raise Exception('following options need to be re-considered')
            if ('template' not in rec):
                rec['template'] = None
            in_re = impute_vol_keys(vk=rec, ang=rec['angle'], loc=rec['loc'], tk=rec['template'], align_to_template=True, normalize=True, cache=self.cache)
            vt = in_re['vi']
        if (vol_sum is None):
            vol_sum = N.zeros(vt.shape, dtype=N.float64, order='F')
        vol_sum += vt
        if (mask_sum is None):
            mask_sum = N.zeros(in_re['vm_r'].shape, dtype=N.float64, order='F')
        mask_sum += in_re['vm_r']
    re = {'vol_sum': vol_sum, 'mask_sum': mask_sum, 'vol_count': len(data_json), 'op': op, }
    if return_key:
        re_key = self.cache.save_tmp_data(re, fn_id=self.task.task_id)
        assert (re_key is not None)
        return {'key': re_key, }
        return re

def cluster_averaging_vols(self, clusters, op={}):
    start_time = time.time()
    if op['centerize_loc']:
        clusters_cnt = copy.deepcopy(clusters)
        for c in clusters_cnt:
            loc = N.zeros((len(clusters_cnt[c]), 3))
            for (i, rec) in enumerate(clusters_cnt[c]):
                loc[i, :] = N.array(rec['loc'], dtype=N.float)
            loc -= N.tile(loc.mean(axis=0), (loc.shape[0], 1))
            assert N.all((N.abs(loc.mean(axis=0)) <= 1e-10))
            for (i, rec) in enumerate(clusters_cnt[c]):
                rec['loc'] = loc[i]
        clusters = clusters_cnt
    tasks = []
    for c in clusters:
        while clusters[c]:
            part = clusters[c][:op['n_chunk']]
            op_t = copy.deepcopy(op)
            op_t['cluster'] = c
            tasks.append(self.runner.task(module='tomominer.pursuit.multi.util', method='vol_avg__local', kwargs={'data_json': part, 'op': op_t, 'return_key': True, }))
            clusters[c] = clusters[c][op['n_chunk']:]
    cluster_sums = {}
    cluster_mask_sums = {}
    cluster_sizes = {}
    for res in self.runner.run__except(tasks):
        with open(res.result['key']) as f:
            re = pickle.load(f)
        oc = re['op']['cluster']
        ms = re['mask_sum']
        s = re['vol_sum']
        vc = re['vol_count']
        if (oc not in cluster_sums):
            cluster_sums[oc] = s
            cluster_sums[oc] += s
        if (oc not in cluster_mask_sums):
            cluster_mask_sums[oc] = ms
            cluster_mask_sums[oc] += ms
        if (oc not in cluster_sizes):
            cluster_sizes[oc] = vc
            cluster_sizes[oc] += vc
        del oc, ms, s, vc
    cluster_avg_dict = {}
    for c in cluster_sums:
        assert (cluster_sizes[c] > 0)
        assert (cluster_mask_sums[c].max() > 0)
        if op['use_fft']:
            ind = (cluster_mask_sums[c] >= op['mask_count_threshold'])
            if (ind.sum() == 0):
            cluster_sums_fft = NF.fftshift(NF.fftn(cluster_sums[c]))
            cluster_avg = N.zeros(cluster_sums_fft.shape, dtype=N.complex)
            cluster_avg[ind] = (cluster_sums_fft[ind] / cluster_mask_sums[c][ind])
            cluster_avg = N.real(NF.ifftn(NF.ifftshift(cluster_avg)))
            cluster_avg = (cluster_sums[c] / cluster_sizes[c])
        if op['mask_binarize']:
            cluster_mask_avg = (cluster_mask_sums[c] >= op['mask_count_threshold'])
            cluster_mask_avg = (cluster_mask_sums[c] / cluster_sizes[c])
        cluster_avg_dict[c] = {'vol': cluster_avg, 'mask': cluster_mask_avg, }
    if ('smooth' in op):
        print('smoothing', op['smooth']['mode'], end=' ')
        for c in cluster_avg_dict:
            if (c not in op['smooth']['fsc']):
            cluster_avg_dict[c]['smooth'] = {'vol-original': cluster_avg_dict[c]['vol'], }
            s = cluster_averaging_vols__smooth(v=cluster_avg_dict[c]['vol'], fsc=op['smooth']['fsc'][c], mode=op['smooth']['mode'])
            cluster_avg_dict[c]['vol'] = s['v']
            if ('fit' in s):
                cluster_avg_dict[c]['smooth']['fit'] = s['fit']
                print('average', c, 'sigma ', cluster_avg_dict[c]['smooth']['fit']['c'], '    ', end=' ')
    return {'cluster_avg_dict': cluster_avg_dict, 'cluster_sums': cluster_sums, 'cluster_mask_sums': cluster_mask_sums, 'cluster_sizes': cluster_sizes, }

def cluster_averaging_vols__smooth(v, fsc, mode):
    re = {}
    assert N.all((fsc >= 0))
    if (fsc.max() == 0.0):
        return {'v': v, }
    if (mode == 'fsc_direct'):
        band_pass_curve = fsc
    elif (mode == 'fsc_gaussian'):
        import aitom.tomominer.fitting.gaussian.one_dim as FGO
        bands = N.array(list(range(len(fsc))))
        fit = FGO.fit__zero_mean__multi_start(x=bands, y=fsc)
        if (fit['c'] < bands.max()):
            re['fit'] = fit
            band_pass_curve = FGO.fit__zero_mean__gaussian_function(x=bands, a=fit['a'], c=fit['c'])
            re['v'] = v
            return re
        raise Exception('mode')
    import aitom.tomominer.filter.band_pass as IB
    re['v'] = IB.filter_given_curve(v=v, curve=band_pass_curve)
    return re

def cluster_averaging(self, clusters, op={}):
    cav = cluster_averaging_vols(self, clusters=clusters, op=op)
    if (not os.path.isdir(op['out_dir'])):
    with open(os.path.join(op['out_dir'], 'cluster.pickle'), 'wb') as f:
        pickle.dump(cav, f, protocol=(-1))
    cluster_avg_dict = cav['cluster_avg_dict']
    template_keys = {}
    for c in cluster_avg_dict:
        template_keys[c] = {}
        template_keys[c]['cluster'] = c
        vol_avg_out_key = os.path.join(op['out_dir'], ('clus_vol_avg_%03d.mrc' % c))
        IV.put_mrc(N.array(cluster_avg_dict[c]['vol'], order='F'), vol_avg_out_key)
        template_keys[c]['subtomogram'] = vol_avg_out_key
        if ('smooth' in cluster_avg_dict[c]):
            vol_avg_original_out_key = os.path.join(op['out_dir'], ('clus_vol_avg_original_%03d.mrc' % c))
            IV.put_mrc(N.array(cluster_avg_dict[c]['smooth']['vol-original'], order='F'), vol_avg_original_out_key)
            template_keys[c]['subtomogram-original'] = vol_avg_original_out_key
            vol_avg__riginal_out_key = None
        mask_avg_out_key = os.path.join(op['out_dir'], ('clus_mask_avg_%03d.mrc' % c))
        IV.put_mrc(N.array(cluster_avg_dict[c]['mask'], order='F'), mask_avg_out_key)
        template_keys[c]['mask'] = mask_avg_out_key
        if ('pass_i' in op):
            template_keys[c]['pass_i'] = op['pass_i']
    return {'template_keys': template_keys, }

def cluster_average_select_fsc(self, cluster_info, cluster_info_stat, op=None, debug=False):
    ci = []
    for i in cluster_info:
        for c in cluster_info[i]:
            if ('fsc' not in cluster_info[i][c]):
            if (len(cluster_info[i][c]['data_json']) < op['cluster']['size_min']):
            if ((not op['keep_non_specific_clusters']) and ('is_specific' in cluster_info_stat[i][c]) and (cluster_info_stat[i][c]['is_specific'] is not None)):
    ci = sorted(ci, key=(lambda x: float((- x['fsc'].sum()))))
    ci_cover = []
    covered_set = set()
    for ci_t in ci:
        if ('template_key' not in ci_t):
        subtomograms_t = set((_['subtomogram'] for _ in ci_t['data_json']))
        overlap_ratio_t = (float(len(covered_set.intersection(subtomograms_t))) / len(subtomograms_t))
        if (overlap_ratio_t <= op['cluster']['overlap_ratio']):
        if debug:
            print(ci_t['pass_i'], ci_t['cluster'], len(ci_t['data_json']), ci_t['fsc'].sum(), overlap_ratio_t)
    del ci
    print('Set sizes', sorted([len(_['data_json']) for _ in ci_cover]))
    tk = {}
    for (i, cc) in enumerate(ci_cover):
        tk[i] = copy.deepcopy(cc['template_key'])
        tk[i]['id'] = i
        assert (tk[i]['pass_i'] == cc['pass_i'])
        assert (tk[i]['cluster'] == cc['cluster'])
    tk_selected = set((tk[_]['subtomogram'] for _ in tk))
    tk_info = {}
    for i in cluster_info:
        for c in cluster_info[i]:
            if ('template_key' not in cluster_info[i][c]):
            tk_subtomogram = cluster_info[i][c]['template_key']['subtomogram']
            if (tk_subtomogram not in tk_selected):
            tk_info[tk_subtomogram] = cluster_info[i][c]
    if op['keep_non_specific_clusters']:
        for (i, tkt) in tk.items():
            cluster_info_stat[tkt['pass_i']][tkt['cluster']]['is_specific'] = None
    assert (len(tk) > 0)
    return {'selected_templates': tk, 'tk_info': tk_info, }

def cluster_removal_according_to_center_matching_specificity(ci, cis, al, tk, significance_level, test_type=0, test_sample_num_min=10):
    tk = copy.deepcopy(tk)
    tkd = {tk[_]['subtomogram']: _ for _ in tk}
    tk_fsc = {}
    for pass_i in ci:
        for c in ci[pass_i]:
            ci0 = ci[pass_i][c]
            if ('template_key' not in ci0):
            tk0 = ci0['template_key']['subtomogram']
            if (tk0 not in tkd):
            tk_fsc[tk0] = ci0['fsc'].sum()
    non_specific_clusters = []
    wilcoxion_stat = defaultdict(dict)
    for pass_i in ci:
        for ci_c0 in ci[pass_i]:
            ci0 = ci[pass_i][ci_c0]
            cis0 = cis[pass_i][ci_c0]
            if ('is_specific' not in cis0):
                cis0['is_specific'] = None
            if (cis0['is_specific'] is not None):
            if ('template_key' not in ci0):
            tk0 = ci0['template_key']['subtomogram']
            if (tk0 not in tkd):
            c0 = tkd[tk0]
            ci0s = set((str(_['subtomogram']) for _ in ci0['data_json']))
            al0 = [_ for _ in al if (_['vol_key']['subtomogram'] in ci0s)]
            ss = {}
            for c1 in tk:
                tk1 = tk[c1]['subtomogram']
                if (tk1 not in tkd):
                if (tk_fsc[tk1] < tk_fsc[tk0]):
                ss[c1] = N.array([_['align'][c1]['score'] for _ in al0])
            for c1 in ss:
                if (c1 == c0):
                tk1 = tk[c1]['subtomogram']
                assert (tk1 is not tk0)
                assert (tk1 in tkd)
                assert (tk_fsc[tk1] > tk_fsc[tk0])
                ind_t = N.logical_and(N.isfinite(ss[c0]), N.isfinite(ss[c1]))
                if (ind_t.sum() < test_sample_num_min):
                if N.all((ss[c0][ind_t] > ss[c1][ind_t])):
                is_specific = None
                if (test_type == 0):
                    (t_, p_) = SCS.wilcoxon(ss[c1][ind_t], ss[c0][ind_t])
                    if (p_ > significance_level):
                        is_specific = {'tk0': tk[c0], 'tk1': tk[c1], 'stat': {'t': t_, 'p': p_, }, }
                    elif (N.median(ss[c1][ind_t]) > N.median(ss[c0][ind_t])):
                        is_specific = {'tk0': tk[c0], 'tk1': tk[c1], 'stat': {'t': t_, 'p': p_, }, }
                elif (test_type == 1):
                    (t_, p_) = SCS.mannwhitneyu(ss[c1][ind_t], ss[c0][ind_t])
                    if (p_ >= significance_level):
                        is_specific = {'tk0': tk[c0], 'tk1': tk[c1], 'stat': {'t': t_, 'p': p_, }, }
                    raise AttributeError('test_type')
                wilcoxion_stat[c0][c1] = {'t': t_, 'p': p_, 'median_c0': N.median(ss[c0][ind_t]), 'median_c1': N.median(ss[c1][ind_t]), }
                if (is_specific is not None):
                    cis0['is_specific'] = is_specific
                    del tkd[tk0]
    none_specific_cluster_ids = []
    for c in list(tk.keys()):
        if (tk[c]['subtomogram'] in tkd):
        del tk[c]
    for al_ in al:
        best = {}
        best['score'] = (- N.inf)
        best['template_id'] = None
        best['angle'] = (N.random.random(3) * (N.pi * 2))
        best['loc'] = N.zeros(3)
        for c in tk:
            al_c = al_['align'][c]
            if (al_c['score'] > best['score']):
                best['score'] = al_c['score']
                best['angle'] = al_c['angle']
                best['loc'] = al_c['loc']
                best['template_id'] = c
        al_['best'] = best
    print(len(non_specific_clusters), 'redundant averages detected', none_specific_cluster_ids)
    return {'non_specific_clusters': non_specific_clusters, 'wilcoxion_stat': wilcoxion_stat, }

def cluster_average_align_common_frame__pairwise_alignment(self, template_keys, align_op):
    template_keys_inv = {}
    for c in template_keys:
        template_keys_inv[template_keys[c]['subtomogram']] = c
    tasks = []
    for c0 in template_keys:
        for c1 in template_keys:
            if (c1 <= c0):
            tasks.append(self.runner.task(module='tomominer.pursuit.multi.util', method='align_keys', kwargs={'v1k': template_keys[c0], 'v2k': template_keys[c1], 'op': align_op, }))
    pair_align = defaultdict(dict)
    for res_t in self.runner.run__except(tasks):
        res = res_t.result
        c0 = template_keys_inv[res['v1_key']['subtomogram']]
        c1 = template_keys_inv[res['v2_key']['subtomogram']]
        pair_align[c0][c1] = res
        pair_align[c0][c1].update({'c0': c0, 'c1': c1, })
        pair_align[c1][c0] = copy.deepcopy(res)
        (ang_rev, loc_rev) = AAL.reverse_transform_ang_loc(ang=res['angle'], loc_r=res['loc'])
        pair_align[c1][c0].update({'angle': ang_rev, 'loc': loc_rev, 'c0': c1, 'c1': c0, })
        if (res['err'] is not None):
            print(('cluster_average_align_common_frame__pairwise() alignment error ' + repr(res['err'])))
    return pair_align

def cluster_average_align_common_frame__multi_pair(self, tk, align_op, loc_r_max, pass_dir):
    print('align averages to common frames')
    out_dir = os.path.join(pass_dir, 'common_frame')
    if (not os.path.isdir(out_dir)):
    pa = cluster_average_align_common_frame__pairwise_alignment(self=self, template_keys=tk, align_op=align_op)
    pal = []
    for c0 in pa:
        for c1 in pa[c0]:
            if (c0 >= c1):
            if (N.linalg.norm(pa[c0][c1]['loc']) > loc_r_max):
    pal = sorted(pal, key=(lambda _: (- _['score'])))
    align_to_clus = {}
    unrotated_clus = set()
    for i in range(len(pal)):
        palt = pal[i]
        c0 = palt['c0']
        c1 = palt['c1']
        assert (c0 < c1)
        if (c1 in unrotated_clus):
        if (c0 in align_to_clus):
        if (c1 in align_to_clus):
        align_to_clus[c1] = c0
    assert (len(unrotated_clus.intersection(list(align_to_clus.keys()))) == 0)
    tka = {}
    for c in tk:
        tka[c] = {}
        tka[c]['id'] = c
        tka[c]['pass_i'] = tk[c]['pass_i']
        tka[c]['cluster'] = tk[c]['cluster']
        tka[c]['subtomogram'] = os.path.join(out_dir, ('clus_vol_avg_%03d.mrc' % (c,)))
        tka[c]['mask'] = os.path.join(out_dir, ('clus_mask_avg_%03d.mrc' % (c,)))
        if (c in align_to_clus):
            rotate_key(self=self, vk=tk[c], vk_out=tka[c], angle=pa[align_to_clus[c]][c]['angle'], loc=pa[align_to_clus[c]][c]['loc'])
            print(align_to_clus[c], '-', c, ':', ('%0.3f' % (pa[align_to_clus[c]][c]['score'],)), N.linalg.norm(pa[align_to_clus[c]][c]['loc']), '\t', end=' ')
            shutil.copyfile(tk[c]['subtomogram'], tka[c]['subtomogram'])
            shutil.copyfile(tk[c]['mask'], tka[c]['mask'])
            print(('copy(%d)' % (c,)), '\t', end=' ')
    return {'tka': tka, 'unrotated_clus': unrotated_clus, 'align_to_clus': align_to_clus, 'pa': pa, 'pal': pal, }

def template_segmentation__single(c, tk, op):
    out_path = os.path.join(os.path.dirname(tk['subtomogram']), ('clus_vol_seg_phi_%03d.mrc' % (c,)))
    v = IV.read_mrc(tk['subtomogram'])['value']
    phi = template_segmentation__single_vol(v=v, op=op)
    if (phi is not None):
        tk['segmentation'] = out_path
        IV.put_mrc(phi, out_path, overwrite=True)
    return {'c': c, 'tk': tk, }

def template_segmentation__single_vol(v, op):
    if (not op['density_positive']):
        v = (- v)
    del op['density_positive']
    if (('normalize_and_take_abs' in op) and op['normalize_and_take_abs']):
        v -= v.mean()
        v = N.abs(v)
    if ('gaussian_smooth_sigma' in op):
        vg = FG.smooth(v=v, sigma=float(op['gaussian_smooth_sigma']))
        del op['gaussian_smooth_sigma']
        vg = v
    phi = SACS.segment_with_postprocessing(vg, op)
    if (phi is None):
        sys.stderr.write((('Warning: segmentation failed ' + out_path) + '\n'))
    if (phi is not None):
        bm = MU.boundary_mask(phi.shape)
        if (bm[(phi < 0)].sum() < bm[(phi > 0)].sum()):
            phi = None
            sys.stderr.write((('Warning: segmentation of the following cluster average violates boundary condition ' + out_path) + '\n'))
    return phi

def template_segmentation(self, tk, op, multiprocessing=False):
    if multiprocessing:
        if (self.pool is None):
            self.pool = Pool()
        pool_results = [self.pool.apply_async(func=template_segmentation__single, kwds={'c': c, 'tk': tk[c], 'op': op, }) for c in tk]
        for r in pool_results:
            r = r.get(999999)
            tk[r['c']] = r['tk']
        self.pool = None
        for c in tk:
            r = template_segmentation__single(c=c, tk=copy.deepcopy(tk[c]), op=copy.deepcopy(op))
            tk[r['c']] = r['tk']

def template_guided_segmentation(v, m, op):
    op = copy.deepcopy(op)
    v_org = N.copy(v)
    if (not op['density_positive']):
        v = (- v_org)
    del op['density_positive']
    if ('gaussian_smooth_sigma' in op):
        vg = FG.smooth(v=v, sigma=float(op['gaussian_smooth_sigma']))
        del op['gaussian_smooth_sigma']
        vg = v
    if ((m > 0.5).sum() == 0):
    if ((m < 0.5).sum() == 0):
    op['mean_values'] = [vg[(m < 0.5)].mean(), vg[(m > 0.5)].mean()]
    phi = SACS.segment_with_postprocessing(vg, op)
    if (phi is None):
    if ((phi > 0).sum() == 0):
    if ((phi < 0).sum() == 0):
    struc = N.array((phi > 0), dtype=N.int32, order='F')
    mcr = core.connected_regions(struc)
    struc_tem = N.zeros(vg.shape)
    for l in range(1, (mcr['max_lbl'] + 1)):
        if (N.logical_and((mcr['lbl'] == l), m).sum() > 0):
            struc_tem[(mcr['lbl'] == l)] = 1
            struc_tem[(mcr['lbl'] == l)] = 2
    if ((struc_tem == 1).sum() == 0):
    sws = SW.segment(vol_map=phi, vol_lbl=struc_tem)
    seg = (sws['vol_seg_lbl'] == 1)
    if (seg.sum() == 0):
    seg = N.logical_and((phi > (phi[seg].max() * op['phi_propotion_cutoff'])), seg)
    if (seg.sum() == 0):
    vs = (N.zeros(v.shape) + N.nan)
    vs[seg] = v_org[seg]
    return vs

def align_to_templates__pair_align(c, t_key, v, vm, align_op):
    if align_op['with_missing_wedge']:
        t = IV.get_mrc(t_key['subtomogram'])
        tm = IV.get_mrc(t_key['mask'])
        at_re = align_vols_with_wedge(v1=t, m1=tm, v2=v, m2=vm, op=align_op)
        t = IV.get_mrc(t_key['subtomogram'])
        at_re = align_vols_no_wedge(v1=t, v2=vi, op=align_op)
    at_re['c'] = c
    return at_re

def align_to_templates(self, rec=None, segmentation_tg_op=None, tem_keys=None, template_wedge_cutoff=0.1, align_op=None, multiprocessing=False):
    vi = None
    if align_op['with_missing_wedge']:
        v = self.cache.get_mrc(rec['subtomogram'])
        vm = self.cache.get_mrc(rec['mask'])
        raise Exception('following options are need to be doube checked')
        if ('template' not in rec):
            rec['template'] = None
        vi = impute_vol_keys(vk=rec, ang=rec['angle'], loc=rec['loc'], tk=rec['template'], align_to_template=False, normalize=True, cache=self.cache)['vi']
    if ((segmentation_tg_op is not None) and ('template' in rec) and ('segmentation' in rec['template'])):
        v = align_to_templates__segment(rec=rec, v=v, segmentation_tg_op=segmentation_tg_op)['v']
    if multiprocessing:
        if (self.pool is None):
            self.pool = Pool()
        pool_results = [self.pool.apply_async(func=align_to_templates__pair_align, kwds={'c': c, 't_key': tem_keys[c], 'v': v, 'vm': vm, 'align_op': align_op, }) for c in tem_keys]
        align_re = {}
        for r in pool_results:
            at_re = r.get(999999)
            c = at_re['c']
            align_re[c] = at_re
            if N.isnan(align_re[c]['score']):
                if (self.logger is not None):
                    self.logger.warning('alignment failed: rec %s, template %s, error %s ', repr(rec), repr(tem_keys[c]), repr(align_re[c]['err']))
        self.pool = None
        align_re = {}
        for c in tem_keys:
            if self.work_queue.done_tasks_contains(self.task.task_id):
                raise Exception('Duplicated task')
            align_re[c] = align_to_templates__pair_align(c=c, t_key=tem_keys[c], v=v, vm=vm, align_op=align_op)
            if N.isnan(align_re[c]['score']):
                if (self.logger is not None):
                    self.logger.warning('alignment failed: rec %s, template %s, error %s ', repr(rec), repr(tem_keys[c]), repr(align_re[c]['err']))
    return {'vol_key': rec, 'align': align_re, }

def align_to_templates__segment(rec, v, segmentation_tg_op):
    phi = IV.read_mrc_vol(rec['template']['segmentation'])
    phi_m = (phi > 0.5)
    (ang_inv, loc_inv) = AAL.reverse_transform_ang_loc(rec['angle'], rec['loc'])
    phi_mr = GR.rotate(phi_m, angle=ang_inv, loc_r=loc_inv, default_val=0)
    v_s = template_guided_segmentation(v=v, m=phi_mr, op=segmentation_tg_op)
    if ((v_s is not None) and (v_s[N.isfinite(v_s)].std() > 0)):
        v = v_s
        del v_s
        v_t = N.zeros(v.shape)
        v_f = N.isfinite(v)
        v_t[v_f] = ((v[v_f] - v[v_f].mean()) / v[v_f].std())
        v = v_t
        del v_f, v_t
    return {'v': v, 'phi_m': phi_m, 'phi_mr': phi_mr, }

def align_to_templates__batch(self, op, data_json, segmentation_tg_op, tmp_dir, tem_keys):
    if (('template' in op) and ('match' in op['template']) and ('priority' in op['template']['match'])):
        task_priority = op['template']['match']['priority']
        task_priority = (2000 + N.random.randint(100))
    print('align against templates', 'segmentation_tg_op', (segmentation_tg_op if op['template']['match']['use_segmentation_mask'] else None), 'task priority', task_priority)
    at_ress = []
    for f in os.listdir(tmp_dir):
        if (not f.endswith('.pickle')):
        res_file = os.path.join(tmp_dir, f)
        if (not os.path.isfile(res_file)):
        with open(res_file, 'rb') as f:
            at_ress_t = pickle.load(f)
    if (len(at_ress) > 0):
        print('loaded previous', len(at_ress), ' resutlts')
    completed_subtomogram_set = set([_.result['vol_key']['subtomogram'] for _ in at_ress])
    tasks = []
    for rec in data_json:
        if (rec['subtomogram'] in completed_subtomogram_set):
        tasks.append(self.runner.task(priority=task_priority, module='tomominer.pursuit.multi.util', method='align_to_templates', kwargs={'rec': rec, 'segmentation_tg_op': (segmentation_tg_op if op['template']['match']['use_segmentation_mask'] else None), 'tem_keys': tem_keys, 'align_op': op['align'], 'multiprocessing': False, }))
    for at_ress_t in self.runner.run__except(tasks):
        res_file = os.path.join(tmp_dir, ('%s.pickle' % at_ress_t.task_id))
        with open(res_file, 'wb') as f:
            pickle.dump(at_ress_t, f, protocol=0)
    return at_ress

def cluster_formation_alignment_fsc__by_global_maximum(self, dj, op=None):
    if ('debug' not in op):
        op['debug'] = False
    dj = copy.deepcopy(dj)
    djm = defaultdict(list)
    for r in dj:
        if ('template' not in r):
    djm_org = copy.deepcopy(djm)
    for k in djm:
        djmt = djm[k]
        djmt = sorted(djmt, key=(lambda _: float(_['score'])), reverse=True)
        if (('max_expansion_size' in op) and (len(djmt) > op['max_expansion_size'])):
            djmt = djmt[:op['max_expansion_size']]
        djm[k] = djmt
    ssnr_sequential_op = copy.deepcopy(op['ssnr_sequential'])
    ssnr_sequential_op['n_chunk'] = op['n_chunk']
    ssnr_s = SS.ssnr_sequential_parallel(self=self, data_json_dict=djm, op=ssnr_sequential_op)
    fsc_sum = {}
    for k in ssnr_s:
        fsc_sum[k] = N.array([N.sum(_) for _ in ssnr_s[k]['fsc']])
    import scipy.ndimage.filters as SDF
    if ('gaussian_smooth_sigma' in op):
        for k in fsc_sum:
            fsc_sum[k] = SDF.gaussian_filter1d(fsc_sum[k], op['gaussian_smooth_sigma'])
    if ('min_expansion_size' in op):
        for k in copy.deepcopy(list(fsc_sum.keys())):
            if (len(fsc_sum[k]) < op['min_expansion_size']):
                del fsc_sum[k]
            fsc_sum[k][:op['min_expansion_size']] = (N.min(fsc_sum[k]) - 1)
    dj_gm = {}
    for k in fsc_sum:
        i = N.argmax(fsc_sum[k])
        if op['debug']:
            print('template', k, 'original subtomogram num', len(djm_org[k]), 'global maximum', i)
        dj_gm[k] = {'k': k, 'i': i, 'data_json': copy.deepcopy(djm[k][:(i + 1)]), 'fsc': ssnr_s[k]['fsc'][i], 'fsc_sum': fsc_sum[k][i], }
    return {'dj_gm': dj_gm, 'djm': djm, 'ssnr_s': ssnr_s, }