定義一個特徵提取的類別:
參考 PyTorch 論壇:How to extract features of an image from a trained model
import torch
import torch.nn as nn
from torchvision.models import resnet18
# Pretrained ResNet-18 from torchvision. Printing it lists the names of its
# direct child layers (conv1, layer1, avgpool, ...) so they can be chosen
# for feature extraction below.
myresnet = resnet18(pretrained=True)
print(myresnet)
class FeatureExtractor(nn.Module):
    """Run a model's direct children in order and collect selected outputs.

    Adapted from the PyTorch forum thread "How to extract features of an
    image from a trained model".

    Args:
        submodule: Model whose direct child modules are applied to the input
            one after another, in registration order. This matches the
            model's own ``forward`` only for purely sequential models; the one
            extra step replicated here is flattening before a child named
            ``"fc"``.
        extracted_layers: Names of the direct children whose outputs are
            collected.

    ``forward`` returns a list of the collected outputs, ordered by position in
    ``submodule`` (not by order in ``extracted_layers``). The name of every
    child is printed as it runs.
    """

    def __init__(self, submodule, extracted_layers):
        super().__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        outputs = []
        # Iterate the raw child registry rather than named_children(): the
        # latter skips a module instance registered under a second name,
        # which would silently drop a layer from the computation.
        for name, module in self.submodule._modules.items():
            # Compare by value: `is` tests object identity and can miss a
            # name string that merely equals "fc".
            if name == "fc":
                # nn.Linear acts on the last dimension only, so collapse
                # everything after the batch dimension (e.g. pooled
                # (N, C, 1, 1) features -> (N, C)). flatten() also works for
                # non-contiguous tensors and empty batches, unlike view().
                x = x.flatten(1)
            x = module(x)  # each layer's output feeds the next layer
            print(name)
            if name in self.extracted_layers:
                outputs.append(x)
        return outputs
# Direct children of myresnet whose outputs are collected.
exact_list = ["conv1", "layer1", "avgpool"]
# Use the GPU when one is available instead of a hard-coded .cuda(), which
# crashes on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
myexactor = FeatureExtractor(myresnet, exact_list).to(device)
# Random input batch of shape (5, 3, 224, 224), created directly on `device`
# so it remains a leaf tensor. torch.autograd.Variable is deprecated, and
# Variable(...).cuda() yielded a non-leaf copy.
x = torch.rand(5, 3, 224, 224, device=device, requires_grad=True)
y = myexactor(x)  # 5x64x112x112 5x64x56x56 5x512x1x1
print(myexactor)
print(type(y))
print(type(y[0]))
for feat in y:
    # detach(): the features require grad, and .numpy() refuses such tensors.
    arr = feat.detach().cpu().numpy()
    print(arr.size)
    print(arr.shape)
# <class 'list'>
# <class 'torch.Tensor'>
# 4014080
# (5, 64, 112, 112)
# 1003520
# (5, 64, 56, 56)
# 2560
# (5, 512, 1, 1)
# Visualize the extracted features: each channel of the first collected
# output (y[0], the "conv1" features with the exact_list above) for the
# first sample in the batch, drawn on an 8x8 grid.
import matplotlib.pyplot as plt

# The feature maps live in y; x is only the 3-channel input. detach()/cpu()
# make the tensor convertible to NumPy even if it requires grad or is on GPU.
feature_maps = y[0].detach().cpu().numpy()[0]  # (channels, H, W) of sample 0
num_maps = min(64, feature_maps.shape[0])  # the 8x8 grid holds 64 plots
for i in range(num_maps):
    ax = plt.subplot(8, 8, i + 1)
    ax.set_title('Channel #{}'.format(i))  # i indexes channels, not samples
    ax.axis('off')
    ax.imshow(feature_maps[i], cmap='jet')
plt.show()
- Accessing and modifying different layers of a pretrained model in PyTorch: https://github.com/mortezamg63/Accessing-and-modifying-different-layers-of-a-pretrained-model-in-pytorch
C/C++基本文法學習
STL
C++ Primer