天天看點

【Pytorch】nn.Module類

一、繼承nn.Module類并自定義層

我們要利用pytorch提供的很多便利的方法,則需要将很多自定義操作封裝成nn.Module類。

首先,簡單實作一個Mylinear類:

from torch import nn

# Mylinear繼承Module
class Mylinear(nn.Module):
    # 傳入輸入次元和輸出次元
    def __init__(self,in_d,out_d):
        # 調用父類構造函數
        super(Mylinear,self).__init__()
        # 使用Parameter類将w和b封裝,這樣可以通過nn.Module直接管理,并提供給優化器優化
        self.w = nn.Parameter(torch.randn(out_d,in_d))
        self.b = nn.Parameter(torch.randn(out_d))

    # 實作forward函數,該函數為預設執行的函數,即計算過程,并将輸出傳回
    def forward(self, x):
        x = [email protected].w.t() + self.b
        return x
           

這樣就可以将我們自定義的Mylinear加入整個網絡

# 網絡結構
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()

        self.model = nn.Sequential(
            #nn.Linear(784, 200),
            Mylinear(784,200),
            nn.BatchNorm1d(200, eps=1e-8),
            nn.LeakyReLU(inplace=True),
            #nn.Linear(200, 200),
            Mylinear(200, 200),  
            nn.BatchNorm1d(200, eps=1e-8),
            nn.LeakyReLU(inplace=True),
            #nn.Linear(200, 10),
            Mylinear(200,10),
            nn.LeakyReLU(inplace=True)
        )
           

我們可以看出,MLP網絡實際上也是繼承自Module,這就說明了,nn.Module實際上可以實作一個嵌套的結構,我們的整個網絡就是由一個嵌套的樹形結構組成的。例如:

# Mylinear繼承Module
class Mylinear(nn.Module):
    # 傳入輸入次元和輸出次元
    def __init__(self, in_d, out_d):
        # 調用父類構造函數
        super(Mylinear, self).__init__()
        # 使用Parameter類将w和b封裝,這樣可以通過nn.Module直接管理,并提供給優化器優化
        self.w = nn.Parameter(torch.randn(out_d, in_d))
        self.b = nn.Parameter(torch.randn(out_d))

    # 實作forward函數,該函數為預設執行的函數,即計算過程,并将輸出傳回
    def forward(self, x):
        x = x @ self.w.t() + self.b
        return x


# 将幾個nn.Module元件綜合成一個
class Mylayer(nn.Module):
    def __init__(self, in_d, out_d):
        super(Mylayer, self).__init__()
        # 包含一個全連接配接層,一個BN層,一個Leaky Relu層
        self.lin = Mylinear(in_d, out_d)
        self.bn = nn.BatchNorm1d(out_d, eps=1e-8)
        self.lrelu = nn.LeakyReLU(inplace=True)

    # 按順序跑一遍3種網絡,傳回最終結果
    def forward(self, x):
        x = self.lin(x)
        x = self.bn(x)
        x = self.lrelu(x)
        return x


# 網絡結構
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()

        self.model = nn.Sequential(
            Mylayer(784, 200),
            Mylayer(200, 200),
            # nn.Linear(200, 10),
            Mylinear(200, 10),
            nn.LeakyReLU(inplace=True)
        )
           

上述代表表示的結構如下圖所示:

【Pytorch】nn.Module類

其中所有的類都繼承自nn.Module,從前往後是嵌套的關系。在上述代碼中,真正做計算的是橙色部分1-8,而其他的都隻是作為封裝。其中nn.Sequential、nn.BatchNorm1d、nn.LeakyReLU是pytorch提供的類,Mylinear和Mylayer是我們自己封裝的類。

二、實作一個常用類Flatten類

Flatten就是将2D的特征圖壓扁為1D的特征向量,用于全連接配接層的輸入。

# Flatten繼承Module
class Flatten(nn.Module):
    # 構造函數,沒有什麼要做的
    def __init__(self):
        # 調用父類構造函數
        super(Flatten, self).__init__()

    # 實作forward函數
    def forward(self, input):
        # 儲存batch次元,後面的次元全部壓平,例如輸入是28*28的特征圖,壓平後為784的向量
        return input.view(input.size(0), -1)
           

三、nn.Module類的作用

1.便于儲存模型:

# 每隔N epoch儲存一次模型
torch.save(net.state_dict(),'ckpt_n_epoch.mdl')
# 下次訓練時可以直接導入接着訓練
net.load_state_dict(torch.load('ckpt_n_epoch.mdl'))
           

2.友善切換train和val模式

### 不同模式對于某些層的操作時不同的,例如BN,dropout層等
# 切換到train模式
net.train()
# 切換到validation模式
net.eval()
           

3.友善将網絡轉移到GPU上

# 定義GPU裝置
device = torch.device('cuda')
# 将網絡轉移到GPU,注意to函數傳回的是net的引用(引用是不變的)
# 不同的是net中的參數都轉移到GPU上去了
net.to(device)
    
# 不同于參數直接轉移,轉移後的w2(在GPU上)和轉移前的w(在CPU上)兩者完全是不一樣的
# 我們要使之在GPU上運作,則必須使用w2
#w2 = w.to(device)
           

4.友善檢視各層參數

# 擷取由每一層參數組成的清單
para_list = list(net.parameters())
# 擷取一個(name,每層參數)的tuple組成的清單
para_named_list = list(net.named_parameters())
# 擷取一個{'model.0.weight': 參數,'model.0.bias': 參數, 'model.1.weight': 參數}
para_named_dict = dict(net.named_parameters())
           

四、資料增強

torchvision提供了很友善的資料預處理工具,資料增強可以一次性搞定

from torchvision import datasets, transforms

train_data_trans = datasets.MNIST('../data', train=True, download=True,
                            transform=transforms.Compose([
                                # 水準翻轉,50%執行
                                transforms.RandomHorizontalFlip(),
                                # 垂直翻轉,50%執行
                                transforms.RandomVerticalFlip(),
                                # 随機旋轉範圍在正負15°之間,也可以寫(-15,15)
                                transforms.RandomRotation(15),
                                # 旋轉範圍在90-270之間
                                #transforms.RandomRotation([90,270]),
                                # 将圖檔方縮放到指定大小
                                transforms.Resize([32,32]),
                                # 随機剪裁圖檔到指定大小
                                transforms.RandomCrop([28,28]),

                                transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))
                            ]))
           

如果pytorch沒有提供需要的預處理類,我們可以參照源碼仿造寫一個自定義處理的類來進行處理。例如對圖檔添加白噪聲,按通道變換顔色等等。

*文章來源:*https://www.cnblogs.com/leokale-zz/p/11294912.html

繼續閱讀