
Summary of the two-tower model and its optimization methods

Author: Big data and artificial intelligence sharing
Author: Xinghan Link: https://zhuanlan.zhihu.com/p/576286147

Thanks to its high inference efficiency, the two-tower structure is widely used in the recall (candidate retrieval) stage of recommender systems and text matching. Classic methods such as Microsoft's DSSM [1], Google's YoutubeDNN [2], and Airbnb's personalized user embeddings [3] have been deployed in many industrial scenarios with significant results.
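To make the efficiency argument concrete, here is a minimal sketch of two-tower retrieval, assuming each tower has already produced fixed-size embeddings; the toy vectors and the function name `top_k_items` are hypothetical. Because the item tower never sees the user, item embeddings can be computed offline and cached, and online scoring reduces to one inner product per candidate.

```python
import heapq


def dot(a: list[float], b: list[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def top_k_items(user_emb: list[float], item_embs: dict[str, list[float]], k: int) -> list[str]:
    """Return the ids of the k items whose cached embeddings score highest against user_emb.

    This brute-force scan stands in for the approximate nearest-neighbour index
    a real recall system would use; the score is the same inner product.
    """
    scored = ((dot(user_emb, emb), item_id) for item_id, emb in item_embs.items())
    return [item_id for _, item_id in heapq.nlargest(k, scored)]


# Item-tower outputs, computed offline and cached (toy values).
items = {"a": [0.9, 0.1], "b": [0.2, 0.8], "c": [0.5, 0.5]}
# The user tower runs once per request; scoring is one inner product per item.
print(top_k_items([1.0, 0.0], items, k=2))  # ['a', 'c']
```

The price of this decoupling is that user and item features never interact before the final inner product, which is exactly the gap the methods below try to close.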

As optimization of the two-tower model matures, the marginal gains keep shrinking. Some recent work starts instead from the effectiveness gap between two-tower models and interaction models, and achieves good results either by adding interaction information to the two-tower model or by distilling an interaction model into it, which opens new room for optimizing the two-tower model.


1. Overview of optimization methods

The optimization methods covered in this article all build on the traditional two-tower structure; methods that rely on heavy engineering support (COLD, the TDM series, DR, etc.) are out of scope.


2. Knowledge distillation

1. PFD[4]

During offline training, the teacher model is given additional privileged features, which fall mainly into two categories. Interactive features cannot be used by the two-tower structure, e.g., the user's clicks on items of the same category in the past 24 hours. Post-click information is not available during online serving, e.g., the time spent on the page after clicking, or whether the user contacted the seller. The teacher does not need to be pre-trained; it is trained alongside the student. In the initial stage the two models are updated independently, and distillation starts once the teacher has stabilized.


Distillation method: the teacher and student share the embeddings of the features they have in common, and an auxiliary loss is added so that the student's prediction fits both the true label and the teacher's prediction.
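The way the losses combine can be sketched for a binary (e.g., click) label, assuming sigmoid cross-entropy for every term, a step-based warm-up before distillation starts, and a scalar weight on the distillation term. `pfd_losses`, `warmup_steps`, and `distill_weight` are hypothetical names, and the shared embeddings are not modeled; only the loss structure is shown.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def bce(target: float, logit: float) -> float:
    """Binary cross-entropy between a (possibly soft) target in [0, 1] and a logit."""
    p = sigmoid(logit)
    eps = 1e-12
    return -(target * math.log(p + eps) + (1.0 - target) * math.log(1.0 - p + eps))


def pfd_losses(label: float, student_logit: float, teacher_logit: float,
               step: int, warmup_steps: int = 1000, distill_weight: float = 0.5) -> dict:
    """Per-example losses for jointly training a PFD-style teacher and student.

    warmup_steps and distill_weight are illustrative hyperparameters (assumptions),
    not values from the paper.
    """
    teacher_loss = bce(label, teacher_logit)  # teacher (with privileged features) fits the label
    student_loss = bce(label, student_logit)  # student fits the label
    distill_loss = 0.0
    if step >= warmup_steps:
        # Student also fits the teacher's prediction. In an autodiff framework the
        # teacher's probability would be wrapped in stop_gradient here.
        distill_loss = bce(sigmoid(teacher_logit), student_logit)
    total = teacher_loss + student_loss + distill_weight * distill_loss
    return {"teacher": teacher_loss, "student": student_loss,
            "distill": distill_loss, "total": total}
```

Before `warmup_steps` the two models train independently, as described above; after it, the distillation term pulls the student toward the teacher.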


2. ENDX[5]


The teacher model uses interaction information to generate query and answer representations (answers correspond to items in a recommender system) with the same dimension as the student's, so that it can give the student richer learning signals during distillation. By contrast, in PFD the teacher and student networks differ so much that only the predicted logit can be distilled, which limits the supervision the teacher provides.

Distillation method: the Geometry Alignment Mechanism (GAM) aligns the student's query and answer representations with the teacher's corresponding representations, in the following steps:

  1. The similarity between two vectors is expressed as a conditional probability: the more similar the two vectors, the larger the conditional probability.
  2. For all vector pairs of the same type in a batch, compute the probability distribution over similarities in the student network, and likewise in the teacher network. The closeness of the two distributions is measured by KL divergence.
  3. The four types of vector pairs, P(answer|query), P(answer|answer), P(query|answer), and P(query|query), together form the auxiliary transfer loss.
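The three GAM steps can be sketched as follows, assuming dot-product similarity with a softmax over the in-batch candidates as the conditional probability, KL(teacher || student) as the divergence, and equal weights for the four pair types. The temperature `tau` and the masking of self-pairs in same-type distributions are assumptions of this sketch, not details from the paper.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cond_prob_rows(src, dst, same_set, tau=1.0):
    """Row i holds P(dst_j | src_i) = softmax_j(<src_i, dst_j> / tau).

    For same-type pairs (e.g. P(answer|answer)) the self-pair j == i is masked
    out (an assumption of this sketch).
    """
    rows = []
    for i, s in enumerate(src):
        logits = [float("-inf") if same_set and i == j else dot(s, d) / tau
                  for j, d in enumerate(dst)]
        m = max(logits)
        exps = [math.exp(l - m) for l in logits]  # exp(-inf) == 0.0
        z = sum(exps)
        rows.append([e / z for e in exps])
    return rows


def kl(p, q, eps=1e-12):
    """KL(p || q), skipping entries where p is zero."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)


def gam_loss(q_s, a_s, q_t, a_t, tau=1.0):
    """Sum over the four pair types of the batch-mean KL(teacher || student).

    q_* / a_*: lists of query / answer vectors for one batch (student _s, teacher _t).
    The same-type terms need a batch of at least 3 to carry any signal.
    """
    pair_types = [
        (q_s, a_s, q_t, a_t, False),  # P(answer | query)
        (a_s, a_s, a_t, a_t, True),   # P(answer | answer)
        (a_s, q_s, a_t, q_t, False),  # P(query | answer)
        (q_s, q_s, q_t, q_t, True),   # P(query | query)
    ]
    total = 0.0
    for src_s, dst_s, src_t, dst_t, same in pair_types:
        p_s = cond_prob_rows(src_s, dst_s, same, tau)
        p_t = cond_prob_rows(src_t, dst_t, same, tau)
        total += sum(kl(pt, ps) for pt, ps in zip(p_t, p_s)) / len(p_s)
    return total
```

Because only the shape of each similarity distribution is matched, the student is aligned to the teacher's geometry without needing identical raw vectors.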

3. TRMD[6]

There are two teachers, a cross-encoder teacher and a bi-encoder teacher. Both are pre-trained, and their parameters are frozen during distillation. Unfortunately, the authors do not report how much two teachers improve over a single teacher.


Distillation method:

  1. The cross-encoder teacher guides the student to learn the CLS representation (BERT's output at the [CLS] token, the special token prepended to the input sequence, commonly used as the representation of the whole sequence).
  2. The bi-encoder teacher (ColBERT) guides the student to learn the REP representation (REP is all of BERT's output representations, covering the CLS, query, and doc tokens). Note that the bi-encoder teacher is itself a two-tower model, so like the student it produces two CLS vectors; to learn from the cross-encoder teacher, which has only one, the student's two CLS vectors must first be aggregated, e.g., by max or sum.
  3. The student mimics the two teachers to compute two scores separately, and their sum is the model's prediction.

How does the bi-encoder student produce a cross-encoder-style score? Its two CLS representations are aggregated into one, after which the rest of the cross-encoder's scoring process is unchanged.
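A sketch of the student's two rankers and the two distillation terms, under stated assumptions: `head_w` stands in for a hypothetical cross-encoder-style linear scoring head applied to the aggregated CLS vector, the REP ranker is scored ColBERT-style with MaxSim, and representation distillation uses mean squared error. None of these specifics come from the article; they only illustrate how the pieces fit.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def aggregate_cls(cls_q, cls_d, mode="sum"):
    """Merge the two towers' CLS vectors into one, by element-wise sum or max."""
    if mode == "sum":
        return [a + b for a, b in zip(cls_q, cls_d)]
    if mode == "max":
        return [max(a, b) for a, b in zip(cls_q, cls_d)]
    raise ValueError(f"unknown mode: {mode}")


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def trmd_student(cls_q, cls_d, rep_q, rep_d, head_w, mode="sum"):
    """Return (final score, aggregated CLS vector) for the two-ranker student."""
    cls_joint = aggregate_cls(cls_q, cls_d, mode)
    cls_score = dot(cls_joint, head_w)  # CLS ranker: hypothetical linear head
    rep_score = sum(max(dot(q, d) for d in rep_d) for q in rep_q)  # REP ranker: MaxSim
    return cls_score + rep_score, cls_joint  # the two scores are added


def trmd_distill_loss(cls_joint, teacher_cls, rep_student, rep_teacher):
    """Cross-encoder teacher -> CLS ranker, plus ColBERT teacher -> REP ranker."""
    cls_loss = mse(cls_joint, teacher_cls)
    rep_loss = sum(mse(s, t) for s, t in zip(rep_student, rep_teacher)) / len(rep_student)
    return cls_loss + rep_loss
```

At serving time only `trmd_student` runs; the teachers appear only in the training loss.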

4. VIRT[7]

In the teacher, the query and doc sequences are concatenated, so the Q and K vectors that the Transformer encoder layers compute for every token contain query-doc interaction information. Since every token has its own Q, K, and V vectors, the student can be distilled using the teacher's interaction-bearing Q and K. The teacher is pre-trained, and its parameters are frozen during distillation.

In the following figure, X represents query and Y represents doc.


Distillation method:

  1. The teacher's attention matrix (the product of Q and K over the concatenated sequence) decomposes into four parts: query only, query attending to doc, doc attending to query, and doc only. The student encodes query and doc separately, so it lacks the two cross parts; these are computed from the student's own Q and K and distilled from the teacher's results, as shown in Figure B, through an auxiliary loss.
  2. To compensate for the lack of interaction information in a plain vector inner product, the authors design an attention mechanism that uses the interaction information learned in the distillation step above to weight the last-layer representations on the query and doc sides, respectively.
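The decomposition in step 1 can be checked numerically. This sketch works on single-head attention logits with toy Q and K vectors (X = query tokens, Y = doc tokens); following the text, the teacher's cross blocks are the targets and the loss is an L2 (mean squared) distance. Applying a row-wise softmax to each cross block before comparing is an assumption of this sketch. Step 2's attention formula is not given in the text, so it is omitted.

```python
import math


def matmul_t(a, b):
    """Return a @ b^T where a and b are lists of row vectors."""
    return [[sum(x * y for x, y in zip(r, c)) for c in b] for r in a]


def row_softmax(m, scale):
    """Softmax of each row of m after multiplying by scale."""
    out = []
    for row in m:
        mx = max(row)
        e = [math.exp((v - mx) * scale) for v in row]
        z = sum(e)
        out.append([v / z for v in e])
    return out


def attention_blocks(q_x, k_x, q_y, k_y):
    """Four blocks of the attention logits Q K^T over the concatenated sequence [X; Y]."""
    return {
        "xx": matmul_t(q_x, k_x),  # query only (the student has this)
        "xy": matmul_t(q_x, k_y),  # query attends to doc (cross, missing in the student)
        "yx": matmul_t(q_y, k_x),  # doc attends to query (cross, missing in the student)
        "yy": matmul_t(q_y, k_y),  # doc only (the student has this)
    }


def virt_cross_loss(teacher, student):
    """Mean squared error between row-softmaxed cross blocks, teacher as target.

    teacher / student: dicts with q_x, k_x, q_y, k_y (lists of per-token vectors).
    The student's cross blocks are "virtual": built from Q and K that its two
    towers computed independently.
    """
    scale = 1.0 / math.sqrt(len(teacher["q_x"][0]))
    t_blocks = attention_blocks(teacher["q_x"], teacher["k_x"], teacher["q_y"], teacher["k_y"])
    s_blocks = attention_blocks(student["q_x"], student["k_x"], student["q_y"], student["k_y"])
    loss, count = 0.0, 0
    for key in ("xy", "yx"):
        t_rows = row_softmax(t_blocks[key], scale)
        s_rows = row_softmax(s_blocks[key], scale)
        for t_row, s_row in zip(t_rows, s_rows):
            for t_val, s_val in zip(t_row, s_row):
                loss += (t_val - s_val) ** 2
                count += 1
    return loss / count
```

Slicing `matmul_t(q_x + q_y, k_x + k_y)` into its four quadrants reproduces the four blocks exactly, which is why the cross parts can be supervised block by block.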

5. Distilled-DualEncoder[8]

The core idea is similar to VIRT, with an additional soft-label distillation between the two models' predictions.


Distillation method:

  1. As in VIRT, a Transformer encodes each token, and the teacher's cross-interaction information, obtained through matrix multiplication, is used to distill the part missing from the student. The only difference is the choice of auxiliary loss: VIRT uses the L2 distance, whereas this paper uses KL divergence.
  2. The teacher's predictions are distilled into the student's predictions as soft labels.
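Swapping VIRT's L2 distance for KL divergence and adding soft-label distillation gives the sketch below. It assumes single-head cross-attention logits, KL(teacher || student) in both terms, a hypothetical `temperature` for the prediction soft labels, and equal weights for the two terms; these choices are for illustration only.

```python
import math


def softmax(v, temperature=1.0):
    m = max(v)
    e = [math.exp((x - m) / temperature) for x in v]
    z = sum(e)
    return [x / z for x in e]


def kl(p, q, eps=1e-12):
    """KL(p || q), skipping entries where p is zero."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)


def attention_kl(teacher_rows, student_rows, scale=1.0):
    """Mean row-wise KL(teacher || student) over cross-attention logits."""
    total = sum(kl(softmax([x * scale for x in t]), softmax([x * scale for x in s]))
                for t, s in zip(teacher_rows, student_rows))
    return total / len(teacher_rows)


def soft_label_loss(teacher_logits, student_logits, temperature=1.0):
    """KL between the temperature-softened prediction distributions."""
    return kl(softmax(teacher_logits, temperature), softmax(student_logits, temperature))


def dde_aux_loss(t_cross, s_cross, t_pred, s_pred, scale=1.0, temperature=1.0):
    """Attention distillation plus soft-label distillation, equally weighted."""
    return attention_kl(t_cross, s_cross, scale) + soft_label_loss(t_pred, s_pred, temperature)
```

Unlike the L2 distance, KL treats each attention row as a distribution, so it only cares about where the teacher puts its attention mass, not the raw logit scale.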

6. ERNIE-Search[9]

Because it is hard for a bi-encoder to learn directly from a cross-encoder, ColBERT is introduced as an intermediate bridge for transferring information. After training, the student keeps the bi-encoder structure, so scoring still requires only one inner product.


Distillation method: two-level (cascaded) distillation

  1. ColBERT distills the student.

For each model separately, the query's scores over all candidate docs are normalized into a probability distribution.


The closeness of the two distributions is measured by KL divergence.

  2. The cross-encoder teacher distills ColBERT.

As in the previous step, the teacher's scores over all candidate docs are turned into a probability distribution, and KL divergence measures the closeness. In addition, a token-level auxiliary loss distills attention information, similar to VIRT and Distilled-DualEncoder: the cross-information contained in the teacher guides the learning of the part missing from the student.

  3. The authors also add an auxiliary loss that distills the teacher directly into the student. The final loss therefore comprises the three models' own training losses and four auxiliary losses.
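The cascade can be sketched as one loss per query, assuming that all three models score the same candidate docs, that each distillation edge is KL(teacher || student) between listwise softmax distributions, that each model's own training loss is a softmax cross-entropy on the positive doc, and that all terms are equally weighted. `attention_loss` is a hypothetical placeholder for the token-level attention term described above.

```python
import math


def log_softmax(v):
    m = max(v)
    log_z = m + math.log(sum(math.exp(x - m) for x in v))
    return [x - log_z for x in v]


def listwise_kl(teacher_scores, student_scores):
    """KL(P_teacher || P_student) over one query's candidate docs."""
    lt, ls = log_softmax(teacher_scores), log_softmax(student_scores)
    return sum(math.exp(a) * (a - b) for a, b in zip(lt, ls))


def contrastive_loss(scores, pos_index):
    """A model's own training loss: -log P(positive doc)."""
    return -log_softmax(scores)[pos_index]


def ernie_search_loss(ce_scores, colbert_scores, student_scores, pos_index, attention_loss=0.0):
    # Three training losses, one per model (all trained jointly).
    train = sum(contrastive_loss(s, pos_index)
                for s in (ce_scores, colbert_scores, student_scores))
    # Four auxiliary losses. In an autodiff framework the teacher side of each
    # KL would typically be detached.
    aux = (listwise_kl(colbert_scores, student_scores)   # ColBERT -> student
           + listwise_kl(ce_scores, colbert_scores)      # cross-encoder -> ColBERT
           + attention_loss                              # token-level attention (placeholder)
           + listwise_kl(ce_scores, student_scores))     # cross-encoder -> student
    return train + aux
```

Routing the signal through ColBERT means each KL edge bridges a smaller architectural gap than a direct cross-encoder to bi-encoder step would.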

3. Interaction enhancement

1. ColBERT[10]

ColBERT is a pioneering work in this direction, and many later papers build on it. The paper calls its interaction late interaction, in contrast to the early interaction of a cross-encoder, where query and doc interact from the very first layer. Query-doc similarity is upgraded from a single-vector inner product to a sum, over the query's tokens, of each token's similarity to the doc.


Interaction method: for each query token, compute its matching score against every doc token and keep the maximum; the sum of these maxima over all query tokens is the final score.


Number of inner products: query_word_num * doc_word_num
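ColBERT's late interaction (often called MaxSim) is short enough to write out. The sketch assumes the token embeddings have already been produced by the two encoders and L2-normalizes them, as ColBERT does, so each inner product is a cosine similarity; it also counts inner products to confirm the query_word_num * doc_word_num figure. `colbert_score` is a hypothetical name.

```python
import math


def l2_normalize(v):
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def colbert_score(q_tokens, d_tokens, normalize=True):
    """Late interaction (MaxSim). Returns (score, number of inner products)."""
    if normalize:
        q_tokens = [l2_normalize(q) for q in q_tokens]
        d_tokens = [l2_normalize(d) for d in d_tokens]
    score, n_products = 0.0, 0
    for q in q_tokens:
        sims = [sum(a * b for a, b in zip(q, d)) for d in d_tokens]
        n_products += len(sims)
        score += max(sims)  # best-matching doc token for this query token
    return score, n_products
```

Doc token embeddings can still be precomputed offline; only the MaxSim step runs online, which keeps ColBERT far cheaper than a cross-encoder while costing more than a single inner product.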

2. IntTower[11]

The last-layer hidden vector on the item side interacts with the hidden vectors of multiple layers on the user side to model similarity. In addition, self-supervised learning strengthens the similarity modeling: an auxiliary loss (L_cir) built from positive and negative samples guides the model.


Interaction method: this paper fuses two types of interaction.

  1. The item's last layer interacts with the multi-layer (L layers) representations on the user side to produce the model's prediction, in three steps:
  • The hidden layers on the user and item sides are each mapped into M subspaces (M vectors)
  • For each user-side layer, inner products are computed between its M vectors and the item's M vectors, and the maximum is taken as that layer's score
  • The scores of all user-side layers are summed to give the model's prediction
  2. Interaction between positive and negative sample pairs serves as the auxiliary loss L_cir

All samples other than the positive are treated as negatives, and InfoNCE, which is commonly used in recall, pulls the user toward the positive and away from the negatives.


Number of inner products: user-side hidden layers L * subspaces M * subspaces M
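A sketch of both interactions, assuming the user-side layers have already been projected to the item's dimension and that "mapping to M subspaces" is done by slicing each vector into M equal chunks (a learned projection per subspace is equally plausible). `inttower_score` follows the three steps above and also returns the inner-product count, which should equal L * M * M; `info_nce` is a standard in-batch InfoNCE with a hypothetical temperature `tau`. Both names are illustrative.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def split_subspaces(vec, m):
    """Map one hidden vector to M sub-vectors by even slicing (an assumption)."""
    if len(vec) % m:
        raise ValueError("vector length must be divisible by m")
    size = len(vec) // m
    return [vec[i * size:(i + 1) * size] for i in range(m)]


def inttower_score(user_layers, item_last, m):
    """Sum over user layers of the max inner product between the layer's and the item's M sub-vectors.

    Returns (score, number of inner products).
    """
    item_heads = split_subspaces(item_last, m)
    score, n_products = 0.0, 0
    for layer in user_layers:
        user_heads = split_subspaces(layer, m)
        sims = [dot(u, v) for u in user_heads for v in item_heads]  # M * M products
        n_products += len(sims)
        score += max(sims)  # this layer's score
    return score, n_products


def info_nce(user_vec, pos_item, neg_items, tau=0.1):
    """-log softmax probability of the positive among positive + negatives."""
    logits = [dot(user_vec, pos_item) / tau] + [dot(user_vec, n) / tau for n in neg_items]
    mx = max(logits)
    log_z = mx + math.log(sum(math.exp(l - mx) for l in logits))
    return log_z - logits[0]
```

Because the item side contributes only its last layer, the user-side layers can still be computed independently of the item, preserving most of the two-tower efficiency.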

3. MVKE[12]

The user side first generates multiple shared interest representations; the item-side representation (Tag in the figure below) is then used as a query to aggregate them with attention into each user's item-specific interest representation.


Interaction method:

  1. VK-Experts model the user's interests. Each has a Virtual Kernel, a learnable parameter that serves as the query for an attention-weighted aggregation over all user-side features (Fields); the result is that VK-Expert's output.
  2. The item-side representation is used to attention-aggregate the outputs of the multiple user-side VK-Experts into the user's item-specific interest representation.
  3. The inner product of this item-specific interest representation and the item representation is the model's prediction.

Number of inner products: number of VK-Experts + 1
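A sketch of the three steps, simplified so that each attention uses the same vectors as keys and values (the paper's projections are omitted, which is an assumption). It makes the inner-product count visible: the VK-Experts depend only on the user, so per item the cost is one attention logit per expert plus the final inner product. `attend`, `vk_experts`, and `mvke_score` are hypothetical names.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attend(query, vectors):
    """Softmax(<query, v>)-weighted sum of vectors (keys and values are the same here)."""
    logits = [dot(query, v) for v in vectors]
    m = max(logits)
    weights = [math.exp(l - m) for l in logits]
    z = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, vectors)) / z for d in range(len(vectors[0]))]


def vk_experts(field_embs, virtual_kernels):
    """Step 1: each learnable virtual kernel attends over the user's field embeddings.

    Item-independent, so it can be computed once per user.
    """
    return [attend(vk, field_embs) for vk in virtual_kernels]


def mvke_score(experts, item_vec):
    """Steps 2-3: the item attends over the experts, then one final inner product.

    Returns (score, item-dependent inner products == len(experts) + 1).
    """
    user_vec = attend(item_vec, experts)
    return dot(user_vec, item_vec), len(experts) + 1
```

With a single VK-Expert the attention weight is 1, so the score reduces to a plain two-tower inner product between that expert and the item.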

References

  1. (DSSM) Huang et al. Learning deep structured semantic models for web search using clickthrough data. CIKM. 2013.
  2. (YoutubeDNN) Covington et al. Deep neural networks for YouTube recommendations. RecSys. 2016.
  3. Grbovic et al. Real-time personalization using embeddings for search ranking at Airbnb. KDD. 2018.
  4. (PFD) Xu et al. Privileged features distillation at Taobao recommendations. KDD. 2020.
  5. (ENDX) Wang et al. Enhancing Dual-Encoders with Question and Answer Cross-Embeddings for Answer Retrieval. arXiv. 2022.
  6. (TRMD) Choi et al. Improving Bi-encoder Document Ranking Models with Two Rankers and Multi-teacher Distillation. SIGIR. 2021.
  7. (VIRT) Li et al. VIRT: Improving Representation-based Models for Text Matching through Virtual Interaction. arXiv. 2021.
  8. (Distilled-DualEncoder) Wang et al. Distilled Dual-Encoder Model for Vision-Language Understanding. arXiv. 2021.
  9. (ERNIE-Search) Lu et al. ERNIE-Search: Bridging Cross-Encoder with Dual-Encoder via Self On-the-fly Distillation for Dense Passage Retrieval. arXiv. 2022.
  10. (ColBERT) Khattab et al. ColBERT: Efficient and effective passage search via contextualized late interaction over BERT. SIGIR. 2020.
  11. (IntTower) IntTower: the Next Generation of Two-Tower Model for Pre-Ranking System.
  12. (MVKE) Xu et al. Mixture of virtual-kernel experts for multi-objective user profile modeling. KDD. 2022.
