
[Hands-on] One-Click Training of a Company-Name NER Model Based on BERT + CRF

Company-Name Named Entity Recognition

    • Model definition
    • Model training

The data used in the experiments can be downloaded here.

Full code: GitHub or Gitee

Model Definition

import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torchcrf import CRF
from transformers import BertTokenizerFast, AutoModel
from transformers.models.bert.modeling_bert import BertPreTrainedModel

class BertNER(BertPreTrainedModel):
    def __init__(self, config):
        super(BertNER, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = AutoModel.from_pretrained('ckiplab/albert-tiny-chinese')
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # lstm_embedding_size=128,
        # lstm_dropout_prob=0.5
        # self.bilstm = nn.LSTM(
        #     input_size=lstm_embedding_size,  # 1024
        #     hidden_size=config.hidden_size // 2,  # 1024
        #     batch_first=True,
        #     num_layers=2,
        #     dropout=lstm_dropout_prob,  # 0.5
        #     bidirectional=True
        # )
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(config.num_labels, batch_first=True)

        self.init_weights()

    def forward(self, input_data, token_type_ids=None, attention_mask=None, labels=None,
                position_ids=None, inputs_embeds=None, head_mask=None):
        input_ids, input_token_starts = input_data
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds)
        sequence_output = outputs[0]

        # Drop [CLS] and other non-start positions to obtain a pre_label
        # representation aligned with the labels
        origin_sequence_output = [layer[starts.nonzero().squeeze(1)]
                                  for layer, starts in zip(sequence_output, input_token_starts)]
        # Pad each sequence's pred_label representation to the max length in the batch
        padded_sequence_output = pad_sequence(origin_sequence_output, batch_first=True)
        # Apply dropout to part of the pred_label features
        padded_sequence_output = self.dropout(padded_sequence_output)
        # lstm_output, _ = self.bilstm(padded_sequence_output)
        # Compute the emission scores (logits)
        logits = self.classifier(padded_sequence_output)
        # logits = padded_sequence_output
        outputs = (logits,)
        # If labels are provided, compute the CRF loss; otherwise output only
        # the linear-layer logits, which can later be decoded into predictions
        # with the CRF's decode function.
        if labels is not None:
            loss_mask = labels.gt(-1)
            # torchcrf's forward returns the log-likelihood, so negate it for the loss
            loss = self.crf(logits, labels, loss_mask) * (-1)
            outputs = (loss,) + outputs

        # contains: (loss), logits
        return outputs
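
BertNER.forward takes input_data as a pair (input_ids, input_token_starts), where the second tensor flags the first WordPiece of each original token so that starts.nonzero() can pick out the label-aligned positions. Below is a minimal sketch of one way to build such a batch with BertTokenizerFast; the helper name build_input_data and the bert-base-chinese tokenizer checkpoint are illustrative assumptions, not part of the original code.

import torch
from transformers import BertTokenizerFast

# Hypothetical helper (not from the original repo): build the
# (input_ids, input_token_starts) pair that BertNER.forward expects.
def build_input_data(batch_tokens, tokenizer, device='cpu'):
    encoded = tokenizer(batch_tokens,
                        is_split_into_words=True,
                        padding=True,
                        return_tensors='pt')
    input_ids = encoded['input_ids'].to(device)
    # 1 marks the first sub-word of each original token; [CLS], [SEP],
    # padding and continuation pieces stay 0 and are dropped in forward()
    token_starts = torch.zeros_like(input_ids)
    for i in range(len(batch_tokens)):
        previous = None
        for pos, word_id in enumerate(encoded.word_ids(batch_index=i)):
            if word_id is not None and word_id != previous:
                token_starts[i, pos] = 1
            previous = word_id
    return input_ids, token_starts

tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
input_data = build_input_data([list('阿里巴巴集團'), list('騰訊')], tokenizer)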

           
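At inference time no labels are passed, so the model returns just the logits, which can be decoded into tag sequences with the CRF's Viterbi decoder (crf.decode is part of the torchcrf API). A short sketch, reusing the model and the input_data built above:

model.eval()
with torch.no_grad():
    logits = model(input_data)[0]       # shape: (batch, padded_len, num_labels)
    # Viterbi decoding: returns one list of tag ids per sentence; a mask
    # can be passed as a second argument to exclude padded positions
    pred_tag_ids = model.crf.decode(logits)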

Model Training

import os
import logging

# config, train_epoch and evaluate are defined elsewhere in the repository
def train(train_loader, dev_loader, model, optimizer, scheduler, model_dir):
    """Train the model and evaluate its performance on the dev set"""
    # reload weights from restore_dir if specified
    if model_dir is not None and config.load_before:
        model = BertNER.from_pretrained(model_dir)
        model.to(config.device)
        logging.info("--------Load model from {}--------".format(model_dir))
    best_val_f1 = 0.0
    patience_counter = 0
    # start training
    for epoch in range(1, config.epoch_num + 1):
        train_epoch(train_loader, model, optimizer, scheduler, epoch)
        val_metrics = evaluate(dev_loader, model, mode='dev')
        val_f1 = val_metrics['f1']
        logging.info("Epoch: {}, dev loss: {}, f1 score: {}".format(epoch, val_metrics['loss'], val_f1))
        improve_f1 = val_f1 - best_val_f1
        if improve_f1 > 1e-5:
            best_val_f1 = val_f1
            model_dir_new = config.model_dir + str(val_f1)[:6] +'_' + str(val_metrics['loss'])[:6] +'_' + str(epoch) + '/'
            if not os.path.exists(model_dir_new):               # check whether the directory exists
                os.makedirs(model_dir_new)                       # create it if it does not
            model.save_pretrained(model_dir_new)
            logging.info("--------Save best model!--------")
            # An improvement smaller than config.patience still counts toward early stopping
            if improve_f1 < config.patience:
                patience_counter += 1
            else:
                patience_counter = 0
        else:
            patience_counter += 1
        # Early stopping and logging best f1
        if (patience_counter >= config.patience_num and epoch > config.min_epoch_num) or epoch == config.epoch_num:
            logging.info("Best val f1: {}".format(best_val_f1))
            break
    logging.info("Training Finished!")
           
