天天看點

(Neural Turning Machine)神經圖靈機了解和pytorch實作Neural Turning MachinePytorch實作

一般的神經網絡不具有記憶功能,輸出的結果隻基于目前的輸入;而LSTM網絡的出現則讓網絡有了記憶:能夠根據之前的輸入給出目前的輸出。但是,LSTM的記憶程度并不是那麼理想,對于比較長的輸入序列,LSTM的最終輸出隻與最後的幾步輸入有關,也就是long dependency問題,當然這個問題可以由注意力機制解決,然而卻不能從根本上解決長期記憶的問題,原因是由于LSTM是假設在時間序列上的輸入輸出:由t-1時刻得到t時刻的輸出,然後再循環輸入t時刻的結果得到t+1時刻的輸出,這樣勢必會使處于前面序列的輸入被淹沒,導緻這部分記憶被“丢掉“。

神經圖靈機通過引入外部記憶解決了這個問題。 舉個簡單的例子,我們人類在記憶一些事情的時候,除了用腦袋記,還會寫在備忘錄上,當我們想不起來的時候,就可以去翻閱備忘錄,進而獲得相關的記憶。神經圖靈機模仿人類記憶的過程:其中的控制器(controller)相當于我們人類的大腦,用于把輸入事物的特征提取出來;外部記憶(memory)相當于我們的備忘錄,把事物的特征記錄在上面,那麼完整的過程就是:控制器将目前輸入轉化為特征,寫入記憶,再讀取與目前輸入特征有關的記憶作為最後的輸出。整個過程與圖靈機的讀寫很像,隻不過神經圖靈機這裡讓所有的讀寫操作都可微分化,是以可以用神經網絡誤差後向傳播的方式去訓練模型。

(Neural Turning Machine)神經圖靈機了解和pytorch實作Neural Turning MachinePytorch實作

那麼問題就來了,當獲得一個輸入的特征後,我們如何确定在記憶中儲存的位置,而且如何從記憶中擷取與目前輸入相關的資訊呢?這就是接下來要分析的神經圖靈機主要工作。

Neural Turning Machine

(Neural Turning Machine)神經圖靈機了解和pytorch實作Neural Turning MachinePytorch實作

1. 讀記憶 (Read Heads)

我們把記憶看作是一個 N × M N \times M N×M的矩陣 M t M_t Mt​,t表示目前時刻,表示記憶會随着時間發生變化。我們的讀過程就是生成一個定位權值向量 w t w_t wt​,長度為 N N N,表示N個位置對應的記憶權值大小,最後讀出的記憶向量 r t r_t rt​為:

r t = ∑ i N w t ( i ) M t ( i ) r_t = \sum_i^N w_t(i) M_{t}(i) rt​=i∑N​wt​(i)Mt​(i)

其中權值向量的和為1: ∑ i w t ( i ) = 1 \sum_i w_t(i)=1 ∑i​wt​(i)=1,本質上是一個對N條記憶進行一個權重求和的思想。

2. 寫記憶(Write Heads)

神經圖靈機的寫過程參考了LSTM的門的概念:先用輸入門決定增加的資訊,再用遺忘門決定要丢棄的資訊,最後用更新門加上增加的資訊并減去丢棄的資訊。具體來說,神經圖靈機會生成一個擦除向量 e t e_t et​(erase vector) 和一個增加向量 a t a_t at​(add vector),長度都為 N N N,向量中每個元素的值大小範圍從0到1,表示要增加或者删除的資訊。對于寫記憶過程,神經圖靈機首先執行一個擦除操作,擦除程度的大小同樣由向量 w t w_t wt​決定:

M t ′ ( i ) = M t − 1 ( i ) ( 1 − w t ( i ) e t ( i ) ) M_t'(i)=M_{t-1}(i)(1-w_t(i)e_t(i)) Mt′​(i)=Mt−1​(i)(1−wt​(i)et​(i))

這個操作表示從 t − 1 t-1 t−1時刻的記憶中丢棄了一些資訊,若 w t w_t wt​和 e t e_t et​同時為0,則表示記憶沒有丢棄資訊,目前記憶與 t − 1 t-1 t−1時刻保持不變。執行完擦除後,然後執行增加操作:

M t ( i ) = M t ′ ( i ) + w t ( i ) a t ( i ) M_t(i)=M'_t(i)+w_t(i)a_t(i) Mt​(i)=Mt′​(i)+wt​(i)at​(i)

這步表示在丢棄一些資訊後需要新增的資訊,同樣,若 w t w_t wt​和 a t a_t at​都為0,表示目前記憶無新增,與擦除後的記憶保持一緻。其中, e t e_t et​和 a t a_t at​都是由控制器給出,而控制器基本上由神經網絡實作,可以是LSTM,也可以是MLP。

由于整個過程都是都是矩陣的加減乘除,所有的讀寫操作都是可微分的,是以我們可以用梯度下降法訓練整個參數模型。但是接下來,我們需要确定 w t w_t wt​定位向量,由于這個向量直接決定着目前輸入與記憶的相關性,是以神經圖靈機在生成 w t w_t wt​向量上做了很多工作。

3. 定位機制(Addressing Mechanism)

關于決定其相關性的方法有很多,主要分為兩大類: 基于内容的(content-based)和基于位置的(location-based)。神經圖靈機結合了這兩個方法提出一個定位機制用于生成定位向量 w t w_t wt​,具體來說,先用基于内容的方法,再用基于位置的方法。

(Neural Turning Machine)神經圖靈機了解和pytorch實作Neural Turning MachinePytorch實作

3.1 Content-based Addressing

基于内容的定位計算主要基于餘弦相似度:首先控制器給出一個 k t k_t kt​向量作為查詢的key,然後計算 k t k_t kt​與 M t M_t Mt​中各個記憶向量的餘弦相似度,最後經過一個softmax操作得到基于内容的定位向量 w t c w_t^c wtc​:

w t c ( i ) = exp ⁡ ( β t K [ k t , M t ( i ) ] ) ∑ j exp ⁡ ( β t K [ k t , M t ( j ) ] ) w_t^c(i) = \frac{\exp(\beta_tK[k_t,M_t(i)])}{\sum_j \exp(\beta_t K[k_t,M_t(j)])} wtc​(i)=∑j​exp(βt​K[kt​,Mt​(j)])exp(βt​K[kt​,Mt​(i)])​

其中 K [ . , . ] K[.,.] K[.,.]是餘弦相似度計算:

K [ u , v ] = u ⋅ v ∣ ∣ u ∣ ∣ ⋅ ∣ ∣ v ∣ ∣ K[u,v] = \frac{u \cdot v}{ ||u|| \cdot ||v||} K[u,v]=∣∣u∣∣⋅∣∣v∣∣u⋅v​

3.2 Location-based Addressing

3.2.1. Interpolation(插值)

控制器生成一個門檻值 g t g_t gt​對目前的内容定位向量 w t c w_t^c wtc​與t-1時刻的定位向量 w t − 1 w_{t-1} wt−1​進行一個插值操作,插值的結果即為輸出值 w t g w_t^g wtg​:

w t g = g t w t c + ( 1 − g t ) w t − 1 w_t^g=g_tw_t^c+(1-g_t)w_{t-1} wtg​=gt​wtc​+(1−gt​)wt−1​

這裡的插值操作可以了解為LSTM的更新門,結合過去的 w w w權值計算新的 w w w

3.2.2. shift(偏移)

對于 w t g w_t^g wtg​中的每個位置元素 w t g ( i ) w_t^g(i) wtg​(i),我們考慮它相鄰的k個偏移元素,認為這k個元素與 w t g ( i ) w_t^g(i) wtg​(i)相關,如當k=3時,三個相鄰元素分别是: w t g ( i ) w_t^g(i) wtg​(i)本身和位置偏移為1的元素 w t g ( i − 1 ) w_t^g(i-1) wtg​(i−1)和 w t g ( i + 1 ) w_t^g(i+1) wtg​(i+1),此時,我們希望新的位置為i的元素能包含這三個元素,是以用一個長度為3的偏移權值向量 s t s_t st​來表示這三個元素的權重,然後權值求和得到輸出值 w t ′ w'_t wt′​:

w t ′ ( i ) = ∑ j = − 1 1 w t g ( i + j ) s ( j + 1 ) w'_t(i)=\sum_{j=-1} ^1 w_t^g(i+j) s(j+1) wt′​(i)=j=−1∑1​wtg​(i+j)s(j+1)

這裡的偏移操作在原文中用的是循環卷積(circular convolution)公式表示的,我們可以了解為把向量 w t g w_t^g wtg​首尾相連形成一個環狀,然後在環中用 s t s_t st​作為卷積核做一維卷積操作。本質上是假設目前元素與相鄰的偏移元素相關。

3.2.3. Sharping(重塑)

當偏移操作中的權值比較平均的時候,上述的卷積操作會導緻資料的分散(dispersion)和洩漏(leakage),就像把一個點的資訊分散在三個點中,權值如果太平均會使三個點包含的值太模糊(個人了解),是以需要把權值大小的差別進行強化,也就是sharping。具體來說,控制器生成一個參數 γ t > 1 \gamma_t>1 γt​>1,然後對各個權值進行 γ t \gamma_t γt​指數然後歸一化:

w t ( i ) = w t ′ ( i ) γ t ∑ j w t ′ ( j ) γ t w_t(i)= \frac{w'_t(i)^{\gamma_t}}{ \sum_j w'_t(j)^{\gamma_t}} wt​(i)=∑j​wt′​(j)γt​wt′​(i)γt​​

最後我們得出了最終的 w t w_t wt​用于提取和儲存記憶。

Pytorch實作

這裡代碼基于的是pytorch-ntm,代碼寫的相當工整,可讀性很高,這裡隻分析一些重要的步驟:

讀過程

讀過程就是從控制器(LTSM)輸出的值提取我們需要的k, beta, g, s, gama值,然後調用_address_memory獲得目前的定位權值向量w, 再用矩陣乘法獲得讀過程的輸出

def forward(self, embeddings, w_prev):
    """NTMReadHead forward function.

    :param embeddings: input representation of the controller.
    :param w_prev: previous step state
    """
    o = self.fc_read(embeddings)
    k, beta, g, s, gama = _split_cols(o, self.read_lengths)

    # Read from memory
    w = self._address_memory(k, beta, g, s, gama, w_prev)
    r = self.memory.read(w)

    return r, w

 def read(self, w):
     """Read from memory (according to section 3.1)."""
     return torch.matmul(w.unsqueeze(1), self.memory).squeeze(1)
           

寫過程

寫過程同樣是獲得定位機制需要的k,beta, g, s, gama以及需要擦除的向量e和增加的向量a,然後調用_address_memory獲得定位向量w,然後根據e和a計算得出最後的寫入向量

def forward(self, embeddings, w_prev):
    """NTMWriteHead forward function.

    :param embeddings: input representation of the controller.
    :param w_prev: previous step state
    """
    o = self.fc_write(embeddings)
    k, beta, g, s, gama, e, a = _split_cols(o, self.write_lengths)

    # e should be in [0, 1]
    e = F.sigmoid(e)

    # Write to memory
    w = self._address_memory(k, beta, g, s, gama, w_prev)
    self.memory.write(w, e, a)

    return w

def write(self, w, e, a):
    """write to memory (according to section 3.2)."""
    self.prev_mem = self.memory
    self.memory = Variable(torch.Tensor(self.batch_size, self.N, self.M))
    erase = torch.matmul(w.unsqueeze(-1), e.unsqueeze(1))
    add = torch.matmul(w.unsqueeze(-1), a.unsqueeze(1))
    self.memory = self.prev_mem * (1 - erase) + add
           

Addressing Mechanism

定位機制的計算非常直覺,首先_similarity方法計算餘弦相似讀獲得wc,然後調用_interpolate與過去的w_prev進行插值操作,接着_shift偏移操作,這裡實際上調用的是_convolve循環卷積方法,最後進行_sharpen操作獲得最終的w

def address(self, k, beta, g, s, gama, w_prev):

    # Content focus
    wc = self._similarity(k, beta)

    # Location focus
    wg = self._interpolate(w_prev, wc, g)
    w1 = self._shift(wg, s)
    w = self._sharpen(w1, gama)

    return w

def _similarity(self, k, beta):
    k = k.view(self.batch_size, 1, -1)
    w = F.softmax(beta * F.cosine_similarity(self.memory + 1e-16, k + 1e-16, dim=-1), dim=1)
    return w

def _interpolate(self, w_prev, wc, g):
    return g * wc + (1 - g) * w_prev

def _shift(self, wg, s):
    result = Variable(torch.zeros(wg.size()))
    for b in range(self.batch_size):
        result[b] = _convolve(wg[b], s[b])
    return result

def _sharpen(self, w1, gamma):
    w = w1 ** gamma
    w = torch.div(w, torch.sum(w, dim=1).view(-1, 1) + 1e-16)
    return w

def _convolve(w, s):
    """Circular convolution implementation."""
    assert s.size(0) == 3
    t = torch.cat([w[-1:], w, w[:1]])
    c = F.conv1d(t.view(1, 1, -1), s.view(1, 1, -1)).view(-1)
    return c
           

訓練過程

首先輸入一系列的資料,每次輸入一個樣本,都先後進行讀和寫過程,然後在不給定輸入的情況下,獲得一系列輸出值,每次獲得一個輸出值時,同樣先後進行着讀和寫過程;隻不過輸出的時候控制器接受的是0向量,而輸入資料的時候控制器接受的是樣本x值。我們可以根據輸出的值與樣本label的差距計算loss,對于copy任務來說,輸入樣本和label都是樣本本身,損失可以使用binary entropy loss,最後梯度下降法更新整合模型參數

def train_batch(net, criterion, optimizer, X, Y):
    """Trains a single batch."""
    optimizer.zero_grad()
    inp_seq_len = X.size(0)
    outp_seq_len, batch_size, _ = Y.size()

    # New sequence
    net.init_sequence(batch_size)

    # Feed the sequence + delimiter
    for i in range(inp_seq_len):
        net(X[i])

    # Read the output (no input given)
    y_out = Variable(torch.zeros(Y.size()))
    for i in range(outp_seq_len):
        y_out[i], _ = net()

    loss = criterion(y_out, Y)
    loss.backward()
    clip_grads(net)
    optimizer.step()

    y_out_binarized = y_out.clone().data
    y_out_binarized.apply_(lambda x: 0 if x < 0.5 else 1)

    # The cost is the number of error bits per sequence
    cost = torch.sum(torch.abs(y_out_binarized - Y.data))

    return loss.data[0], cost / batch_size

# 每次調用net(x)或者net()獲得輸出值的forward方法
def forward(self, x, prev_state):
    """NTM forward function.

    :param x: input vector (batch_size x num_inputs)
    :param prev_state: The previous state of the NTM
    """
    # Unpack the previous state
    prev_reads, prev_controller_state, prev_heads_states = prev_state

    # Use the controller to get an embeddings
    inp = torch.cat([x] + prev_reads, dim=1)
    controller_outp, controller_state = self.controller(inp, prev_controller_state)

    # Read/Write from the list of heads
    reads = []
    heads_states = []
    for head, prev_head_state in zip(self.heads, prev_heads_states):
        if head.is_read_head():
            r, head_state = head(controller_outp, prev_head_state)
            reads += [r]
        else:
            head_state = head(controller_outp, prev_head_state)
        heads_states += [head_state]

    # Generate Output
    inp2 = torch.cat([controller_outp] + reads, dim=1)
    o = F.sigmoid(self.fc(inp2))

    # Pack the current state
    state = (reads, controller_state, heads_states)

    return o, state
           
(Neural Turning Machine)神經圖靈機了解和pytorch實作Neural Turning MachinePytorch實作

關于訓練結果,可以去github裡看,目前隻有copy和deepcopy兩個任務,應該是分開訓練,但是按照前面分析的,神經圖靈機應該是可以先後訓練多個任務,并且保持新的任務不會覆寫舊的任務,從理論上分析,如果讓記憶矩陣非常大,那麼就可以把每個任務儲存到記憶中不同的塊中,保持記憶矩陣的稀疏性,是可以做到任務間不互相幹涉,是以讓模型達到能學習多個任務的能力。谷歌16年在Nature中提出的DNC其實也就是神經圖靈機,論文裡介紹了一些現在神經圖靈機可以完成的通用任務,想了解神經圖靈機具體應用的可以去看看。下面放出論文位址和代碼位址:

神經圖靈機(NTM):https://arxiv.org/abs/1410.5401

DNC: https://www.nature.com/articles/nature20101

參考代碼:https://github.com/loudinthecloud/pytorch-ntm

繼續閱讀