"""Generate raw-signal plots for every read listed in a fused_reads.txt file.

Reads windows out of an ONT bulk fast5 file, overlays read/state annotations,
and writes one PNG per fused-read coordinate via bokeh's export_png (which
requires selenium, phantomjs and Pillow to be installed).
"""
import math
from argparse import ArgumentParser
from pathlib import Path

import h5py
import numpy as np
import pandas as pd
from bokeh.io import export_png
from bokeh.models import ColumnDataSource, Label, Span, Title
from bokeh.plotting import figure
from tqdm import tqdm


def get_annotations(path, fields, enum_field):
    """Collect annotation columns from an HDF5 compound dataset.

    Args:
        path: h5py dataset with a compound dtype containing ``fields``.
        fields: field names to extract into the returned DataFrame.
        enum_field: field whose dtype may carry an HDF5 enum mapping.

    Returns:
        tuple ``(labels_df, data_dtypes)``: a DataFrame of the requested
        columns and a dict mapping enum int values -> enum names
        (empty dict when the field has no enum dtype).
    """
    data_labels = {field: path[field] for field in fields}
    data_dtypes = {}
    dataset_dtype = h5py.check_dtype(enum=path.dtype[enum_field])
    if dataset_dtype:
        # Inverting k->v may lose entries if several enum names share a value.
        data_dtypes = {v: k for k, v in dataset_dtype.items()}
    labels_df = pd.DataFrame(data=data_labels)
    return labels_df, data_dtypes


def create_figure(ch, s, e, sf, file, filename):
    """Build a bokeh figure of the raw signal for one channel/time window.

    Args:
        ch: channel number used to build the 'Channel_{n}' group name.
        s: window start time, in seconds.
        e: window end time, in seconds.
        sf: sample frequency (samples per second).
        file: open ``h5py.File`` handle on the bulk fast5 file.
        filename: path of the bulk file, shown in the plot title.

    Returns:
        A bokeh figure with the thinned signal trace plus annotation
        spans/labels for reads and channel states inside the window.
    """
    ch_str = 'Channel_{n}'.format(n=ch)
    # Convert the time window (seconds) into raw sample indices.
    sq = int(s * sf)
    eq = int(e * sf)
    n = e - s
    step = 1 / sf
    x_data = np.arange(s, e, step)
    y_data = file["Raw"][ch_str]["Signal"][()]
    y_data = y_data[sq:eq]
    # Thin long traces so rendering stays fast; math.e ** 2.5 (~12.18) sets
    # the target point density — presumably tuned by eye, confirm if changed.
    thin_factor = math.ceil(n / (math.e ** 2.5))
    # Drop points outside the plausible raw-signal range before thinning.
    greater_delete_index = np.argwhere(y_data > 1500)
    x_data = np.delete(x_data, greater_delete_index)
    y_data = np.delete(y_data, greater_delete_index)
    lesser_delete_index = np.argwhere(y_data < 0)
    x_data = np.delete(x_data, lesser_delete_index)
    y_data = np.delete(y_data, lesser_delete_index)
    source = ColumnDataSource(data={
        'x': x_data[::thin_factor],
        'y': y_data[::thin_factor],
    })
    p = figure(
        plot_height=965,
        plot_width=800
    )
    p.output_backend = 'canvas'
    p.add_layout(Title(
        text="Channel: {ch} Start: {st} End: {ed} Sample rate: {sf}".format(
            ch=ch, st=s, ed=e, sf=sf
        )),
        'above'
    )
    p.add_layout(Title(
        text="Bulk-file: {s}".format(s=Path(filename).name)),
        'above'
    )
    p.toolbar.logo = None
    p.yaxis.axis_label = "Raw signal"
    p.yaxis.major_label_orientation = "horizontal"
    p.xaxis.axis_label = "Time (seconds)"
    p.line(source=source, x='x', y='y', line_width=1)
    p.xaxis.major_label_orientation = math.radians(45)
    p.x_range.range_padding = 0.05
    # Read-level annotations for this channel.
    path = file["IntermediateData"][ch_str]["Reads"]
    fields = ['read_id', 'read_start', 'modal_classification']
    label_df, label_dt = get_annotations(path, fields, 'modal_classification')
    label_df = label_df.drop_duplicates(
        subset=['read_id', 'modal_classification'], keep="first"
    )
    # Convert sample index -> seconds, and bytes -> str for read ids.
    label_df.read_start = label_df.read_start / sf
    label_df.read_id = label_df.read_id.str.decode('utf8')
    # Channel-state annotations, renamed to match the read columns so the
    # two frames can be combined.
    path = file["StateData"][ch_str]["States"]
    fields = ['acquisition_raw_index', 'summary_state']
    state_label_df, state_label_dtypes = get_annotations(path, fields, 'summary_state')
    state_label_df.acquisition_raw_index = state_label_df.acquisition_raw_index / sf
    state_label_df = state_label_df.rename(
        columns={'acquisition_raw_index': 'read_start',
                 'summary_state': 'modal_classification'}
    )
    # DataFrame.append was removed in pandas 2.0; concat is the supported API.
    label_df = pd.concat([label_df, state_label_df], ignore_index=True)
    label_df.sort_values(by='read_start', ascending=True, inplace=True)
    label_dt.update(state_label_dtypes)
    # Keep only annotations that fall inside the plotted window.
    slim_label_df = label_df[
        (label_df['read_start'] >= s) & (label_df['read_start'] <= e)
    ]
    for index, label in slim_label_df.iterrows():
        event_line = Span(
            location=label.read_start,
            dimension='height',
            line_color='green',
            line_dash='dashed',
            line_width=1
        )
        p.add_layout(event_line)
        labels = Label(
            x=label.read_start,
            y=800,
            text="{cl} - {ri}".format(
                cl=label_dt[label.modal_classification], ri=label.read_id
            ),
            level='glyph',
            x_offset=0,
            y_offset=0,
            render_mode='canvas',
            # NOTE(review): bokeh Label angles default to radians; -300 looks
            # like it was meant as degrees — confirm intended rotation.
            angle=-300
        )
        p.add_layout(labels)
    return p


def get_args():
    """Parse and return the command-line arguments."""
    parser = ArgumentParser(
        description="""Generate plots for all reads in a fused_reads.txt file.
        This uses bokeh to render a plot and requires selenium, phantomjs,
        and Pillow to be installed. These are available via conda/pip.""",
        add_help=False)
    general = parser.add_argument_group(
        title='General options')
    general.add_argument("-h", "--help",
                         action="help",
                         help="Show this help and exit"
                         )
    in_args = parser.add_argument_group(
        title='Input sources'
    )
    in_args.add_argument("-f", "--fused",
                         help="A fused read file generated by whale_watch.py",
                         type=str,
                         default="",
                         required=True,
                         metavar=''
                         )
    in_args.add_argument("-b", "--bulk-file",
                         help="An ONT bulk-fast5-file",
                         type=str,
                         default='',
                         required=True,
                         metavar=''
                         )
    out_args = parser.add_argument_group(
        title='Output files'
    )
    out_args.add_argument('-D', '--out-dir',
                          help='''Specify the output directory where plots
                          will be saved. Defaults to current working
                          directory''',
                          type=str,
                          default='',
                          metavar=''
                          )
    return parser.parse_args()


def main():
    """Entry point: plot every fused-read window that matches this run."""
    args = get_args()
    # Create the output dir if it does not exist (renamed from `dir`,
    # which shadowed the builtin).
    out_dir = Path(args.out_dir).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)
    # Context manager guarantees the bulk file is closed on exit/error.
    with h5py.File(args.bulk_file, "r") as bulkfile:
        # Sample frequency and run id live in the UniqueGlobalKey attrs.
        sf = int(bulkfile["UniqueGlobalKey"]["context_tags"]
                 .attrs["sample_frequency"].decode('utf8'))
        run_id = (bulkfile["UniqueGlobalKey"]["tracking_id"]
                  .attrs["run_id"].decode('utf8'))
        fields = ['coords', 'run_id']
        fused_df = pd.read_csv(args.fused, sep='\t', usecols=fields)
        # Limit to reads belonging to this bulk file's run.
        fused_df = fused_df[fused_df['run_id'] == run_id]
        for index, row in tqdm(fused_df.iterrows(), total=fused_df.shape[0]):
            # coords look like "<channel>:<start>-<end>" (seconds).
            line = row['coords']
            coords = line.split(":")
            times = coords[1].split("-")
            channel_num = coords[0]
            start_time, end_time = int(times[0]), int(times[1])
            plot = create_figure(channel_num, start_time, end_time, sf,
                                 bulkfile, args.bulk_file)
            line = '{c}.png'.format(c=line.replace(':', '_'))
            write_path = Path(out_dir / line)
            export_png(plot, filename=write_path)
    return


if __name__ == '__main__':
    main()