天天看點

訓練神經網絡時train loss或是validation loss出現nan

最近使用帶有SE block的網絡在pytorch架構下做訓練。training loss 随着epoch增多不斷下降,但是突然到某一個epoch出現loss為nan的情況,但是兩三個epoch之後,loss竟然又恢複正常,而且下降了。

這幾篇部落格是我debug的借鑒,真的非常有用。

這篇介紹了出現nan的基本解決思路。

https://blog.csdn.net/qq_32799915/article/details/80612342

這篇介紹了為什麼在多層dense layer之後某一層dense layer輸出可能會出現nan,以及weight初始化的重要性及初始化方法。

https://yey.world/2020/12/17/Pytorch-14/

我自己的基本排查思路是:

  1. 首先檢查資料中是否有inf 或者nan的情況。

普通numpy數組可用

我自己用的是醫療圖像nifiti格式的壓縮圖像,是以使用代碼如下:

def check_image(img_fname: str):
	#檢查圖像/numpy數組中是否有nan存在
    npy = sitk.GetArrayFromImage(sitk.ReadImage(img_fname))
    return np.any(np.isnan(npy))

def check_image_inf(img_fname: str):
	#檢查圖像/numpy數組中資料是否都是finit的(提示:若使用此函數,單獨檢查nan的函數可不用)
    npy = sitk.GetArrayFromImage(sitk.ReadImage(img_fname))
    return np.all(np.isfinite(npy))

def check_for_nan(input_folder: str):
    nii_files = subfiles(input_folder, suffix='.nii.gz')
    for n in nii_files:
        if check_image(n):
            print("nans found in ", n)
        elif not check_image_inf(n):
            print("infs found in ", n)
            
img_fold = '' #儲存資料的檔案夾
check_for_nan(img_fold)
           

如果資料沒有問題,檢查資料是否有normalization,如果沒有歸一化也可能出現nan或者網絡層計算中出現infinit。

  1. 檢查使用loss是否帶有除法,算log的時候有負數或者很小的數。

    我所用的檢查loss是否為nan的方法:

如果loss中有infit或者nan,則會輸出

'loss is nan or ifinit', loss(這裡會輸出loss的值)
           

如果确認loss也并沒有問題,那麼問題可能出現在forward path中。

  1. 檢查forward path每一層的輸出結果,進行問題定位。在每一層後加入:

如果是某一層計算出問題,考慮是不是初始化函數沒有使用或者用得不對。

接下來,就到我的檢查血淚史了。我發現我SE block的nn.AdaptiveAvgPool3d(1)輸出中有inf!!!奇怪的是,這一次層的輸入卻沒這個問題。後來我直接用了torch.nn.AvgPool3d代替前面的函數,它終于正常跑起來了。

b, c, D, H, W = x.size()  #b: batch size, c: channels, (D, H, W) data shape
y = torch.nn.AvgPool3d((D,H,W), padding=0)(x)
y = y.view(b, c) # 将y變成shape為(b,c)的tensor, 後面接全連接配接層
           

一周的辛酸血淚,終于好了。

具體nn.AdaptiveAvgPool3d()函數的内部實作我還沒來得及研究,之後搞明白了再來分享到底為啥出問題。