天天看點

pytorch 通過load_state_dict加載權重

1、如果預訓練網絡與目前網絡一模一樣

net = model(input_channels=input_channels, angle_cls=args.angle_bins)
# 加載預訓練模型
pretrained_dict = torch.load("預訓練權重路徑")
net.load_state_dict(pretrained_dict) 
           

2、如果預訓練網絡和目前網絡中有一部分網絡層的名稱不同

如,預訓練網絡中有一層卷積定義如下:

但是目前網絡把這一層去掉了。通過以下方式加載權重

net = model(input_channels=input_channels, angle_cls=args.angle_bins)
# 加載預訓練模型
pretrained_dict = torch.load("預訓練權重路徑")
model_dict = net.state_dict()
state_dict = {k:v for k,v in pretrained_dict.items() if k in sgdn_dict.keys()}
model_dict .update(state_dict)
net.load_state_dict(model_dict)
           

3、如果如果預訓練網絡和目前網絡中有一部分網絡層的名稱相同,但形狀不同

比如有個網絡,預訓練時以4通道的RGBD圖像作為輸入;現在訓練時,以單通道的深度圖為輸入,即第一個層的權重尺寸不同。通過以下方式加載權重

net = model(input_channels=input_channels, angle_cls=args.angle_bins)
# 加載預訓練模型
pretrained_dict = torch.load("預訓練權重路徑")
model_dict = net.state_dict()
state_dict = {k:v for k,v in pretrained_dict.items() if k in sgdn_dict.keys() and v.shape == model_dict [k].shape} 
model_dict .update(state_dict)
net.load_state_dict(model_dict) 
           

繼續閱讀