天天看點

pytorch 深度學習實踐 第8講 加載資料集第8講 加載資料集 dataset and dataloader

第8講 加載資料集 dataset and dataloader

pytorch學習視訊——B站視訊連結:《PyTorch深度學習實踐》完結合集_哔哩哔哩_bilibili

以下是視訊内容筆記以及小練習源碼,筆記純屬個人了解,如有錯誤歡迎路過的大佬指出 。

前言:dataset——構造資料集,資料集應該支援索引操作;dataloader——為訓練提供mini-batch資料

1. 基本概念了解

  • mini-batch的定義及作用

    将資料分成若幹個批次,周遊一個批次的資料集計算一次損失函數,然後計算函數對各個參數的梯度,更新參數。使用mini-batch可以有效解決鞍點問題,同時也降低了随機性(和訓練一個資料更新參數比較)。

  • epoch

    一個周期:所有樣本都進行了一次前饋、回報和更新的過程,也就是所有樣本都參與了一次訓練。

  • batch-size

    批量大小:訓練一次所需要的樣本數量,也就是一個mini-batch的大小。

  • iteration

    疊代次數:一個mini-batch經過多少次疊代把所有樣本訓練完一次,直覺上來看就是總的batch所包含的mini-batch的數量。

如圖所示:

pytorch 深度學習實踐 第8講 加載資料集第8講 加載資料集 dataset and dataloader
  • Dataloader

    shuffle=True表示打亂樣本順序;然後将樣本分成2個一組batch。如圖所示

    pytorch 深度學習實踐 第8講 加載資料集第8講 加載資料集 dataset and dataloader

2. 代碼說明

  • Dataset是一個抽象類,不能被執行個體化,需要 定義一個類繼承自 Dataset類——>DiabetesDataset類
  • 定義的DiabetesDataset類需要實作Dataset類的getitem和len方法
    • _getitem_ (魔法函數):為資料集增加索引,可以通過索引将資料集的某一個資料取出
    • _len_:擷取整個資料的長度/數量
    • _init_:初始化資料集時有兩種方式:第一種是直接将所有資料讀入記憶體,适用于資料容量不大的時候;第二種是将各個資料樣本的檔案名讀取到清單中,使用索引來讀取資料集,不一次性讀取所有資料。

源碼 dataset_and_dataloader.py

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

# 新定義一個類繼承自Dataset
class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        # 直接取xy資料集的形狀的行,表示長度,即樣本數量
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[ :, :-1])
        self.y_data = torch.from_numpy(xy[ :, [-1]])

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len


# 執行個體化資料集對象
dataset = DiabetesDataset('../dataset/diabetes.csv.gz')

# 加載資料集
# dataset是資料集對象,batch-size是批次大小,shuffle表示是否打亂樣本順序,num_workers表示使用多線程
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()
        # self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x


model = Model()

criterion = torch.nn.BCELoss(size_average=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss_list = []
if __name__ == '__main__':
    for epoch in range(100):
        for i, data in enumerate(train_loader, 0):
            # 1. 準備資料
            inputs, labels = data
            # 2. 前饋
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            # print(epoch, i, loss.item())
            # 3. 回報
            optimizer.zero_grad()
            optimizer.step()
        print(epoch, loss.item())
        loss_list.append(loss.item())

    print(loss_list)
    plt.plot(range(100), loss_list)
    plt.xlabel('epoch')
    plt.ylabel('cost')
    plt.show()
           

結果怎麼是這個鬼樣子:

pytorch 深度學習實踐 第8講 加載資料集第8講 加載資料集 dataset and dataloader

(以上代碼已經在pycharm上經過測試,筆記純屬個人了解,如有錯誤勿介或者歡迎路過的大佬指出 嘻嘻嘻。)

——未完待續……