#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File :pytorch學習 -> RNN_regression
@IDE :PyCharm
@Author :zgq
@Date :2021/3/23 12:11
@Desc :Vanilla RNN regression demo - learn to predict cos(t) from sin(t) and plot the fit live.
=================================================='''
import torch
from torch import nn
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
# Seed PyTorch's RNG so runs are reproducible.
torch.manual_seed(1)

# Hyper parameters
TIME_STEP = 10   # number of samples in each training segment
INPUT_SIZE = 1   # features per time step fed to the RNN
LR = 0.02        # learning rate handed to the optimizer

# Preview the data: one full period of the input (sin) and the target (cos).
steps = np.linspace(0, 2 * np.pi, num=100, dtype=np.float32)
x_np, y_np = np.sin(steps), np.cos(steps)

plt.plot(steps, y_np, 'r-', label='target (cos)')
plt.plot(steps, x_np, 'b-', label='input (sin)')
plt.legend(loc='best')
plt.show()
class RNN(nn.Module):
    """Single-layer vanilla RNN with a linear read-out applied at every time step.

    Args:
        input_size: Number of features per time step. When omitted, the
            module-level ``INPUT_SIZE`` constant is used.
        hidden_size: Width of the recurrent hidden state (32 by default).
    """

    def __init__(self, input_size=None, hidden_size=32):
        super(RNN, self).__init__()
        if input_size is None:
            # Looked up at call time so a bare ``RNN()`` keeps using the
            # script-level constant.
            input_size = INPUT_SIZE
        # The previous lesson used an LSTM; this one uses a plain RNN.
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,  # tensors are laid out (batch, time_step, feature)
        )
        # Maps each hidden vector to a single regression value.
        self.out = nn.Linear(hidden_size, 1)

    def forward(self, x, h_state):
        """Run the RNN over a sequence and regress one value per time step.

        Args:
            x: Input of shape (batch, time_step, input_size).
            h_state: Initial hidden state of shape
                (n_layers, batch, hidden_size), or ``None`` to let
                ``nn.RNN`` choose its default initial state.

        Returns:
            A tuple ``(outs, h_state)`` where ``outs`` has shape
            (batch, time_step, 1) and ``h_state`` is the final hidden state
            of shape (n_layers, batch, hidden_size).
        """
        # r_out: (batch, time_step, hidden_size)
        r_out, h_state = self.rnn(x, h_state)
        # nn.Linear acts on the last dimension only, so a single call projects
        # every time step at once; this equals applying it per step and
        # stacking the results along dim=1.
        return self.out(r_out), h_state
# Build the model and print its layer structure.
rnn = RNN()
print(rnn)

optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)
loss_func = nn.MSELoss()

# No hidden state exists before the first segment; None lets nn.RNN pick its
# default initial state.
h_state = None

plt.figure(1, figsize=(12, 5))
plt.ion()  # continuously plot

for step in range(100):
    # Each iteration trains on the next half period [step*pi, (step+1)*pi].
    start, end = step * np.pi, (step + 1) * np.pi
    # use sin to predict cos
    steps = np.linspace(start, end, TIME_STEP, dtype=np.float32)
    x_np = np.sin(steps)
    y_np = np.cos(steps)
    # Add batch and feature axes: (TIME_STEP,) -> (1, TIME_STEP, 1).
    x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis])
    # The regression target is the cosine curve, not the sine input.
    y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])

    prediction, h_state = rnn(x, h_state)
    # Carry the hidden-state values into the next segment but cut them off
    # from this iteration's autograd graph, so the next backward() does not
    # try to propagate through a graph that has already been freed.
    h_state = h_state.detach()

    loss = loss_func(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # plotting: target in red, current prediction in blue
    plt.plot(steps, y_np.flatten(), 'r-')
    plt.plot(steps, prediction.detach().numpy().flatten(), 'b-')
    plt.draw()
    plt.pause(0.05)

# Leave interactive mode and keep the final figure open.
plt.ioff()
plt.show()