機器之心專欄
機器之心編輯部
騰訊 AI Lab、帝國理工與中山大學合作發表論文《Learning Neural Set Functions Under the Optimal Subset Oracle》,提出基于最優子集的集合函數學習方法。
集合函數被廣泛應用于各種場景之中,例如商品推薦、異常檢測和分子篩選等。在這些場景中,集合函數可以被視為一個評分函數:其将一個集合作為輸入并輸出該集合的分數。我們希望從給定的集合中選取出得分最高的子集。鑒于集合函數的廣泛應用,如何學習一個适用的集合函數是解決許多問題的關鍵。為此,騰訊 AI Lab、帝國理工與中山大學合作發表論文《Learning Neural Set Functions Under the Optimal Subset Oracle》,提出基于最優子集的集合函數學習方法。該方法在多個應用場景中取得良好效果。論文已被 NeurIPS 2022 接收并選為口頭報告(Oral Presentation)。
- 論文位址:https://arxiv.org/abs/2203.01693
- 代碼位址:https://github.com/SubsetSelection/EquiVSet
一、引言
很多現實應用場景與集合密切相關,例如推薦系統、異常檢測和分子篩選等。這些應用都潛在地學習了一個集合函數來評價給定集合的得分,使得輸出的集合擁有最高得分。以商品推薦為例子(如下圖所示),我們希望從某個網店的商品庫V中推薦子集,使得使用者對該商品子集擁有最高評分
圖 1 集合函數學習在商品推薦中的例子
具體地,我們假設每個使用者心中存在一個評分函數
,該函數将一個商品子集
作為輸入,輸出使用者對該子集的評分,即
。使用者總是從系統推薦的商品集合中購買得分最高的商品子集:
我們希望學習一個函數
,使其盡可能逼近真正的評分函數
. 然而在實際應用場景,由于标注成本過高,我們無法得到使用者對每一個商品子集的評分。是以,我們假設資料集的形式為
,其中
為使用者i購買的商品子集,
為對應的商品庫。我們希望找到合适的參數
, 使得使用者購買的商品最大化集合函數
然而找到合适的參數
并不是一件容易的事情。為此,我們将目标函數定義為最大似然估計
其中我們通過
的正比限制使得最終學習到的集合函數滿足上文的要求。進一步地,我們希望該機率分布及由此推導的訓練方法滿足若幹的性質,如:置換不變性、最小先驗假設和可擴充性(scalability)等。
在本文中,我們提出了等變變分集合函數學習方法 (Equivariant Variational inference for Set function learning. EquiVSet). 具體地,我們使用能量模型(energy-based models)模組化機率
;能量模型是最大熵分布,滿足最小先驗假設。其次,我們通過 DeepSet 類型的模型架構模組化集合函數
,使其滿足置換不變性;最後我們使用均攤變分分布來滿足可擴充性的要求。實驗證明,EquiVSet 在商品推薦、異常檢測和分子篩選等現實應用場景中都有出色表現。值得一提的是,雖然傳統的端到端子集預測模型也适用于以上場景,但是他們通常屬于黑盒模型。在本文中,我們顯式模組化集合函數,并通過最大化集合函數來進行子集預測。學習的集合函數可用來評價不同子集的效益,是以更具有可解釋性。
二、方法簡介
圖 2 EquiVSet 訓練和推理過程概覽
我們首先将機率模型定義為能量模型:
, 并用 DeepSet 架構對能量函數
進行模組化,進而實作置換不變性。此外,由于能量模型為最大熵分布,其具有最小資訊先驗假設的特點。為了訓練該模型,我們進一步引入了變分分布
, 并通過神經網絡對其進行模組化。如圖 2 所示,模型訓練包含兩個步驟:
1. 為了學習變分
,我們可以最小化
之間的 KL 距離。
2. 為了訓練模型
,我們首先通過神經網絡 EquiNet 輸出變分分布的初始參數
, 然後通過平均場變分推斷來更新變分參數
,使其逼近模型分布
。此外,該步驟可以讓模型參數
依賴于變分參數
。是以我們可以通過最小化交叉熵損失來更新模型參數
:.
模型通過反複疊代步驟 1 和 2 來更新參數, 進而達到合作式訓練變分網絡 EquiNet 和能量網絡的目的。這種合作學習方式的效果可以通過圖 3 形象示意:在每輪疊代中,我們通過均攤變分推斷
來更新變分分布的參數,使其不斷逼近模型
;參數
更新完畢後,我們通過最小化交叉熵損失來訓練模型
,使其不斷逼近真實資料分布
。模型訓練完畢後,近似地
成立,是以我們可以根據變分參數來選取最優子集。值得注意的是, 算法角度來說, 簡單使用端到端子集預測模型相當于隻模組化了變分網絡 EquiNet, 即隻模組化了變分分布, 是以無法達到合作學習的目的。 我們後續的實驗部分也驗證了這種端到端子集預測方法的性能與合作學習方法 EquiVSet 相差甚遠。
圖 3 EquiVSet 參數更新示意圖
三、實驗結果
為了驗證 EquiVSet 的有效性,我們在三個任務上進行測試:商品推薦、異常檢測和分子篩選。
1. 在商品推薦任務中,我們使用 amazon baby register dataset,該資料集包含了真實的使用者購買記錄。在該任務上,EquiVSet 在大部分場景中都取得最佳性能。具體地,相比于先前的 SOTA 算法 PGM,EuiVSet 的性能平均提升 33%。相比于傳統的黑盒端到端子集預測方法 DeepSet(NoSetFn)(該方法相當于僅模組化了變分網絡 EquiNet),EquiVSet 的性能平均提升 39%,說明了顯式模組化集合函數的重要性。
2. 在異常檢測任務中,我們使用四個經典資料集:double mnist,celebA,fashion-mnist 和 cifar-10。下圖給出了 celebA 上異常檢測的例子。
圖 3 celebA 資料集。每一行是一個資料樣本。在每個樣本中,正常圖檔擁有兩個共同屬性(最右列),異常圖檔(紅色方框)沒有該屬性。
以下表格提供不同方法在該任務上的性能對比,可以看出 EquiVSet 顯著優于其他方法, 并比 PGM 和 DeepSet(NoSetFn) 的性能分别平均提高 37% 和 80%。
3. 在分子篩選中,我們使用 PDBBind 和 BindingDB 兩個經典資料。該任務是從給定的分子庫中,篩選出符合一定屬性的分子。下表是 EquiVSet 和各個方法的對比結果。
四、結論
本文提出的基于最優子集的集合函數學習方法。通過将集合機率定義成能量模型,使得模型滿足置換不變性、最小先驗等特點。借助最大似然方法和等變變分技巧,模型能夠高效地訓練和推理。在商品推薦、異常檢測和分子篩選上的應用認證了該方法的有效性。