
Usage of torch.nn.init.normal_ and torch.nn.init.constant_

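torch.nn.init.normal_: initializes a tensor in place with values drawn from a normal distribution. It is typically used to initialize the weight parameters of a network.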

torch.nn.init.normal_(tensor, mean=0.0, std=1.0), where mean is the mean and std is the standard deviation of the normal distribution.

Code example:

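import torch
import torch.nn as nn

# 1x1 convolution with 2 input channels and 2 output channels
l = nn.Conv2d(2, 2, kernel_size=1)
a = l.weight
print("a:", a)  # weights from PyTorch's default initialization
# normal_ overwrites l.weight in place and returns the same tensor
b = nn.init.normal_(l.weight, mean=0, std=0.01)
print("b:", b)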

           

Output:

a: Parameter containing:
tensor([[[[-0.1551]],
         [[-0.6292]]],
        [[[ 0.5094]],
         [[ 0.3613]]]], requires_grad=True)
b: Parameter containing:
tensor([[[[-0.0013]],
         [[-0.0062]]],
        [[[ 0.0093]],
         [[ 0.0019]]]], requires_grad=True)
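
With only four weights, the output above cannot really show that the values follow the requested distribution. The sketch below is one way to check it, under a few assumptions of its own: the seed, the 1000x1000 tensor size, and the variable names are arbitrary choices for illustration. It also checks that the returned tensor shares memory with the layer's weight, since normal_ works in place.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only so repeated runs give the same numbers

# normal_ modifies the tensor in place; the return value refers to the same memory
layer = nn.Conv2d(2, 2, kernel_size=1)
returned = nn.init.normal_(layer.weight, mean=0.0, std=0.01)
shares_memory = returned.data_ptr() == layer.weight.data_ptr()
print("returned tensor shares memory with layer.weight:", shares_memory)

# With many samples, the empirical mean and std should be close to the requested values
big = torch.empty(1000, 1000)
nn.init.normal_(big, mean=0.0, std=0.01)
sample_mean = big.mean().item()
sample_std = big.std().item()
print("sample mean:", sample_mean)
print("sample std:", sample_std)
```

Because the change happens in place, a and b in the example above refer to the same parameter after the call; the two printouts differ only because a was printed before initialization.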
           

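torch.nn.init.constant_: initializes a tensor in place with a constant, so every element has the same value. It is typically used to initialize the bias of a network.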

torch.nn.init.constant_(tensor, val), where val is the constant value used to fill the tensor.

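import torch
import torch.nn as nn

# 1x1 convolution with 2 output channels, so the bias has 2 elements
l = nn.Conv2d(2, 2, kernel_size=1)
a = l.bias
print("a:", a)  # bias from PyTorch's default initialization
# constant_ overwrites l.bias in place and returns the same tensor
b = nn.init.constant_(l.bias, val=0)
print("b:", b)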
           

Output:

a: Parameter containing:
tensor([ 0.1539, -0.5360], requires_grad=True)
b: Parameter containing:
tensor([0., 0.], requires_grad=True)
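
In practice the two functions are often used together: normal_ for the weights and constant_ for the biases. The following is a minimal sketch of that pattern, not a prescribed recipe. The helper name init_weights, the small Sequential model, and the choice of std=0.01 are assumptions made for illustration; Module.apply is the standard PyTorch method that calls a function on every submodule.

```python
import torch
import torch.nn as nn


def init_weights(module):
    """Hypothetical helper: normal weights and zero bias for every Conv2d layer."""
    if isinstance(module, nn.Conv2d):
        nn.init.normal_(module.weight, mean=0.0, std=0.01)
        if module.bias is not None:  # bias is None when the layer uses bias=False
            nn.init.constant_(module.bias, 0.0)


# Illustrative model; any nn.Module would work the same way
model = nn.Sequential(
    nn.Conv2d(2, 4, kernel_size=1),
    nn.ReLU(),
    nn.Conv2d(4, 2, kernel_size=1),
)

# apply() calls init_weights on every submodule and on the model itself
model.apply(init_weights)

for name, param in model.named_parameters():
    print(name, tuple(param.shape), "mean abs value:", param.abs().mean().item())
```

Checking module.bias against None keeps the helper usable for layers created with bias=False.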