PyTorch weight initialization
Discussion of weight initialization on the official forum
Initializing model parameters
Official forum link: https://discuss.pytorch.org/t/weight-initilzation/157/3
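Define a standalone weights_init function whose input argument m is a torch.nn.Module (or a user-defined subclass of nn.Module).
Then call net.apply() to initialize the parameters.
m.__class__.__name__ gives the class name of the module m.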
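A quick way to see what m.__class__.__name__ returns, and how a find()-based test such as the one in the DCGAN code behaves, is to list the names for a small model. The toy model here is only an assumption for illustration. The test is a plain substring match, so every class whose name contains 'Conv' (Conv2d, ConvTranspose2d, ...) falls into the same branch.

```python
import torch.nn as nn

# Toy model, hypothetical and only for illustration
net = nn.Sequential(
    nn.ConvTranspose2d(8, 4, 4, 2, 1, bias=False),
    nn.BatchNorm2d(4),
    nn.ReLU(True),
)

# net.modules() yields the container itself first, then its submodules
names = [m.__class__.__name__ for m in net.modules()]

# str.find returns the index of the first match, or -1 if the substring is absent
conv_like = [n for n in names if n.find('Conv') != -1]
bn_like = [n for n in names if n.find('BatchNorm') != -1]

print(names)      # ['Sequential', 'ConvTranspose2d', 'BatchNorm2d', 'ReLU']
print(conv_like)  # ['ConvTranspose2d']
print(bn_like)    # ['BatchNorm2d']
```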
DCGAN GitHub link
```python
import torch
import torch.nn as nn

# Hyperparameters as set in the DCGAN tutorial
nc = 3      # number of channels in the output images
nz = 100    # length of the latent vector z
ngf = 64    # size of the feature maps in the generator
ngpu = 1    # number of GPUs available (0 = CPU mode)

# Device to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Weight initialization used in DCGAN
def weights_init(m):
    classname = m.__class__.__name__  # class name of the layer, e.g. ConvTranspose2d
    # str.find returns -1 when the substring is absent,
    # so != -1 means "the class name contains 'Conv'"
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Generator Code
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

# Now, we can instantiate the generator and apply the ``weights_init``
# function. Check out the printed model to see how the generator object is
# structured.

# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-GPU if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
# to mean=0, stdev=0.02 (BatchNorm weights use mean=1, biases 0).
netG.apply(weights_init)

# Print the model
print(netG)
```
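To check that weights_init really produces the distributions described in the comment, the sketch below applies the same function to a small stand-in for the first generator layers and prints the resulting statistics. The layer sizes and the random seed are assumptions, chosen only so the estimates are stable.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def weights_init(m):
    # Same rule as the DCGAN code: Conv* -> N(0, 0.02); BatchNorm* -> weight N(1, 0.02), bias 0
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Hypothetical stand-in for the first generator block (100 -> 64 channels)
model = nn.Sequential(
    nn.ConvTranspose2d(100, 64, 4, 1, 0, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(True),
)
model.apply(weights_init)

conv, bn = model[0], model[1]
print(conv.weight.mean().item(), conv.weight.std().item())  # close to 0 and 0.02
print(bn.weight.mean().item(), bn.bias.abs().max().item())  # close to 1, and exactly 0
```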
Using torch.nn.Module.apply(fn)
See apply under torch.nn in the official PyTorch documentation.
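```python
torch.nn.Module.apply(fn)
# Recursively calls fn (e.g. weights_init) on every submodule of the nn.Module
# (as returned by .children()) and on the module itself, passing each one as the argument
# Commonly used to initialize the parameters of a model
# fn is a handle to the initialization function; it takes an nn.Module
# (or a user-defined subclass of nn.Module) as its argument
# fn (Module -> None) – function to be applied to each submodule
# Returns: self
# Return type: Module
```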
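The order in which apply() visits modules can be checked by recording each call. In current PyTorch, Module.apply first recurses into every child and then calls fn(self), so children are visited before their parent and the root module comes last; afterwards it returns the module itself. The nested model and the record function below are hypothetical.

```python
import torch.nn as nn

visited = []

def record(m):
    # Record the class name of every module that apply() passes to fn
    visited.append(m.__class__.__name__)

# Hypothetical nested model: a Linear, then a Sequential holding ReLU and Linear
net = nn.Sequential(nn.Linear(2, 2), nn.Sequential(nn.ReLU(), nn.Linear(2, 2)))
ret = net.apply(record)

print(visited)     # ['Linear', 'ReLU', 'Linear', 'Sequential', 'Sequential']
print(ret is net)  # True: apply returns the module it was called on
```

This is also why, in the output of the next example, the Sequential container is printed after its two Linear layers.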
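```python
import torch
import torch.nn as nn

def init_weights(m):
    classname = m.__class__.__name__
    print(m)
    if classname.find('Linear') != -1:  # matches nn.Linear
        m.weight.data.fill_(1.0)        # set every weight to 1
        print(m.weight)

# apply() never runs a forward pass, so the mismatched layer sizes
# (2 -> 2 followed by 3 -> 3) do not matter here
net = nn.Sequential(nn.Linear(2, 2), nn.Linear(3, 3))
net.apply(init_weights)
```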
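Output (from an older PyTorch release, which printed tensors as `[torch.FloatTensor of size ...]`). The first Sequential(...) is printed by init_weights itself, because apply() also calls fn on the root module; the second is the return value of net.apply(), echoed by the interactive interpreter: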
```
Linear(in_features=2, out_features=2, bias=True)
 1  1
 1  1
[torch.FloatTensor of size (2,2)]
Linear(in_features=3, out_features=3, bias=True)
 1  1  1
 1  1  1
 1  1  1
[torch.FloatTensor of size (3,3)]
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=3, out_features=3, bias=True)
)
Sequential(
  (0): Linear(in_features=2, out_features=2, bias=True)
  (1): Linear(in_features=3, out_features=3, bias=True)
)
```
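Current PyTorch code usually avoids writing through .data. A minimal sketch of the same example in the style of the current Module.apply documentation wraps the function in torch.no_grad() and tests the type with isinstance instead of matching a substring of the class name; the effect (every nn.Linear weight set to 1) is the same. Printing is replaced by a loop at the end so the result is easy to inspect.

```python
import torch
import torch.nn as nn

@torch.no_grad()  # in-place writes to parameters are not tracked by autograd
def init_weights(m):
    # isinstance tests the type itself rather than a substring of its name
    if isinstance(m, nn.Linear):
        m.weight.fill_(1.0)

net = nn.Sequential(nn.Linear(2, 2), nn.Linear(3, 3))
net.apply(init_weights)

for layer in net:
    print(layer.weight)  # all ones; requires_grad stays True
```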
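apply(fn): recursively applies fn to every submodule of the network, and finally to the network itself; it is mainly used for parameter initialization.
To use apply(), first define a parameter-initialization function.
Then define your own network, instantiate it, and call apply() on the model; the initialization function can then initialize the conv layers and the BN layers separately.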
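The three steps above (define an init function, build the network, call apply()) can be sketched for a network that has both conv and BN layers. The network SmallCNN and the particular scheme (Kaiming normal for conv weights, weight 1 and bias 0 for BN) are assumptions for illustration, not the DCGAN rule; isinstance gives each layer type its own branch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def init_conv_bn(m):
    # Hypothetical scheme: Kaiming (He) normal for conv weights, zero conv bias,
    # and weight = 1, bias = 0 for batch-norm layers
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)

class SmallCNN(nn.Module):
    # Hypothetical network, only here to have conv and BN layers to initialize
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.features(x)

net = SmallCNN()
net.apply(init_conv_bn)

conv, bn = net.features[0], net.features[1]
out = net(torch.randn(1, 3, 8, 8))
print(out.shape)  # torch.Size([1, 16, 8, 8])
```

With mode='fan_out', the target standard deviation is sqrt(2 / fan_out), where fan_out = 16 * 3 * 3 = 144 for this conv layer.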
References:
PyTorch forum thread: weight-initilzation
PyTorch usage notes (2): parameter initialization
DCGAN TUTORIAL (learning PyTorch by building projects)