模型构建方法
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.c1 = nn.Conv2d(3, 6 ,5)
self.p = nn.MaxPool2d(2,2)
self.c2 = nn.Conv2d(6, 16, 5)
self.l1 = nn.Linear(16*5*5, 120)
self.l2 = nn.Linear(120, 84)
self.l3 = nn.Linear(84, 10)
def forward(self, x):
"""前向传播过程"""
# 卷积 过激活 池化
x = self.p(func.relu(self.c1(x)))
x = self.p(func.relu(self.c2(x)))
# 全连接
x = x.view(x.size(0), -1) # 全连接层前将向量特征拉直 -1代表由计算机进行计算
x = func.relu(self.l1(x)) # .size(0)表示求第0维个数
x = func.relu(self.l2(x))
x = self.l3(x)
return x
全连接第一层的输入为特征向量拉直后的个数。
损失函数构建方法
# 两种方案
# 1. 直接从functional模块中调用使用
loss = torch.nn.functional.cross_entropy(output, label)
# 2. 利用构造器构造损失函数 再调用
# CrossEntropy是个构造器
# 所以loss = torch.nn.CrossEntropyLoss()(output, target)这么写也对
loss_func = nn.CrossEntropyLoss()
loss = loss_func(output, label)
# 写成loss = torch.nn.CrossEntropyLoss(output, target) 则报错
# RuntimeError: bool value of Tensor with more than one value is ambiguous loss =
# nn.CrossEntropy...
值得注意的是,关于pytorch中对标签的转换,如果使用了交叉熵损失函数,并不需要我们单独进行one-hot编码,因为该函数已经替我们执行了这一操作,我们只需要出入longtensor类型的label就可以
.item()方法
计算地到的loss是一个Tensor标量,需要使用.item()方法取出loss值
模型存取
# 模型快速存储
path = "./cifar_net.pth" # .state_dict()是网络的状态字典
torch.save(net.state_dict(), path) # 存放训练过程中需要学习的权重和偏执系数
# 模型读取
net = CNN()
path = "./cifar_net.pth"
net.load_state_dict(torch.load(path))
GPU加速
# 需要分别将要训练的数据和模型都迁移至GPU上
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0")
net = CNN()
net.to(device)
inpt, label = data
inpt = inpt.to(device)
label = label.to(device)
CIFAR10数据集本地加载
对于官方下载数据集慢的方式,使用torchvision下载也很慢,所以参照之前解决MINIST数据集本地文件下载的方法:
使用torchvision下载MINIST及配置踩坑笔记
1、首先按照上面的链接提前下载好数据集
2、进入CIFAR10函数
3、修改url,将url由官网下载改成本地下载
![](https://img.laitimes.com/img/__Qf2AjLwojIjJCLyojI0JCLiAzNfRHLGZkRGZkRfJ3bs92YsYTMfVmepNHLxUleOFza61UNNpHW4Z0MMBjVtJWd0ckW65UbM5WOHJWa5kHT20ESjBjUIF2X0hXZ0xCMx81dvRWYoNHLrdEZwZ1Rh5WNXp1bwNjW1ZUba9VZwlHdssmch1mclRXY39CXldWYtlWPzNXZj9mcw1ycz9WL49zZuBnL4AzN0QDMyEjM0IDMxAjMwIzLc52YucWbp5GZzNmLn9Gbi1yZtl2Lc9CX6MHc0RHaiojIsJye.png)