天天看點

tensorflow中slim進階庫的應用                                  tensorflow中slim庫學習

                                  tensorflow中slim庫學習

        在閱讀用tensorflow實作的深度學習網絡結構的源碼時,經常會看到作者使用TF中封裝的slim進階庫,看起來(實際上也是)比直接調用TF的API簡潔好多。為了弄懂網絡源碼和學習slim庫應用,特地查閱了一些資料,在這裡做一下學習時的記錄。         tensorflow中關于 slim庫的介紹         某位部落客關于上面slim英文介紹的一些 翻譯         下面直接貼出我實作的一個應用slim庫的代碼:

import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('./data/MNIST',one_hot=True)


def cal_loss(y_pre,y_label):
    return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_label, logits=y_pre))

def cal_accuracy(y_pre,y_label):
    return tf.reduce_mean(tf.cast(tf.equal(tf.arg_max(y_pre, dimension=1),tf.arg_max(y_label, dimension=1)),tf.float32))

def network(inputs,y_label):
    with slim.arg_scope([slim.conv2d],######可以在清單裡添加其他要簡化的操作,比如再添加全連接配接。函數中下面的參數是預設執行的操作
                        activation_fn=tf.nn.relu,########可以應用自己編寫的激活函數
                        weights_initializer=slim.xavier_initializer(),####預設xavier_initializer初始化權值
                        biases_initializer=tf.zeros_initializer(),
                        weights_regularizer=slim.l2_regularizer(0.0005),
                        padding='SAME'):
        print inputs.get_shape()
        net = slim.conv2d(inputs,num_outputs=32,kernel_size=[3,3],stride=1,scope='conv1')
        print net.get_shape()
        net = slim.max_pool2d(net, kernel_size=[2,2], stride=2, scope='pool1')
        print net.get_shape()
        net = slim.conv2d(net, num_outputs=64, kernel_size=[3,3], stride=1, scope='conv2')
        print net.get_shape()
        net = slim.max_pool2d(net, kernel_size=[2,2], stride=2, scope='pool2')
        print net.get_shape()
        net = slim.conv2d(net,num_outputs=64,kernel_size=[3,3],scope='conv3')
        print net.get_shape()
        fc_flat = slim.flatten(net)
        print fc_flat.get_shape()
        fc1 = slim.fully_connected(fc_flat, num_outputs=512, scope='fc1')
        print fc1.get_shape()
        y_out = slim.fully_connected(fc1, num_outputs=10, scope='y_out')
        print y_out.get_shape()
        
        accuracy = cal_accuracy(y_out, y_label)
        l2_loss = tf.add_n(slim.losses.get_regularization_losses())
        return cal_loss(y_out,y_label) + l2_loss , accuracy
    
def main():
    x_data = tf.placeholder(dtype=tf.float32, shape=[None,784], name='x_data')
    y_label = tf.placeholder(dtype=tf.float32, shape=[None,10], name='y_label')
    
    x_input = tf.reshape(x_data, shape=[-1,28,28,1], name='x_input')
    loss,accuracy = network(x_input,y_label)
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        for i in range(30001):
            xs,ys = mnist.train.next_batch(64)
            if i % 1000 == 0:
                loss_op,ac = sess.run([loss,accuracy],feed_dict={x_data:xs,y_label:ys})
                print 'the %dth iteration loss: %f'%(i,loss_op)
                print 'the %dth iteration accuracy: %f'%(i,ac)
            sess.run(train_op,feed_dict={x_data:xs,y_label:ys})
        total_acc = sess.run(accuracy,feed_dict={x_data:mnist.validation.images,y_label:mnist.validation.labels})
        print 'the total accuracy: %f'%(total_acc)
        
            

if __name__ == '__main__':
    main()
           

代碼運作結果:

tensorflow中slim進階庫的應用                                  tensorflow中slim庫學習

        在上面添加的連結中有關于slim的詳細介紹,大家仔細看看即可,有不懂的或者我寫錯了的地方,可以交流~

繼續閱讀