# Copyright 2019 Xiaomi, Inc.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
converter.Convert - class to manage top level of converting progress

python convert.py --input=path/to/kaldi_model.mdl \
                  --output=path/to/save/onnx_model.onnx \
                  --configure=path/to/save/configure.conf \
                  --trans-model=path/to/save/transition.mdl \
                  --batch=b --chunk-size=cs --nnet-type=(2 or 3) \
                  --left-context=lc(required) --right-context=rc(required) \
                  --modulus=m(default is 1) \
                  --subsample-factor=sf(default is 1) \
                  --fuse-lstm=(true or false, default is true) \
                  --fuse-stats=(true or false, default is true)

Using -h or --help for more details.
"""

import argparse
import logging
import os

import six

from common import *
from graph import Graph
from node import make_node
from parser import Nnet2Parser, Nnet3Parser
from utils import kaldi_check


class Converter(object):
    """Drive the conversion of a parsed Kaldi nnet2/nnet3 model to ONNX.

    The converter parses the model file into components, turns the
    components into nodes / graph inputs / graph outputs and hands them to
    `Graph`, which produces the ONNX model.
    """

    def __init__(self,
                 nnet_file,
                 batch,
                 chunk_size,
                 left_context,
                 right_context,
                 modulus,
                 nnet_type,
                 subsample_factor=1,
                 fuse_lstm=True,
                 fuse_stats=True):
        """Store the conversion settings.

        Args:
            nnet_file: opened model file; it is passed to the parser as is.
            batch: batch size forwarded to `Graph`.
            chunk_size: requested frames per chunk; it is rounded up by
                `get_chunk_size` to a multiple of both `subsample_factor`
                and `modulus`.
            left_context: left context forwarded to `Graph` and written to
                the configure lines.
            right_context: right context, handled like `left_context`.
            modulus: modulus of the model.
            nnet_type: `NNet2` or `NNet3`, selects the parser.
            subsample_factor: frame subsampling factor.
            fuse_lstm: forwarded to `Graph`.
            fuse_stats: forwarded to `Graph`.
        """
        self._components = []
        self._nodes = []
        self._inputs = []
        self._outputs = []
        self._input_dims = {}
        self._nnet_type = nnet_type
        self._batch = batch
        self._chunk_size = self.get_chunk_size(chunk_size,
                                               subsample_factor,
                                               modulus)
        self._nnet_file = nnet_file
        self._fuse_lstm = fuse_lstm
        self._fuse_stats = fuse_stats
        self._left_context = left_context
        self._right_context = right_context
        self._subsample_factor = subsample_factor
        self._modulus = modulus
        self._transition_model = []
        logging.info(
            "frames per chunk: %s, left-context: %s, right-context: %s,"
            " modulus: %s",
            self._chunk_size, self._left_context,
            self._right_context, self._modulus)

    def run(self):
        """Run the whole conversion.

        Returns:
            A tuple `(onnx_model, conf_lines, transition_model)` where
            `conf_lines` is a list of newline-terminated `--key=value`
            strings describing the converted model.
        """
        # parse config file to get components
        self.parse_configs()
        # convert components to nodes, inputs and outputs
        self.convert_components()
        # to build graph, graph will take over the converting work
        g = Graph(self._nodes,
                  self._inputs,
                  self._outputs,
                  self._batch,
                  self._chunk_size,
                  self._left_context,
                  self._right_context,
                  self._modulus,
                  self._input_dims,
                  self._subsample_factor,
                  self._nnet_type,
                  self._fuse_lstm,
                  self._fuse_stats)
        onnx_model = g.run()
        input_info, output_info, cache_info = g.model_interface_info()

        conf_lines = [
            "--left-context=" + str(self._left_context) + "\n",
            "--right-context=" + str(self._right_context) + "\n",
            "--modulus=" + str(self._modulus) + "\n",
            "--frames-per-chunk=" + str(self._chunk_size) + "\n",
            "--frame-subsampling-factor=" +
            str(self._subsample_factor) + "\n",
        ]

        input_nodes_str, input_shapes_str = self.nodes_info_to_str(input_info)
        conf_lines.append("--input-nodes=" + input_nodes_str + "\n")
        conf_lines.append("--input-shapes=" + input_shapes_str + "\n")

        output_nodes_str, output_shapes_str = \
            self.nodes_info_to_str(output_info)
        conf_lines.append("--output-nodes=" + output_nodes_str + "\n")
        conf_lines.append("--output-shapes=" + output_shapes_str + "\n")

        # cache entries are only written when the graph reports any
        if len(cache_info) > 0:
            cache_nodes_str, cache_shapes_str = \
                self.nodes_info_to_str(cache_info)
            conf_lines.append("--cache-nodes=" + cache_nodes_str + "\n")
            conf_lines.append("--cache-shapes=" + cache_shapes_str + "\n")
        return onnx_model, conf_lines, self._transition_model

    @staticmethod
    def get_chunk_size(chunk, subsample_factor, modulus):
        """Round `chunk` up to a multiple of `subsample_factor` and `modulus`.

        Note that the `MaxChunkSize` limit is only checked while rounding
        up; a `chunk` that is already aligned is returned unchanged.

        Raises:
            Exception: if `subsample_factor` or `modulus` is not positive,
                or if rounding up reaches `MaxChunkSize`.
        """
        # A zero divisor would raise ZeroDivisionError below and a negative
        # one would silently skip the alignment, so reject both up front.
        if subsample_factor <= 0 or modulus <= 0:
            raise Exception(
                "subsample-factor(%s) and modulus(%s) should be positive."
                % (subsample_factor, modulus))
        frames_per_chunk = chunk
        while (frames_per_chunk % subsample_factor > 0 or
               frames_per_chunk % modulus > 0):
            frames_per_chunk += 1
            if frames_per_chunk >= MaxChunkSize:
                raise Exception(
                    "The chunk size(%s) is over-ranged, maximum value is %s."
                    % (frames_per_chunk, MaxChunkSize))
        return frames_per_chunk

    @staticmethod
    def shape_to_str(shape):
        """Return the items of `shape` joined by commas, e.g. `1,20,40`."""
        return ','.join(str(dim) for dim in shape)

    def nodes_info_to_str(self, node_info):
        """Format a `{name: shape}` mapping as two space separated strings.

        Returns:
            A tuple `(names, shapes)`; every entry, including the last one,
            is followed by a single space.
        """
        names = []
        shapes = []
        for name, shape in node_info.items():
            names.append(name + ' ')
            shapes.append(self.shape_to_str(shape) + ' ')
        return ''.join(names), ''.join(shapes)

    def parse_configs(self):
        """Parse the model file into components and transition model.

        Raises:
            Exception: if the nnet type is neither `NNet2` nor `NNet3`.
        """
        if self._nnet_type == NNet3:
            parser = Nnet3Parser(self._nnet_file)
        elif self._nnet_type == NNet2:
            parser = Nnet2Parser(self._nnet_file)
        else:
            raise Exception("nnet-type should be 2 or 3.")
        self._components, self._transition_model = parser.run()

    def convert_components(self):
        """Convert parsed components to nodes, graph inputs and outputs.

        Raises:
            Exception: if a component has an unrecognised type.
        """
        nodes = []
        for component in self._components:
            kaldi_check('type' in component,
                        "'type' is required in component: %s" % component)
            op_type = component['type']
            if op_type in KaldiOps:
                nodes.append(self.node_from_component(component))
            elif op_type == 'Input':
                self.convert_input(component)
            elif op_type == 'Output':
                self.convert_output(component)
            else:
                raise Exception(
                    "Unrecognised component type: {0}.".format(op_type))
        # only publish the nodes once every component has been converted
        self._nodes = nodes

    def node_from_component(self, component):
        """Build a node from a component dict.

        The component must provide 'name', 'type' and 'input'. Entries
        listed in `ATTRIBUTE_NAMES` for the type become attributes, entries
        listed in `CONSTS_NAMES` become constants named `<name>_<param>`
        which are also appended to the node inputs.
        """
        kaldi_check('name' in component and
                    'input' in component and
                    'type' in component,
                    "'name', 'type' and 'input'"
                    " are required in component: %s" % component)
        op_type = component['type']
        name = component['name']
        inputs = component['input']
        if not isinstance(inputs, list):
            inputs = [inputs]
        # always builds a new list, so the component itself is not mutated
        # when constant names are appended below
        inputs = [item if isinstance(item, six.string_types) else str(item)
                  for item in inputs]

        attrs = {}
        if op_type in ATTRIBUTE_NAMES:
            attrs_names = ATTRIBUTE_NAMES[op_type]
            attrs = {key: value for key, value in component.items()
                     if key in attrs_names}
        if op_type == KaldiOpType.ReplaceIndex.name:
            attrs['chunk_size'] = self._chunk_size
            attrs['left_context'] = self._left_context
            attrs['right_context'] = self._right_context
        if op_type == KaldiOpType.IfDefined.name:
            attrs['chunk_size'] = self._chunk_size

        consts = {}
        if op_type in CONSTS_NAMES:
            for p_name in CONSTS_NAMES[op_type]:
                if p_name in component:
                    p_tensor_name = name + '_' + p_name
                    consts[p_tensor_name] = component[p_name]
                    inputs.append(p_tensor_name)
        return make_node(name, op_type, inputs, [name], attrs, consts)

    def convert_input(self, component):
        """Register a graph input and its dimension.

        The dimension is read from 'input_dim' when present, otherwise
        from 'dim'.
        """
        kaldi_check('input_dim' in component or 'dim' in component,
                    "input_dim or dim attribute is required in input"
                    " component: %s" % component)
        if 'input_dim' in component:
            dim = component['input_dim']
        else:
            dim = component['dim']
        self._input_dims[component['name']] = int(dim)
        self._inputs.append(component['name'])

    def convert_output(self, component):
        """Register the inputs of an output component as graph outputs."""
        outputs = component['input']
        # A plain string would be split into single characters by extend(),
        # so wrap non-list values the same way node_from_component does.
        if not isinstance(outputs, list):
            outputs = [outputs]
        self._outputs.extend(outputs)


def str2bool(v):
    """Convert a command line string to bool (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: if the value is not recognised.
    """
    value = v.lower()
    if value in ('yes', 'true', 't', 'y', '1'):
        return True
    if value in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')


def get_args():
    """Parse commandline."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True,
                        help="Input model file(*.mdl, should be text format).")
    parser.add_argument("--output",
                        help="Output onnx model file(*.onnx). "
                             "Using input model file's name as default.")
    parser.add_argument("--configure", dest='conf',
                        help="Path to save configure file(*.conf). "
                             "Using output model file's name as default.")
    parser.add_argument("--trans-model", dest='trans_path',
                        help="Output transition model file(*.trans). "
                             "Using output model file's name as default.")
    parser.add_argument('--chunk-size', type=int, dest='chunk_size',
                        help='chunk size, default is 20',
                        default=DefaultChunkSize)
    parser.add_argument('--batch', type=int, dest='batch',
                        help='batch size, default is 1',
                        default=DefaultBatch)
    parser.add_argument('--nnet-type', type=int, dest='nnet_type',
                        help='nnet type: 2 or 3', default=NNet3)
    parser.add_argument('--fuse-lstm', type=str2bool, dest='fuse_lstm',
                        help='fuse lstm four parts to dynamic lstm or not,'
                             ' default is true',
                        default=True)
    parser.add_argument('--fuse-stats', type=str2bool, dest='fuse_stats',
                        help='fuse StatisticsExtraction/StatisticsPooling'
                             ' or not, default is true',
                        default=True)
    parser.add_argument('--left-context', required=True, type=int,
                        dest='left_context', help='Add Left Context')
    parser.add_argument('--right-context', required=True, type=int,
                        dest='right_context', help='Add RightContext')
    parser.add_argument('--modulus', type=int, dest='modulus',
                        help='Modulus of the model.', default=1)
    parser.add_argument('--subsample-factor', type=int,
                        dest='subsample_factor',
                        help='Add Subsample factor, default is 1.',
                        default=1)
    args = parser.parse_args()
    return args


def _replace_extension(path, extension):
    """Return `path` with its extension replaced by `extension`."""
    return os.path.splitext(path)[0] + extension


def main():
    """Convert the model given on the command line and save the results.

    Writes the ONNX model, the transition model (when the parser returned
    one) and the configure file. Paths that are not given explicitly are
    derived from the input / output model path.
    """
    args = get_args()
    if not args.input:
        raise Exception("invalid input file path: {0}.".format(args.input))

    # the model file is only needed while the converter is running
    with open(args.input, 'r') as model_file:
        converter = Converter(model_file,
                              args.batch,
                              args.chunk_size,
                              args.left_context,
                              args.right_context,
                              args.modulus,
                              args.nnet_type,
                              args.subsample_factor,
                              args.fuse_lstm,
                              args.fuse_stats)
        onnx_model, configs, trans_model = converter.run()
    logging.info("Kaldi to ONNX converting finished!")

    output_path = args.output or _replace_extension(args.input, '.onnx')
    with open(output_path, "wb") as of:
        of.write(onnx_model.SerializeToString())
    logging.info("The new onnx model file is: %s", output_path)

    if len(trans_model) > 0:
        trans_path = (args.trans_path or
                      _replace_extension(output_path, '.trans'))
        with open(trans_path, "w") as trans_file:
            trans_file.writelines(trans_model)
        logging.info("The transition model file is: %s", trans_path)
        # the configure file records where the transition model was saved
        configs.append("--trans-model=" + trans_path + "\n")

    if len(configs) > 0:
        conf_path = args.conf or _replace_extension(output_path, '.conf')
        with open(conf_path, "w") as conf_file:
            conf_file.writelines(configs)
        logging.info("The configure file is: %s", conf_path)


if __name__ == "__main__":
    logging.basicConfig(
        format="%(asctime)s %(name)s %(levelname)s %(message)s",
        level=logging.INFO)
    main()