在pytorch 中一些常用的功能都已經被封裝成了子產品,是以我們隻需要繼承并重寫部分函數即可。首先介紹一下本文最終希望實作的目标, 對本地的一維資料 (1xn)的ndarry 進行一個多分類,資料集為mn的資料,标簽為m1的數組。下面是結合代碼記錄一下踩坑過程。
繼承Dataset類,可以看到我這裡重寫了三個函數,init 函數用于載入numpy資料并将其轉化為相應的tensor,__gititem__函數用于定義訓練時會傳回的單個資料與标簽,__len__表示資料數量m。
通過繼承nn.Module來自定義神經網絡
其中__init__函數來自定義定義我們需要的網絡參數,這裡我們block1 的in_channels為1,輸出參數可根據需要自己設定,但而且目前層的輸出channel應該和下一層的輸入channel相同,
注意:MaxPool1d的inchannel需要自己計算一下,當然如果你不想算,可以給個參數直接運作,看報錯資訊的提示
__forward__ 函數定義了網絡的連接配接方式,注意此處應傳回x。
主程式。為了更好的說明,先放一下主程式。這裡的程式是已經載入了資料的,data是mn 數組,label為m1數組。
執行個體化DataLoader的第一個參數是Dataset的執行個體,通過DataLoader,其功能是為下文訓練和測試過程提供資料。
定義訓練階段,從DataLoader中取出資料,這裡X,y分别為batch_sizen,batch_size1的資料。
首先要進行一個調整,将X調整為batch_size1n的float,設定float的轉化過程放在Dataset的初始化函數裡完成了
注意:如果沒有這一步會報錯
期望是long但得到了float的錯誤。(雖然我也不明白為啥錯誤不是期望float...)
y為1*batch的數組并轉化成long(這裡y的形式可能與損失函數有關)
定義測試過程(同上)
大體流程就是這些,最後記得修改加入輸出語句與儲存模型等操作。