天天看点

pytorch加载bert权重与转换成onnx

由于BERT是Google提出的模型,其官方实现大多基于TensorFlow编写。自从有了transformers库,在PyTorch中加载BERT模型也简单了许多。

权重文件,如图所示:

pytorch加载bert权重与转换成onnx

config.json是BERT的配置文件,包含hidden_size、dropout等超参数,如下所示:

{
  "attention_probs_dropout_prob": 0.1, 
  "directionality": "bidi", 
  "hidden_act": "gelu", 
  "hidden_dropout_prob": 0.1, 
  "hidden_size": 768, 
  "initializer_range": 0.02, 
  "intermediate_size": 3072, 
  "max_position_embeddings": 512, 
  "num_attention_heads": 12, 
  "num_hidden_layers": 12, 
  "pooler_fc_size": 768, 
  "pooler_num_attention_heads": 12, 
  "pooler_num_fc_layers": 3, 
  "pooler_size_per_head": 128, 
  "pooler_type": "first_token_transform", 
  "type_vocab_size": 2, 
  "vocab_size": 21128
}
           

.bin文件则是保存模型权重(state_dict)的二进制文件;计算图本身并不保存在其中,而是由代码中的模型类(如BertModel)定义。

import os
import tempfile
import numpy as np
from onnxruntime import InferenceSession
import torch
from torch import nn
from transformers import BertPreTrainedModel, BertModel, BertForSequenceClassification

# This script only runs inference and ONNX export, so autograd is switched
# off at import time; no gradient graph is recorded for any forward pass.
torch.set_grad_enabled(mode=False)
class bert_model(BertPreTrainedModel):
    """BERT sequence classifier whose ``forward`` returns only the logits tensor.

    The submodules (``bert`` / ``dropout`` / ``classifier``) use the same names
    as ``BertForSequenceClassification``. Checkpoint keys ``bert.*`` (and
    ``classifier.*``, if present) therefore map onto this model when it is
    loaded with ``from_pretrained``. Returning a bare tensor rather than a
    ``ModelOutput`` keeps the graph traced by ``torch.onnx.export`` to a single
    output.
    """

    def __init__(self, config):
        super(bert_model, self).__init__(config)
        self.num_labels = config.num_labels
        # Hold BertModel directly. Wrapping BertForSequenceClassification here
        # would nest the encoder at ``bert.bert.*``, so pretrained ``bert.*``
        # weights would not match and would silently stay randomly initialized.
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights (e.g. the classifier head) that the checkpoint
        # may not provide.
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return classification logits (last dim is ``config.num_labels``).

        Args:
            input_ids: Token id tensor passed through to ``BertModel``.
            attention_mask: Optional mask passed through to ``BertModel``.
            token_type_ids: Optional segment ids passed through to ``BertModel``.
        """
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        # Index 1 is the pooler output. Indexing works for both tuple and
        # ModelOutput return types.
        pooled_output = self.dropout(outputs[1])
        return self.classifier(pooled_output)


def export_to_onnx(task, model_dir, output_model_name):
    """Export a BERT classifier to ONNX and check ONNX Runtime against PyTorch.

    Args:
        task: Task selector. Only ``1`` (``bert_model`` with two labels) is
            implemented.
        model_dir: Directory containing ``config.json`` and the ``.bin``
            weights, passed to ``bert_model.from_pretrained``.
        output_model_name: File path the exported ONNX model is written to.

    Raises:
        ValueError: If ``task`` is not a supported task id.
        AssertionError: If the ONNX Runtime logits differ from the PyTorch
            logits beyond 5 decimal places.
    """
    if task != 1:
        raise ValueError(f"Unsupported task: {task!r}; only task 1 is implemented")

    model = bert_model.from_pretrained(model_dir, num_labels=2)
    # Disable dropout before tracing so the exported graph and the PyTorch
    # reference output below are deterministic.
    model.eval()

    # One example of length 7; attention_mask marks the last position as padding.
    dummy_input = {
        "input_ids": torch.tensor([[101, 2769, 1372, 2682, 2127, 102, 0]]),
        "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 0]]),
        "token_type_ids": torch.tensor([[0, 0, 0, 0, 0, 0, 0]]),
    }
    # Let batch size (axis 0) and sequence length (axis 1) vary at inference time.
    dynamic_axes = {
        'input_ids': [0, 1],
        'attention_mask': [0, 1],
        'token_type_ids': [0, 1],
        'logits': [0],
    }
    # bert_model.forward returns a single logits tensor, so the graph has one output.
    output_names = ['logits']

    # Write straight to the requested path. Exporting into a temporary file
    # would discard the result and risk reading an unflushed file.
    torch.onnx.export(model,
                      # Positional order must match forward(input_ids, attention_mask, token_type_ids).
                      args=tuple(dummy_input.values()),
                      f=output_model_name,
                      input_names=list(dummy_input),
                      output_names=output_names,
                      dynamic_axes=dynamic_axes,
                      opset_version=10)

    # Run the same input through ONNX Runtime and compare with PyTorch.
    sess = InferenceSession(output_model_name)
    torch_logits = model(**dummy_input)
    (onnx_logits,) = sess.run(
        output_names=output_names,
        input_feed={key: value.numpy() for key, value in dummy_input.items()})
    np.testing.assert_almost_equal(torch_logits.numpy(), onnx_logits, 5)
           

环境配置列表:

torch == 1.8.1

transformers == 4.6.1

onnxruntime == 1.8.0

加载时只需把.bin权重文件与config.json放在同一个文件夹中,再把该文件夹路径传给from_pretrained即可。

转换成ONNX时,由于模型有三个输入(input_ids、attention_mask、token_type_ids),torch.onnx.export的args参数使用tuple按forward的参数顺序打包这些张量(tuple of arguments),input_names则按相同顺序为计算图中的输入节点命名(list of strings)。

继续阅读