tf.cast()函數的作用是執行 tensorflow 中張量資料類型轉換,比如讀入的圖檔如果是int8類型的,一般在要在訓練前把圖像的資料格式轉換為float32。
cast定義:
cast(x, dtype, name=None)
- 第一個參數 x: 待轉換的資料(張量)
- 第二個參數 dtype: 目标資料類型
- 第三個參數 name: 可選參數,定義操作的名稱
int32轉換為float32:
import tensorflow as tf
t1 = tf.Variable([1,2,3,4,5])
t2 = tf.cast(t1,dtype=tf.float32)
print 't1: {}'.format(t1)
print 't2: {}'.format(t2)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(t2)
print t2.eval()
# print(sess.run(t2))
輸出:
t1: <tf.Variable 'Variable:0' shape=(5,) dtype=int32_ref>
t2: Tensor("Cast:0", shape=(5,), dtype=float32)
[ 1. 2. 3. 4. 5.]
tensorflow中的資料類型清單:
資料類型 | Python 類型 | 描述 |
| | 32 位浮點數. |
| | 64 位浮點數. |
| | 64 位有符号整型. |
| | 32 位有符号整型. |
| | 16 位有符号整型. |
| | 8 位有符号整型. |
| | 8 位無符号整型. |
| | 可變長度的位元組數組.每一個張量元素都是一個位元組數組. |
| | 布爾型. |
| | 由兩個32位浮點數組成的複數:實數和虛數. |
| | 用于量化Ops的32位有符号整型. |
| | 用于量化Ops的8位有符号整型. |
|