在pytorch 中一些常用的功能都已经被封装成了模块,所以我们只需要继承并重写部分函数即可。首先介绍一下本文最终希望实现的目标, 对本地的一维数据 (1xn)的ndarry 进行一个多分类,数据集为mn的数据,标签为m1的数组。下面是结合代码记录一下踩坑过程。
继承Dataset类,可以看到我这里重写了三个函数,init 函数用于载入numpy数据并将其转化为相应的tensor,__gititem__函数用于定义训练时会返回的单个数据与标签,__len__表示数据数量m。
通过继承nn.Module来自定义神经网络
其中__init__函数来自定义定义我们需要的网络参数,这里我们block1 的in_channels为1,输出参数可根据需要自己设定,但而且当前层的输出channel应该和下一层的输入channel相同,
注意:MaxPool1d的inchannel需要自己计算一下,当然如果你不想算,可以给个参数直接运行,看报错信息的提示
__forward__ 函数定义了网络的连接方式,注意此处应返回x。
主程序。为了更好的说明,先放一下主程序。这里的程序是已经载入了数据的,data是mn 数组,label为m1数组。
实例化DataLoader的第一个参数是Dataset的实例,通过DataLoader,其功能是为下文训练和测试过程提供数据。
定义训练阶段,从DataLoader中取出数据,这里X,y分别为batch_sizen,batch_size1的数据。
首先要进行一个调整,将X调整为batch_size1n的float,设置float的转化过程放在Dataset的初始化函数里完成了
注意:如果没有这一步会报错
期望是long但得到了float的错误。(虽然我也不明白为啥错误不是期望float...)
y为1*batch的数组并转化成long(这里y的形式可能与损失函数有关)
定义测试过程(同上)
大体流程就是这些,最后记得修改加入输出语句与保存模型等操作。