天天看點

【PyTorch】【segmentation_models_pytorch】【Unet】源碼解析 - 【Encoder】目的

目的

  • Ok,先來說說為什麼有這篇文章。
  • 作為一個才入門的小白,在使用unet訓練model時遇到各種問題,看過論文,查過資料,在github上找過大佬複現的unet,最終再使用pytorch自己提供的unet模型時,效果稍稍好了些,但還是存在問題,于是下定決心檢視segmentation_models_pytorch中Unet是怎麼實作的。話不多說,現在開始。

首先看見的是

如何調用

  • 如下圖,我是這麼寫的
  • 【PyTorch】【segmentation_models_pytorch】【Unet】源碼解析 - 【Encoder】目的
  • 我使用的編輯器是pycharm,是以同時按下Ctrl + 滑鼠左鍵,即可進入源碼

于是,

點選進入Unet

  • 進來之後你可以看見如下内容
  • 【PyTorch】【segmentation_models_pytorch】【Unet】源碼解析 - 【Encoder】目的
  • 這裡我們先看encoder,如圖所示可以看見

    self.encoder = get_encoder(...)

  • 其參數即為我開始初始化的設定
encoder_name = "resnet101"
in_channels = 3
depth = encoder_depth 		# 預設值5
weights = "imagenet"
           

下一步檢視

get_encoder

怎麼實作的

  • 再次進入檢視源碼,如下圖
  • 【PyTorch】【segmentation_models_pytorch】【Unet】源碼解析 - 【Encoder】目的
  • 在這裡已經可以看見

    Encoder = encoders[name]["encoder"]

    • 首先看

      name

      , 根據上一步傳入的參數可以知道

      name = encoder_name = "restnet101"

    • 那麼

      encoders

      是什麼呢?可以猜測得到他是一個

      dict

      ,根據

      encoder_name

      取值
    • 寫在這裡: 我個人習慣是先不往下看,後面看到需要什麼參數了,再回來找

讓我們來看看

Encoder

到底是什麼

  • 首先,我們得找到

    encoders

  • 【PyTorch】【segmentation_models_pytorch】【Unet】源碼解析 - 【Encoder】目的
  • 這些參數是提前提供好的,在上方

    import

    中導入
  • 而我使用的是

    resnet101

    , 于是大膽推測我需要的就是第一行的

    resnet_encoders

  • 這是我個人習慣,有點魯莽,個人建議還是仔細檢視認真核對,以防出錯
  • 接下來看一下

    resnet_encoders

    到底是什麼?
  • 【PyTorch】【segmentation_models_pytorch】【Unet】源碼解析 - 【Encoder】目的
  • 可以看見他是一個

    dict

  • 在裡面找到我所需要的

    resnet101

  • 【PyTorch】【segmentation_models_pytorch】【Unet】源碼解析 - 【Encoder】目的
  • 可以根據之前

    get_encoder

    裡面的代碼

    Encoder = encoders[name][encoder]

    得到這個

    Encoder = ResNetEncoder

  • 曬出

    ResNetEncoder

    源碼
  • 【PyTorch】【segmentation_models_pytorch】【Unet】源碼解析 - 【Encoder】目的
  • 以程式的思維,先找到入口

    forward

    然後逐行執行
    • 1.傳入參數

      x

      ,這裡其實就是輸入的特征
    • 2.然後執行個體化

      stages = self.get_stages()

    • 【PyTorch】【segmentation_models_pytorch】【Unet】源碼解析 - 【Encoder】目的
    • 可以看見他是一系列的網絡層
    • 然後就是一個

      for

      循環,循環次數

      _depth

      預設為5,也可自動傳入,

      先做個記号A

      ,待會會用到這裡
    • 在這裡可以看見

      stages

      的實際長度為6,如果你輸入的

      depth > 5

      就會出錯哦
    • 然後檢視

      stages

      • 第一個:

        nn.Identity()

        實際就是一個輸入層
      • 第二個:

        nn.Sequential(self.conv1, self.bn1, self.relu)

        ,這一層需要執行三個操作,繼續看着三個操作是什麼
      • 【PyTorch】【segmentation_models_pytorch】【Unet】源碼解析 - 【Encoder】目的
      • 第三個:

        nn.Sequential(self.maxpool, self.layer1)

        , 這裡一看就是一個最大池化層,那

        self.layer1

        是什麼呢?還有下面的

        self.layer2

        self.layer3

        self.layer4

      • 繼續看看
      • 【PyTorch】【segmentation_models_pytorch】【Unet】源碼解析 - 【Encoder】目的
      • 他們都是執行的一個方法

        self._make_layer

        ,隻是傳入的參數不同
      • 【PyTorch】【segmentation_models_pytorch】【Unet】源碼解析 - 【Encoder】目的
      • 可以結合前面一步的

        self.layer1...self.layer4

        看出來,其實主要就是

        block

        blocks

        ,根據形參和實參位置對應,得到

        blocks = layers[*]

        , 也就是說主要的就是

        block

        layers

      • 那麼這兩個參數是在哪裡傳入的呢?
      • 仔細回想,我們之前的操作隻執行到了

        Encoder = encoders[name]["encoder"]

        , 隻是初始化了類,還并未執行個體化,更不用說調用了,那麼,問題就容易解決了,去看看在哪裡執行個體化的。直接回到初始位置

        get_encoder

        可以發現下圖
      • 【PyTorch】【segmentation_models_pytorch】【Unet】源碼解析 - 【Encoder】目的
      • 那麼

        params

        是什麼?這裡也可以看見傳入的之前記号A

        depth

      • 繼續看
      • 【PyTorch】【segmentation_models_pytorch】【Unet】源碼解析 - 【Encoder】目的
      • 問題解決,

        block = Bottleneck

        以及

        layers=[3, 4, 23, 3]

      • 先看

        Bottleneck

      • 【PyTorch】【segmentation_models_pytorch】【Unet】源碼解析 - 【Encoder】目的
      • 再結合之前

        self._make_layer(...)

      • 【PyTorch】【segmentation_models_pytorch】【Unet】源碼解析 - 【Encoder】目的
      • 可以看見

        block

        主要是對傳入的特征進行如下操作:

        省略參數

        • conv1: conv1x1

        • BatchNorm2d

        • ReLU

        • conv2: conv3x3

        • BatchNorm2d

        • ReLU

        • conv3: conv1x1

        • BatchNorm2d

        • 如果需要

          downsample

          則執行
          • 【PyTorch】【segmentation_models_pytorch】【Unet】源碼解析 - 【Encoder】目的
          • conv1x1

            +

            BatchNorm2d

        • 這裡的

          conv1x1

          conv3x3

          分别為:

          主要差别是kernel_size不同

        • 【PyTorch】【segmentation_models_pytorch】【Unet】源碼解析 - 【Encoder】目的
        • 在經過

          downsample

          之後再次進行

          ReLU

到這裡就已經基本上可以清楚了解了

Unet.encoder

是如何組成的,其實可以畫個圖更加直覺,但是由于時間優先,這裡暫時先不加,後續畫完了我再添加上來,另外寫這篇文章,其實主要目的是幫我自己梳理思路,邊分析編寫可以記得更真切。同時如果能幫助到跟我一樣有迷惑的人,那就更好了。

——<未完待續>