天天看點

根據神經網絡中間層結果繪製heatmap

以resnet18為例，輸出4個stage結果的heatmap

import cv2
from torchvision.models.resnet import resnet18
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np

# img = np.zeros((256,256),dtype=np.uint8)
# for r in range(16):
#     for c in range(16):
#         val = r*16 + c
#         xmin = c*16
#         xmax = (c+1)*16
#         ymin = r*16
#         ymax = (r+1)*16
#         img[ymin:ymax,xmin:xmax] = val
# plt.imshow(img)
# plt.show()

# Feature maps captured by ``viz``: one entry per hooked module per forward pass.
list_features = []


def viz(module, input, outputs):
    """Forward hook that records the feature map a module produced.

    Meant for ``nn.Module.register_forward_hook``, which calls it as
    ``hook(module, input, output)`` after the module's ``forward`` returns.

    Args:
        module: The module the hook is attached to (unused).
        input: Tuple of positional arguments given to ``forward`` (unused).
            Recording this instead of ``outputs`` would show each stage's
            *input*, i.e. the previous stage's result, and the last stage's
            result would never be captured.
        outputs: The value ``forward`` returned. If it is a tuple/list, its
            first element is used.

    Side effect:
        Appends the first batch element of ``outputs``, detached from the
        autograd graph, to the module-level ``list_features``. For an
        (N, C, H, W) output that is a (C, H, W) tensor (the expected layout
        for the ResNet stages hooked in this script; confirm for others).
    """
    if isinstance(outputs, (tuple, list)):
        outputs = outputs[0]
    # detach() so the stored map does not keep the whole autograd graph alive.
    list_features.append(outputs[0].detach())


# ImageNet preprocessing: HWC uint8 -> CHW float in [0, 1], then per-channel
# normalisation with the RGB mean/std the pretrained weights expect.
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Pass ``pretrained`` by keyword; the positional form is deprecated in newer torchvision.
model = resnet18(pretrained=True).eval()

# Hook the four residual stages so ``viz`` captures each stage's feature map.
stage_names = {'layer{}'.format(i) for i in range(1, 5)}
for name, m in model.named_modules():
    if name in stage_names:
        m.register_forward_hook(viz)

image_path = 'Screenshot from 2021-04-09 10-13-37.png'
img = cv2.imread(image_path)
if img is None:
    # cv2.imread signals failure by returning None rather than raising.
    raise FileNotFoundError('could not read image: {}'.format(image_path))
# ``img`` stays BGR: it is reused later for the OpenCV overlay and display.
img = cv2.resize(img, (224, 224))
# OpenCV loads BGR, but the network and the mean/std above expect RGB.
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img_tensor = trans(img_rgb).unsqueeze(0)
# Pure inference: no autograd graph is needed.
with torch.no_grad():
    result = model(img_tensor).softmax(dim=-1)
# Highest class probability.
print(torch.max(result))
print()
print()

# Show every channel of every captured stage three ways: as a grayscale map,
# as a JET colour map, and as the colour map blended over the input image.
# Any key advances to the next channel; Esc stops early.
height, width = img.shape[:2]
# Flatten (stage, channel) into one stream of 2-D maps so a single break can stop.
channels = (
    channel
    for feature_map in list_features
    for channel in feature_map.detach().cpu().numpy()
)
for heatmap in channels:
    # Min-max normalise to [0, 1]; the epsilon avoids 0/0 on constant channels.
    v_min = heatmap.min()
    v_max = heatmap.max()
    heatmap = (heatmap - v_min) / max(v_max - v_min, 1e-10)
    # Upsample to the image size (cv2.resize takes (width, height)), scale to 0..255.
    heatmap = (cv2.resize(heatmap, (width, height)) * 255).astype(np.uint8)
    cv2.imshow("heatmap1", heatmap)
    heatmap2 = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    cv2.imshow("heatmap2", heatmap2)
    # 40 % heatmap over 60 % original image (both BGR).
    superimposed_img = heatmap2 * 0.4 + img * 0.6
    superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)
    cv2.imshow("superimposed_img", superimposed_img)
    if cv2.waitKey() & 0xFF == 27:  # Esc
        break
cv2.destroyAllWindows()
           

繼續閱讀