#!/usr/bin/env python3 
# -*- coding: utf-8 -*- 



# Author: Xiaoy LI 
# Last update: 2019.04.23 
# First create: 2019.04.23 
# Description:
# bert_tagger.py
# BERT encoder with a multi-layer classifier head for token-level sequence tagging.


import os 
import sys 


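# Make the repository root importable so the `layers` package resolves
# regardless of the current working directory.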
root_path = "/".join(os.path.realpath(__file__).split("/")[:-2])
if root_path not in sys.path:
    sys.path.insert(0, root_path)


import torch 
import torch.nn as nn 
from torch.nn import CrossEntropyLoss 


from layers.classifier import MultiNonLinearClassifier
from layers.bert_basic_model import BertConfig, BertModel
from layers.bert_layernorm import BertLayerNorm 




class BertTagger(nn.Module):
    """BERT encoder followed by a multi-layer non-linear classifier for token-level tagging."""

    def __init__(self, config, num_labels=5):
        super(BertTagger, self).__init__()
        self.num_labels = num_labels

        bert_config = BertConfig.from_dict(config.bert_config.to_dict())
        self.bert = BertModel(bert_config)

        self.hidden_size = config.hidden_size 
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.classifier = MultiNonLinearClassifier(config.hidden_size, self.num_labels)
        # Replace the randomly initialized encoder with pretrained BERT weights.
        self.bert = BertModel.from_pretrained(config.bert_model)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                labels=None, input_mask=None):
        # last_bert_layer: (batch_size, seq_len, hidden_size) hidden states of the final layer.
        last_bert_layer, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
                                                   output_all_encoded_layers=False)
        # Flatten to (batch_size * seq_len, hidden_size) before the token-level classifier.
        last_bert_layer = last_bert_layer.view(-1, self.hidden_size)
        last_bert_layer = self.dropout(last_bert_layer)
        logits = self.classifier(last_bert_layer)

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            if input_mask is not None:
                # Compute the loss only over positions marked as valid by input_mask.
                active_positions = input_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_positions]
                active_labels = labels.view(-1)[active_positions]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return loss
        else:
            return logits
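

if __name__ == "__main__":
    # Minimal usage sketch, not part of the training pipeline. It assumes the
    # vendored BertConfig/BertModel follow the pytorch-pretrained-bert API
    # (default BertConfig hyper-parameters, downloadable "bert-base-uncased"
    # weights); the config namespace below only mirrors the attributes that
    # BertTagger.__init__ reads and is otherwise hypothetical.
    from argparse import Namespace

    bert_config = BertConfig(vocab_size_or_config_json_file=30522)
    config = Namespace(
        bert_config=bert_config,
        bert_model="bert-base-uncased",
        hidden_size=bert_config.hidden_size,
        hidden_dropout_prob=bert_config.hidden_dropout_prob,
    )

    model = BertTagger(config, num_labels=5)
    model.eval()

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, bert_config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones_like(input_ids)
    labels = torch.zeros(batch_size, seq_len, dtype=torch.long)

    # With labels the forward pass returns a scalar loss; without labels it
    # returns logits of shape (batch_size * seq_len, num_labels).
    loss = model(input_ids, attention_mask=attention_mask,
                 labels=labels, input_mask=attention_mask)
    logits = model(input_ids, attention_mask=attention_mask)
    print(loss.item(), logits.shape)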