當pytorch模型寫明是eval()時有時表現的結果相對于train(True)差别非常巨大,這種差别經過逐層檢視,主要來源于使用了BN,在eval下,使用的BN是一個固定的running rate,而在train下這個running rate會根據輸入發生改變。
解決方案是凍住bn
def freeze_bn(m):
if isinstance(m, nn.BatchNorm2d):
m.eval()
model.apply(freeze_bn)
這樣可以獲得穩定輸出的結果。