"""This script merges multitrack pianorolls to five-track pianorolls. """ import os.path import argparse from pypianoroll import Multitrack, Track from utils import make_sure_path_exists, change_prefix, findall_endswith from config import CONFIG if CONFIG['multicore'] > 1: import joblib TRACK_INFO = ( ('Drums', 0), ('Piano', 0), ('Guitar', 24), ('Bass', 32), ('Strings', 48), ) def parse_args(): """Return the parsed command line arguments.""" parser = argparse.ArgumentParser() parser.add_argument('src', help="root path to the source dataset") parser.add_argument('dst', help="root path to the destination dataset") args = parser.parse_args() return args.src, args.dst def get_merged(multitrack): """Merge the multitrack pianorolls into five instrument families and return the resulting multitrack pianoroll object.""" track_lists_to_merge = [[] for _ in range(5)] for idx, track in enumerate(multitrack.tracks): if track.is_drum: track_lists_to_merge[0].append(idx) elif track.program//8 == 0: track_lists_to_merge[1].append(idx) elif track.program//8 == 3: track_lists_to_merge[2].append(idx) elif track.program//8 == 4: track_lists_to_merge[3].append(idx) elif track.program < 96 or 104 <= track.program < 112: track_lists_to_merge[4].append(idx) tracks = [] for idx, track_list_to_merge in enumerate(track_lists_to_merge): if track_list_to_merge: merged = multitrack[track_list_to_merge].get_merged_pianoroll('max') tracks.append(Track(merged, TRACK_INFO[idx][1], (idx == 0), TRACK_INFO[idx][0])) else: tracks.append(Track(None, TRACK_INFO[idx][1], (idx == 0), TRACK_INFO[idx][0])) return Multitrack(None, tracks, multitrack.tempo, multitrack.downbeat, multitrack.beat_resolution, multitrack.name) def merger(filepath, src, dst): """Load and merge a multitrack pianoroll and save to the given path.""" # Load and merge the multitrack pianoroll multitrack = Multitrack(filepath) merged = get_merged(multitrack) # Save the merged multitrack pianoroll result_path = change_prefix(filepath, src, dst) make_sure_path_exists(os.path.dirname(result_path)) merged.save(result_path) def main(): """Main function.""" src, dst = parse_args() make_sure_path_exists(dst) if CONFIG['multicore'] > 1: joblib.Parallel(n_jobs=CONFIG['multicore'], verbose=5)( joblib.delayed(merger)(npz_path, src, dst) for npz_path in findall_endswith('.npz', src)) else: for npz_path in findall_endswith('.npz', src): merger(npz_path, src, dst) if __name__ == "__main__": main()