天天看點

利用pytorch實作Visualising Image Classification Models and Saliency Maps

素材來源自cs231n-assignment3-NetworkVisualization

saliency map

saliency map即特征圖,可以告訴我們圖像中的像素點對圖像分類結果的影響。

計算它的時候首先要計算與圖像像素對應的正确分類中的标準化分數的梯度(這是一個标量)。如果圖像的形狀是(3, H, W),這個梯度的形狀也是(3, H, W);對于圖像中的每個像素點,這個梯度告訴我們當像素點發生輕微改變時,正确分類分數變化的幅度。

計算saliency map的時候,需要計算出梯度的絕對值,然後再取三個顔色通道的最大值;是以最後的saliency map的形狀是(H, W)為一個通道的灰階圖。

下圖即為例子:

上圖為圖像,下圖為特征圖,可以看到下圖中亮色部分為神經網絡感興趣的部分。

理論依據

需要注意一下:

程式解釋

下面為計算特征圖函數,上下文資訊通過注釋來擷取。

def compute_saliency_maps(X, y, model):
    """
    使用模型圖像(image)X和标記(label)y計算正确類的saliency map.

    輸入:
    - X: 輸入圖像; Tensor of shape (N, 3, H, W)
    - y: 對應X的标記; LongTensor of shape (N,)
    - model: 一個預先訓練好的神經網絡模型用于計算X.

    傳回值:
    - saliency: A Tensor of shape (N, H, W) giving the saliency maps for the input
    images.
    """
    # Make sure the model is in "test" mode
    model.eval()

    # Wrap the input tensors in Variables
    X_var = Variable(X, requires_grad=True)
    y_var = Variable(y)
    saliency = None
    ##############################################################################
    #
    # 首先進行前向操作,将輸入圖像pass through已經訓練好的model,再進行反向操作,
    # 進而得到對應圖像,正确分類分數的梯度
    # 
    ##############################################################################

    # 前向操作
    scores = model(X_var)

    # 得到正确類的分數,scores為[5]的Tensor
    scores = scores.gather(1, y_var.view(-1, 1)).squeeze() 

    #反向計算,從輸出的分數到輸入的圖像進行一系列梯度計算
    scores.backward(torch.FloatTensor([1.0,1.0,1.0,1.0,1.0])) # 參數為對應長度的梯度初始化
#     scores.backward() 必須有參數,因為此時的scores為非标量,為5個元素的向量

    # 得到正确分數對應輸入圖像像素點的梯度
    saliency = X_var.grad.data

    saliency = saliency.abs() # 取絕對值
    saliency, i = torch.max(saliency,dim=1)  # 從3個顔色通道中取絕對值最大的那個通道的數值
    saliency = saliency.squeeze() # 去除1維
#     print(saliency)

    return saliency           

再定義一個顯示圖像函數,進行圖像顯示

def show_saliency_maps(X, y):
    # Convert X and y from numpy arrays to Torch Tensors
    X_tensor = torch.cat([preprocess(Image.fromarray(x)) for x in X], dim=0)
    y_tensor = torch.LongTensor(y)

    # Compute saliency maps for images in X
    saliency = compute_saliency_maps(X_tensor, y_tensor, model)

    # Convert the saliency map from Torch Tensor to numpy array and show images
    # and saliency maps together.
    saliency = saliency.numpy()
    N = X.shape[0]

    for i in range(N):
        plt.subplot(2, N, i + 1)
        plt.imshow(X[i])
        plt.axis('off')
        plt.title(class_names[y[i]])
        plt.subplot(2, N, N + i + 1)
        plt.imshow(saliency[i], cmap=plt.cm.hot)
        plt.axis('off')
        plt.gcf().set_size_inches(12, 5)
    plt.show()

show_saliency_maps(X, y)           

output:

另一種梯度的計算法,通過了損失函數計算出來的梯度

out = model( X_var )  
    loss_func = torch.nn.CrossEntropyLoss()
    loss = loss_func( out, y_var ) 
    loss.backward()
    grads = X_var.grad
    grads = grads.abs()
    mx, index_mx = torch.max( grads, 1 )
#     print(mx, index_mx)
    saliency = mx.data
#     print(saliency)           

這中方法的output為:

參考資料:

1、 Karen Simonyan, Andrea Vedaldi, and Andrew Zisserman. “Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps”, ICLR Workshop 2014.

2、

http://cs231n.stanford.edu/syllabus.html