天天看點

學習筆記之思路整理

1.圖檔處理:(流程被配置設定在16個線程中處理)

圖檔會被統一裁剪到24x24像素大小,裁剪中央區域用于評估或随機裁剪用于訓練;

圖檔會進行近似的白化處理,使得模型對圖檔的動态範圍變化不敏感。

對圖像進行随機的左右翻轉;

随機變換圖像的亮度;

随機變換圖像的對比度;

訓練方法與損失的定義:

訓練一個可進行N維分類的網絡的常用方法是使用多項式邏輯回歸(softmax 回歸),

Softmax 回歸在網絡的輸出層上附加了一個softmax nonlinearity,

并且計算歸一化的預測值和label的1-hot encoding的交叉熵。

在正則化過程中,對所有學習變量應用權重衰減損失(使用了L2範式,強調模型的參數的稀疏性),

求交叉熵損失和所有權重衰減項的和,loss()函數的傳回值就是這個值

2.資料讀取:

(1)讀取:

讀取檔案隊列名,用read_cifar10()來擷取一個樣本的資訊結構體(大小、資料、标簽),

使用tf.cast轉換uint8成float32

(2)切割:read_cifar10(),該函數從二進制資料中讀取資料并規整,

每條樣本都是先标簽後資料,CIFAR10是一個位元組标簽,

CIFAR100是2位元組,使用切片函數tf.slice()

(3)處理原始圖檔:初步擷取資料後就需要變形成tensor了, tf.random_crop(reshaped_image,[height,

width,3]) 1D變換成3D,對圖像進行了很多随機扭曲處理…通過tf.train.shuffle_batch中設定隊列大小、緩沖區大小,

直接就保證整理好一個資料集合的隊列

3.建立訓練網絡:

(1)參數設定函數

_variable_with_weight_decay(name,shape,stddev,wd)

對應功能:輸入名稱、形狀、偏差和均值 就定義一個參數tensor

(2)生成資料

先設定常量參數,再由tf.nn.l2_loss(var)增加L2範式稀疏化

L2範式定義為:output = sum(t ** 2) / 2,然後乘以一個衰減系數wd做為一個訓練名額:

這個值應該盡量小,以保證稀疏性

用tf.add_to_collection(‘losses’,weight_decay)把所有的系數作為以losses為标簽進行收集

用summary用于檢視輸出的稀疏性:tf.scalar_summary(tensor_name+’/sparsity’,

tf.nn.zero_fraction(x)),統計0的比例反應稀疏性。tf.histogram_summary(tensor_name+’/activations’,

x),輸出數值的分布直接反應神經元的活躍性,如果全是很小的值說明不活躍。

全連接配接層的展開次元192、384這些數字與GPU的架構有關,全連接配接層的wd是0.004,略微強調了一下稀疏性

ps:多個GPU需要tf.get_variable()用于分享資料,而單個GPU隻需要tf.Variable()

4.損失函數:

cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(logits,

labels,name=’cross_entropy_per_example’)

5.訓練:

(1)學習率更新:首先是根據目前的訓練步數、衰減速度、之前的學習速率确定新的學習速率

lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE……… staircase=True)

式子:decayed_learning_rate=learening_rate*decay_rata^(global_step/decay_steps)

如果staircase=True則取整數

(2)均值線(ExponentialMovingAverage)

(3)計算梯度及更新梯度compute_gradients,opt.apply_gradients反向傳播

(4)summary和句柄

6.測試模型

(1)傳入驗證函數的參數:

eval_once(saver, saver是用讀取moving_average的

summary_writer, summary_writer和summary_op是儲存記錄的

top_k_op, top_k_op傳入了模型和驗證模型

summary_op)

(2) 讀取檢查點:

ckpt=tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)

從檢查點恢複圖和參數:

saver.restore(sess,ckpt.model_checkpoint_path)

繼續閱讀