#!/usr/bin/env python
# coding: utf-8
"""Visualise the tokens shared between a source and a target text file.

Both input files are read as whitespace-tokenised text. Every token type that
occurs in either file is paired with its frequency in each file, and the pairs
are rendered as standalone plotly HTML files (not opened automatically):

* ``<prefix>_scatter.html`` -- source vs. target frequency of every token on
  log axes, coloured by log(min_freq / max_freq).
* ``<prefix>_rate.html`` -- frequency-weighted histogram of the "bias"
  (max_freq / min_freq) of tokens present in both files, together with the
  cumulative fraction of all token occurrences (both files) covered up to
  each bias value.
"""
import plotly.offline as py
import plotly.graph_objs as go
import numpy as np
import collections as cl
import itertools as it
from tqdm import tqdm
import os
import sys
import argparse as ap


def _draw_scatter(all_vocabs, all_freqs, output_prefix):
    """Write ``<output_prefix>_scatter.html``: src vs. trg frequency per token.

    Args:
        all_vocabs: token strings, shown as hover text.
        all_freqs: ``(src_freq, trg_freq)`` pairs aligned with *all_vocabs*.
        output_prefix: path prefix of the generated HTML file.
    """
    # log(min/max) is 0 for equally frequent tokens and grows more negative
    # the more lopsided the two frequencies are. Tokens missing from one side
    # get 0; they have a zero coordinate and cannot be placed on log axes.
    colors = []
    for s, t in all_freqs:
        ratio = min(s, t) / max(s, t) if s and t else 0
        colors.append(np.log(ratio) if ratio else 0)

    trace = go.Scattergl(
        x=[s for s, _ in all_freqs],
        y=[t for _, t in all_freqs],
        mode='markers',
        text=all_vocabs,
        marker=dict(color=colors, showscale=True, colorscale='Viridis'))
    layout = go.Layout(
        title='Scatter plot of shared tokens',
        hovermode='closest',
        xaxis=dict(title='src freq', type='log', autorange=True),
        yaxis=dict(title='trg freq', type='log', autorange=True))
    fig = go.Figure(data=[trace], layout=layout)
    py.plot(
        fig, filename='{}_scatter.html'.format(output_prefix), auto_open=False)


def _draw_rate(all_vocabs, all_freqs, output_prefix):
    """Write ``<output_prefix>_rate.html``: bias histogram of shared tokens.

    The bias of a token present in both files is ``max(src, trg) /
    min(src, trg)`` (always >= 1). Each shared token contributes its total
    occurrence count ``src + trg`` to the histogram. The secondary curve is
    the cumulative histogram divided by the total number of occurrences of
    *all* tokens, shared or not.

    If no token is shared, no file is written and a note goes to stderr.

    Args:
        all_vocabs: unused; accepted for signature parity with
            ``_draw_scatter``.
        all_freqs: ``(src_freq, trg_freq)`` pairs.
        output_prefix: path prefix of the generated HTML file.
    """
    biases = np.array(
        [max(s, t) / min(s, t) if s and t else 0 for s, t in all_freqs],
        dtype=float)
    freqs = np.array([s + t for s, t in all_freqs])
    shared = biases > 0
    if not shared.any():
        # With no shared tokens the bin count below would be 0, which
        # np.histogram rejects with ValueError.
        print('no shared tokens; skipping {}_rate.html'.format(output_prefix),
              file=sys.stderr)
        return

    # Shared biases are >= 1, so this always requests at least one bin.
    hist, bin_edges = np.histogram(
        biases[shared], weights=freqs[shared], bins=int(biases.max()))
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    # cumsum(hist)[i] counts everything up to the *upper* edge of bin i.
    share_token_rates = np.cumsum(hist) / freqs.sum()

    t1 = go.Scatter(
        x=bin_centers,
        y=hist,
        name='num of tokens',
        mode='lines',
        fill='tozeroy')
    t2 = go.Scatter(
        x=bin_edges[1:],
        y=share_token_rates,
        name='share token rates',
        mode='lines',
        yaxis='y2')
    layout = go.Layout(
        title='Shared tokens rates',
        xaxis=dict(title='bias', autorange=True),
        yaxis=dict(title='num of tokens', type='log', autorange=True),
        yaxis2=dict(
            title='accumulative share token rates',
            autorange=True,
            side='right',
            overlaying='y'))
    fig = go.Figure(data=[t1, t2], layout=layout)
    py.plot(
        fig, filename='{}_rate.html'.format(output_prefix), auto_open=False)


def _count_tokens(fname):
    """Return a Counter of whitespace-separated tokens in UTF-8 file *fname*."""
    desc = 'gen vocab from {}'.format(os.path.basename(fname))
    with open(fname, encoding='utf-8') as f:
        # str.split() with no argument also discards leading/trailing
        # whitespace, so no explicit strip() is needed.
        return cl.Counter(
            w for line in tqdm(f, desc=desc) for w in line.split())


def main(args):
    """Count tokens in both files and draw the plot(s) selected by args.type.

    Args:
        args: namespace with ``src_fname``, ``trg_fname``, ``type`` (one of
            'scatter', 'rate', 'both'; any other value draws nothing) and
            ``output_prefix``.
    """
    src_freqs = _count_tokens(args.src_fname)
    trg_freqs = _count_tokens(args.trg_fname)
    if not src_freqs or not trg_freqs:
        print('at least one input file contains no tokens; nothing to draw',
              file=sys.stderr)
        return

    # Sorted so the point order is reproducible across runs; set iteration
    # order depends on per-process string hashing.
    all_vocabs = sorted(src_freqs.keys() | trg_freqs.keys())
    all_freqs = [(src_freqs.get(v, 0), trg_freqs.get(v, 0))
                 for v in all_vocabs]

    if args.type in ('rate', 'both'):
        _draw_rate(all_vocabs, all_freqs, args.output_prefix)
    if args.type in ('scatter', 'both'):
        _draw_scatter(all_vocabs, all_freqs, args.output_prefix)


if __name__ == '__main__':
    draw_parser = ap.ArgumentParser(
        description='Plot token frequencies shared by two text files.')
    draw_parser.add_argument(
        'src_fname', type=str, help='the source file name.')
    draw_parser.add_argument(
        'trg_fname', type=str, help='the target file name')
    draw_parser.add_argument(
        '--type',
        type=str,
        choices=['scatter', 'rate', 'both'],
        default='both',
        help='which plot(s) to draw (default: both).')
    draw_parser.add_argument(
        '--output_prefix', default='pref', help='output prefix.')
    args = draw_parser.parse_args()
    main(args)