"""BERT NER""" __author__ = "Chenyang Yuan" from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import csv import os import random import modeling import optimization import tokenization import tensorflow as tf from tf_metrics import precision, recall, f1 import numpy as np from tensorflow.contrib.layers.python.layers import initializers # from bilm_crf_layer import BLSTM_CRF flags = tf.flags FLAGS = flags.FLAGS ## Required parameters flags.DEFINE_string( "data_dir", None, "The input data dir. Should contain the .tsv files (or other data files) " "for the task.") flags.DEFINE_string( "bert_config_file", None, "The config json file corresponding to the pre-trained BERT model. " "This specifies the model architecture.") flags.DEFINE_string("task_name", None, "The name of the task to train.") flags.DEFINE_string("vocab_file", None, "The vocabulary file that the BERT model was trained on.") flags.DEFINE_string( "output_dir", None, "The output directory where the model checkpoints will be written.") ## Other parameters flags.DEFINE_string( "init_checkpoint", None, "Initial checkpoint (usually from a pre-trained BERT model).") flags.DEFINE_bool( "do_lower_case", True, "Whether to lower case the input text. Should be True for uncased " "models and False for cased models.") flags.DEFINE_integer( "max_seq_length", 128, "The maximum total input sequence length after WordPiece tokenization. " "Sequences longer than this will be truncated, and sequences shorter " "than this will be padded.") flags.DEFINE_bool("do_train", False, "Whether to run training.") flags.DEFINE_bool("do_eval", False, "Whether to run eval on the dev set.") flags.DEFINE_bool( "do_predict", False, "Whether to run the model in inference mode on the test set.") flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.") flags.DEFINE_integer("eval_batch_size", 32, "Total batch size for eval.") flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.") flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.") flags.DEFINE_float("num_train_epochs", 3.0, "Total number of training epochs to perform.") flags.DEFINE_float( "warmup_proportion", 0.1, "Proportion of training to perform linear learning rate warmup for. " "E.g., 0.1 = 10% of training.") flags.DEFINE_integer("save_checkpoints_steps", 1000, "How often to save the model checkpoint.") flags.DEFINE_integer("iterations_per_loop", 1000, "How many steps to make in each estimator call.") flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.") tf.flags.DEFINE_string( "tpu_name", None, "The Cloud TPU to use for training. This should be either the name " "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 " "url.") tf.flags.DEFINE_string( "tpu_zone", None, "[Optional] GCE zone where the Cloud TPU is located in. If not " "specified, we will attempt to automatically detect the GCE project from " "metadata.") tf.flags.DEFINE_string( "gcp_project", None, "[Optional] Project name for the Cloud TPU-enabled project. If not " "specified, we will attempt to automatically detect the GCE project from " "metadata.") tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.") flags.DEFINE_integer( "num_tpu_cores", 8, "Only used if `use_tpu` is True. Total number of TPU cores to use.") flags.DEFINE_integer( "lstm_size", 128, "lstm units of bilm.") flags.DEFINE_string('cell', 'lstm', 'which rnn cell used') flags.DEFINE_integer('num_layers', 1, 'number of rnn layers, default is 1') flags.DEFINE_float('dropout_rate', 0.5, 'Dropout rate') class InputFeatures(object): """A single set of features of data.""" def __init__(self, input_ids, input_mask, segment_ids, label_ids, seq_length, is_real_example=True): self.input_ids = input_ids self.input_mask = input_mask self.segment_ids = segment_ids self.label_ids = label_ids self.is_real_example = is_real_example self.seq_length = seq_length class DataProcessor(object): """Base class for data converters for sequence classification data sets.""" def get_train_examples(self, data_dir): """Gets a collection of `InputExample`s for the train set.""" raise NotImplementedError() def get_dev_examples(self, data_dir): """Gets a collection of `InputExample`s for the dev set.""" raise NotImplementedError() def get_test_examples(self, data_dir): """Gets a collection of `InputExample`s for prediction.""" raise NotImplementedError() def get_labels(self): """Gets the list of labels for this data set.""" raise NotImplementedError() @classmethod def _read_tsv(cls, input_file, quotechar=None): """Reads a tab separated value file.""" with tf.gfile.Open(input_file, "r") as f: reader = csv.reader(f, delimiter="\t", quotechar=quotechar) lines = [] for line in reader: lines.append(line) return lines class NerProcessor(DataProcessor): def get_train_examples(self, data_dir): """See base class.""" return self._create_examples( self._read_tsv(os.path.join(data_dir, "train.tsv")), "train") def get_dev_examples(self, data_dir): """See base class.""" return self._create_examples( self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")), "dev_matched") def get_test_examples(self, data_dir): """See base class.""" return self._create_examples( self._read_tsv(os.path.join(data_dir, "test_matched.tsv")), "test") def get_labels(self): """See base class.""" return ["contradiction", "BPER", "BLOC", "BORG", "BTIME", "IPER", "ILOC", "IORG", "ITIME", "EPER", "ELOC", "EORG", "ETIME", "O", "OTHER"] # return ["contradiction", "B", "I", "E", "O"] def _create_examples(self, lines, set_type): '''NER任务训练数据产生方法''' examples = [] for (i, line) in enumerate(lines): if i == 0: continue guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0])) text_a = tokenization.convert_to_unicode(line[1].strip()) # text_b = tokenization.convert_to_unicode(line[9]) label_t = tokenization.convert_to_unicode(line[2]) labels = list() if set_type == "test": # print("*****test*****") labels = ["contradiction" for i in range(FLAGS.max_seq_length)] else: labels = label_t.split() examples.append( InputExample(guid=guid, text_a=text_a, labels=labels)) return examples class InputExample(object): """A single training/test example for simple sequence classification.""" def __init__(self, guid, text_a, labels=None): """Constructs a InputExample. Args: guid: Unique id for the example. text_a: string. The untokenized text of the first sequence. For single sequence tasks, only this sequence must be specified. text_b: (Optional) string. The untokenized text of the second sequence. Only must be specified for sequence pair tasks. label: (Optional) string. The label of the example. This should be specified for train and dev examples, but not for test examples. """ self.guid = guid self.text_a = text_a self.labels = labels class PaddingInputExample(object): """Fake example so the num input examples is a multiple of the batch size. When running eval/predict on the TPU, we need to pad the number of examples to be a multiple of the batch size, because the TPU requires a fixed batch size. The alternative is to drop the last batch, which is bad because it means the entire output data won't be generated. We use this class instead of `None` because treating `None` as padding battches could cause silent errors. """ def convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer): """Converts a single `InputExample` into a single `InputFeatures`.""" if isinstance(example, PaddingInputExample): return InputFeatures( input_ids=[0] * max_seq_length, input_mask=[0] * max_seq_length, segment_ids=[0] * max_seq_length, label_ids=[0] * max_seq_length, seq_length = max_seq_length, is_real_example=False) label_map = {} for (i, label) in enumerate(label_list): label_map[label] = i tokens_a = tokenizer.tokenize(example.text_a) tokens_b = None if len(tokens_a) > max_seq_length: tokens_a = tokens_a[0:(max_seq_length)] tokens = [] segment_ids = [] # tokens.append("[CLS]") # segment_ids.append(0) for token in tokens_a: tokens.append(token) segment_ids.append(0) # tokens.append("[SEP]") # segment_ids.append(0) input_ids = tokenizer.convert_tokens_to_ids(tokens) seq_length = len(input_ids) # The mask has 1 for real tokens and 0 for padding tokens. Only real # tokens are attended to. input_mask = [1] * len(input_ids) # Zero-pad up to the sequence length. while len(input_ids) < max_seq_length: input_ids.append(0) input_mask.append(0) segment_ids.append(0) assert len(input_ids) == max_seq_length assert len(input_mask) == max_seq_length assert len(segment_ids) == max_seq_length label_id = None label_id = [] # label_id.append(4) for l in example.labels: label_id.append(label_map[l]) # label_id.append(5) while len(label_id) < max_seq_length: label_id.append(0) if ex_index < 5: tf.logging.info("*** Example ***") tf.logging.info("guid: %s" % (example.guid)) tf.logging.info("tokens: %s" % " ".join( [tokenization.printable_text(x) for x in tokens])) tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) tf.logging.info("label: %s (id = %s)" % (example.labels, label_id)) tf.logging.info("seq_length: %s (id = %s)" % (seq_length, label_id)) feature = InputFeatures( input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, label_ids=label_id, seq_length = seq_length, is_real_example=True) return feature def _truncate_seq_pair(tokens_a, tokens_b, max_length): """Truncates a sequence pair in place to the maximum length.""" # This is a simple heuristic which will always truncate the longer sequence # one token at a time. This makes more sense than truncating an equal percent # of tokens from each, since if one sequence is very short then each token # that's truncated likely contains more information than a longer sequence. while True: total_length = len(tokens_a) + len(tokens_b) if total_length <= max_length: break if len(tokens_a) > len(tokens_b): tokens_a.pop() else: tokens_b.pop() def create_model(bert_config, is_training, input_ids, input_mask, segment_ids, label_ids, seq_length, num_labels, use_one_hot_embeddings): """Creates a classification model.""" model = modeling.BertModel( config=bert_config, is_training=is_training, input_ids=input_ids, input_mask=input_mask, token_type_ids=segment_ids, use_one_hot_embeddings=use_one_hot_embeddings) embedding = model.get_sequence_output() embeddings = tf.layers.dropout(embedding, rate=FLAGS.dropout_rate, training=is_training) with tf.variable_scope('Graph', reuse=None, custom_getter=None): # LSTM t = tf.transpose(embeddings, perm=[1, 0, 2]) lstm_cell_fw = tf.contrib.rnn.LSTMBlockFusedCell(128) # 序列标注问题中一般lstm单元个数就是max_seq_length lstm_cell_bw = tf.contrib.rnn.LSTMBlockFusedCell(128) lstm_cell_bw = tf.contrib.rnn.TimeReversedFusedRNN(lstm_cell_bw) output_fw, _ = lstm_cell_fw(t, dtype=tf.float32, sequence_length=seq_length) output_bw, _ = lstm_cell_bw(t, dtype=tf.float32, sequence_length=seq_length) output = tf.concat([output_fw, output_bw], axis=-1) output = tf.transpose(output, perm=[1, 0, 2]) output = tf.layers.dropout(output, rate=0.5, training=is_training) # CRF logits = tf.layers.dense(output, num_labels) crf_params = tf.get_variable("crf", [num_labels, num_labels], dtype=tf.float32) trans = tf.get_variable( "transitions", shape=[num_labels, num_labels], initializer=initializers.xavier_initializer()) pred_ids, trans = tf.contrib.crf.crf_decode(logits, crf_params, seq_length) log_likelihood, _ = tf.contrib.crf.crf_log_likelihood( logits, label_ids, seq_length, crf_params) loss = tf.reduce_mean(-log_likelihood) # if mode == tf.estimator.ModeKeys.EVAL: # return tf.estimator.EstimatorSpec( # mode, loss=loss, eval_metric_ops=metrics) # elif mode == tf.estimator.ModeKeys.TRAIN: return loss, logits, trans, pred_ids def model_fn_builder(bert_config, init_checkpoint, learning_rate, num_train_steps, num_warmup_steps, use_tpu, use_one_hot_embeddings): """Returns `model_fn` closure for TPUEstimator.""" def model_fn(features, mode, params): # pylint: disable=unused-argument """The `model_fn` for TPUEstimator.""" tf.logging.info("*** Features ***") for name in sorted(features.keys()): tf.logging.info(" name = %s, shape = %s" % (name, features[name].shape)) input_ids = features["input_ids"] label_ids = features["label_ids"] input_mask = features["input_mask"] segment_ids = features["segment_ids"] seq_length = features["seq_length"] is_training = (mode == tf.estimator.ModeKeys.TRAIN) (total_loss, logits, trans, pred_ids) = create_model( bert_config, is_training, input_ids, input_mask, segment_ids, label_ids, seq_length, params['num_labels'], use_one_hot_embeddings) tf.logging.info("****num_labels****") tf.logging.info("number_labels: %s" % params['num_labels']) tvars = tf.trainable_variables() scaffold_fn = None # 加载BERT模型 if init_checkpoint: (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint) tf.train.init_from_checkpoint(init_checkpoint, assignment_map) if use_tpu: def tpu_scaffold(): tf.train.init_from_checkpoint(init_checkpoint, assignment_map) return tf.train.Scaffold() scaffold_fn = tpu_scaffold else: tf.train.init_from_checkpoint(init_checkpoint, assignment_map) tf.logging.info("**** Trainable Variables ****") # 打印加载模型的参数,调试的时候比较方便看到自己的程序哪里出了问题 # for var in tvars: # init_string = "" # if var.name in initialized_variable_names: # init_string = ", *INIT_FROM_CKPT*" # tf.logging.info(" name = %s, shape = %s%s", var.name, var.shape, # init_string) output_spec = None if not mode == tf.estimator.ModeKeys.PREDICT: eval_metrics = None def metric_fn(seq_length, max_len, label_ids, pred_ids): indices = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] # indice参数告诉评估矩阵评估哪些标签 # Metrics weights = tf.sequence_mask(seq_length, maxlen=max_len) metrics = { 'acc': tf.metrics.accuracy(label_ids, pred_ids, weights), 'precision': precision(label_ids, pred_ids, params['num_labels'], indices, weights), 'recall': recall(label_ids, pred_ids, params['num_labels'], indices, weights), 'f1': f1(label_ids, pred_ids, params['num_labels'], indices, weights), } for metric_name, op in metrics.items(): tf.summary.scalar(metric_name, op[1]) return eval_metrics eval_metrics = metric_fn(seq_length, FLAGS.max_seq_length, label_ids, pred_ids) if mode == tf.estimator.ModeKeys.EVAL: output_spec = tf.contrib.tpu.TPUEstimatorSpec( mode=mode, loss=total_loss, scaffold_fn=scaffold_fn, eval_metrics=eval_metrics) if mode == tf.estimator.ModeKeys.TRAIN: train_op = optimization.create_optimizer( total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu) output_spec = tf.contrib.tpu.TPUEstimatorSpec( mode=mode, loss=total_loss, train_op=train_op, eval_metrics=eval_metrics, scaffold_fn=scaffold_fn) # 钩子,这里用来将BERT中的参数作为我们模型的初始值 else: output_spec = tf.contrib.tpu.TPUEstimatorSpec( mode=mode, predictions=pred_ids, # prediction参数告诉estimator在预测时的输出,使用tensorflow server的时候很有用 scaffold_fn=scaffold_fn ) return output_spec return model_fn def input_fn_builder(features, seq_length, is_training, drop_remainder): """Creates an `input_fn` closure to be passed to TPUEstimator.""" all_input_ids = [] all_input_mask = [] all_segment_ids = [] all_label_ids = [] for feature in features: all_input_ids.append(feature.input_ids) all_input_mask.append(feature.input_mask) all_segment_ids.append(feature.segment_ids) all_label_ids.append(feature.label_id) def input_fn(params): """The actual input function.""" # batch_size = params["batch_size"] num_examples = len(features) batch_size = params["batch_size"] # This is for demo purposes and does NOT scale to large data sets. We do # not use Dataset.from_generator() because that uses tf.py_func which is # not TPU compatible. The right way to load data is with TFRecordReader. d = tf.data.Dataset.from_tensor_slices({ "input_ids": tf.constant( all_input_ids, shape=[num_examples, seq_length], dtype=tf.int32), "input_mask": tf.constant( all_input_mask, shape=[num_examples, seq_length], dtype=tf.int32), "segment_ids": tf.constant( all_segment_ids, shape=[num_examples, seq_length], dtype=tf.int32), "label_ids": tf.constant(all_label_ids, shape=[num_examples, seq_length], dtype=tf.int32), }) if is_training: d = d.repeat() d = d.shuffle(buffer_size=100) d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder) return d return input_fn # This function is not used by this file but is still used by the Colab and # people who depend on it. def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): """Convert a set of `InputExample`s to a list of `InputFeatures`.""" features = [] for (ex_index, example) in enumerate(examples): if ex_index % 10000 == 0: tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) feature = convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer) features.append(feature) return features def file_based_convert_examples_to_features( examples, label_list, max_seq_length, tokenizer, output_file): """Convert a set of `InputExample`s to a TFRecord file.""" writer = tf.python_io.TFRecordWriter(output_file) for (ex_index, example) in enumerate(examples): if ex_index % 10000 == 0: tf.logging.info("Writing example %d of %d" % (ex_index, len(examples))) feature = convert_single_example(ex_index, example, label_list, max_seq_length, tokenizer) def create_int_feature(values): if isinstance(values, list): f = tf.train.Feature(int64_list=tf.train.Int64List(value=values)) else: f = tf.train.Feature(int64_list=tf.train.Int64List(value=[values])) return f features = collections.OrderedDict() features["input_ids"] = create_int_feature(feature.input_ids) features["input_mask"] = create_int_feature(feature.input_mask) features["segment_ids"] = create_int_feature(feature.segment_ids) features["label_ids"] = create_int_feature(feature.label_ids) features["seq_length"] = create_int_feature([feature.seq_length]) features["is_real_example"] = create_int_feature( [int(feature.is_real_example)]) tf_example = tf.train.Example(features=tf.train.Features(feature=features)) writer.write(tf_example.SerializeToString()) writer.close() def file_based_input_fn_builder(input_file, seq_length, is_training, drop_remainder): """Creates an `input_fn` closure to be passed to TPUEstimator.""" name_to_features = { "input_ids": tf.FixedLenFeature([seq_length], tf.int64), "input_mask": tf.FixedLenFeature([seq_length], tf.int64), "segment_ids": tf.FixedLenFeature([seq_length], tf.int64), "label_ids": tf.FixedLenFeature([seq_length], tf.int64), "seq_length": tf.FixedLenFeature([], tf.int64), "is_real_example": tf.FixedLenFeature([], tf.int64), # "seq_length_vec": tf.FixedLenFeature([], tf.int64), } def _decode_record(record, name_to_features): """Decodes a record to a TensorFlow example.""" example = tf.parse_single_example(record, name_to_features) # tf.Example only supports tf.int64, but the TPU only supports tf.int32. # So cast all int64 to int32. for name in list(example.keys()): t = example[name] if t.dtype == tf.int64: t = tf.to_int32(t) example[name] = t return example def input_fn(params): """The actual input function.""" batch_size = params["batch_size"] # For training, we want a lot of parallel reading and shuffling. # For eval, we want no shuffling and parallel reading doesn't matter. d = tf.data.TFRecordDataset(input_file) if is_training: d = d.repeat() d = d.shuffle(buffer_size=100) d = d.apply( tf.contrib.data.map_and_batch( lambda record: _decode_record(record, name_to_features), batch_size=batch_size, drop_remainder=drop_remainder)) return d return input_fn def main(_): tf.logging.set_verbosity(tf.logging.INFO) bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) # validate_flags_or_throw(bert_config) tf.gfile.MakeDirs(FLAGS.output_dir) tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case, FLAGS.init_checkpoint) if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict: raise ValueError( "At least one of `do_train`, `do_eval` or `do_predict' must be True.") bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file) if FLAGS.max_seq_length > bert_config.max_position_embeddings: raise ValueError( "Cannot use sequence length %d because the BERT model " "was only trained up to sequence length %d" % (FLAGS.max_seq_length, bert_config.max_position_embeddings)) processor = NerProcessor() label_list = processor.get_labels() label_num = len(label_list) tokenizer = tokenization.FullTokenizer( vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) tpu_cluster_resolver = None if FLAGS.use_tpu and FLAGS.tpu_name: tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 run_config = tf.contrib.tpu.RunConfig( cluster=tpu_cluster_resolver, master=FLAGS.master, model_dir=FLAGS.output_dir, save_checkpoints_steps=FLAGS.save_checkpoints_steps, tpu_config=tf.contrib.tpu.TPUConfig( iterations_per_loop=FLAGS.iterations_per_loop, num_shards=FLAGS.num_tpu_cores, per_host_input_for_training=is_per_host)) train_examples = None num_train_steps = None num_warmup_steps = None if FLAGS.do_train: train_examples = processor.get_train_examples(FLAGS.data_dir) num_train_steps = int( len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs) num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion) eval_examples = processor.get_dev_examples(FLAGS.data_dir) # Pre-shuffle the input to avoid having to make a very large shuffle # buffer in in the `input_fn`. rng = random.Random(12345) rng.shuffle(train_examples) model_fn = model_fn_builder( bert_config=bert_config, init_checkpoint=FLAGS.init_checkpoint, learning_rate=FLAGS.learning_rate, num_train_steps=num_train_steps, num_warmup_steps=num_warmup_steps, use_tpu=FLAGS.use_tpu, use_one_hot_embeddings=FLAGS.use_tpu) # If TPU is not available, this will fall back to normal Estimator on CPU # or GPU. params={ "num_labels": label_num, } estimator = tf.contrib.tpu.TPUEstimator( use_tpu=False, model_fn=model_fn, config=run_config, train_batch_size=FLAGS.train_batch_size, eval_batch_size=FLAGS.eval_batch_size, predict_batch_size=FLAGS.predict_batch_size, params=params) if FLAGS.do_train: # We write to a temporary file to avoid storing very large constant tensors # in memory. if os.path.exists(os.path.join(FLAGS.output_dir, "train.tf_record")): train_file = os.path.join(FLAGS.output_dir, "train.tf_record") else: train_file = os.path.join(FLAGS.output_dir, "train.tf_record") file_based_convert_examples_to_features( train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file) if os.path.exists(os.path.join(FLAGS.output_dir, "eval.tf_record")): eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") else: eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record") file_based_convert_examples_to_features( eval_examples, label_list, FLAGS.max_seq_length, tokenizer, eval_file) tf.logging.info("***** Running training *****") tf.logging.info(" Num examples = %d", len(train_examples)) tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) tf.logging.info(" Num steps = %d", num_train_steps) tf.logging.info("***** evaluation *****") tf.logging.info(" Num examples = %d", len(eval_examples)) tf.logging.info(" Batch size = %d", FLAGS.train_batch_size) tf.logging.info(" Num steps = %d", num_train_steps) del train_examples del eval_examples train_input_fn = file_based_input_fn_builder( input_file=train_file, seq_length=FLAGS.max_seq_length, is_training=True, drop_remainder=True) eval_input_fn = file_based_input_fn_builder( input_file=eval_file, seq_length=FLAGS.max_seq_length, is_training=True, drop_remainder=True) # estimator.train(input_fn=train_input_fn, max_steps=num_train_steps) hook = tf.contrib.estimator.stop_if_no_increase_hook( estimator, 'loss', 100000, min_steps=50000, run_every_secs=300) train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, hooks=[hook]) eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, throttle_secs=300) tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) if __name__ == "__main__": flags.mark_flag_as_required("vocab_file") flags.mark_flag_as_required("bert_config_file") flags.mark_flag_as_required("output_dir") tf.app.run()