A couple of tricks with PyTorch
最近在将Pytorch模型轉為caffe時發現,不少人在Github上給出了
pytorch2caffe
的轉換工具。但這些轉換工具對于Pytorch的某些操作并沒有很好的支援,例如
torch.nn.functional.conv2d
、
torch.nn.functional.pad
等。對此,我們常常要修改代碼添加caffe層。為友善起見,我們将以常用的Pytorch語句替換這些操作。
center crop
In network design, we may need to crop a feature map around its center:
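```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class center_crop(nn.Module):
    def __init__(self, kernel):
        super(center_crop, self).__init__()
        # Number of pixels removed from each border
        self.pad = (kernel - 1) // 2

    def forward(self, x):
        # Negative padding crops; the tuple is (left, right, top, bottom) of the last two dims
        return F.pad(x, (-self.pad, -self.pad, -self.pad, -self.pad)).contiguous()
```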
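This module can be replaced with a depthwise `torch.nn.Conv2d` whose fixed kernel is 1 at the center and 0 elsewhere. With `padding = 0` the output shrinks by `kernel_size - 1`, so for an odd `kernel_size` this gives exactly the same crop: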
```python
import numpy as np
import torch
import torch.nn as nn


class CenterCrop(nn.Module):
    def __init__(self, kernel_size, channel_num):
        super(CenterCrop, self).__init__()
        self.kernel_size = kernel_size
        self.channel_num = channel_num
        # Depthwise conv (groups = channel_num) without padding: output size is H - kernel_size + 1
        self.conv = nn.Conv2d(in_channels = self.channel_num, out_channels = self.channel_num, kernel_size = self.kernel_size, stride = 1, padding = 0, groups = self.channel_num, bias = False)
        # Fixed, non-trainable one-hot kernel
        self.conv.weight.data = self.kernel(self.channel_num)
        self.conv.weight.requires_grad = False

    def kernel(self, channel_num):
        # 1 at the kernel center, 0 elsewhere: each output pixel copies the input pixel under the center
        kernel = np.zeros(shape = (self.kernel_size, self.kernel_size))
        kernel[self.kernel_size // 2, self.kernel_size // 2] = 1
        # Shape (channel_num, 1, kernel_size, kernel_size), as a depthwise conv expects
        return torch.from_numpy(kernel).unsqueeze(0).unsqueeze(0).float().repeat(channel_num, 1, 1, 1)

    def forward(self, x):
        return self.conv(x)
```
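The following is a quick sketch of that equivalence on random data. It compares the negative-padding crop, a plain slice, and the one-hot depthwise convolution. The helper names `crop_by_pad` and `crop_by_conv` and the tensor sizes are illustrative assumptions, and the helpers compactly redefine the two modules so the snippet runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def crop_by_pad(x, kernel):
    # Hypothetical helper: negative padding removes (kernel - 1) // 2 pixels per border
    p = (kernel - 1) // 2
    return F.pad(x, (-p, -p, -p, -p))


def crop_by_conv(x, kernel):
    # Hypothetical helper: depthwise conv with a fixed one-hot kernel at the center
    c = x.shape[1]
    conv = nn.Conv2d(c, c, kernel_size=kernel, groups=c, bias=False)
    with torch.no_grad():
        conv.weight.zero_()
        conv.weight[:, :, kernel // 2, kernel // 2] = 1.0
        return conv(x)


x = torch.rand(2, 4, 15, 15)   # arbitrary batch of feature maps
k = 5                          # odd kernel size
p = (k - 1) // 2

a = crop_by_pad(x, k)
b = crop_by_conv(x, k)
print(a.shape)                              # torch.Size([2, 4, 11, 11])
print(torch.equal(a, x[:, :, p:-p, p:-p]))  # True: same as slicing
print(torch.allclose(a, b))                 # True: same as the one-hot conv
```

With an even kernel the two versions disagree. For `k = 4`, the padding version keeps 15 - 2 = 13 rows, while the convolution keeps 15 - 4 + 1 = 12. This is why the replacement assumes odd kernel sizes.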
torch.nn.functional.conv2d
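In Caffe, a `Convolution` layer has only one input blob, and its `weight` is fixed: it is a learned layer parameter rather than an input. If the weights need to be a variable, for example a tensor produced elsewhere in the network, we can implement the operation with `torch.nn.functional.conv2d`. To convert this statement to Caffe without heavily modifying the converter code, we can combine `torch.nn.ConvTranspose2d`, `torch.mul`, `torch.split`, `torch.cat`, and similar operations: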
```python
import torch
import torch.nn as nn


class Fconv2d(nn.Module):
    '''
    Replaces F.conv2d(input, weight) when weight is a tensor input.
    Only fits batch_size = 1, square input and kernel, stride 1, no padding.
    '''
    def __init__(self, inp, oup, input_size, kernel_size):
        super(Fconv2d, self).__init__()
        self.inp = inp
        self.oup = oup
        self.kernel_size = kernel_size
        self.output_size = input_size - kernel_size + 1
        # Chunk sizes [k, O, k, O, ..., k]; the even chunks are the k-wide sliding windows
        self.split_list = [kernel_size for _ in range(2 * self.output_size - 1)]
        for m in range(self.output_size - 1):
            self.split_list[2 * m + 1] = self.output_size
        # All-ones depthwise transposed convs that tile the input / weight O x O times
        self.tronv = nn.ConvTranspose2d(in_channels = inp, out_channels = inp, kernel_size = self.output_size, stride = 1, padding = 0, output_padding = 0, bias = False, groups = inp, dilation = input_size)
        self.kernel_tronv = nn.ConvTranspose2d(in_channels = inp, out_channels = inp, kernel_size = self.output_size, stride = 1, padding = 0, output_padding = 0, bias = False, groups = inp, dilation = kernel_size)
        self.tronv.weight.data = torch.ones(self.tronv.weight.data.shape)
        self.kernel_tronv.weight.data = torch.ones(self.kernel_tronv.weight.data.shape)
        self.tronv.weight.requires_grad = False
        self.kernel_tronv.weight.requires_grad = False
        # All-ones conv that sums each k x k block over all input channels.
        # Built once here rather than in every forward call, so it follows .to(device)
        self.conv_sum = nn.Conv2d(in_channels = inp, out_channels = 1, kernel_size = kernel_size, stride = kernel_size, padding = 0, bias = False)
        self.conv_sum.weight.data = torch.ones(self.conv_sum.weight.data.shape)
        self.conv_sum.weight.requires_grad = False

    def _trans(self, x):
        out = self.conv_sum(x)            # (oup, 1, O, O)
        return out.permute(1, 0, 2, 3)    # (1, oup, O, O)

    def forward(self, input, weight):
        # Tile the input, then keep only the sliding windows (an im2col-style layout)
        x = self.tronv(input)
        tensors_x = torch.split(x, self.split_list, dim = 2)
        tensors_y = torch.split(torch.cat([tensors_x[2 * j] for j in range(self.output_size)], 2), self.split_list, dim = 3)
        input_old = torch.cat([tensors_y[2 * k] for k in range(self.output_size)], 3)
        # Tile the kernel into the same (O*k) x (O*k) layout
        weight_old = self.kernel_tronv(weight)
        input_new = input_old.view(1, self.inp, -1)
        weight_new = weight_old.view(self.oup, self.inp, -1)
        # Element-wise product for every output channel (copies made explicit instead of broadcasting)
        out_mul = torch.mul(weight_new, torch.cat([input_new for _ in range(self.oup)], 0)).view(self.oup, self.inp, self.output_size * self.kernel_size, self.output_size * self.kernel_size)
        output = self._trans(out_mul)
        return output
```
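In the code above, `inp` is the number of input channels, `oup` is the number of output channels, `input_size = input.shape[-1]`, and `kernel_size = weight.shape[-1]`. The output has spatial size `input_size - kernel_size + 1`. In our experiments, the code gives the same result as `F.conv2d(input = input, weight = weight, groups = input.shape[0])` when `input.shape[0] = 1`, which is an ordinary convolution with `groups = 1`. In the session below, `F` is `torch.nn.functional`: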
```python
>>> input = torch.rand([1, 8, 22, 22])
>>> weight = torch.rand([10, 8, 8, 8])
>>> trans_conv = Fconv2d(inp = 8, oup = 10, input_size = 22, kernel_size = 8)
>>> output = trans_conv(input, weight)
>>> out = F.conv2d(input, weight, groups = 1)
>>> print(torch.sum(out - output > 1e-4))
tensor(0)
```
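The check above is one-sided, because it counts only the positions where `out` exceeds `output` by more than 1e-4. The following is a sketch of the same im2col idea with a symmetric `torch.allclose` check. It reuses the all-ones dilated `ConvTranspose2d` and the `[k, O, k, ..., k]` split from `Fconv2d`. However, it swaps `kernel_tronv` for `Tensor.repeat` and `conv_sum` for a plain `sum`. As a result, it illustrates the math rather than producing a Caffe-friendly graph. The sizes are arbitrary assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Assumed sizes: batch 1, square input and kernel, stride 1, no padding
C_in, C_out, H, K = 3, 4, 7, 3
O = H - K + 1                      # 5: spatial size of the valid convolution

x = torch.rand(1, C_in, H, H)
w = torch.rand(C_out, C_in, K, K)

# Step 1: an all-ones depthwise ConvTranspose2d with kernel O and dilation H
# writes one copy of x at every offset (H*a, H*b), i.e. tiles x O x O times.
tile = nn.ConvTranspose2d(C_in, C_in, kernel_size=O, groups=C_in,
                          dilation=H, bias=False)
with torch.no_grad():
    tile.weight.fill_(1.0)
    tiled = tile(x)                # shape (1, C_in, H*O, H*O)

# Step 2: chunks of K, O, K, O, ..., K. The j-th K-sized chunk starts at
# j*(K+O) = j*(H+1), so inside tile j it covers rows j..j+K-1.
sizes = [K if i % 2 == 0 else O for i in range(2 * O - 1)]
rows = torch.cat(torch.split(tiled, sizes, dim=2)[0::2], dim=2)
windows = torch.cat(torch.split(rows, sizes, dim=3)[0::2], dim=3)
# windows[0, c, j*K + u, l*K + v] == x[0, c, j + u, l + v]

# Step 3: multiply by the kernel tiled O x O times, then sum every K x K
# block over all input channels.
prod = windows * w.repeat(1, 1, O, O)      # broadcasts to (C_out, C_in, O*K, O*K)
out = prod.view(C_out, C_in, O, K, O, K).sum(dim=(1, 3, 5)).unsqueeze(0)

ref = F.conv2d(x, w)
print(out.shape)                           # torch.Size([1, 4, 5, 5])
print(torch.allclose(out, ref, atol=1e-5)) # True
```

Because the window layout matches `F.unfold(x, K)` after a permute, the split trick is effectively an im2col built only from `split` and `cat`.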