天天看點

深入了解 Tensorflow :如何讀訓練資料

深入了解 Tensorflow :如何讀訓練資料

以下分析來自 tensorflow slim 庫代碼精簡之後

dataset = dataset_factory.get_dataset(dataset_name, dataset_split_name, dataset_dir)
provider = slim.dataset_data_provider.DatasetDataProvider(
          dataset,
          num_readers=num_readers,
          common_queue_capacity= * batch_size,
          common_queue_min= * batch_size)

    key, data = parallel_reader.parallel_read(
        dataset.data_sources,
        reader_class=dataset.reader,
        num_epochs=num_epochs,
        num_readers=num_readers,
        reader_kwargs=reader_kwargs,
        shuffle=shuffle,
        capacity=common_queue_capacity,
        seed=seed,
        scope=scope)

        data_files = get_data_files(dataset.data_sources)
        # 這裡對資料源建立一個 FIFO 隊列
        filename_queue = tf_input.string_input_producer(data_files, num_epochs=num_epochs, shuffle=shuffle, seed=seed, name='filenames')
            input_tensor = ops.convert_to_tensor(data_files, dtype=dtypes.string)
            if shuffle:
                input_tensor = random_ops.random_shuffle(input_tensor, seed=seed)
            # 最多讀 num_epochs 次,超過就會抛 OutOfRangeError,當 num_epochs 為 None 時,可以無限次讀
            input_tensor = limit_epochs(input_tensor, num_epochs)
            element_shape = input_tensor.shape[:].merge_with([])
            q = data_flow_ops.FIFOQueue(capacity=, dtypes=[input_tensor.dtype.base_dtype],
                                shapes=[element_shape], shared_name=shared_name, name=name)
            enq = q.enqueue_many([input_tensor])
            queue_runner.add_queue_runner(queue_runner.QueueRunner(q, [enq], cancel_op=cancel_op))
            return q

        if shuffle:
            common_queue = data_flow_ops.RandomShuffleQueue(
                capacity=capacity,
                min_after_dequeue=min_after_dequeue,
                dtypes=dtypes,
                seed=seed,
                name='common_queue')
        else:
            common_queue = data_flow_ops.FIFOQueue(capacity=capacity, dtypes=dtypes, name='common_queue')

        reader_kwargs = reader_kwargs or {}
        enqueue_ops = []
        for reader in [reader_class(**reader_kwargs) for _ in range(num_readers)]
          enqueue_ops.append(common_queue.enqueue(reader.read(queue)))

        queue_runner.add_queue_runner(queue_runner.QueueRunner(common_queue, enqueue_ops))
        return common_queue.dequeue(name=name)

    items = dataset.decoder.list_items()
    tensors = dataset.decoder.decode(data, items)
    items_to_tensors[record_key] = key

    return super(DatasetDataProvider, self).__init__(items_to_tensors=items_to_tensors, num_samples=dataset.num_samples)
           

由上分析可見,建立了兩組隊列

1. FIFOQueue 隊列,從 data_files 讀取資料,寫入該隊列尾部

2. num_readers 個 FIFOQueue 或 RandomShuffleQueue 隊列,從 FIFOQueue 隊列頭讀資料

其中 add_queue_runner 将各個 queue_runner 加入 ops.GraphKeys.QUEUE_RUNNERS,

當訓練開始的時候, 會調用 start_queue_runners,它會為 enqueue_ops 中的每個

操作啟動一個線程。 具體參考 python/training/queue_runner_impl.py

還有一點需要注意的,

  1. 隊列的實作是 cpp 來實作的,
  2. queue_runner 是 python 的線程。
  3. TFRecordReader 和 TFExampleDecoder 核心都是 cpp 實作的

備注:關于隊列部分和 TFRecordReader,我将開專門的文章分析。

這個實作有什麼問題?

  1. FIFOQueue 隊列的 capacity 太小隻要 32,是以,瓶頸可能在 FIFOQueue 隊列
  2. 隊列都是本機内的,無法跨主機,而事實上對于一個大型深度學習系統來說,資料一般不可能在同一台機器。跨機器通路是剛需
  3. 當然,如果程式中斷,網絡中斷,必須從頭開始,是以可靠性不夠

改進,将 FIFOQueue 隊列改為一個類似 kafka 的分布式隊列即可

繼續閱讀