A couple of tricks with Pytorch

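Recently, while converting PyTorch models to Caffe, I noticed that quite a few people have published pytorch2caffe conversion tools on GitHub. These tools, however, handle some PyTorch operations poorly, for example torch.nn.functional.conv2d and torch.nn.functional.pad, and supporting them usually means modifying the converter to add new Caffe layers. For convenience, we instead rewrite these operations in terms of commonly used PyTorch modules.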

center crop

In network design we may need to center-crop a feature map:

import torch
import torch.nn as nn
import torch.nn.functional as F


class center_crop(nn.Module):
    def __init__(self, kernel):
        super(center_crop, self).__init__()
        self.pad = (kernel - 1) // 2

    def forward(self, x):
        # Negative padding crops self.pad pixels from every border.
        return F.pad(x, (-self.pad, -self.pad, -self.pad, -self.pad)).contiguous()

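The center_crop module can be replaced with torch.nn.Conv2d. A depthwise convolution (groups = channel_num) whose fixed kernel is 1 at the center and 0 elsewhere copies the center pixel of each kernel_size × kernel_size window. With stride 1 and no padding, an odd kernel_size therefore removes (kernel_size - 1) // 2 pixels from every border, exactly like center_crop(kernel); for an even kernel_size the result is one pixel smaller than center_crop's, so the replacement assumes an odd kernel: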

import torch
import torch.nn as nn


class CenterCrop(nn.Module):
    """Center crop as a fixed depthwise convolution (assumes an odd kernel_size)."""
    def __init__(self, kernel_size, channel_num):
        super(CenterCrop, self).__init__()
        self.kernel_size = kernel_size
        self.channel_num = channel_num
        # One group per channel: each output channel only sees its own input channel.
        self.conv = nn.Conv2d(in_channels = self.channel_num, out_channels = self.channel_num, kernel_size = self.kernel_size, stride = 1, padding = 0, groups = self.channel_num, bias = False)
        with torch.no_grad():
            self.conv.weight.copy_(self.kernel(self.channel_num))
        # The kernel is fixed and must not be updated during training.
        self.conv.weight.requires_grad = False

    def kernel(self, channel_num):
        # k x k kernel that is 1 at the center and 0 elsewhere, repeated per channel:
        # shape (channel_num, 1, k, k), as required by groups = channel_num.
        kernel = torch.zeros(self.kernel_size, self.kernel_size)
        kernel[self.kernel_size // 2, self.kernel_size // 2] = 1
        return kernel.unsqueeze(0).unsqueeze(0).repeat(channel_num, 1, 1, 1)

    def forward(self, x):
        return self.conv(x)
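A quick way to check the replacement is to run both modules on a random tensor. The sketch below restates both modules in compact form so it runs on its own; the compare helper and the sizes are our own choices. For an odd kernel the outputs should agree; for an even kernel the shapes differ by one pixel.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class center_crop(nn.Module):
    def __init__(self, kernel):
        super(center_crop, self).__init__()
        self.pad = (kernel - 1) // 2

    def forward(self, x):
        # Negative padding removes self.pad pixels from each border.
        p = self.pad
        return F.pad(x, (-p, -p, -p, -p)).contiguous()


class CenterCrop(nn.Module):
    def __init__(self, kernel_size, channel_num):
        super(CenterCrop, self).__init__()
        self.conv = nn.Conv2d(channel_num, channel_num, kernel_size, stride=1,
                              padding=0, groups=channel_num, bias=False)
        kernel = torch.zeros(kernel_size, kernel_size)
        kernel[kernel_size // 2, kernel_size // 2] = 1
        with torch.no_grad():
            # (k, k) repeated to (channel_num, 1, k, k)
            self.conv.weight.copy_(kernel.repeat(channel_num, 1, 1, 1))
        self.conv.weight.requires_grad = False

    def forward(self, x):
        return self.conv(x)


def compare(kernel, channels=3, size=11):
    """Hypothetical helper: returns both output shapes and whether they agree."""
    x = torch.rand(1, channels, size, size)
    a = center_crop(kernel)(x)
    b = CenterCrop(kernel, channels)(x)
    same = a.shape == b.shape and torch.allclose(a, b, atol=1e-6)
    return a.shape, b.shape, same


print(compare(5))  # odd kernel: both 7 x 7, values agree
print(compare(4))  # even kernel: 9 x 9 vs 8 x 8
```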

torch.nn.functional.conv2d

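In Caffe, a Convolution layer takes a single input blob, and its weights are fixed parameters of the layer. If the weights need to be a variable, i.e. a second input tensor rather than a stored parameter, the natural PyTorch implementation is torch.nn.functional.conv2d. To convert such a model to Caffe without modifying the conversion tool too much, we can instead combine torch.nn.ConvTranspose2d, torch.mul, torch.split, torch.cat, and similar operations: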

import torch
import torch.nn as nn


class Fconv2d(nn.Module):
    '''
    Computes F.conv2d(input, weight) with a variable weight using only
    ConvTranspose2d, Conv2d, split, cat and mul. Only fits batch_size = 1.
    '''
    def __init__(self, inp, oup, input_size, kernel_size):
        super(Fconv2d, self).__init__()
        self.inp = inp
        self.oup = oup
        self.kernel_size = kernel_size
        self.output_size = input_size - kernel_size + 1
        # Split pattern [k, o, k, o, ..., k] (o = output_size): the m-th kept
        # chunk is a k-wide window at offset m inside the m-th tile.
        self.split_list = [kernel_size for _ in range(2 * self.output_size - 1)]
        for m in range(self.output_size - 1):
            self.split_list[2 * m + 1] = self.output_size
        # All-ones transposed convs whose dilation equals the tile size repeat
        # their input output_size x output_size times.
        self.tronv = nn.ConvTranspose2d(in_channels = inp, out_channels = inp, kernel_size = self.output_size, stride = 1, padding = 0, output_padding = 0, bias = False, groups = inp, dilation = input_size)
        self.kernel_tronv = nn.ConvTranspose2d(in_channels = inp, out_channels = inp, kernel_size = self.output_size, stride = 1, padding = 0, output_padding = 0, bias = False, groups = inp, dilation = kernel_size)
        # Sums every k x k block over all input channels. Built once here instead of
        # being re-created (on the CPU) in every forward call, so it follows .to(device).
        self.conv_sum = nn.Conv2d(in_channels = inp, out_channels = 1, kernel_size = kernel_size, stride = kernel_size, padding = 0, bias = False)
        for conv in (self.tronv, self.kernel_tronv, self.conv_sum):
            nn.init.ones_(conv.weight)
            conv.weight.requires_grad = False

    def _trans(self, x):
        out = self.conv_sum(x)            # (oup, 1, o, o)
        return out.permute(1, 0, 2, 3)    # (1, oup, o, o)

    def forward(self, input, weight):
        # input: (1, inp, H, H), weight: (oup, inp, k, k)
        x = self.tronv(input)             # (1, inp, H * o, H * o)
        tensors_x = torch.split(x, self.split_list, dim = 2)
        tensors_y = torch.split(torch.cat([tensors_x[2 * j] for j in range(self.output_size)], 2), self.split_list, dim = 3)
        # Block (m, n) of input_old is the patch input[:, :, m:m+k, n:n+k].
        input_old = torch.cat([tensors_y[2 * n] for n in range(self.output_size)], 3)
        # Every k x k block of weight_old is a copy of weight.
        weight_old = self.kernel_tronv(weight)
        input_new = input_old.view(1, self.inp, -1)
        weight_new = weight_old.view(self.oup, self.inp, -1)
        # Repeat the input explicitly (instead of broadcasting) to match weight_new.
        out_mul = torch.mul(weight_new, torch.cat([input_new for _ in range(self.oup)], 0)).view(self.oup, self.inp, self.output_size * self.kernel_size, self.output_size * self.kernel_size)
        return self._trans(out_mul)

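In Fconv2d, inp is the number of input channels, oup is the number of output channels, input_size = input.shape[-1], and kernel_size = weight.shape[-1] (square inputs and kernels are assumed). In our experiments, for input.shape[0] = 1, Fconv2d gives the same result as F.conv2d(input = input, weight = weight, groups = input.shape[0]), which with a batch of one is an ordinary convolution with groups = 1: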

>>> import torch.nn.functional as F
>>> input = torch.rand([1, 8, 22, 22])
>>> weight = torch.rand([10, 8, 8, 8])
>>> trans_conv = Fconv2d(inp = 8, oup = 10, input_size = 22, kernel_size = 8)
>>> output = trans_conv(input, weight)
>>> out = F.conv2d(input, weight, groups = 1)
>>> print(torch.sum(out - output > 1e-4))
tensor(0)
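To see why Fconv2d works, its building blocks can be checked separately. The sketch below is plain PyTorch with a single channel; H, k and o stand for input_size, kernel_size and output_size, and the sizes are our own choice. It checks that (1) an all-ones ConvTranspose2d with kernel_size = o and dilation = H tiles an H × H map o × o times, (2) keeping the even-indexed chunks of the [k, o, k, ..., k] split pattern yields exactly the k × k patches that F.unfold (im2col) produces, and (3) multiplying each patch by the kernel and summing gives F.conv2d. Step (3) is what Fconv2d does with torch.mul and the all-ones conv_sum. The last comparison uses the absolute difference, since a one-sided test such as out - output > 1e-4 only catches deviations in one direction.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

H, k = 6, 3            # input size and kernel size (square, one channel)
o = H - k + 1          # output size of an unpadded, stride-1 convolution

x = torch.rand(1, 1, H, H)

# (1) All-ones transposed conv with dilation H: the taps are H apart, so the
# output consists of o x o non-overlapping copies of x.
tile = nn.ConvTranspose2d(1, 1, kernel_size=o, dilation=H, bias=False)
nn.init.ones_(tile.weight)
with torch.no_grad():
    tiled = tile(x)    # (1, 1, H * o, H * o)
print(torch.allclose(tiled, x.repeat(1, 1, o, o)))

# (2) Split pattern [k, o, k, o, ..., k]: the m-th kept chunk starts at
# offset m inside the m-th tile, i.e. it holds rows m .. m + k - 1 of x.
split_list = [k if i % 2 == 0 else o for i in range(2 * o - 1)]
rows = torch.cat(torch.split(tiled, split_list, dim=2)[::2], dim=2)
patches = torch.cat(torch.split(rows, split_list, dim=3)[::2], dim=3)

# Block (m, n) of `patches` should equal x[..., m:m+k, n:n+k]; compare with F.unfold.
blocks = patches.view(o, k, o, k).permute(0, 2, 1, 3).reshape(o * o, k * k)
ref = F.unfold(x, kernel_size=k)[0].t()  # (o*o, k*k), row index = m * o + n
print(torch.allclose(blocks, ref))

# (3) Multiplying each patch by the kernel and summing gives the convolution.
w = torch.rand(1, 1, k, k)
out = (blocks * w.view(1, -1)).sum(dim=1).view(1, 1, o, o)
print((out - F.conv2d(x, w)).abs().max() < 1e-5)
```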