天天看點

關于checkpoints的問題

關于單與多GPU的checkpoints的問題

經過驗證,多GPU訓練并儲存的checkpoints與單GPU checkpoints有差別

驗證用四個gpu【0,1,2,3】訓練的模型可以用【0,1,2】三個gpu來test,也可以用【0,1,2,3,5】五個gpu,甚至可以用【0】單個gpu來驗證,需要将代碼設定為以下:

net = get_network(args)
device_ids = [0,1,2]
net = nn.DataParallel(net, device_ids=device_ids)
           

其他地方無需改動,net在get_network已經放到cuda上了

if use_gpu:
        net = net.cuda()
    return net
           

net.cuda()

單GPU儲存的checkpoints也要用單個GPU,需要将其注釋掉

net = get_network(args)
#device_ids = [0,1,2]
#net = nn.DataParallel(net, device_ids=device_ids)
           

繼續閱讀