try:
    from distributed_single_sentence_classification.model_interface import model_zoo
    from distillation import distillation_utils
    from loss import loss_utils
except:
    from distributed_single_sentence_classification.model_interface import model_zoo
    from distillation import distillation_utils
    from loss import loss_utils

import tensorflow as tf
import numpy as np

from model_io import model_io
from task_module import classifier
from metric import tf_metrics
from task_module import pretrain
from utils.bert import bert_utils
from optimizer import distributed_optimizer as optimizer
from utils.simclr import simclr_utils


def build_accuracy(logits, labels, mask, loss_type=None):
    mask = tf.cast(mask, tf.float32)
    if loss_type == 'contrastive_loss':
        # logits behave like a distance score in [0, 1]; tf.rint applies a 0.5 threshold
        temp_sim = tf.subtract(tf.ones_like(logits), tf.rint(logits), name="temp_sim")
        correct = tf.equal(
            tf.cast(temp_sim, tf.float32),
            tf.cast(labels, tf.float32)
        )
    elif loss_type == 'exponent_neg_manhattan_distance_mse':
        temp_sim = tf.rint(logits)
        correct = tf.equal(
            tf.cast(temp_sim, tf.float32),
            tf.cast(labels, tf.float32)
        )
    else:
        # fallback for the calls below that pass no loss_type (masked-lm log-probs,
        # task logits): treat logits as multi-class scores and use argmax accuracy
        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
        correct = tf.equal(
            predictions,
            tf.cast(labels, tf.int32)
        )
    accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)*mask)/(1e-10+tf.reduce_sum(mask))
    return accuracy


def model_fn_builder(
                    model,
                    model_config,
                    num_labels,
                    init_checkpoint,
                    model_reuse=None,
                    load_pretrained=True,
                    model_io_config={},
                    opt_config={},
                    exclude_scope="",
                    not_storage_params=[],
                    target="a",
                    label_lst=None,
                    output_type="sess",
                    task_layer_reuse=None,
                    **kargs):

    def model_fn(features, labels, mode):

        task_type = kargs.get("task_type", "cls")
        label_ids = tf.cast(features["{}_label_ids".format(task_type)], tf.float32)

        if task_type in ['mnli', 'cmnli']:
            loss_mask = tf.cast(features["{}_loss_multipiler".format(task_type)], tf.float32)
            neutral_label = tf.not_equal(
                label_ids,
                tf.zeros_like(label_ids)
            )
            pos_label = tf.equal(
                label_ids,
                tf.ones_like(label_ids)
            )
            neg_label = tf.not_equal(
                label_ids,
                2*tf.ones_like(label_ids)
            )
            # mask out examples whose label is 0 and map label 2 to 0,
            # reducing the three-way NLI labels to a binary pair label
            loss_mask *= tf.cast(neutral_label, dtype=tf.float32)
            label_ids *= tf.cast(neg_label, dtype=tf.float32)
        else:
            loss_mask = tf.cast(features["{}_loss_multipiler".format(task_type)], tf.float32)

        num_task = kargs.get('num_task', 1)

        model_io_fn = model_io.ModelIO(model_io_config)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
            is_training = True
        else:
            dropout_prob = 0.0
            is_training = False

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        if kargs.get("get_pooled_output", "pooled_output") == "pooled_output":
            pooled_feature = model.get_pooled_output()
        elif kargs.get("get_pooled_output", "task_output") == "task_output":
            pooled_feature_dict = model.get_task_output()
            pooled_feature = pooled_feature_dict['pooled_feature']

            if kargs.get('apply_head_proj', False):
                with tf.variable_scope(scope+"/head_proj", reuse=tf.AUTO_REUSE):
                    feature_a = simclr_utils.projection_head(pooled_feature_dict['feature_a'],
                                                is_training,
                                                head_proj_dim=128,
                                                num_nlh_layers=1,
                                                head_proj_mode='nonlinear',
                                                name='head_contrastive')
                    pooled_feature_dict['feature_a'] = feature_a

                with tf.variable_scope(scope+"/head_proj", reuse=tf.AUTO_REUSE):
                    feature_b = simclr_utils.projection_head(pooled_feature_dict['feature_b'],
                                                is_training,
                                                head_proj_dim=128,
                                                num_nlh_layers=1,
                                                head_proj_mode='nonlinear',
                                                name='head_contrastive')
                    pooled_feature_dict['feature_b'] = feature_b
                tf.logging.info("****** apply contrastive feature projection *******")
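            # Note: the projection head applied above is defined in utils/simclr and is
            # not shown in this file. With head_proj_mode='nonlinear' and
            # num_nlh_layers=1 it is assumed to be a small SimCLR-style MLP, roughly:
            #
            #   hidden = tf.layers.dense(feature, 128, activation=tf.nn.relu)
            #   proj   = tf.layers.dense(hidden, head_proj_dim)
            #
            # (a sketch, not the actual simclr_utils implementation), so the contrastive
            # objective below operates in a learned projection space rather than directly
            # on the pooled encoder features.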
        loss = tf.constant(0.0)

        params_size = model_io_fn.count_params(model_config.scope)
        print("==total encoder params==", params_size)

        if kargs.get("feature_distillation", True):
            universal_feature_a = features.get("input_ids_a_features", None)
            universal_feature_b = features.get("input_ids_b_features", None)

            if universal_feature_a is None or universal_feature_b is None:
                tf.logging.info("****** not apply feature distillation *******")
                feature_loss = tf.constant(0.0)
            else:
                feature_a = pooled_feature_dict['feature_a']
                feature_a_shape = bert_utils.get_shape_list(feature_a, expected_rank=[2,3])
                pretrain_feature_a_shape = bert_utils.get_shape_list(universal_feature_a, expected_rank=[2,3])
                if feature_a_shape[-1] != pretrain_feature_a_shape[-1]:
                    with tf.variable_scope(scope+"/feature_proj", reuse=tf.AUTO_REUSE):
                        proj_feature_a = tf.layers.dense(feature_a, pretrain_feature_a_shape[-1])
                    # with tf.variable_scope(scope+"/feature_rec", reuse=tf.AUTO_REUSE):
                    #     proj_feature_a_rec = tf.layers.dense(proj_feature_a, feature_a_shape[-1])
                    # loss += tf.reduce_mean(tf.reduce_sum(tf.square(proj_feature_a_rec-feature_a), axis=-1))/float(num_task)
                    tf.logging.info("****** apply auto-encoder for feature compression *******")
                else:
                    proj_feature_a = feature_a
                feature_a_norm = tf.stop_gradient(tf.sqrt(tf.reduce_sum(tf.pow(proj_feature_a, 2), axis=-1, keepdims=True))+1e-20)
                proj_feature_a /= feature_a_norm

                feature_b = pooled_feature_dict['feature_b']
                if feature_a_shape[-1] != pretrain_feature_a_shape[-1]:
                    with tf.variable_scope(scope+"/feature_proj", reuse=tf.AUTO_REUSE):
                        proj_feature_b = tf.layers.dense(feature_b, pretrain_feature_a_shape[-1])
                    # with tf.variable_scope(scope+"/feature_rec", reuse=tf.AUTO_REUSE):
                    #     proj_feature_b_rec = tf.layers.dense(proj_feature_b, feature_a_shape[-1])
                    # loss += tf.reduce_mean(tf.reduce_sum(tf.square(proj_feature_b_rec-feature_b), axis=-1))/float(num_task)
                    tf.logging.info("****** apply auto-encoder for feature compression *******")
                else:
                    proj_feature_b = feature_b
                feature_b_norm = tf.stop_gradient(tf.sqrt(tf.reduce_sum(tf.pow(proj_feature_b, 2), axis=-1, keepdims=True))+1e-20)
                proj_feature_b /= feature_b_norm

                feature_a_distillation = tf.reduce_mean(tf.square(universal_feature_a-proj_feature_a), axis=-1)
                feature_b_distillation = tf.reduce_mean(tf.square(universal_feature_b-proj_feature_b), axis=-1)

                feature_loss = tf.reduce_mean((feature_a_distillation + feature_b_distillation)/2.0)/float(num_task)
                loss += feature_loss
                tf.logging.info("****** apply pretrained feature distillation *******")

        if kargs.get("embedding_distillation", True):
            word_embed = model.emb_mat
            random_embed_shape = bert_utils.get_shape_list(word_embed, expected_rank=[2,3])
            print("==random_embed_shape==", random_embed_shape)
            pretrained_embed = kargs.get('pretrained_embed', None)
            if pretrained_embed is None:
                tf.logging.info("****** not apply pretrained embedding distillation *******")
                embed_loss = tf.constant(0.0)
            else:
                pretrain_embed_shape = bert_utils.get_shape_list(pretrained_embed, expected_rank=[2,3])
                print("==pretrain_embed_shape==", pretrain_embed_shape)
                if random_embed_shape[-1] != pretrain_embed_shape[-1]:
                    with tf.variable_scope(scope+"/embedding_proj", reuse=tf.AUTO_REUSE):
                        proj_embed = tf.layers.dense(word_embed, pretrain_embed_shape[-1])
                else:
                    proj_embed = word_embed

                embed_loss = tf.reduce_mean(tf.reduce_mean(tf.square(proj_embed-pretrained_embed), axis=-1))/float(num_task)
                loss += embed_loss
                tf.logging.info("****** apply pretrained embedding distillation *******")
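        # For reference, the two distillation terms above amount to (with u the frozen
        # pretrained/"universal" vector and p the l2-normalized projected student vector):
        #
        #   feature_loss = mean_batch( (mean_dim((u_a - p_a)^2) + mean_dim((u_b - p_b)^2)) / 2 ) / num_task
        #   embed_loss   = mean( mean_dim((proj_embed - pretrained_embed)^2) ) / num_task
        #
        # Both are divided by num_task, presumably so that each task head contributes a
        # comparable share of the distillation signal when several tasks share the encoder.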
        if kargs.get('loss', 'contrastive_loss') == 'contrastive_loss':
            feature_a = tf.nn.l2_normalize(1e-20+pooled_feature_dict['feature_a'], axis=-1)
            feature_b = tf.nn.l2_normalize(1e-20+pooled_feature_dict['feature_b'], axis=-1)
            # feature_a = pooled_feature_dict['feature_a']
            # feature_b = pooled_feature_dict['feature_b']
            per_example_loss, logits = loss_utils.contrastive_loss(label_ids,
                                                    feature_a,
                                                    feature_b,
                                                    kargs.get('margin', 1.0))
            tf.logging.info("****** contrastive_loss *******")
        elif kargs.get('loss', 'contrastive_loss') == 'exponent_neg_manhattan_distance_mse':
            feature_a = tf.nn.l2_normalize(1e-20+pooled_feature_dict['feature_a'], axis=-1)
            feature_b = tf.nn.l2_normalize(1e-20+pooled_feature_dict['feature_b'], axis=-1)
            # feature_a = pooled_feature_dict['feature_a']
            # feature_b = pooled_feature_dict['feature_b']
            per_example_loss, logits = loss_utils.exponent_neg_manhattan_distance(label_ids,
                                                    feature_a,
                                                    feature_b,
                                                    'mse')
            tf.logging.info("****** exponent_neg_manhattan_distance_mse *******")
        else:
            feature_a = tf.nn.l2_normalize(1e-20+pooled_feature_dict['feature_a'], axis=-1)
            feature_b = tf.nn.l2_normalize(1e-20+pooled_feature_dict['feature_b'], axis=-1)
            # feature_a = pooled_feature_dict['feature_a']
            # feature_b = pooled_feature_dict['feature_b']
            per_example_loss, logits = loss_utils.contrastive_loss(label_ids,
                                                    feature_a,
                                                    feature_b,
                                                    kargs.get('margin', 1.0))
            tf.logging.info("****** contrastive_loss *******")

        # loss_mask = tf.cast(features["{}_loss_multipiler".format(task_type)], tf.float32)
        masked_per_example_loss = per_example_loss * loss_mask
        task_loss = tf.reduce_sum(masked_per_example_loss) / (1e-10+tf.reduce_sum(loss_mask))
        loss += task_loss
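        # loss_utils.contrastive_loss is assumed to follow the standard siamese
        # contrastive formulation (a sketch; the actual definition lives in loss/loss_utils):
        #
        #   d                = ||feature_a - feature_b||_2
        #   per_example_loss = y * d^2 + (1 - y) * max(margin - d, 0)^2
        #
        # with y in {0, 1}. build_accuracy treats the returned `logits` as a
        # distance-like score in [0, 1] and thresholds it at 0.5 via tf.rint.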
        # with tf.variable_scope(scope+"/{}/classifier".format(task_type), reuse=task_layer_reuse):
        #     feature_a = pooled_feature_dict['feature_a']
        #     feature_b = pooled_feature_dict['feature_a']
        #     logtis_feature = tf.concat([feature_a, feature_b], axis=-1)
        #     (_,
        #     cls_per_example_loss,
        #     cls_logits) = classifier.classifier(model_config,
        #                                 logtis_feature,
        #                                 num_labels,
        #                                 label_ids,
        #                                 dropout_prob)
        #     loss_mask = tf.cast(features["{}_loss_multipiler".format(task_type)], tf.float32)
        #     masked_per_example_loss = cls_per_example_loss * loss_mask
        #     task_loss = tf.reduce_sum(masked_per_example_loss) / (1e-10+tf.reduce_sum(loss_mask))
        #     loss += task_loss

        if mode == tf.estimator.ModeKeys.TRAIN:
            multi_task_config = kargs.get("multi_task_config", {})
            if multi_task_config[task_type].get("lm_augumentation", False):
                print("==apply lm_augumentation==")
                masked_lm_positions = features["masked_lm_positions"]
                masked_lm_ids = features["masked_lm_ids"]
                masked_lm_weights = features["masked_lm_weights"]
                (masked_lm_loss,
                masked_lm_example_loss,
                masked_lm_log_probs) = pretrain.get_masked_lm_output(
                                                model_config,
                                                model.get_sequence_output(),
                                                model.get_embedding_table(),
                                                masked_lm_positions,
                                                masked_lm_ids,
                                                masked_lm_weights,
                                                reuse=model_reuse)

                masked_lm_loss_mask = tf.expand_dims(loss_mask, -1) * tf.ones((1, multi_task_config[task_type]["max_predictions_per_seq"]))
                masked_lm_loss_mask = tf.reshape(masked_lm_loss_mask, (-1, ))

                masked_lm_label_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_loss_mask *= tf.cast(masked_lm_label_weights, tf.float32)

                masked_lm_example_loss *= masked_lm_loss_mask  # multiply task_mask
                masked_lm_loss = tf.reduce_sum(masked_lm_example_loss) / (1e-10+tf.reduce_sum(masked_lm_loss_mask))
                loss += multi_task_config[task_type]["masked_lm_loss_ratio"]*masked_lm_loss

                masked_lm_label_ids = tf.reshape(masked_lm_ids, [-1])

                print(masked_lm_log_probs.get_shape(), "===masked lm log probs===")
                print(masked_lm_label_ids.get_shape(), "===masked lm ids===")
                print(masked_lm_label_weights.get_shape(), "===masked lm mask===")

                lm_acc = build_accuracy(masked_lm_log_probs,
                                        masked_lm_label_ids,
                                        masked_lm_loss_mask)

            if kargs.get("task_invariant", "no") == "yes":
                print("==apply task adversarial training==")
                with tf.variable_scope(scope+"/dann_task_invariant", reuse=model_reuse):
                    (_,
                    task_example_loss,
                    task_logits) = distillation_utils.feature_distillation(model.get_pooled_output(),
                                                    1.0,
                                                    features["task_id"],
                                                    kargs.get("num_task", 7),
                                                    dropout_prob,
                                                    True)
                    masked_task_example_loss = loss_mask * task_example_loss
                    masked_task_loss = tf.reduce_sum(masked_task_example_loss) / (1e-10+tf.reduce_sum(loss_mask))
                    loss += kargs.get("task_adversarial", 1e-2) * masked_task_loss

        tvars = model_io_fn.get_params(model_config.scope,
                                        not_storage_params=not_storage_params)

        if mode == tf.estimator.ModeKeys.TRAIN:
            multi_task_config = kargs.get("multi_task_config", {})
            if multi_task_config[task_type].get("lm_augumentation", False):
                print("==apply lm_augumentation==")
                masked_lm_pretrain_tvars = model_io_fn.get_params("cls/predictions",
                                                    not_storage_params=not_storage_params)
                tvars.extend(masked_lm_pretrain_tvars)

        try:
            params_size = model_io_fn.count_params(model_config.scope)
            print("==total params==", params_size)
        except:
            print("==not count params==")
        # print(tvars)

        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:
            acc = build_accuracy(logits, label_ids, loss_mask,
                                loss_type=kargs.get('loss', 'contrastive_loss'))

            return_dict = {
                    "loss":loss,
                    "logits":logits,
                    "task_num":tf.reduce_sum(loss_mask),
                    "tvars":tvars,
                    "positive_label":tf.reduce_sum(label_ids*loss_mask)
                }
            return_dict["{}_acc".format(task_type)] = acc
            if kargs.get("task_invariant", "no") == "yes":
                return_dict["{}_task_loss".format(task_type)] = masked_task_loss
                task_acc = build_accuracy(task_logits, features["task_id"], loss_mask)
                return_dict["{}_task_acc".format(task_type)] = task_acc
            if multi_task_config[task_type].get("lm_augumentation", False):
                return_dict["{}_masked_lm_loss".format(task_type)] = masked_lm_loss
                return_dict["{}_masked_lm_acc".format(task_type)] = lm_acc
            if kargs.get("embedding_distillation", True):
                return_dict["embed_loss"] = embed_loss*float(num_task)
            else:
                return_dict["embed_loss"] = task_loss
            if kargs.get("feature_distillation", True):
                return_dict["feature_loss"] = feature_loss*float(num_task)
            else:
                return_dict["feature_loss"] = task_loss
            return_dict["task_loss"] = task_loss
            return return_dict
        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_dict = {
                    "loss":loss,
                    "logits":logits,
                    "feature":model.get_pooled_output()
                }
            if kargs.get("adversarial", "no") == "adversarial":
                eval_dict["task_logits"] = task_logits
            return eval_dict

    return model_fn
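# Minimal usage sketch (assumed wiring; the training script, the encoder `model`, and
# the config objects live outside this file, so names below are illustrative only):
#
#   model = model_zoo(model_config)  # hypothetical: builds the siamese encoder
#   model_fn = model_fn_builder(model, model_config,
#                               num_labels=2,
#                               init_checkpoint=init_checkpoint,
#                               load_pretrained="yes",
#                               model_io_config=model_io_config,
#                               opt_config=opt_config,
#                               task_type="cls",
#                               loss="contrastive_loss",
#                               margin=1.0)
#   train_dict = model_fn(features, labels, tf.estimator.ModeKeys.TRAIN)
#   total_loss = train_dict["loss"]  # the caller optimizes this via the optimizer module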