天天看點

Transformers學習筆記2. HuggingFace資料集Datasets

作者:程式設計圈子
Transformers學習筆記2. HuggingFace資料集Datasets

一、簡介

Datasets庫是Hugging Face的一個重要的資料集庫。 當需要微調一個模型的時候,需要進行下面操作:

  1. 下載下傳資料集
  2. 使用Dataset.map() 預處理資料
  3. 加載和計算名額

    可以在官網來搜尋資料集:

    https://huggingface.co/datasets

二、操作

1. 下載下傳資料集

使用的示例資料集:

Transformers學習筆記2. HuggingFace資料集Datasets
from datasets import load_dataset

# 加載資料
dataset = load_dataset(path='seamew/ChnSentiCorp', split='train')

print(dataset)
           

列印結果:

Dataset({
    features: ['text', 'label'],
    num_rows: 9600
})
{'text': '選擇珠江花園的原因就是友善,有電動扶梯直接到達海邊,周圍餐館、食廊、商場、超市、攤位一應俱全。酒店裝修一般,但還算整潔。 泳池在大堂的屋頂,是以很小,不過女兒倒是喜歡。 包的早餐是西式的,還算豐富。 服務嗎,一般', 'label': 1}
           

2. 常用函數

(1)排序

sortData = dataset.sort('label')           
Transformers學習筆記2. HuggingFace資料集Datasets

(2)打亂順序

shuffleData = sortData.shuffle(seed=20);           

(3)選擇函數

從資料集中取出某些指定的部分。

dataset.select([0,1,2,3])           

(4)過濾

def filter(data):
    return data['text'].startswith('1')
b = dataset.filter(filter)           
Transformers學習筆記2. HuggingFace資料集Datasets

(5)切分資料集

dataset.train_test_split(test_size=0.1)           

把資料集切分,10%為測試集。

Transformers學習筆記2. HuggingFace資料集Datasets

(6)分桶

把資料集均數若幹份,取其中的第幾份。

dataset.shard(num_shards=5, index=0)           

(7)列重命名

c = a.rename_column('text', 'newColumn')           

(8)列删除

d = c.remove_columns(['newColumn'])           

(9)資料集轉換

set_format函數用來實作與其它庫資料格式的轉換;

# 轉為PyTorch資料集格式 
dataset.set_format(type='torch', columns=['label'])
# 轉為Pandas格式 
dataset.set_format(type='pandas', columns=['label'])           

(10)map函數

周遊資料,對每個資料進行處理

def handler(data):
	data['text'] = 'Prefix' + data['text']
	return data

datasetMap = dataset.map(handler)           

(11)資料的儲存和加載

dataset.save_to_disk('./')           
Transformers學習筆記2. HuggingFace資料集Datasets
from datasets import load_from_disk
dataset = load_from_disk('./')           

3. 評價名額 Evaluate

安裝Evaluate庫:

pip install evaluate           

(1)加載

import evaluate
accuracy = evaluate.load("accuracy")           

(2)從社群加載子產品

element_count = evaluate.load("lvwerra/element_count", module_type="measurement")
           

(3)列出可用子產品

evaluate.list_evaluation_modules(
  module_type="comparison",
  include_community=False,
  with_details=True)
           
Transformers學習筆記2. HuggingFace資料集Datasets

(4)子產品屬性

屬性 描述
description 評估子產品說明
citation 用于引用的 BibTex 字元串(如果可用)。
features 定義輸入格式的對象的特征
inputs_description 說明
homepage 子產品的首頁
license 子產品的許可證
codebase_urls 子產品代碼連結
reference_urls 其他引用網址

(5)計算,直接調用函數計算

# 評估值正确率有一半
accuracy.compute(references=[0,1,0,1], predictions=[1,0,0,1])
# 輸出
{'accuracy': 0.5}           

(6)計算單個或一批名額

for ref, pred in zip([0,1,0,1], [1,0,0,1]):
    accuracy.add(references=ref, predictions=pred)
accuracy.compute()           

輸出:

{'accuracy': 0.5}           

批添加:

for refs, preds in zip([[0,1],[0,1]], [[1,0],[0,1]]):
    accuracy.add_batch(references=refs, predictions=preds)
accuracy.compute()           

(7)可視化

import evaluate
from evaluate.visualization import radar_plot

data = [
   {"accuracy": 0.99, "precision": 0.8, "f1": 0.95, "latency_in_seconds": 33.6},
   {"accuracy": 0.98, "precision": 0.87, "f1": 0.91, "latency_in_seconds": 11.2},
   {"accuracy": 0.98, "precision": 0.78, "f1": 0.88, "latency_in_seconds": 87.6},
   {"accuracy": 0.88, "precision": 0.78, "f1": 0.81, "latency_in_seconds": 101.6}
   ]
model_names = ["Model 1", "Model 2", "Model 3", "Model 4"]
plot = radar_plot(data=data, model_names=model_names)
plot.show()
           
Transformers學習筆記2. HuggingFace資料集Datasets

繼續閱讀