@A couple of tricks with Pytorch
最近在将Pytorch模型转为caffe时发现,不少人在Github上给出了
pytorch2caffe
的转换工具。但这些转换工具对于Pytorch的某些操作并没有很好的支持,例如
torch.nn.functional.conv2d
、
torch.nn.functional.pad
等。对此,我们常常要修改代码添加caffe层。为方便起见,我们将以常用的Pytorch语句替换这些操作。
center crop
在网络设计中,我们可能需要对特征图进行中心裁剪:
import torch
import torch.nn as nn
import torch.nn.functional as F
class center_crop(nn.Module):
    """Crop a symmetric border from the two trailing dimensions of a tensor.

    The border width is ``(kernel - 1) // 2`` on every side, i.e. the amount
    an un-padded, stride-1 convolution with an odd ``kernel`` would remove.

    Args:
        kernel: Size of the convolution kernel whose shrinkage is emulated.
            Must be at least 1.

    Raises:
        ValueError: If ``kernel`` is smaller than 1.
    """

    def __init__(self, kernel):
        super(center_crop, self).__init__()
        if kernel < 1:
            # A negative border would flip the sign of the padding used in
            # ``forward`` and silently zero-pad the input instead of cropping.
            raise ValueError("kernel must be >= 1, got {}".format(kernel))
        self.pad = (kernel - 1) // 2

    def forward(self, x):
        """Return ``x`` with ``self.pad`` elements removed from each border."""
        # Negative padding amounts make F.pad trim instead of extend.
        border = -self.pad
        return F.pad(x, (border, border, border, border)).contiguous()
以上语句,我们可以采用
torch.nn.Conv2d
替换(当 kernel_size 为奇数时,两者的输出完全一致):
import numpy as np
import torch
import torch.nn as nn
class CenterCrop(nn.Module):
    """Centre crop expressed as a fixed depthwise ``nn.Conv2d``.

    Every channel is convolved (stride 1, no padding, no bias) with a
    ``kernel_size x kernel_size`` kernel that is zero everywhere except for a
    single 1 at index ``kernel_size // 2``.  Each output pixel is therefore a
    copy of one input pixel, and the output is ``kernel_size - 1`` smaller
    than the input in both spatial dimensions.  For odd ``kernel_size`` this
    is the same as removing ``(kernel_size - 1) // 2`` pixels from each border.

    Args:
        kernel_size: Side length of the square selection kernel (an int).
        channel_num: Number of channels of the feature map; also used as the
            number of convolution groups.
    """

    def __init__(self, kernel_size, channel_num):
        super(CenterCrop, self).__init__()
        self.kernel_size = kernel_size
        self.channel_num = channel_num
        self.conv = nn.Conv2d(
            in_channels=channel_num,
            out_channels=channel_num,
            kernel_size=kernel_size,
            stride=1,
            padding=0,
            groups=channel_num,
            bias=False,
        )
        # Copy into the existing parameter (instead of rebinding ``.data``) so
        # the weight keeps the dtype/device the layer was created with.
        with torch.no_grad():
            self.conv.weight.copy_(self.kernel(channel_num))
        # The selection kernel is a constant, never a trainable weight.
        self.conv.weight.requires_grad_(False)

    def kernel(self, channel_num):
        """Build the one-hot selection kernel for ``channel_num`` channels.

        Returns:
            A float32 tensor of shape
            ``(channel_num, 1, kernel_size, kernel_size)`` holding a 1 at the
            centre index of every channel's kernel and 0 elsewhere.
        """
        size = self.kernel_size
        centre = size // 2
        one_hot = torch.zeros(size, size, dtype=torch.float32)
        one_hot[centre, centre] = 1.0
        return one_hot.unsqueeze(0).unsqueeze(0).repeat(channel_num, 1, 1, 1)

    def forward(self, x):
        """Apply the fixed depthwise convolution to ``x``."""
        return self.conv(x)
torch.nn.functional.conv2d
在caffe中,
convolution
层的输入
blob
只有一个,且
weight
是固定的。如果我们需要将权重作为变量,则可使用
torch.nn.functional.conv2d
实现,为方便将该语句转为caffe而不对转换工具代码做太多修改,我们可以综合
torch.nn.ConvTranspose2d
、
torch.mul
、
torch.split
、
torch.cat
等:
import torch
import torch.nn as nn
class Fconv2d(nn.Module):
    """Emulate ``F.conv2d(input, weight)`` with ``weight`` as a runtime input.

    The convolution (stride 1, no padding, square input and kernel) is
    rebuilt from ``nn.ConvTranspose2d``, ``torch.split``, ``torch.cat``,
    ``torch.mul`` and ``nn.Conv2d`` so that the kernel can be passed as a
    second tensor instead of being a fixed layer parameter.

    Only valid for a batch size of 1.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        input_size: Spatial size of the input, ``input.shape[-1]``.
        kernel_size: Spatial size of the kernel, ``weight.shape[-1]``.

    Raises:
        ValueError: If ``kernel_size`` is smaller than 1 or larger than
            ``input_size`` (the output would be empty).
    """

    def __init__(self, inp, oup, input_size, kernel_size):
        super(Fconv2d, self).__init__()
        if kernel_size < 1 or input_size < kernel_size:
            raise ValueError(
                "need 1 <= kernel_size <= input_size, got kernel_size={} "
                "and input_size={}".format(kernel_size, input_size)
            )
        self.inp = inp
        self.oup = oup
        self.kernel_size = kernel_size
        self.output_size = input_size - kernel_size + 1
        # Chunk sizes alternate kernel_size, output_size, kernel_size, ...
        # They sum to output_size * input_size, so even-numbered chunk j of the
        # tiled input starts at offset j * (input_size + 1): the window of
        # kernel_size rows/columns beginning at position j of the input.
        self.split_list = [
            kernel_size if index % 2 == 0 else self.output_size
            for index in range(2 * self.output_size - 1)
        ]
        # All-ones depthwise transposed convolutions whose dilation equals the
        # size of their operand: the taps never overlap, so each one tiles its
        # operand output_size x output_size times.
        self.tronv = nn.ConvTranspose2d(
            in_channels=inp,
            out_channels=inp,
            kernel_size=self.output_size,
            stride=1,
            padding=0,
            output_padding=0,
            bias=False,
            groups=inp,
            dilation=input_size,
        )
        self.kernel_tronv = nn.ConvTranspose2d(
            in_channels=inp,
            out_channels=inp,
            kernel_size=self.output_size,
            stride=1,
            padding=0,
            output_padding=0,
            bias=False,
            groups=inp,
            dilation=kernel_size,
        )
        # All-ones convolution that sums every kernel_size x kernel_size tile
        # over all input channels.  It is created once here rather than on
        # every forward call, so it follows ``.to()`` / ``.cuda()`` together
        # with the rest of the module and is not reallocated per call.
        self.conv_sum = nn.Conv2d(
            in_channels=inp,
            out_channels=1,
            kernel_size=kernel_size,
            stride=kernel_size,
            padding=0,
            bias=False,
        )
        # The three helper layers are constants, not trainable weights.
        for layer in (self.tronv, self.kernel_tronv, self.conv_sum):
            with torch.no_grad():
                layer.weight.fill_(1.0)
            layer.weight.requires_grad_(False)

    def _trans(self, x):
        """Sum each window over channels and kernel positions.

        ``x`` has shape ``(oup, inp, output_size * kernel_size,
        output_size * kernel_size)``; the result has shape
        ``(1, oup, output_size, output_size)``.
        """
        out = self.conv_sum(x)
        # The output channels were carried in the batch dimension; move them
        # back to the channel dimension.
        return out.permute(1, 0, 2, 3)

    def forward(self, input, weight):
        """Convolve ``input`` with ``weight``.

        Args:
            input: Tensor of shape ``(1, inp, input_size, input_size)``.
            weight: Tensor of shape ``(oup, inp, kernel_size, kernel_size)``.

        Returns:
            Tensor of shape ``(1, oup, output_size, output_size)``.
        """
        # Tile the input, then keep only the even-numbered chunks along the
        # rows and then along the columns, which lays all sliding windows
        # side by side.
        x = self.tronv(input)
        row_chunks = torch.split(x, self.split_list, dim=2)
        rows = torch.cat([row_chunks[2 * j] for j in range(self.output_size)], 2)
        col_chunks = torch.split(rows, self.split_list, dim=3)
        input_old = torch.cat([col_chunks[2 * k] for k in range(self.output_size)], 3)
        # Tile the kernel so one copy lines up with every window.
        weight_old = self.kernel_tronv(weight)
        input_new = input_old.view(1, self.inp, -1)
        weight_new = weight_old.view(self.oup, self.inp, -1)
        # Repeat the windows once per output channel (explicit concatenation,
        # no broadcasting) and multiply element-wise with the tiled kernels.
        side = self.output_size * self.kernel_size
        repeated = torch.cat([input_new for _ in range(self.oup)], 0)
        out_mul = torch.mul(weight_new, repeated).view(self.oup, self.inp, side, side)
        return self._trans(out_mul)
上述代码中,
inp
为输入的通道数,
oup
为输出的通道数,
input_size = input.shape[-1]
、
kernel_size = weight.shape[-1]
。经实验,上述代码与
F.conv2d(input = input, weight = weight, groups = 1)
在
input.shape[0] = 1
时所得结果相同(比较时取差值的绝对值,以同时统计正、负两个方向的偏差):
>>> input = torch.rand([1, 8, 22, 22])
>>> weight = torch.rand([10, 8, 8, 8])
>>> trans_conv = Fconv2d(inp = 8, oup = 10, input_size = 22, kernel_size = 8)
>>> output = trans_conv(input, weight)
>>> out = F.conv2d(input, weight, groups = 1)
>>> print(torch.sum(torch.abs(out - output) > 1e-4))
tensor(0)