#!/usr/bin/env python
# coding: utf-8
"""Sample lines from a source and a target file toward a token sharing rate.

Lines are drawn ``k`` at a time from each file. The shared vocabulary is the
set of tokens that occur both in some sampled source line and in some sampled
target line. The token sharing rate is the fraction of tokens, over all
sampled lines of both files, that belong to that vocabulary. While the current
rate is below the target ``r``, the unsampled lines with the highest
shared-token ratio are drawn; otherwise those with the lowest ratio are drawn.
Sampling stops once more than ``n`` lines have been drawn and the rate is
within 0.001 of ``r``, or when the shorter file is exhausted.
"""
import os
import sys
import numpy as np
import collections as cl
import argparse as ap
from tqdm import tqdm as tqdm
import time
import plotly.offline as py
import plotly.graph_objs as go


def _pool(fname):
    """Load ``fname`` into a sampling pool; each line is split on whitespace.

    Returns a dict with:
      sents  -- {line index: Counter of tokens}. Tokens are popped from these
                counters once they join the shared vocabulary (_update_pool).
      total  -- number of tokens on each line.
      shared -- number of tokens on each line that are in the shared vocab.
      mask   -- True for lines that have already been sampled.
    """
    with open(fname) as fin:
        sents = {
            idx: cl.Counter(line.strip().split())
            for idx, line in enumerate(fin)
        }
    total = np.array(
        [sum(sents[idx].values()) for idx in range(len(sents))], dtype=int)
    shared = np.zeros(len(sents), dtype=int)
    mask = np.zeros(len(sents), dtype=bool)
    return dict(sents=sents, total=total, shared=shared, mask=mask)


def _sample_k(pool, k, large=True):
    """Return indices of ``k`` unsampled lines ranked by shared-token ratio.

    With ``large`` the ``k`` highest ratios are returned, otherwise the ``k``
    lowest. Returns an empty index array when ``k <= 0``.
    """
    if k <= 0:
        # ``[-0:]`` would otherwise select every index.
        return np.array([], dtype=int)
    total = pool['total']
    # Empty lines get ratio 0 instead of 0/0 = NaN; NaN sorts last and would
    # make empty lines the first pick whenever ``large`` is True.
    ratio = np.divide(pool['shared'], total,
                      out=np.zeros(len(total), dtype=float),
                      where=total > 0)
    shared = np.ma.array(ratio, mask=pool['mask'])
    if large:
        # Already-sampled lines are filled with -inf so they sort first and
        # never tie with unsampled lines whose ratio is 0.
        return shared.argsort(fill_value=-np.inf)[-k:]
    # Already-sampled lines are filled with +inf so they sort last.
    return shared.argsort(fill_value=np.inf)[:k]


def _vocab(counters):
    """Return the union of the keys of the given token Counters (may be empty)."""
    return set().union(*(c.keys() for c in counters))


def _shared_vocab(src_sents, trg_sents):
    """Return tokens present in some ``src_sents`` and some ``trg_sents`` value.

    Both arguments map line index -> token Counter; either may be empty.
    """
    return _vocab(src_sents.values()) & _vocab(trg_sents.values())


def _update_pool(pool, idxs, vocab):
    """Mark ``idxs`` as sampled and account for newly shared tokens ``vocab``.

    For every line in the pool (sampled or not), occurrences of tokens in
    ``vocab`` are moved from the line's Counter into ``pool['shared']``, so
    each token is counted as shared only once per line.
    """
    for idx in idxs:
        pool['mask'][idx] = True
    if not vocab:
        return
    for idx, counter in pool['sents'].items():
        # Intersection is a new set, so popping from ``counter`` is safe.
        for v in counter.keys() & vocab:
            pool['shared'][idx] += counter.pop(v)


def _r(src_pool, trg_pool):
    """Return the token sharing rate over all sampled lines of both pools.

    Returns 0.0 when the sampled lines contain no tokens at all.
    """
    src_sel = src_pool['mask']
    trg_sel = trg_pool['mask']
    shared = src_pool['shared'][src_sel].sum() + trg_pool['shared'][trg_sel].sum()
    total = src_pool['total'][src_sel].sum() + trg_pool['total'][trg_sel].sum()
    return shared / total if total else 0.0


def _draw(rs, steps, vocabs, filename):
    """Write an HTML plot of sharing rate and new-vocab size per step.

    rs     -- sharing rate after each recorded step (right y axis).
    steps  -- sampled-line count at each recorded step (x axis).
    vocabs -- set of newly shared tokens at each step (bars, log y axis).
    """
    t_r = go.Scatter(x=steps, y=rs, yaxis='y2', name='token sharing rate')
    t_num_vocabs = go.Bar(
        y=[len(v) for v in vocabs],
        x=steps,
        # Render each token set as a plain string label for the bar.
        text=[' '.join(sorted(map(str, v))) for v in vocabs],
        name='vocabs')
    data = [t_r, t_num_vocabs]
    layout = go.Layout(
        title='token rate sampling progress',
        yaxis=dict(title='num of vocabs', type='log'),
        yaxis2=dict(
            title='sampled token sharing rate', overlaying='y',
            side='right'))
    fig = go.Figure(data=data, layout=layout)
    py.plot(fig, filename=filename, auto_open=False)


def _sample(src_pool, trg_pool, n, r, k=1, draw=''):
    """Greedily sample lines from both pools toward sharing rate ``r``.

    Mutates both pools in place (``mask``, ``shared``, ``sents``). Stops when
    more than ``n`` lines are sampled and the rate is within 0.001 of ``r``,
    or when the shorter pool is exhausted. If ``draw`` is a filename, a
    progress plot is written there periodically and once at the end.

    Raises:
      ValueError: if ``k`` < 1.
    """
    if k < 1:
        raise ValueError('k must be >= 1, got {}'.format(k))

    # Vocabulary of already-sampled lines on each side (normally empty).
    src_seen = _vocab(
        src_pool['sents'][i] for i in np.flatnonzero(src_pool['mask']))
    trg_seen = _vocab(
        trg_pool['sents'][i] for i in np.flatnonzero(trg_pool['mask']))
    shared_vocab = src_seen & trg_seen
    _update_pool(src_pool, [], shared_vocab)
    _update_pool(trg_pool, [], shared_vocab)
    shared_vocab_step = set()
    current_r = _r(src_pool, trg_pool)

    limit = min(len(src_pool['sents']), len(trg_pool['sents']))
    t = tqdm(
        total=n, file=sys.stdout, initial=int(src_pool['mask'].sum()),
        ncols=100)

    # History for the optional plot; snapshot the set so later updates to
    # ``shared_vocab`` do not alter the first entry.
    d_rs = [current_r]
    d_steps = [t.n]
    d_shared_vocabs = [set(shared_vocab)]
    if draw:
        _draw(d_rs, d_steps, d_shared_vocabs, filename=draw)

    step = 0
    while t.n < limit and not (t.n > n and abs(current_r - r) < 0.001):
        # Never request more lines than remain in the shorter pool.
        take = min(k, limit - t.n)
        large = current_r < r
        src_idxs = _sample_k(src_pool, take, large)
        trg_idxs = _sample_k(trg_pool, take, large)

        new_src = _vocab(src_pool['sents'][i] for i in src_idxs)
        new_trg = _vocab(trg_pool['sents'][i] for i in trg_idxs)
        src_seen |= new_src
        trg_seen |= new_trg
        # A token becomes shared when a new line on one side matches any
        # sampled line on the other side, not only lines from this step.
        shared_vocab_step = (
            (new_src & trg_seen) | (src_seen & new_trg)) - shared_vocab
        shared_vocab |= shared_vocab_step

        _update_pool(src_pool, src_idxs, shared_vocab_step)
        _update_pool(trg_pool, trg_idxs, shared_vocab_step)
        current_r = _r(src_pool, trg_pool)

        d_rs.append(current_r)
        d_steps.append(t.n)
        d_shared_vocabs.append(shared_vocab_step)
        if draw and step % 20 == 0:
            _draw(d_rs, d_steps, d_shared_vocabs, filename=draw)
        step += 1

        desc = 'v: {} | new v: {} | r: {:.2%} '.format(
            len(shared_vocab), len(shared_vocab_step), current_r)
        t.desc = desc
        t.update(take)
    t.close()

    if draw:
        # Make sure the final state is plotted, not just every 20th step.
        _draw(d_rs, d_steps, d_shared_vocabs, filename=draw)


def _save(pool, in_fname, out_fname):
    """Copy the lines of ``in_fname`` whose pool mask is set to ``out_fname``."""
    with open(in_fname) as fin, open(out_fname, 'wt') as fout:
        fout.writelines(
            line for idx, line in enumerate(fin) if pool['mask'][idx])


def main(args):
    """Build both pools, run the sampler, and write the sampled lines."""
    src_pool = _pool(args.src_fname)
    trg_pool = _pool(args.trg_fname)
    _sample(src_pool, trg_pool, args.n, args.r, args.k, draw=args.draw)
    _save(src_pool, args.src_fname, args.src_output)
    _save(trg_pool, args.trg_fname, args.trg_output)


if __name__ == '__main__':
    sample_parser = ap.ArgumentParser()
    sample_parser.add_argument('src_fname', type=str, help='source file name.')
    sample_parser.add_argument('trg_fname', type=str, help='target file name.')
    # -n and -r have no usable default: the sampler compares against them.
    sample_parser.add_argument(
        '-n',
        type=int,
        required=True,
        help=
        'num of sampled sentences. should not larger than num of lines in either files.'
    )
    sample_parser.add_argument(
        '-r',
        type=float,
        required=True,
        help='the target share token rate for sampling.')
    # Default matches _sample's own default of k=1.
    sample_parser.add_argument(
        '-k',
        type=int,
        default=1,
        help='num of sents extracted for each sample step.')
    sample_parser.add_argument(
        '-d',
        '--draw',
        type=str,
        help='if given, draw a graph of sampling process. should end with .html'
    )
    sample_parser.add_argument(
        '--src_output',
        type=str,
        default='src_sampled.txt',
        help='source output filename.')
    sample_parser.add_argument(
        '--trg_output',
        type=str,
        default='trg_sampled.txt',
        help='target output filename.')
    args = sample_parser.parse_args()
    main(args)