# In[1]:
import os
import sys

import numpy as np
import matplotlib.pyplot as plt

# Root of the Caffe checkout. Set the CAFFE_ROOT environment variable to
# override this hard-coded default instead of editing the notebook.
caffe_root = os.environ.get('CAFFE_ROOT', '/home/che/caffe-master/')
# pycaffe lives in <caffe_root>/python. It must be on sys.path BEFORE
# `import caffe`; inserting it afterwards has no effect on that import.
sys.path.insert(0, os.path.join(caffe_root, 'python'))
import caffe  # noqa: E402 -- must follow the sys.path insert above

# Later cells open files by paths relative to caffe_root
# (e.g. 'examples/cracks/solver.prototxt').
os.chdir(caffe_root)
# In[2]:
# Run Caffe on GPU device 0 and build an SGD solver from the solver definition.
# The solver path is relative; it resolves against the working directory that
# the first cell switched to (caffe_root).
GPU_ID = 0
SOLVER_PROTOTXT = 'examples/cracks/solver.prototxt'

caffe.set_device(GPU_ID)
caffe.set_mode_gpu()
solver = caffe.SGDSolver(SOLVER_PROTOTXT)
# In[3]:
# Step the solver one iteration at a time so the training loss can be recorded
# after every step and the test net evaluated every `test_interval` iterations.
# niter corresponds to max_iter in solver.prototxt, i.e. the maximum number of
# solver iterations.
niter = 150
test_interval = 1  # evaluate on the test net every `test_interval` iterations
test_iters = 1     # test batches averaged per evaluation; raise to the solver's
                   # test_iter to score the whole test set each time
train_loss = np.zeros(niter)
# Integer ceiling division: one slot per `it` in range(0, niter, test_interval).
# (np.ceil(niter/test_interval) truncates first under Python 2's integer
# division, undersizing the array and causing an IndexError below.)
test_acc = np.zeros((niter + test_interval - 1) // test_interval)
for it in range(niter):
    solver.step(1)  # one SGD iteration performed by Caffe
    train_loss[it] = solver.net.blobs['loss'].data
    if it % test_interval == 0:
        acc = 0.0
        for test_it in range(test_iters):
            # Full forward pass so the data layer loads a fresh test batch;
            # forward(start='conv1') would skip the data layer and re-score
            # whatever batch is already in memory.
            solver.test_nets[0].forward()
            acc += float(solver.test_nets[0].blobs['accuracy'].data)
        acc /= test_iters
        print('Iteration %d testing... accuracy: %.4f' % (it, acc))
        test_acc[it // test_interval] = acc
In[4]:%matplotlib inline
In[5]:print test_acc
_,ax1=plt.subplots()
ax2=ax1.twinx()
ax1.plot(np.arange(niter),train_loss)
ax2.plot(test_interval*np.arange(len(test_acc)),test_acc,'r')
ax1.set_xlabel('iteration')
ax1.set_ylabel('train loss')
ax2.set_ylabel('test accuracy')