# Why is nothing printed from inside forward()? It is printed only when the vgg backbone is used; the default backbone is resnet.
class vgg16(nn.Module):
    """VGG-16 style backbone that returns intermediate feature maps.

    The main trunk (``self.base``) is built by ``vgg`` from the ``'tun'``
    config, which is the VGG-16 convolution layout (13 conv widths and five
    ``'M'`` entries). An extra stage (``self.base_ex``), built by ``vgg_ex``
    from the ``'tun_ex'`` config, is applied after the trunk. ``forward``
    collects the outputs of the trunk layers whose indices are listed in
    ``self.extract``, followed by the output of the extra stage.
    """

    def __init__(self):
        super().__init__()
        self.cfg = {
            'tun': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
                    512, 512, 512, 'M', 512, 512, 512, 'M'],
            'tun_ex': [512, 512, 512],
        }
        # Indices into self.base whose outputs forward() returns, intended
        # as feature maps c(2)..c(5). An earlier variant also used index 3.
        # NOTE(review): which layer sits at each index depends on how vgg()
        # expands the config (conv/ReLU/pool, optional BN) — confirm there.
        self.extract = [8, 15, 22, 29]
        # NOTE(review): not read inside this class; presumably used by
        # callers or related code — confirm before removing.
        self.extract_ex = [5]
        # 3 = input channel count passed to vgg() (presumably RGB).
        self.base = nn.ModuleList(vgg(self.cfg['tun'], 3))
        # 512 = input width of the extra stage, matching the last 'tun' width.
        self.base_ex = vgg_ex(self.cfg['tun_ex'], 512)

        # Parameter init: conv weights ~ N(0, 0.01); BatchNorm starts as the
        # identity (weight 1, bias 0). Conv biases keep PyTorch's default init.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.01)
            elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
                # affine=False BatchNorm has no weight/bias to initialise.
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def load_pretrained_model(self, model):
        """Load a state dict into the main trunk ``self.base`` only.

        ``self.base_ex`` is left untouched. ``model`` is passed straight to
        ``load_state_dict``, so its keys must match the ``ModuleList``
        layout (presumably ``'<index>.weight'`` etc. — confirm with the
        pretrained file).
        """
        self.base.load_state_dict(model)

    def forward(self, x, multi=0):
        """Run the backbone and collect feature maps.

        Args:
            x: input to the first layer of ``self.base`` (presumably an NCHW
                image batch with 3 channels — confirm with callers).
            multi: if equal to 1, only the first collected feature map is
                returned. Any other value returns all of them.

        Returns:
            A list of tensors. It holds the outputs of the ``self.base``
            layers whose indices appear in ``self.extract`` (in layer order),
            followed by the output of ``self.base_ex``. With ``multi == 1``
            the list holds only the first of these.
        """
        features = []
        for idx, layer in enumerate(self.base):
            x = layer(x)
            if idx in self.extract:
                features.append(x)
        # Extra stage on top of the full trunk -> c(6).
        features.append(self.base_ex(x))
        # NOTE: every layer still runs when multi == 1; only the first map is
        # kept. Layers may carry state (e.g. BN running stats in training), so
        # the computation is not short-circuited.
        if multi == 1:
            return features[:1]
        return features