天天看點

MxNet系列——how_to——bucketing在MXNet中使用Bucketing

部落格新址: http://blog.xuezhisd.top

郵箱:[email protected]

在MXNet中使用Bucketing

Bucketing是一種訓練多個不同但又相似的結構的網絡,這些網絡共享相同的參數集。一個典型的應用是循環神經網絡(RNNs)。在使用符号網絡定義的工具箱中,實作RNNs通常會沿時間軸将網絡顯式地展開。顯式地展開RNNs之前需要知道序列的長度。為了處理序列中的所有元素,我們需要将網絡展開成最大可能的序列長度。然而這很浪費資源,因為對于較短的序列,大部分計算都是在填充後的資料上執行的。

Bucketing,是從 Tensorflow’s sequence training example 借鑒而來的一個簡單的方法。它不再将網絡展開成最大可能長度,而是展開成多個不同長度的執行個體(比如,長度為5, 10, 20, 30)。在訓練過程中,對于不同長度的最小批資料,我們使用最恰當的展開模型。對于RNNs,盡管這些模型具有不同的架構,但參數在時間軸上是共享的。盡管選出的不同bucket的模型,并以不同的最小批來訓練,但本質上都是在優化相同的參數集。MXNet 在所有的執行器中重複使用中間的存儲緩存。

對于簡單的RNNs,可以使用一個for循環來周遊輸入序列,通過保持狀态和沿時間的梯度之間的連接配接的方式沿時間反向傳播。而然,這可能會使降低處理速度。這個方法能夠處理不同長度的序列。但對于更加複雜的模型(比如,使用序列到序列網絡的翻譯模型)來說,并不容易展開。在這個例程中,我們将介紹MXNet的允許我們事先bucketing的APIs。

不同長度的序列訓練PTB

在這個例程中,我們使用 PennTreeBank language model example 。如果你對這個例程不熟悉,請首先檢視 原教程 (in Julia)。

例程中使用的架構是兩個LSTM層,加一個簡單的單詞嵌入層。原例程将模型沿時間展開成固定長度(32)。本例程将介紹如何使用bucketing來實作變長序列訓練。

為了使用bucketing,MXNet需要知道如何為不同長度的序列建構一個新的展開的符号架構(圖)。為了實作這個目的,我們不是建構一個使用固定

Symbol

的模型,而是使用一個回調函數,該函數對新的bucket key 生成一個新的

Symbol

model = mx.model.FeedForward(
        ctx     = contexts,
        symbol  = sym_gen)
           

sym_gen

必須是一個函數,它隻有一個輸入,即

bucket_key

;并為這個bucket傳回一個

Symbol

。我們使用序列長度作為 bucket key。任何對象都可以用作bucket key。比如,在神經網絡翻譯應用中,不同長度的輸入和輸出序列的組合對應于不同的展開方式,一對長度值(輸入/輸出長度)可以用作bucket key。

def sym_gen(seq_len):
    return lstm_unroll(num_lstm_layer, seq_len, len(vocab),
                       num_hidden=num_hidden, num_embed=num_embed,
                       num_label=len(vocab))
           

資料疊代器需要報告

default_bucket_key

,它允許MXNet在讀取資料之前初始化參數。現在,模型能夠以不同的buckets進行訓練,這是通過共享參數和不同buckets之間的計算緩存。

為了訓練,我們還需要為

DataIter

添加一些額外的bits。除了報告之前提到的

default_bucket_key

之外,還需要為每最小批報告目前的

bucket_key

。更具體的說,在每個最下批中,通過

DataIter

傳回的

DataBatch

對象需要包含下面的附加屬性:

  • bucket_key

    : 對應于一批資料的 bucket key。 在本例程中,它是指一批資料的序列長度。如果該bucket key對應的執行器還沒有建立,将根據由函數

    gen_sym

    以bucket key為參數生成的符号模型,建構該bucket key對應的執行器。該執行器将會放在緩存中,以便未來使用。注意:生成的

    Symbol

    s 可能是任意的,但他們應具有相同的可訓練參數和輔助狀态。
  • provide_data

    : 和

    DataIter

    對象報告的資訊相同。 因為現在每個bucket都對應一個不同的架構,它們可以有不同的輸入。同時,確定

    DataIter

    對象傳回的

    provide_data

    資訊和

    default_bucket_key

    的架構是相容的。.
  • provide_label

    : 和

    provide_data

    相同。

現在,

DataIter

負責将資料分到不同的 buckets。 假如已經激活随機化,在麼個最小批中,

DataIter

随機選擇一個 bucket (根據一個由bucket尺寸均衡的分布),然後從bucket中随機選擇一個序列來組成一個最小批資料。如果有必要,它将對最小批中的不同長度的序列進行填充。

擷取一個讀取文本序列的

DataIter

(它通過實作上述的API)的完整實作,請檢視 example/rnn/lstm_ptb_bucketing.py。在本例中,你可以使用靜态配置的 bucketing (比如,

buckets = [10, 20, 30, 40, 50, 60]

), 或者讓 MXnet 根據dataset (

buckets = []

)自動生成 bucketing。後一種方法是通過添加一個和長度和輸入數量相同的bucket(bucket足夠長)來實作的。擷取更多資訊,請檢視 default_gen_buckets().

Beyond Sequence Training

在本例程中,簡單的描述了bucketing API是如何工作的。然而,bucketing API不限于上文使用的序列長度的bucketing。bucket的鍵(key)可以是任意的對象,隻要

gen_sym

傳回的架構相容即可。

繼續閱讀