天天看點

Pytorch“ntimeError: Error(s) in loading state_dict for DataParallel: Missing key(s) in state_dict:“

對訓練好的模型進行測試,得到測試樣本。通過下面的程式将模型參數導入到建立的模型中。

由于我們直接用torch.load()存儲的模型資訊會比較大,是以我們可以隻存儲參數資訊,進行測試時再将參數資訊導入到模型中(一定要與儲存的模型大小和内容相同)會提高效率。

torch.save(model.state_dict(),'hscnn_5layer_dim10_276.pkl')
#不直接用torch.save(mode,'hscnn_5layer_dim10_276.pkl')
           

下載下傳儲存的模型參數到測試程式:

model_path = './models/hscnn_5layer_dim10_276.pkl'
img_path = './test_imgs/'
result_path = './test_results1/'
var_name = 'rad'

save_point = torch.load(model_path)
model_param = save_point['state_dict']
print(model_param.keys())
model = resblock(conv_relu_res_relu_block,16,3,31)
model = nn.DataParallel(model)
model.load_state_dict(model_param)

model = model.cuda()
model.eval()

           

運作上面的程式在“model.load_state_dict(model_param)”位置會出現錯誤:

File "E:\install\Anaconda3\envs\pytorch_GPU\lib\site-packages\torch\nn\modules\module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DataParallel:
	Missing key(s) in state_dict: "module.input_conv.weight", "module.input_conv.bias", "module.conv_seq.0.conv1.weight", "module.conv_seq.0.conv1.bias", "module.conv_seq.0.conv2.weight", "module.conv_seq.0.conv2.bias", "module.conv_seq.1.conv1.weight", "module.conv_seq.1.conv1.bias", "module.conv_seq.1.conv2.weight", "module.conv_seq.1.conv2.bias", "module.conv_seq.2.conv1.weight", "module.conv_seq.2.conv1.bias", "module.conv_seq.2.conv2.weight", "module.conv_seq.2.conv2.bias", "module.conv_seq.3.conv1.weight", "module.conv_seq.3.conv1.bias", "module.conv_seq.3.conv2.weight", "module.conv_seq.3.conv2.bias", "module.conv_seq.4.conv1.weight", "module.conv_seq.4.conv1.bias", "module.conv_seq.4.conv2.weight", "module.conv_seq.4.conv2.bias", "module.conv_seq.5.conv1.weight", "module.conv_seq.5.conv1.bias", "module.conv_seq.5.conv2.weight", "module.conv_seq.5.conv2.bias", "module.conv_seq.6.conv1.weight", "module.conv_seq.6.conv1.bias", "module.conv_seq.6.conv2.weight", "module.conv_seq.6.conv2.bias", "module.conv_seq.7.conv1.weight", "module.conv_seq.7.conv1.bias", "module.conv_seq.7.conv2.weight", "module.conv_seq.7.conv2.bias", "module.conv_seq.8.conv1.weight", "module.conv_seq.8.conv1.bias", "module.conv_seq.8.conv2.weight", "module.conv_seq.8.conv2.bias", "module.conv_seq.9.conv1.weight", "module.conv_seq.9.conv1.bias", "module.conv_seq.9.conv2.weight", "module.conv_seq.9.conv2.bias", "module.conv_seq.10.conv1.weight", "module.conv_seq.10.conv1.bias", "module.conv_seq.10.conv2.weight", "module.conv_seq.10.conv2.bias", "module.conv_seq.11.conv1.weight", "module.conv_seq.11.conv1.bias", "module.conv_seq.11.conv2.weight", "module.conv_seq.11.conv2.bias", "module.conv_seq.12.conv1.weight", "module.conv_seq.12.conv1.bias", "module.conv_seq.12.conv2.weight", "module.conv_seq.12.conv2.bias", "module.conv_seq.13.conv1.weight", "module.conv_seq.13.conv1.bias", "module.conv_seq.13.conv2.weight", "module.conv_seq.13.conv2.bias", "module.conv_seq.14.conv1.weight", "module.conv_seq.14.conv1.bias", "module.conv_seq.14.conv2.weight", "module.conv_seq.14.conv2.bias", "module.conv_seq.15.conv1.weight", "module.conv_seq.15.conv1.bias", "module.conv_seq.15.conv2.weight", "module.conv_seq.15.conv2.bias", "module.conv.weight", "module.conv.bias", "module.output_conv.weight", "module.output_conv.bias". 
	Unexpected key(s) in state_dict: "input_conv.weight", "input_conv.bias", "conv_seq.0.conv1.weight", "conv_seq.0.conv1.bias", "conv_seq.0.conv2.weight", "conv_seq.0.conv2.bias", "conv_seq.1.conv1.weight", "conv_seq.1.conv1.bias", "conv_seq.1.conv2.weight", "conv_seq.1.conv2.bias", "conv_seq.2.conv1.weight", "conv_seq.2.conv1.bias", "conv_seq.2.conv2.weight", "conv_seq.2.conv2.bias", "conv_seq.3.conv1.weight", "conv_seq.3.conv1.bias", "conv_seq.3.conv2.weight", "conv_seq.3.conv2.bias", "conv_seq.4.conv1.weight", "conv_seq.4.conv1.bias", "conv_seq.4.conv2.weight", "conv_seq.4.conv2.bias", "conv_seq.5.conv1.weight", "conv_seq.5.conv1.bias", "conv_seq.5.conv2.weight", "conv_seq.5.conv2.bias", "conv_seq.6.conv1.weight", "conv_seq.6.conv1.bias", "conv_seq.6.conv2.weight", "conv_seq.6.conv2.bias", "conv_seq.7.conv1.weight", "conv_seq.7.conv1.bias", "conv_seq.7.conv2.weight", "conv_seq.7.conv2.bias", "conv_seq.8.conv1.weight", "conv_seq.8.conv1.bias", "conv_seq.8.conv2.weight", "conv_seq.8.conv2.bias", "conv_seq.9.conv1.weight", "conv_seq.9.conv1.bias", "conv_seq.9.conv2.weight", "conv_seq.9.conv2.bias", "conv_seq.10.conv1.weight", "conv_seq.10.conv1.bias", "conv_seq.10.conv2.weight", "conv_seq.10.conv2.bias", "conv_seq.11.conv1.weight", "conv_seq.11.conv1.bias", "conv_seq.11.conv2.weight", "conv_seq.11.conv2.bias", "conv_seq.12.conv1.weight", "conv_seq.12.conv1.bias", "conv_seq.12.conv2.weight", "conv_seq.12.conv2.bias", "conv_seq.13.conv1.weight", "conv_seq.13.conv1.bias", "conv_seq.13.conv2.weight", "conv_seq.13.conv2.bias", "conv_seq.14.conv1.weight", "conv_seq.14.conv1.bias", "conv_seq.14.conv2.weight", "conv_seq.14.conv2.bias", "conv_seq.15.conv1.weight", "conv_seq.15.conv1.bias", "conv_seq.15.conv2.weight", "conv_seq.15.conv2.bias", "conv.weight", "conv.bias", "output_conv.weight", "output_conv.bias". 

           

造成改原因的是字典内容表示不比對。我們打開“load_state_dict()”函數檢視内容。

def load_state_dict(self, state_dict: Union[Dict[str, Tensor], Dict[str, Tensor]],
                        strict: bool = True):



        r"""Copies parameters and buffers from :attr:`state_dict` into
        this module and its descendants. If :attr:`strict` is ``True``, then
        the keys of :attr:`state_dict` must exactly match the keys returned
        by this module's :meth:`~torch.nn.Module.state_dict` function.
           

發現當strict為Ture時參數與模型的字典必須完全對應,否則會報錯。我們改成False報錯解除。

上面雖然會報錯解除,但有時也會出現一種問題,因為我們重建立立的網絡模型要與訓練儲存的模型大小和型号要相同。

我們訓練過程建立的模型如下。

if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    if torch.cuda.is_available():
        # model = nn.DataParallel(model)
        model.cuda()
        print('使用GPU訓練')
           

是以多顯示卡訓練時利用了nn.DataParallel(model),是以測試時,在參數導入之前也要有該過程,否則會報錯。如果是單卡訓練,不需要上述語句,直接建立模型load參數即可。

如果沒執行nn.DataParallel(model)語句,則産生的參數為:

輸出為:

print(model_param.keys())
print(model_param['input_conv.weight'].size())
#odict_keys(['input_conv.weight', 'input_conv.bias', 'conv_seq.0.conv1.weight', 'conv_seq.0.conv1.bias', 'conv_seq.0.conv2.weight', 'conv_seq.0.conv2.bias', 'conv_seq.1.conv1.weight', 'conv_seq.1.conv1.bias', 'conv_seq.1.conv2.weight', 'conv_seq.1.conv2.bias', 'conv_seq.2.conv1.weight', 'conv_seq.2.conv1.bias', 'conv_seq.2.conv2.weight', 'conv_seq.2.conv2.bias', 'conv_seq.3.conv1.weight', 'conv_seq.3.conv1.bias', 'conv_seq.3.conv2.weight', 'conv_seq.3.conv2.bias', 'conv_seq.4.conv1.weight', 'conv_seq.4.conv1.bias', 'conv_seq.4.conv2.weight', 'conv_seq.4.conv2.bias', 'conv_seq.5.conv1.weight', 'conv_seq.5.conv1.bias', 'conv_seq.5.conv2.weight', 'conv_seq.5.conv2.bias', 'conv_seq.6.conv1.weight', 'conv_seq.6.conv1.bias', 'conv_seq.6.conv2.weight', 'conv_seq.6.conv2.bias', 'conv_seq.7.conv1.weight', 'conv_seq.7.conv1.bias', 'conv_seq.7.conv2.weight', 'conv_seq.7.conv2.bias', 'conv_seq.8.conv1.weight', 'conv_seq.8.conv1.bias', 'conv_seq.8.conv2.weight', 'conv_seq.8.conv2.bias', 'conv_seq.9.conv1.weight', 'conv_seq.9.conv1.bias', 'conv_seq.9.conv2.weight', 'conv_seq.9.conv2.bias', 'conv_seq.10.conv1.weight', 'conv_seq.10.conv1.bias', 'conv_seq.10.conv2.weight', 'conv_seq.10.conv2.bias', 'conv_seq.11.conv1.weight', 'conv_seq.11.conv1.bias', 'conv_seq.11.conv2.weight', 'conv_seq.11.conv2.bias', 'conv_seq.12.conv1.weight', 'conv_seq.12.conv1.bias', 'conv_seq.12.conv2.weight', 'conv_seq.12.conv2.bias', 'conv_seq.13.conv1.weight', 'conv_seq.13.conv1.bias', 'conv_seq.13.conv2.weight', 'conv_seq.13.conv2.bias', 'conv_seq.14.conv1.weight', 'conv_seq.14.conv1.bias', 'conv_seq.14.conv2.weight', 'conv_seq.14.conv2.bias', 'conv_seq.15.conv1.weight', 'conv_seq.15.conv1.bias', 'conv_seq.15.conv2.weight', 'conv_seq.15.conv2.bias', 'conv.weight', 'conv.bias', 'output_conv.weight', 'output_conv.bias'])
#torch.Size([64, 3, 3, 3])
           

若執行nn.DataParallel(model),輸出為:

print(model_param.keys())
print(model_param['module.input_conv.weight'].size())

#odict_keys(['module.input_conv.weight', 'module.input_conv.bias', 'module.conv_seq.0.conv1.weight', 'module.conv_seq.0.conv1.bias', 'module.conv_seq.0.conv2.weight', 'module.conv_seq.0.conv2.bias', 'module.conv_seq.1.conv1.weight', 'module.conv_seq.1.conv1.bias', 'module.conv_seq.1.conv2.weight', 'module.conv_seq.1.conv2.bias', 'module.conv_seq.2.conv1.weight', 'module.conv_seq.2.conv1.bias', 'module.conv_seq.2.conv2.weight', 'module.conv_seq.2.conv2.bias', 'module.conv_seq.3.conv1.weight', 'module.conv_seq.3.conv1.bias', 'module.conv_seq.3.conv2.weight', 'module.conv_seq.3.conv2.bias', 'module.conv_seq.4.conv1.weight', 'module.conv_seq.4.conv1.bias', 'module.conv_seq.4.conv2.weight', 'module.conv_seq.4.conv2.bias', 'module.conv_seq.5.conv1.weight', 'module.conv_seq.5.conv1.bias', 'module.conv_seq.5.conv2.weight', 'module.conv_seq.5.conv2.bias', 'module.conv_seq.6.conv1.weight', 'module.conv_seq.6.conv1.bias', 'module.conv_seq.6.conv2.weight', 'module.conv_seq.6.conv2.bias', 'module.conv_seq.7.conv1.weight', 'module.conv_seq.7.conv1.bias', 'module.conv_seq.7.conv2.weight', 'module.conv_seq.7.conv2.bias', 'module.conv_seq.8.conv1.weight', 'module.conv_seq.8.conv1.bias', 'module.conv_seq.8.conv2.weight', 'module.conv_seq.8.conv2.bias', 'module.conv_seq.9.conv1.weight', 'module.conv_seq.9.conv1.bias', 'module.conv_seq.9.conv2.weight', 'module.conv_seq.9.conv2.bias', 'module.conv_seq.10.conv1.weight', 'module.conv_seq.10.conv1.bias', 'module.conv_seq.10.conv2.weight', 'module.conv_seq.10.conv2.bias', 'module.conv_seq.11.conv1.weight', 'module.conv_seq.11.conv1.bias', 'module.conv_seq.11.conv2.weight', 'module.conv_seq.11.conv2.bias', 'module.conv_seq.12.conv1.weight', 'module.conv_seq.12.conv1.bias', 'module.conv_seq.12.conv2.weight', 'module.conv_seq.12.conv2.bias', 'module.conv_seq.13.conv1.weight', 'module.conv_seq.13.conv1.bias', 'module.conv_seq.13.conv2.weight', 'module.conv_seq.13.conv2.bias', 'module.conv_seq.14.conv1.weight', 'module.conv_seq.14.conv1.bias', 'module.conv_seq.14.conv2.weight', 'module.conv_seq.14.conv2.bias', 'module.conv_seq.15.conv1.weight', 'module.conv_seq.15.conv1.bias', 'module.conv_seq.15.conv2.weight', 'module.conv_seq.15.conv2.bias', 'module.conv.weight', 'module.conv.bias', 'module.output_conv.weight', 'module.output_conv.bias'])
#torch.Size([64, 3, 3, 3])
           

繼續閱讀