

Alibaba Chinese Pretrained Model Generalization Challenge: Task 1

  • Background
  • Baseline Error Notes
  • Launching Training
  • Launching Evaluation

Background

The challenge is set in a natural language processing context: participants must devise algorithms that produce a Chinese pretrained model with strong generalization ability. The problem is designed to guide participants toward a better understanding of how pretrained models work, and toward deeper exploration of model construction and training, rather than just simple fine-tuning for a specific task.

Baseline Error Notes

First, the errors around the imports:

import sys
import os
import tensorflow as tf
from easytransfer import base_model, Config, FLAGS
from easytransfer import layers
from easytransfer import model_zoo
from easytransfer import preprocessors
from easytransfer.datasets import TFRecordReader  # the MultiTaskTFRecordReader used below subclasses this (presumably defined in the baseline notebook; not shown in this post)
from easytransfer.losses import softmax_cross_entropy
from sklearn.metrics import classification_report
import numpy as np
           

First, note that two extra packages have to be installed:

pip install tensorflow-gpu --user ## --user makes permission errors less likely

pip install easytransfer

The output was as follows:

[Screenshot: error output showing that tf.logging no longer exists]

This is because TensorFlow 2.1 no longer has tf.logging.

Browsing the forums turned up this fix:

Replace tf.logging with tf.compat.v1.logging.
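
If you stay on TF 2.x, a minimal shim along these lines (my sketch, assuming the tf.compat.v1 module is available) can be run before importing any library code that still references tf.logging:

import tensorflow as tf

# Compatibility shim (assumption: TF 2.x, where tf.compat.v1 is available).
# Restores the removed tf.logging alias before legacy code goes looking for it.
if not hasattr(tf, "logging"):
    tf.logging = tf.compat.v1.logging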

But the error persisted, so I went back and dutifully followed the environment-setup tips (pinned install commands below):

  • tensorflow-gpu 1.12.3
  • easytransfer 0.1.2
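
To pin those exact versions, the installs become (same pip style as above):

pip install tensorflow-gpu==1.12.3 --user

pip install easytransfer==0.1.2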

In practice, though, that did not fix it either.

[Screenshot: the remaining error output]

Then in the group chat I saw 水哥 mention that the baselines variously use TF2, TF1.4, and PyTorch.

With limited time, I did not cover those here. Below is the baseline's Application model definition:

class Application(base_model):
    def __init__(self, **kwargs):
        super(Application, self).__init__(**kwargs)
        self.user_defined_config = kwargs["user_defined_config"]

    def build_logits(self, features, mode=None):

        preprocessor = preprocessors.get_preprocessor(self.pretrain_model_name_or_path,
                                                      user_defined_config=self.user_defined_config)

        model = model_zoo.get_pretrained_model(self.pretrain_model_name_or_path)

        global_step = tf.train.get_or_create_global_step()

        # One classification head per task:
        # TNEWS (15 news categories), OCEMOTION (7 emotions), OCNLI (3 NLI labels)
        tnews_dense = layers.Dense(15,
                     kernel_initializer=layers.get_initializer(0.02),
                     name='tnews_dense')

        ocemotion_dense = layers.Dense(7,
                             kernel_initializer=layers.get_initializer(0.02),
                             name='ocemotion_dense')

        ocnli_dense = layers.Dense(3,
                             kernel_initializer=layers.get_initializer(0.02),
                             name='ocnli_dense')

        input_ids, input_mask, segment_ids, label_ids = preprocessor(features)

        # Shared encoder; outputs[1] is the pooled sentence-level representation
        outputs = model([input_ids, input_mask, segment_ids], mode=mode)
        pooled_output = outputs[1]

        # Dropout only during training
        if mode == tf.estimator.ModeKeys.TRAIN:
            pooled_output = tf.nn.dropout(pooled_output, keep_prob=0.9)

        # Round-robin multi-task training: global_step % 3 picks the active head,
        # matching the order in which the multi-task reader interleaves the
        # three tasks' batches
        logits = tf.case([(tf.equal(tf.mod(global_step, 3), 0), lambda: tnews_dense(pooled_output)),
                          (tf.equal(tf.mod(global_step, 3), 1), lambda: ocemotion_dense(pooled_output)),
                          (tf.equal(tf.mod(global_step, 3), 2), lambda: ocnli_dense(pooled_output)),
                          ], exclusive=True)

        if mode == tf.estimator.ModeKeys.PREDICT:
            ret = {
                "tnews_logits": tnews_dense(pooled_output),
                "ocemotion_logits": ocemotion_dense(pooled_output),
                "ocnli_logits": ocnli_dense(pooled_output),
                "label_ids": label_ids
            }
            return ret

        return logits, label_ids

    def build_loss(self, logits, labels):
        # Use the loss matching whichever task's head was active this step
        # (same global_step % 3 selection as in build_logits)
        global_step = tf.train.get_or_create_global_step()
        return tf.case([(tf.equal(tf.mod(global_step, 3), 0), lambda : softmax_cross_entropy(labels, 15, logits)),
                      (tf.equal(tf.mod(global_step, 3), 1), lambda : softmax_cross_entropy(labels, 7, logits)),
                      (tf.equal(tf.mod(global_step, 3), 2), lambda : softmax_cross_entropy(labels, 3, logits))
                      ], exclusive=True)

    def build_predictions(self, output):
        tnews_logits = output['tnews_logits']
        ocemotion_logits = output['ocemotion_logits']
        ocnli_logits = output['ocnli_logits']

        tnews_predictions = tf.argmax(tnews_logits, axis=-1, output_type=tf.int32)
        ocemotion_predictions = tf.argmax(ocemotion_logits, axis=-1, output_type=tf.int32)
        ocnli_predictions = tf.argmax(ocnli_logits, axis=-1, output_type=tf.int32)

        ret_dict = {
            "tnews_predictions": tnews_predictions,
            "ocemotion_predictions": ocemotion_predictions,
            "ocnli_predictions": ocnli_predictions,
            "label_ids": output['label_ids']
        }
        return ret_dict
           
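To make the step-based task switching concrete, here is a minimal, self-contained sketch (mine, not from the baseline; assumes TF 1.x) of how tf.case on global_step % 3 cycles through branches:

import tensorflow as tf

# Standalone illustration (assumed TF 1.x) of the round-robin trick above:
# tf.case evaluates global_step % 3 and runs exactly one branch per step.
global_step = tf.train.get_or_create_global_step()
task_id = tf.mod(global_step, 3)

active_task = tf.case([(tf.equal(task_id, 0), lambda: tf.constant("tnews")),
                       (tf.equal(task_id, 1), lambda: tf.constant("ocemotion")),
                       (tf.equal(task_id, 2), lambda: tf.constant("ocnli"))],
                      exclusive=True)

increment = tf.assign_add(global_step, 1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(6):
        print(sess.run(active_task))  # cycles tnews -> ocemotion -> ocnli
        sess.run(increment)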

Launching Training

config = Config(mode="train", config_json=config_json)
app = Application(user_defined_config=config)

train_reader = MultiTaskTFRecordReader(input_glob=app.train_input_fp,
                                           is_training=True,
                                           input_schema=app.input_schema,
                                           batch_size=app.train_batch_size)

app.run_train(reader=train_reader)
           
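Note that config_json is defined earlier in the baseline notebook and is not reproduced in this post. Purely for orientation, an EasyTransfer-style config is roughly shaped like the sketch below; every key and value here is an assumption modeled on the baseline, not copied from it:

# Illustrative skeleton only -- all keys/values are assumptions; consult the
# official baseline notebook for the authoritative config_json.
config_json = {
    "preprocess_config": {
        "input_schema": "input_ids:int:128,input_mask:int:128,segment_ids:int:128,label_id:int:1",
        "sequence_length": 128,
    },
    "model_config": {
        "pretrain_model_name_or_path": "pai-bert-base-zh",  # hypothetical model name
    },
    "train_config": {
        "train_input_fp": "train.list_tfrecord",  # hypothetical path
        "train_batch_size": 32,
        "num_epochs": 1,
        "model_dir": "./model_dir",
        "optimizer_config": {"learning_rate": 1e-5},
    },
    "predict_config": {
        "predict_input_fp": "dev.list_tfrecord",  # hypothetical path
        "predict_batch_size": 32,
    },
}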

Launching Evaluation

config = Config(mode="predict", config_json=config_json)
app = Application(user_defined_config=config)
    
predict_reader = MultiTaskTFRecordReader(input_glob=app.predict_input_fp,
                                           is_training=False,
                                           input_schema=app.input_schema,
                                           batch_size=app.predict_batch_size)

# Collect the step numbers of all saved checkpoints by parsing the TF
# "checkpoint" index file; each line looks like
#   all_model_checkpoint_paths: "<dir>/model.ckpt-<step>"
# Strip the oss:// prefix first so splitting on ":" only hits the key/value colon
ckpts = set()
with tf.gfile.GFile(os.path.join(app.config.model_dir, "checkpoint"), mode='r') as reader:
    for line in reader:
        line = line.strip()
        line = line.replace("oss://", "")
        ckpts.add(int(line.split(":")[1].strip().replace("\"", "").split("/")[-1].replace("model.ckpt-", "")))

best_macro_f1 = 0
best_ckpt = None
for ckpt in sorted(ckpts):
    checkpoint_path = os.path.join(app.config.model_dir, "model.ckpt-" + str(ckpt))
    tf.logging.info("checkpoint_path is {}".format(checkpoint_path))
    all_tnews_preds = []
    all_tnews_gts = []
    all_ocemotion_preds = []
    all_ocemotion_gts = []
    all_ocnli_preds = []
    all_ocnli_gts = []
    for i, output in enumerate(app.run_predict(reader=predict_reader, checkpoint_path=checkpoint_path)):
        label_ids = np.squeeze(output['label_ids'])
        # Batches arrive task-interleaved, so i % 3 identifies the task
        if i % 3 == 0:
            tnews_predictions = output['tnews_predictions']
            all_tnews_preds.extend(tnews_predictions.tolist())
            all_tnews_gts.extend(label_ids.tolist())
        elif i % 3 == 1:
            ocemotion_predictions = output['ocemotion_predictions']
            all_ocemotion_preds.extend(ocemotion_predictions.tolist())
            all_ocemotion_gts.extend(label_ids.tolist())
        elif i % 3 == 2:
            ocnli_predictions = output['ocnli_predictions']
            all_ocnli_preds.extend(ocnli_predictions.tolist())
            all_ocnli_gts.extend(label_ids.tolist())

        # Only evaluate the first 21 batches to keep the run short
        if i == 20:
            break

    # classification_report returns formatted text; token [-8] is the macro-avg
    # F1 (position-sensitive parsing; see the more robust alternative below)
    tnews_report = classification_report(all_tnews_gts, all_tnews_preds, digits=4)
    tnews_macro_avg_f1 = float(tnews_report.split()[-8])

    ocemotion_report = classification_report(all_ocemotion_gts, all_ocemotion_preds, digits=4)
    ocemotion_macro_avg_f1 = float(ocemotion_report.split()[-8])

    ocnli_report = classification_report(all_ocnli_gts, all_ocnli_preds, digits=4)
    ocnli_macro_avg_f1 = float(ocnli_report.split()[-8])

    # Final score: unweighted mean of the three tasks' macro F1s
    macro_f1 = (tnews_macro_avg_f1 + ocemotion_macro_avg_f1 + ocnli_macro_avg_f1)/3.0
    if macro_f1 >= best_macro_f1:
        best_macro_f1 = macro_f1
        best_ckpt = ckpt

tf.logging.info("best ckpt {}, best best_macro_f1 {}".format(best_ckpt, best_macro_f1))
           
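One fragile spot above: pulling the macro F1 out of classification_report's formatted text with split()[-8] silently breaks if sklearn changes the report layout. sklearn can return the number directly; a drop-in sketch for those three parsing lines:

from sklearn.metrics import f1_score

# Robust alternative to position-based parsing of classification_report text:
# average="macro" returns the macro-averaged F1 directly.
tnews_macro_avg_f1 = f1_score(all_tnews_gts, all_tnews_preds, average="macro")
ocemotion_macro_avg_f1 = f1_score(all_ocemotion_gts, all_ocemotion_preds, average="macro")
ocnli_macro_avg_f1 = f1_score(all_ocnli_gts, all_ocnli_preds, average="macro")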
