天天看點

pytorch讀取模型失敗RuntimeError: Error(s) in loading state_dict for ResNet: Missing key(s) in state_dict

錯誤

RuntimeError: Error(s) in loading state_dict for ResNet:

Missing key(s) in state_dict:

Missing key(s) in state_dict: “conv1.weight”, “bn1.weight”,

Unexpected key(s) in state_dict: “module.conv1.weight”, “module.bn1.weight”

原因

在訓練模型的時候使用了torch.nn.DataParallel但是加載卻沒有進行這一步

解決

  1. 讀取模型前将新模型也使用torch.nn.DataParallel進行包裝
  2. 使用OrderdDict去除module.,具體代碼如下:

    from collections import OrderedDict

    new_state_dict = OrderedDict()

    for k, v in state_dict.items():

      new_state_dict[k[7:]] = v

    model.load_state_dict(new_state_dict)

  3. 可以試試model.load_state_dict(new_state_dict, strict=False)

前兩種方法是确定可以成功的,第三種不一定。

題外話

如果你需要修改網絡層,比如想修改fc層,那麼需要在使用

torch.nn.DataParallel

前進行修改,是以需要使用方法二來解決,在将參數全部設定好并且初始化完畢擴充或者修改的層後再使用

torch.nn.DataParallel

,否則會報錯網絡沒有fc層的類似錯誤

繼續閱讀