天天看點

解決pytorch訓練的過程中記憶體一直增加的問題

代碼中存在累加loss,但每步的loss沒加item()。

pytorch中,.item()方法 是得到一個元素張量裡面的元素值

具體就是 用于将一個零維張量轉換成浮點數,比如計算loss,accuracy的值

就比如:

loss = (y_pred - y).pow(2).sum()

print(loss.item())

for epoch in range(100):
    index=np.arange(train_sample.shape[0])
    np.random.shuffle(index)
    train_set=train_sample[index].tolist()

    model.train()
    loss,s=0,0

    for s in tqdm(range(0,train_sample.shape[0],batch_size)):
        if s+batch_size>train_sample.shape[0]:
            break
        batch_loss=model(train_set[s:s+batch_size])
        
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        
        # 會導緻記憶體一直增加,需改為loss+=batch_loss.item()
        loss+=batch_loss
        s+=batch_size
        
    loss/=total_batch
    print(epoch,loss)
    if (epoch+1) % 10 ==0:
        model.eval()
        model.save_embedding(epoch)
           

以上代碼會導緻記憶體占用越來越大,解決的方法是:loss+=batch_loss.item()。值得注意的是,要複現記憶體越來越大的問題,模型中需要切換model.train() 和 model.eval(),train_loss以及eval_loss的作用是儲存模型的平均誤差(這裡是累積誤差),儲存到tensorboard中。

繼續閱讀