天天看點

python多線程結合DataLoader加載資料

在模型訓練過程中,通常大家都會将注意力集中在模型加速以及提升GPU使用率,但是有時我們的耗時瓶頸也會在讀取資料上,gpu處理太快,反而cpu喂資料跟不上。當然架構也會提供一些資料讀取加速方案,比如tensorflow的 tf.data.TFRecordDataset,pytorch的DataLoader使用num_workers參數内部采用多線程方案等,還有些代碼是将所有資料制作到一個二進制檔案讀入記憶體,然後從記憶體中快速讀取資料,但是這種方案無法處理大資料項目。

tensorflow的record也需要先生成record檔案格式然後讀取,pytorch的DataLoader在設定num_workers時特别在windows中有些版本設定為非0會存在一些問題,本文介紹自己使用python的多線程來處理資料的一種方案,然後結合pytorch的Dataset和DataLoader擷取資料,供大家參考。

一 建立buffer類

先建立一個buffer類,其中讀寫資料需要使用兩個鎖

import threading
import random

class Buffer:
    def __init__(self, size):
        self.size = size
        self.buffer = []
        self.lock = threading.Lock()
        self.has_data = threading.Condition(self.lock)
        self.has_pos = threading.Condition(self.lock)

    def get_size(self):
        return self.size

    def get(self):
        with self.has_data:
            while len(self.buffer) == 0:
                self.has_data.wait()
            result = self.buffer[0]
            # print("get buffer size", len(self.buffer))
            del self.buffer[0]
            self.has_pos.notify_all()
        return result

    def put(self, data):
        with self.has_pos:
            while len(self.buffer) >= self.size:
                self.has_pos.wait()
            self.buffer.append(data)
            self.has_data.notify_all()

# test
def get():
    while True:
        get_data = buffer.get()
# test
def put():
    while True:
        data = random.randint(0, 9)
        buffer.put(a)           

複制

buffer類參考:https://cloud.tencent.com/developer/article/1724559

二 建立Dataset

生成一個DataReader建立多線程寫資料,以及單線程讀資料。以下為多線程的關鍵代碼

class DataReader:
    def __init__(self, max_buffer_size=5000):
        self.audio_files = files_to_list(training_files)
        random.shuffle(self.audio_files)
        self.buffer = Buffer(max_buffer_size)
    # 消費資料
    def comsume(self):
        while True:
            result = self.buffer.get()
    # 生産資料 
    def produce(self):
        while True:
            global index
            index += 1
            if index >= len(self.audio_files)-1:
                index = 0
            start = time.time()
            file = self.audio_files[index]
            audio = load_wav(file)
            end = time.time()
            self.buffer.put(audio)

    def run_produce(self, thread_num=16):
        # 多線程生産
        for _ in range(thread_num):
            th = threading.Thread(target=self.produce)
            th.start()

    def get_item(self, index):
        result = self.buffer.get()
        return result
                  

複制

下面使用一個Dataset來使用DataReader擷取資料

class AudioDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data_reader = DataReader()
        self.data_reader.run_produce()
        
    def __getitem__(self, index):
        # 從buffer中擷取一個資料
        start = time.time()
        audio = self.data_reader.get_item(index)
        # 進行資料處理
        ...
        audio = torch.from_numpy(audio).float()
        end = time.time()
        # print("get item time cost", (end - start) * 1000, audio.shape)
        return audio.unsqueeze(0)
    def __len__(self):
        return len(self.audio_files)           

複制

三 建立DataLoader

最後就可以通過DataLoader從DataSet中循環擷取batch資料輸入到模型進行訓練了

dataset = AudioDataset()
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
)           

複制