import operator
from ast import literal_eval as make_tuple

from pyspark import RDD, Row

from models.code2vec_features import Code2VecFeatures
from sourced.ml.transformers import Transformer


class Vocabulary2Id(Transformer):
    """
    Turn path-context rows into a saved Code2VecFeatures model.

    Builds integer vocabularies (and frequencies) for values and paths. Each path
    context is mapped to its (value id, path id, value id) triple. The triples are
    grouped per document, and the result is saved to ``output``.
    """

    def __init__(self, sc, output: str, **kwargs):
        """
        :param sc: Spark context used to broadcast the vocabularies to the executors \
                   (presumably a SparkContext; only ``.broadcast()`` is used).
        :param output: Destination passed to ``Code2VecFeatures.save()``.
        :param kwargs: Forwarded to ``Transformer.__init__``.
        """
        super().__init__(**kwargs)
        self.output = output
        self.sc = sc

    def __call__(self, rows: RDD):
        """
        Build the vocabularies and per-document path contexts, then save the model.

        :param rows: Rows of the form ((key, doc), val); see build_vocabularies().
        """
        value2index, path2index, value2freq, path2freq = self.build_vocabularies(rows)
        bc_value2index = self.sc.broadcast(value2index)
        bc_path2index = self.sc.broadcast(path2index)
        try:
            # The RDD is lazy: collect() is what actually reads the broadcasts, so they
            # may only be released after it returns.
            doc2path_contexts = self._group_path_contexts(
                rows, bc_value2index, bc_path2index).collect()
        finally:
            bc_value2index.unpersist(blocking=True)
            bc_path2index.unpersist(blocking=True)
        Code2VecFeatures().construct(value2index=value2index,
                                     path2index=path2index,
                                     value2freq=value2freq,
                                     path2freq=path2freq,
                                     path_contexts=doc2path_contexts).save(self.output)

    @staticmethod
    def _unstringify_path_context(row):
        """
        Take a row ((pc, doc), freq) and parse pc into a tuple (u, path, v).

        The first two characters of pc (the namespace prefix "v.") are dropped before
        the remainder is parsed with ``ast.literal_eval``.
        """
        return make_tuple(row[0][0][2:])

    def build_vocabularies(self, rows: RDD):
        """
        Process rows to gather values and paths with their frequencies.

        :param rows: row structure is ((key, doc), val) where:
            * key: str with the path context
            * doc: file name
            * val: number of occurrences of key in doc
        :return: (value2index, path2index, value2freq, path2freq).
            Indices are assigned 0..n-1 in the order Spark returns the items.
            Values are the str entries and paths are the tuple entries.
            A frequency is the number of rows the item appears in: each row adds 1
            to u, path and v. NOTE(review): ``val`` is not used in this count.
        """
        def _flatten_row(row: Row):
            u, path, v = Vocabulary2Id._unstringify_path_context(row)
            return [(u, 1), (path, 1), (v, 1)]

        # A single job replaces persist + two filtered collects. collect() keeps
        # partition order, so the local split preserves the original element order.
        counts = rows \
            .flatMap(_flatten_row) \
            .reduceByKey(operator.add) \
            .collect()
        values = [(w, freq) for w, freq in counts if isinstance(w, str)]
        paths = [(p, freq) for p, freq in counts if isinstance(p, tuple)]

        value2index = {w: idx for idx, (w, _) in enumerate(values)}
        path2index = {p: idx for idx, (p, _) in enumerate(paths)}
        value2freq = dict(values)
        path2freq = dict(paths)
        return value2index, path2index, value2freq, path2freq

    def build_doc2pc(self, value2index: dict, path2index: dict, rows: RDD):
        """
        Process rows and build elements (doc, [path_context_1, path_context_2, ...]).

        Each path context is the triple (value2index[u], path2index[path], value2index[v]).
        Duplicate (doc, triple) pairs are removed.

        :param value2index: value -> id
        :param path2index: path -> id
        :param rows: Rows of the form ((key, doc), val); see build_vocabularies().
        :return: Lazy RDD of (doc, list of triples).

        NOTE: the returned RDD reads both vocabularies through broadcast variables
        created here. They are not unpersisted here, because doing so before the RDD
        is evaluated frees nothing. __call__ manages broadcast lifetime itself.
        """
        bc_value2index = self.sc.broadcast(value2index)
        bc_path2index = self.sc.broadcast(path2index)
        return self._group_path_contexts(rows, bc_value2index, bc_path2index)

    @staticmethod
    def _group_path_contexts(rows: RDD, bc_value2index, bc_path2index) -> RDD:
        """Lazily map rows to (doc, [(u_id, path_id, v_id), ...]) via broadcast vocabularies."""
        def _doc2pc(row: Row):
            u, path, v = Vocabulary2Id._unstringify_path_context(row)
            doc = row[0][1]
            value2index = bc_value2index.value
            return doc, (value2index[u], bc_path2index.value[path], value2index[v])

        def _create(pc):
            return [pc]

        def _append(acc, pc):
            # In-place merge is allowed by combineByKey and avoids the quadratic
            # cost of rebuilding the list with `acc + [pc]` for every element.
            acc.append(pc)
            return acc

        def _extend(left, right):
            left.extend(right)
            return left

        return rows \
            .map(_doc2pc) \
            .distinct() \
            .combineByKey(_create, _append, _extend)