U-Net++:
一種基于嵌套和密集跳躍連接配接的新分割體系結構,這種設計的跳躍連接配接降低編碼和解碼子網絡的特征圖之間的間隙,這個分割體系結構基于這樣的假設:當将來自編碼器網絡的高分辨率特征圖與來自解碼器網絡的對應語義豐富的特征圖進行融合之前,逐漸豐富模型時,該模型可以有效的捕捉到前景對象的細粒度細節,當來自解碼器和編碼器網絡的特征映射在語義上相似時,網絡将處理更容易的學習任務。
基于U-net改進的一種網絡
其改進思路如下:

1.為什麼降到X(4,0)才開始上采樣
對于不同深度的U-net表現,并不是越深越好,不同層次的特征的重要性對于不同的資料集是不一樣的,是以并不是原文中的4層U-net就一定對所有資料集的分割問題表現最優。
使用淺層和深層的特征,利用不同深度的U-net來各自抓取不同層次的特征
這個網絡(上圖)的好處是不管哪個深度的特征有效,都用上,讓網絡自己去學習不同深度的特征的重要性,其次它共享了一個特征提取器,也就是不需要訓練一堆U-net,隻需要訓練一個encode,它不同層次的特征可以由不同的解碼路徑來還原,編碼路徑還是可以靈活的使用不同的backbone;但是它也存在缺點:這個網絡結構是不能被訓練的,因為不會由任何梯度出現在紅色三角形内,因為它和算loss_func的地方在反向傳播是斷開的。
解決這個問題的方法:1.加入深度監督2.把結構改為如圖
但這個網絡把U-net原來的長連接配接去掉了,U-net的長連接配接的重要性包含3點:1.解決梯度消失;2.學習低級别的特征;3.恢複在下采樣過程中丢失的資訊。
是以說這個長連接配接是必要的,它聯系了輸入圖像的很多資訊,有助于還原降采樣所帶來的資訊損失,在一定程度上與殘差的操作(residual)很類似,x+f(x)。是以這樣就想到長連接配接和短連結結合的方式。這樣就得到U-net++的網絡結構
與U-net網絡不同的是:
重要思想:在編碼器和解碼器特征圖融合之前彌合它們之間的語義間隙
- 在跳躍路徑上有卷積層(綠色),用來彌合編碼和解碼特征圖的語義間隙
- 在跳躍路徑上有緊密的跳過連接配接(藍色),進而改善了梯度流
- 加入了深度監督(紅色),可以進行模型修剪并進行改進或者最壞情況下,可以達到與僅使用一個損耗層相當的性能
重設計的跳躍路徑
U-net:編碼器的特征圖直接在解碼器中接收
U-net++:解碼器在接收編碼器特征圖的路徑中經曆密集的卷積層,這些卷積層數取決于金字塔等級的塊。
密集卷積塊使編碼器特征圖的語義級别更接近在解碼器中等待的特征圖的語義級别,這樣優化器更容易優化。
每個結點的輸出:
H(·)是緊跟激活函數後的卷積操作,U(·)是上采樣層,[]是融合層
深度監督
深度監督可以使模型以兩種模式運作:
1)精确模式,其中對所有細分分支的輸出求平均值;
2)快速模式,其中僅從分割分支之一中選擇最終分割圖,其選擇決定了模型修剪的程度和速度增益。
在訓練過程中在各個leve的子網絡中加深度監督,可以帶的好處:剪枝
由于嵌套的跳躍路徑,Unet++可以在多個語義級别上産生全分辨率特征圖
作者的思想:
測試時,剪掉的部分對剩餘表結構不做影響,訓練時,剪掉部分對剩餘部分有影響
(也就是說在測試階段由于輸入的圖像隻有前向傳播,剪掉的這部分對前面的速出完全沒有影響,而在訓練階段既有前向又有反向傳播,被剪掉的部分時會幫助其他部分做權重更新的。)
在深度監督過程中,每個子網絡的輸出就是圖像的分割結果了,如果其中的子網絡鳳娥結果足夠好,可以随意的裁剪剩餘的部分。
為什麼要在測試時剪枝,而不是直接拿剪完的L1,L2,L3訓練
剪掉的那部分對訓練時的反向傳播時時有貢獻的,如果直接拿L1,L2,L3訓練,就相當于隻訓練不同深度的U-NET,最後的結果會很差
如何去決定剪多少
訓練模型時将資料分為訓練集驗證集和測試集,訓練集是一定拟合好的,測試集當然是不能碰的,那麼根據子網絡在驗證集的結果來決定剪多少。具體的結果作者也給出了。
從結果看出網絡并需要很深就可以在一些資料集取得很好的效果。
将二進制交叉熵和Dice系數作為上述四個語義級别的每個損失函數:
其中^yb和yb分别指平均預測機率和平均真實機率,N是batchsize
**主網絡代碼:
基于pytorch:
class conv_block_nested(nn.Module):
def __init__(self, in_ch, mid_ch, out_ch):
super(conv_block_nested, self).__init__()
self.activation = nn.ReLU(inplace=True)
self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1, bias=True)
self.bn1 = nn.BatchNorm2d(mid_ch)
self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, bias=True)
self.bn2 = nn.BatchNorm2d(out_ch)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.activation(x)
x = self.conv2(x)
x = self.bn2(x)
output = self.activation(x)
return output
#Nested Unet
class NestedUNet(nn.Module):
"""
Implementation of this paper:
https://arxiv.org/pdf/1807.10165.pdf
"""
def __init__(self, in_ch=3, out_ch=1):
super(NestedUNet, self).__init__()
n1 = 64
filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.Up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
#backbone中下采樣的各個級别層
self.conv0_0 = conv_block_nested(in_ch, filters[0], filters[0])
self.conv1_0 = conv_block_nested(filters[0], filters[1], filters[1])
self.conv2_0 = conv_block_nested(filters[1], filters[2], filters[2])
self.conv3_0 = conv_block_nested(filters[2], filters[3], filters[3])
self.conv4_0 = conv_block_nested(filters[3], filters[4], filters[4])
#x^(0,1),x^(1,1),x^(2,1),x^(3,1)層
self.conv0_1 = conv_block_nested(filters[0] + filters[1], filters[0], filters[0])
self.conv1_1 = conv_block_nested(filters[1] + filters[2], filters[1], filters[1])
self.conv2_1 = conv_block_nested(filters[2] + filters[3], filters[2], filters[2])
self.conv3_1 = conv_block_nested(filters[3] + filters[4], filters[3], filters[3])
##x^(0,2),x^(1,2),x^(2,2)層
self.conv0_2 = conv_block_nested(filters[0]*2 + filters[1], filters[0], filters[0])
self.conv1_2 = conv_block_nested(filters[1]*2 + filters[2], filters[1], filters[1])
self.conv2_2 = conv_block_nested(filters[2]*2 + filters[3], filters[2], filters[2])
#x^(0,3),x^(1,3)層
self.conv0_3 = conv_block_nested(filters[0]*3 + filters[1], filters[0], filters[0])
self.conv1_3 = conv_block_nested(filters[1]*3 + filters[2], filters[1], filters[1])
#x^(0,4)層
self.conv0_4 = conv_block_nested(filters[0]*4 + filters[1], filters[0], filters[0])
self.final = nn.Conv2d(filters[0], out_ch, kernel_size=1)
def forward(self, x):
x0_0 = self.conv0_0(x)
x1_0 = self.conv1_0(self.pool(x0_0))
x0_1 = self.conv0_1(torch.cat([x0_0, self.Up(x1_0)], 1))
x2_0 = self.conv2_0(self.pool(x1_0))
x1_1 = self.conv1_1(torch.cat([x1_0, self.Up(x2_0)], 1))
x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.Up(x1_1)], 1))
x3_0 = self.conv3_0(self.pool(x2_0))
x2_1 = self.conv2_1(torch.cat([x2_0, self.Up(x3_0)], 1))
x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.Up(x2_1)], 1))
x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.Up(x1_2)], 1))
x4_0 = self.conv4_0(self.pool(x3_0))
x3_1 = self.conv3_1(torch.cat([x3_0, self.Up(x4_0)], 1))
x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.Up(x3_1)], 1))
x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.Up(x2_2)], 1))
x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.Up(x1_3)], 1))
output = self.final(x0_4)
return output