錯誤
RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict:
Missing key(s) in state_dict: “conv1.weight”, “bn1.weight”,
Unexpected key(s) in state_dict: “module.conv1.weight”, “module.bn1.weight”
原因
在訓練模型的時候使用了torch.nn.DataParallel但是加載卻沒有進行這一步
解決
- 讀取模型前将新模型也使用torch.nn.DataParallel進行包裝
- 使用OrderdDict去除module.,具體代碼如下:
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
new_state_dict[k[7:]] = v
model.load_state_dict(new_state_dict)
- 可以試試model.load_state_dict(new_state_dict, strict=False)
前兩種方法是确定可以成功的,第三種不一定。
題外話
如果你需要修改網絡層,比如想修改fc層,那麼需要在使用
torch.nn.DataParallel
前進行修改,是以需要使用方法二來解決,在将參數全部設定好并且初始化完畢擴充或者修改的層後再使用
torch.nn.DataParallel
,否則會報錯網絡沒有fc層的類似錯誤