# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description: Training entry point for the TensorFlow transformer model.
"""
import os
import sys

import tensorflow as tf

sys.path.append('../../..')

from pycorrector.transformer.tf import config
from pycorrector.transformer.tf.corpus_reader import CGEDReader, save_word_dict
from pycorrector.transformer.tf.model import train, model, checkpoint


def _build_and_save_vocab(data_reader, texts, vocab_path):
    """Build a token -> index mapping from ``texts`` and save it to ``vocab_path``.

    The index of a token is its position in the sequence returned by
    ``data_reader.read_vocab``; if a token occurs more than once, the last
    position wins.
    """
    vocab = data_reader.read_vocab(texts)
    char2id = {char: idx for idx, char in enumerate(vocab)}
    save_word_dict(char2id, vocab_path)


def main(model_dir='',
         src_train_path='',
         tgt_train_path='',
         src_vocab_path='',
         tgt_vocab_path='',
         batch_size=32,
         maximum_length=100,
         train_steps=10000,
         save_every=1000,
         report_every=50):
    """Prepare the vocabularies, restore the latest checkpoint and run training.

    Args:
        model_dir: Directory handed to ``tf.train.CheckpointManager``.
        src_train_path: Path of the source-side training file.
        tgt_train_path: Path of the target-side training file.
        src_vocab_path: Path of the source vocabulary file; created when missing.
        tgt_vocab_path: Path of the target vocabulary file; created when missing.
        batch_size: Forwarded unchanged to ``train``.
        maximum_length: Forwarded unchanged to ``train``.
        train_steps: Forwarded unchanged to ``train``.
        save_every: Forwarded unchanged to ``train``.
        report_every: Forwarded unchanged to ``train``.
    """
    data_reader = CGEDReader(src_train_path)
    src_input_texts = data_reader.build_dataset(src_train_path)
    tgt_input_texts = data_reader.build_dataset(tgt_train_path)

    # Check each vocabulary file on its own: a missing target vocabulary must
    # be generated even when the source vocabulary is already on disk, and an
    # existing vocabulary must never be overwritten.
    need_src_vocab = not os.path.exists(src_vocab_path)
    need_tgt_vocab = not os.path.exists(tgt_vocab_path)
    if need_src_vocab or need_tgt_vocab:
        print('Training data...')
        print('input_texts:', src_input_texts[0])
        print('target_texts:', tgt_input_texts[0])
        max_input_texts_len = max(len(text) for text in src_input_texts)

        print('num of samples:', len(src_input_texts))
        print('max sequence length for inputs:', max_input_texts_len)

        if need_src_vocab:
            _build_and_save_vocab(data_reader, src_input_texts, src_vocab_path)
        if need_tgt_vocab:
            _build_and_save_vocab(data_reader, tgt_input_texts, tgt_vocab_path)

    data_config = {
        "source_vocabulary": src_vocab_path,
        "target_vocabulary": tgt_vocab_path
    }
    model.initialize(data_config)

    # At most five checkpoints are retained in model_dir.
    checkpoint_manager = tf.train.CheckpointManager(checkpoint, model_dir, max_to_keep=5)
    latest_checkpoint = checkpoint_manager.latest_checkpoint
    if latest_checkpoint is not None:
        # Resume from the most recent checkpoint found in model_dir.
        tf.get_logger().info("Restoring parameters from %s", latest_checkpoint)
        checkpoint.restore(latest_checkpoint)

    train(src_train_path,
          tgt_train_path,
          checkpoint_manager,
          batch_size=batch_size,
          maximum_length=maximum_length,
          train_steps=train_steps,
          save_every=save_every,
          report_every=report_every)


if __name__ == "__main__":
    # A non-negative gpu_id pins the process to that single CUDA device.
    if config.gpu_id > -1:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_id)
    main(config.model_dir,
         src_train_path=config.src_train_path,
         tgt_train_path=config.tgt_train_path,
         src_vocab_path=config.src_vocab_path,
         tgt_vocab_path=config.tgt_vocab_path,
         batch_size=config.batch_size,
         maximum_length=config.maximum_length,
         train_steps=config.train_steps,
         save_every=config.save_every,
         report_every=config.report_every)