天天看點

pytorch Sequential ModuleList與python list建構網絡比較

參考别人部落格,自己實驗并總結,供參考。

import torch
import torch.nn as nn

"""
比較了nn.Sequential,ModuleList與python list
建構網絡的差別
"""


class net1(nn.Module):
    def __init__(self):
        super(net2, self).__init__()
        self.conv = nn.Sequential(nn.Conv2d(3, 5, 5, 0))
        # 注意,直接用list的建構方式,層與參數均不會出現在網絡中
        self.linears = [nn.Linear(10, 10) for i in range(2)]

    def forward(self, x):
        x = self.conv(x)
        for m in self.linears:
            x = m(x)
        return x


class net2(nn.Module):
    def __init__(self):
        super(net2, self).__init__()
        self.conv = nn.Sequential(nn.Conv2d(3, 5, 5, 0))
        # 單獨的層去建構則會被自動加入網絡結構
        self.linear1 = nn.Linear(10, 10)
        self.linear2 = nn.Linear(10, 10)

    def forward(self, x):
        x = self.conv(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x


class net3(nn.Module):
    def __init__(self):
        super(net1, self).__init__()
        # 用ModuleList建構的層則會自動注冊在網絡結構中
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(2)])

    def forward(self, x):
        """
        1.
        nn.ModuleList 并沒有定義一個網絡,
        它隻是将不同的子產品儲存在一起,這些子產品之間并沒有什麼先後順序可言
        x = self.linears[2](x)
        x = self.linears[0](x)
        x = self.linears[1](x)
        2.
        一個子產品(層)可以被調用多次,但是被調用多次的子產品,
        使用的是同一組parameters,即他們是參數共享的,
        """
        for m in self.linears:
            x = m(x)
        return x


class net4(nn.Module):
    def __init__(self):
        super(net4, self).__init__()
        self.block = nn.Sequential(nn.Conv2d(1, 20, 5),
                                   nn.ReLU(),
                                   nn.Conv2d(20, 64, 5),
                                   nn.ReLU())

    """
    注意,nn.Sequential與nn.ModuleList的差別主要有兩個
    1.
    nn.Sequential内的子產品是按照順序排列的
    2.
    nn.Sequential已經實作了内部的forward函數,是以可以整個block直接調用
    """

    def forward(self, x):
        x = self.block(x)
        return x


class net5(nn.Module):
    def __init__(self):
        super(net5, self).__init__()
        self.net1 = net2()
        self.net2 = net4()


if __name__ == '__main__':
    net = net5()
    print(net)
           

繼續閱讀