1、如果預訓練網絡與目前網絡一模一樣
net = model(input_channels=input_channels, angle_cls=args.angle_bins)
# 加載預訓練模型
pretrained_dict = torch.load("預訓練權重路徑")
net.load_state_dict(pretrained_dict)
2、如果預訓練網絡和目前網絡中有一部分網絡層的名稱不同
如,預訓練網絡中有一層卷積定義如下:
但是目前網絡把這一層去掉了。通過以下方式加載權重
net = model(input_channels=input_channels, angle_cls=args.angle_bins)
# 加載預訓練模型
pretrained_dict = torch.load("預訓練權重路徑")
model_dict = net.state_dict()
state_dict = {k:v for k,v in pretrained_dict.items() if k in sgdn_dict.keys()}
model_dict .update(state_dict)
net.load_state_dict(model_dict)
3、如果如果預訓練網絡和目前網絡中有一部分網絡層的名稱相同,但形狀不同
比如有個網絡,預訓練時以4通道的RGBD圖像作為輸入;現在訓練時,以單通道的深度圖為輸入,即第一個層的權重尺寸不同。通過以下方式加載權重
net = model(input_channels=input_channels, angle_cls=args.angle_bins)
# 加載預訓練模型
pretrained_dict = torch.load("預訓練權重路徑")
model_dict = net.state_dict()
state_dict = {k:v for k,v in pretrained_dict.items() if k in sgdn_dict.keys() and v.shape == model_dict [k].shape}
model_dict .update(state_dict)
net.load_state_dict(model_dict)