天天看点

java生成png热力图_热力图与原始图像融合

使用神经网络进行预测时,一个明显的缺陷就是缺少可解释性,我们不能通过一些简单的方法来知道网络做出决策或者预测的理由,这在很多方面就使得它的应用受限。

虽然不能通过一些数学方法来证明模型的有效性,但我们仍能够通过一些可视化热力图的方法来观测一下原始数据中的哪些部分对我们网络影响较大。

实现热力图绘制的方法有很多,如:CAM, Grad-CAM, Contrastive EBP等。在热力图生成之后,因为没有原始数据信息,所以我们并不能很直观地观测到模型到底重点关注了图像的哪些区域。这时将热力图叠加到原始图像上的想法就会很自然的产生。这里存在的一个问题是原始图像的色域空间可能和产生的热力图的色域空间是不一致的,当二者叠加的时候,会产生颜色的遮挡。并且因为产生的热力图的尺寸应该与原始图像尺寸一致或者调整到与原始尺寸一致,这样当二者直接简单地叠加的话,产生的图像可能并不是我们想要的,因此,我们需要先对热力图数据进行一些简单的像素处理,然后在考虑与原始图像的融合。以下部分的安排为:1. 热力图的产生 2. 热力图与原始图的叠加 3. 热力图与原始图融合优化

1. 热力图产生

在这里使用3D-Grad-CAM的方法来实现热力图绘制的方法,使用的图像尺寸为144, 168, 152 代码如下:

def cam(img_path, model_path, relu=True, sigmoid=False):

# grad-cam

img_data = np.load(img_path)

img_data = img_data[np.newaxis, :, :, :, np.newaxis]

max_ = np.max(img_data)

min_ = np.min(img_data)

img_data = (img_data - min_) / (max_ - min_)

model = load_model(model_path)

model.summary()

index = 0

pred = model.predict(img_data)

if sigmoid:

if pred >= 0.5:

index = 1

else:

max_ = np.max(pred)

for i in range(4):

if pred[0][i] == max_:

index = i

break

print(pred)

print("index: ", index)

pre_output = model.output[:, index]

last_conv_layer = model.get_layer('conv3d_7')

grads = K.gradients(pre_output, last_conv_layer.output)[0]

pooled_grads = K.mean(grads, axis=(0, 1, 2, 3))

iterate = K.function([model.input], [pooled_grads, last_conv_layer.output[0]])

pooled_grads_value, conv_layer_output_value = iterate([img_data])

if relu:

conv_layer_output_value[np.where(conv_layer_output_value < 0)] = 0

conv_max = np.max(conv_layer_output_value)

conv_min = np.min(conv_layer_output_value)

conv_layer_output_value = (conv_layer_output_value - conv_min) / (conv_max - conv_min)

pool_max = np.max(pooled_grads_value)

pool_min = np.min(pooled_grads_value)

pooled_grads_value = (pooled_grads_value - pool_min) / (pool_max - pool_min)

layer_number = len(pooled_grads_value)

for i in range(layer_number):

conv_layer_output_value[:, :, :, i] *= pooled_grads_value[i]

# along the last dim calculate the mean value

heatmap = np.mean(conv_layer_output_value, axis=-1)

# remove the value which less than 0

heatmap = np.maximum(heatmap, 0)

# uniformization

min_ = np.min(heatmap)

max_ = np.max(heatmap)

heatmap = (heatmap - min_) / (max_ - min_)

return heatmap

2. 热力图与原始图的叠加

通过以下代码获取热力图,并将其尺寸放缩到与原图一致:

heatmap = cam(img_path, model_path)

heatmap = resize(heatmap, (144, 168, 152))

加载数据:

img_data = np.load(img_path)

热力图与原图简单叠加:

def easy_show(data, heatmap):

plt.figure()

plt.subplot(221)

plt.axis('off')

plt.imshow(data, cmap='bone')

plt.subplot(222)

plt.axis('off')

plt.imshow(heatmap, cmap='rainbow')

plt.subplot(223)

plt.axis('off')

plt.imshow(data, cmap='bone')

plt.imshow(heatmap, cmap='rainbow', alpha=0.7)

plt.subplot(224)

plt.axis('off')

plt.imshow(data, cmap='bone')

plt.imshow(heatmap, cmap='rainbow', alpha=0.3)

plt.savefig(r'E:\study\研究生\笔记\studyNote\others\imgs\tmp.png')

# 使用

heatmap = np.load("CNcam.npy")

img_data = np.load(img_path)

easy_show(img_data[:, 84, :], heatmap[:, 84, :])

图像融合结果:

java生成png热力图_热力图与原始图像融合

3. 热力图与原始图融合优化

上面图像融合之后存在的问题是,前景热力图完全遮挡了原图,使得最终的展示图中,原图结构存在模糊。首先对热力图进行优化,使背景颜色变为白色且去掉一些权重过小热力。然后将热力图剩余的部分叠加到原图上。

def img_fusion(img1, img2, save_path):

dpi = 100

save_fig(img1, dpi, "cam.png")

img = Image.open("cam.png")

img = np.array(img)

for i in range(len(img)):

for j in range(len(img[0])):

if img[i][j][0] == 127 and img[i][j][1] == 0 and img[i][j][2] == 255 \

and img[i][j][3] == 255:

img[i][j][:] = 255

save_fig(img2, dpi, "data.png", "bone")

cam_img = cv2.imread("cam.png")

data_img = cv2.imread("data.png")

cam_gray = cv2.cvtColor(cam_img, cv2.COLOR_BGR2GRAY)

rest, mask = cv2.threshold(cam_gray, 80, 255, cv2.THRESH_BINARY)

cam_fg = cv2.bitwise_and(cam_img, cam_img, mask=mask)

dst = cv2.addWeighted(cam_fg, 0.4, data_img, 1, 0)

add_cubic = cv2.resize(dst, (dst.shape[1] * 4, dst.shape[0] * 4), cv2.INTER_CUBIC)

cv2.imwrite(save_path, add_cubic)

使用上面的函数(上面的图像不正,首先向左旋转90°,之后再进行融合):

heatmap = np.load("CNcam.npy")

img_data = np.load(img_path)

heatmap = np.where(heatmap < 0.3, 0, heatmap) * 255

img_data = np.rot90(img_data[:, 84, :], 1) # 向左旋转90度

heatmap = np.rot90(heatmap[:, 84, :], 1)

img_fusion(heatmap, img_data, r'tmp.png')

绘制结果:

java生成png热力图_热力图与原始图像融合