在模型訓練過程中,通常大家都會将注意力集中在模型加速以及提升GPU使用率,但是有時我們的耗時瓶頸也會在讀取資料上,gpu處理太快,反而cpu喂資料跟不上。當然架構也會提供一些資料讀取加速方案,比如tensorflow的 tf.data.TFRecordDataset,pytorch的DataLoader使用num_workers參數内部采用多線程方案等,還有些代碼是将所有資料制作到一個二進制檔案讀入記憶體,然後從記憶體中快速讀取資料,但是這種方案無法處理大資料項目。
tensorflow的record也需要先生成record檔案格式然後讀取,pytorch的DataLoader在設定num_workers時特别在windows中有些版本設定為非0會存在一些問題,本文介紹自己使用python的多線程來處理資料的一種方案,然後結合pytorch的Dataset和DataLoader擷取資料,供大家參考。
一 建立buffer類
先建立一個buffer類,其中讀寫資料需要使用兩個鎖
import threading
import random
class Buffer:
def __init__(self, size):
self.size = size
self.buffer = []
self.lock = threading.Lock()
self.has_data = threading.Condition(self.lock)
self.has_pos = threading.Condition(self.lock)
def get_size(self):
return self.size
def get(self):
with self.has_data:
while len(self.buffer) == 0:
self.has_data.wait()
result = self.buffer[0]
# print("get buffer size", len(self.buffer))
del self.buffer[0]
self.has_pos.notify_all()
return result
def put(self, data):
with self.has_pos:
while len(self.buffer) >= self.size:
self.has_pos.wait()
self.buffer.append(data)
self.has_data.notify_all()
# test
def get():
while True:
get_data = buffer.get()
# test
def put():
while True:
data = random.randint(0, 9)
buffer.put(a)
複制
buffer類參考:https://cloud.tencent.com/developer/article/1724559
二 建立Dataset
生成一個DataReader建立多線程寫資料,以及單線程讀資料。以下為多線程的關鍵代碼
class DataReader:
def __init__(self, max_buffer_size=5000):
self.audio_files = files_to_list(training_files)
random.shuffle(self.audio_files)
self.buffer = Buffer(max_buffer_size)
# 消費資料
def comsume(self):
while True:
result = self.buffer.get()
# 生産資料
def produce(self):
while True:
global index
index += 1
if index >= len(self.audio_files)-1:
index = 0
start = time.time()
file = self.audio_files[index]
audio = load_wav(file)
end = time.time()
self.buffer.put(audio)
def run_produce(self, thread_num=16):
# 多線程生産
for _ in range(thread_num):
th = threading.Thread(target=self.produce)
th.start()
def get_item(self, index):
result = self.buffer.get()
return result
複制
下面使用一個Dataset來使用DataReader擷取資料
class AudioDataset(torch.utils.data.Dataset):
def __init__(self):
self.data_reader = DataReader()
self.data_reader.run_produce()
def __getitem__(self, index):
# 從buffer中擷取一個資料
start = time.time()
audio = self.data_reader.get_item(index)
# 進行資料處理
...
audio = torch.from_numpy(audio).float()
end = time.time()
# print("get item time cost", (end - start) * 1000, audio.shape)
return audio.unsqueeze(0)
def __len__(self):
return len(self.audio_files)
複制
三 建立DataLoader
最後就可以通過DataLoader從DataSet中循環擷取batch資料輸入到模型進行訓練了
dataset = AudioDataset()
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=pin_memory,
)
複制