在開始之前,請記住,存儲量不是參數量!!!!
存儲量不是參數量!!!!
存儲量不是參數量!!!!
通常來說我們在訓練模型的時候會用checkpoint的方法把模型儲存下來,一個模型小則幾十M,大則上百M,并且我們很多時候會把這個存儲量誤認為是參數量,比如
但是實際上這個是存儲量而不是參數量,那麼在pytorch中參數量怎麼計算呢?
實際上我們可以直接在pytorch代碼中print一個模型的參數:
話不多說,代碼如下:
params = list(self.get_model().parameters())#所有參數放在params裡
k = 0
for i in params:
l = 1
for j in i.size():
l*=j #每層的參數存入l,這裡也可以print 每層的參數
k = k+l #各層參數相加
print("all params:"+ str(k)) #輸出總的參數
通常來說這個params最後print出來要除以100萬轉換為M為機關(論文中常用的格式)
好了,這就是pytorch中模型參數量的計算方法