#!/usr/bin/env python3
"""Convert SentencePiece subword alignments into word alignments.

Reads subword alignment lines ("a-b" index pairs separated by whitespace, one
sentence per line) from standard input and maps every subword index to the
index of the word it belongs to, using the SentencePiece-encoded source and
target files given on the command line. A piece containing '▁' starts a new
word. Output is one line of zero-based "i-j" word alignments per input line.
"""
import sys
import itertools
import fileinput

# Marker SentencePiece uses for word-initial pieces.
WORD_BOUNDARY = '▁'


def get_mapping(file_path):
    """Yield, for each line of *file_path*, a subword-to-word index table.

    Each line is a whitespace-separated sequence of SentencePiece pieces.
    Element k of the yielded list is the 1-based number of the word that
    piece k belongs to: a piece containing '▁' starts a new word. The first
    piece always starts a word even when it carries no '▁' (e.g. a model used
    without a dummy prefix), so the smallest number is 1, never 0.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            subwords = line.strip().split()
            starts = (int(i == 0 or WORD_BOUNDARY in piece)
                      for i, piece in enumerate(subwords))
            yield list(itertools.accumulate(starts))


def convert(src_file, tgt_file, alignment_lines=None):
    """Yield the set of word alignments for each sentence.

    Args:
        src_file: path to the SentencePiece-encoded source text.
        tgt_file: path to the SentencePiece-encoded target text.
        alignment_lines: iterable of subword alignment lines (whitespace
            separated "a-b" pairs). Defaults to standard input.

    Yields:
        One set of "i-j" strings (zero-based word indices) per input line.

    Raises:
        ValueError: if the source file, target file and alignment input do
            not all have the same number of lines, or a pair is malformed.
        IndexError: if an alignment refers to a subword index past the end
            of the corresponding line.
    """
    if alignment_lines is None:
        # A FileInput instance avoids fileinput's module-global state, which
        # would raise RuntimeError if convert() were called more than once.
        alignment_lines = fileinput.FileInput(files=["-"])

    missing = object()
    examples = itertools.zip_longest(
        get_mapping(src_file), get_mapping(tgt_file), alignment_lines,
        fillvalue=missing)
    for lineno, (src_map, tgt_map, line) in enumerate(examples, start=1):
        # zip() would silently stop at the shortest input; fail loudly instead.
        if any(x is missing for x in (src_map, tgt_map, line)):
            raise ValueError(
                "input line count mismatch at line {}: source, target and "
                "alignment inputs must have the same number of lines"
                .format(lineno))
        subword_alignments = {(int(a), int(b))
                              for a, b in (x.split("-") for x in line.split())}
        # Word numbers from get_mapping are 1-based; subtract 1 so the
        # emitted alignments are zero-indexed.
        word_alignments = {"{}-{}".format(src_map[a] - 1, tgt_map[b] - 1)
                           for a, b in subword_alignments}
        yield word_alignments


def format_alignments(word_alignments):
    """Return *word_alignments* as one space-separated line.

    Pairs are sorted numerically by (source, target) so the output is
    deterministic; iterating a set of strings is not, under hash
    randomization.
    """
    return " ".join(sorted(word_alignments,
                           key=lambda pair: tuple(map(int, pair.split("-")))))


def main(argv=None):
    """Command-line entry point; returns the process exit status."""
    argv = sys.argv if argv is None else argv
    if len(argv) != 3:
        # Usage goes to stderr: stdout is normally redirected to the output.
        print("Two parameters are required, e.g.: {} text.spm.source text.spm.target < sentence_piece.talp > word.talp".format(argv[0]),
              file=sys.stderr)
        return 1
    for word_alignment in convert(*argv[1:]):
        print(format_alignments(word_alignment))
    return 0


if __name__ == '__main__':
    sys.exit(main())