
Code walkthrough for the paper "Key Fact as Pivot: A Two-Stage Model for Low Resource Table-to-Text Generation"

1. Data Processing

Raw text processing

table2entity2text.py

Example: one record from the raw data.

Key-value pairs:

name_1:walter   name_2:extra    image:<none>    image_size:<none>   caption:<none>  birth_name:<none>   birth_date_1:1954   birth_place:<none>  death_date:<none>   death_place:<none>  death_cause:<none>  resting_place:<none>    resting_place_coordinates:<none>    residence:<none>    nationality_1:german    ethnicity:<none>    citizenship:<none>  other_names:<none>  known_for:<none>    education:<none>    alma_mater:<none>   employer:<none> occupation_1:aircraft   occupation_2:designer   occupation_3:and    occupation_4:manufacturer   home_town:<none>    title:<none>    salary:<none>   networth:<none> height:<none>   weight:<none>   term:<none> predecessor:<none>  successor:<none>    party:<none>    boards:<none>   religion:<none> spouse:<none>   partner:<none>  children:<none> parents:<none>  relations:<none>    signature:<none>    website:<none>  footnotes:<none>    article_title_1:walter  article_title_2:extra      

Description:

walter extra is a german award-winning aerobatic pilot , chief aircraft designer and founder of extra flugzeugbau -lrb- extra aircraft construction -rrb- , a manufacturer of aerobatic aircraft .      

After data processing this becomes the record below. Here field gives the field name of each value token, lpos is the token's position within its field counting from the left, and rpos is its position counting from the right:

{"source": "walter extra 1954 german aircraft designer and manufacturer walter extra", "target": "walter extra is a german award-winning aerobatic pilot , chief aircraft designer and founder of extra flugzeugbau -lrb- extra aircraft construction -rrb- , a manufacturer of aerobatic aircraft .", "field": "name name birth_date nationality occupation occupation occupation occupation article_title article_title", "lpos": "1 2 1 1 1 2 3 4 1 2", "rpos": "2 1 1 1 4 3 2 1 2 1"}      

Then a key-fact label is added for every source token (1 if the word also appears in the target, 0 otherwise); the pivot is the sequence of label-1 tokens:

{"source": "walter extra 1954 german aircraft designer and manufacturer walter extra", "target": "walter extra is a german award-winning aerobatic pilot , chief aircraft designer and founder of extra flugzeugbau -lrb- extra aircraft construction -rrb- , a manufacturer of aerobatic aircraft .", "field": "name name birth_date nationality occupation occupation occupation occupation article_title article_title", "lpos": "1 2 1 1 1 2 3 4 1 2", "rpos": "2 1 1 1 4 3 2 1 2 1", "label": "1 1 0 1 1 1 1 1 1 1", "pivot": "walter extra german aircraft designer and manufacturer walter extra"}      

A separate pivot file is generated at the same time.

extract_entity.py

The text is tagged with the Stanford CoreNLP toolkit, which assigns a POS tag to every word. Words in the target whose POS tag falls in {NN, NNS, NNP, NNPS, JJ, JJR, JJS, CD, FW}, i.e. {noun, singular or mass; noun, plural; proper noun, singular; proper noun, plural; adjective; adjective, comparative; adjective, superlative; cardinal number; foreign word}, are kept and all remaining words are dropped. This builds the pseudo-parallel data, yielding the entity file:

walter extra german award-winning aerobatic pilot chief aircraft designer founder extra flugzeugbau extra aircraft construction manufacturer aerobatic aircraft      
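A hedged sketch of that POS filter, using NLTK's tagger as a lightweight stand-in for Stanford CoreNLP (which is what the repository actually uses; both emit Penn Treebank tags):

import nltk  # assumes the 'averaged_perceptron_tagger' model is downloaded

# Penn Treebank tags to keep, per the description above
KEEP_TAGS = {'NN', 'NNS', 'NNP', 'NNPS', 'JJ', 'JJR', 'JJS', 'CD', 'FW'}

def extract_entity(target: str) -> str:
    tagged = nltk.pos_tag(target.split())
    return ' '.join(word for word, tag in tagged if tag in KEEP_TAGS)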

extract_pivot.py

Combines source (value), target (text), field, lpos, rpos, entity, label, and pivot into one record to produce the train.pivot.jsonl file:

{"value": "walter extra 1954 german aircraft designer and manufacturer walter extra", "text": "walter extra is a german award-winning aerobatic pilot , chief aircraft designer and founder of extra flugzeugbau -lrb- extra aircraft construction -rrb- , a manufacturer of aerobatic aircraft .", "field": "name name birth_date nationality occupation occupation occupation occupation article_title article_title", "lpos": "1 2 1 1 1 2 3 4 1 2", "rpos": "2 1 1 1 4 3 2 1 2 1", "entity": "walter extra german award-winning aerobatic pilot chief aircraft designer founder extra flugzeugbau extra aircraft construction manufacturer aerobatic aircraft", "label": "1 1 0 1 1 1 1 1 1 1", "pivot": "walter extra german aircraft designer and manufacturer walter extra"}      

process_pivot.py

Generates two files:

table2pivot:

Randomly selects 10,000 examples from the dataset and writes their value, label, field, lpos, and rpos into train.t2p.jsonl:

{"value": "walter extra 1954 german aircraft designer and manufacturer walter extra", "label": "1 1 0 1 1 1 1 1 1 1", "field": "name name birth_date nationality occupation occupation occupation occupation article_title article_title", "lpos": "1 2 1 1 1 2 3 4 1 2", "rpos": "2 1 1 1 4 3 2 1 2 1"}      

pivot2text:

for i, d in enumerate(ori_datas):
    if index is None or i in index:
        # labeled subset: pair the gold pivot with the text
        data = {'source': d['pivot'], 'target': d['text']}
    else:
        # unlabeled rest: pair the POS-extracted entities with the text
        data = {'source': d['entity'], 'target': d['text']}
    datas.append(data)
{"source": "walter extra german award-winning aerobatic pilot chief aircraft designer founder extra flugzeugbau extra aircraft construction manufacturer aerobatic aircraft", "target": "walter extra is a german award-winning aerobatic pilot , chief aircraft designer and founder of extra flugzeugbau -lrb- extra aircraft construction -rrb- , a manufacturer of aerobatic aircraft ."}      

construct_corpus.py

Here index is the random-index file generated earlier, with size = 10,000.

train_parallel_dataset:

for d in ori_datas:    
    data = {'source': d['value'], 'target': d['text'], 'field': d['field'], 'lpos': d['lpos'], 'rpos': d['rpos']}    
    datas.append(data)      

Produces train.parallel.{10000}.jsonl, len = 10,000.

train_p2t_dataset:

from typing import Dict, List

def get_filter_data(data: Dict) -> Dict:
    # keep only the field/lpos/rpos entries whose token was labeled as a
    # key fact, so the features stay aligned with the pivot tokens
    label = data['label'].split(' ')
    field = data['field'].split(' ')
    lpos = data['lpos'].split(' ')
    rpos = data['rpos'].split(' ')
    _field = [f for f, l in zip(field, label) if l == '1']
    _lpos = [f for f, l in zip(lpos, label) if l == '1']
    _rpos = [f for f, l in zip(rpos, label) if l == '1']

    return {'source': data['pivot'], 'target': data['text'],
            'field': ' '.join(_field), 'lpos': ' '.join(_lpos),
            'rpos': ' '.join(_rpos)}
def train_p2t_dataset(r_path: str, w_path: str, index: List[int]) -> List[Dict]:
    '''
    value, text, field, lpos, rpos, pivot, entity
    '''
    ori_datas = loads(open(r_path))[1:]

    statistic = {'length': len(ori_datas)}
    datas = [statistic]
    index = set(index)

    for i, d in enumerate(ori_datas):
        if i in index:
            # labeled subset: gold pivot plus its aligned field/position features
            datas.append(get_filter_data(d))
        else:
            # unlabeled rest: pseudo-parallel pair built from the entity file
            datas.append({'source': d['entity'], 'target': d['text']})

    dumps(datas, open(w_path, 'w'))

Produces train.p2t.{10000}.jsonl, len = 580,000.

train_t2p_dataset:

for d in ori_datas:
    data = {'value': d['value'], 'label': d['label'], 'field': d['field'],
            'lpos': d['lpos'], 'rpos': d['rpos']}
    datas.append(data)      

Produces train.t2p.{10000}.jsonl, len = 10,000.

train_aug_dataset:

for i, d in enumerate(ori_datas):
    if i in index:
        datas.append({'source': d['value'], 'target': d['text'], 'field': d['field'], 'lpos': d['lpos'], 'rpos': d['rpos']})
    else:
        datas.append({'source': d['entity'], 'target': d['text']})

Produces train.aug.{10000}.jsonl, len = 580,000 (the loop covers all of ori_datas, as above).

train_semi_dataset:

for i, d in enumerate(ori_datas):
    if i in index:
        datas.append({'source': d['value'], 'target': d['text'], 'field': d['field'], 'lpos': d['lpos'], 'rpos': d['rpos']})
    else:
        datas.append({'source': d['text'], 'target': d['text']})

Produces train.semi.{10000}.jsonl, len = 580,000.

train_pretrain_dataset:

for i, d in enumerate(ori_datas):
    if i in index:
        datas.append({'source': d['value'], 'target': d['text'], 'field': d['field'], 'lpos': d['lpos'], 'rpos': d['rpos']})
    else:
        datas.append({'source': '', 'target': d['text']})

Produces train.pretrain.{10000}.jsonl, len = 580,000. For unlabeled examples the source is empty, so this file effectively pretrains the decoder as a language model on the text alone.

Converting the data into the format AllenNLP expects

dataset.pivot.py

class Table2PivotDataset:

Converts the input data into Instances. Each Instance holds a TextField containing one sentence and a SequenceLabelField containing the corresponding labels.
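A minimal sketch of that Instance construction, assuming a 0.x-era AllenNLP API; the field names 'source' and 'label' here are illustrative, not necessarily the repository's:

from allennlp.data import Instance
from allennlp.data.fields import TextField, SequenceLabelField
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

def text_to_instance(words, labels=None) -> Instance:
    # one TextField holding the sentence...
    sentence = TextField([Token(w) for w in words],
                         {'tokens': SingleIdTokenIndexer()})
    fields = {'source': sentence}
    # ...and, when available, a SequenceLabelField aligned to it
    if labels is not None:
        fields['label'] = SequenceLabelField(labels, sentence)
    return Instance(fields)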

class Pivot2TextDataset:

Converts the generated key facts into Instances and adds noise to the data:

    def add_noise(self, words):
        #words = self.word_shuffle(words)
        # randomly drop some words
        words = self.word_drop(words)
        # randomly replace some words with '@@UNKNOWN@@'
        words = self.word_blank(words)
        # insert a proportion of '@@UNKNOWN@@' tokens at random positions
        words = self.word_append(words)

        return words
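The three noise operations are not shown in the excerpt; below are hypothetical implementations consistent with the comments above (the real ones are parameterized by drop_rate, blank_rate, and append_rate):

import random

def word_drop(words, rate):
    kept = [w for w in words if random.random() >= rate]
    return kept if kept else words  # never return an empty sequence

def word_blank(words, rate, unk='@@UNKNOWN@@'):
    return [unk if random.random() < rate else w for w in words]

def word_append(words, rate, unk='@@UNKNOWN@@'):
    out = list(words)
    for _ in range(int(len(words) * rate)):
        out.insert(random.randrange(len(out) + 1), unk)
    return out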

class Table2PivotCorpus:

Loads the data:

train_path = os.path.join(data_path, 'train.t2p.{0}.jsonl'.format(scale))
dev_path = os.path.join(data_path, 'valid.t2p.jsonl')
test_path = os.path.join(data_path, 'test.t2p.jsonl')

vocab_dir = os.path.join(data_path, 'dicts-{0}-t2p-{1}'.format(vocab_size, scale))

Converts it into Instances:

self.train_dataset = Table2PivotDataset(path=train_path, max_len=max_len)
self.test_dataset = Table2PivotDataset(path=test_path, max_len=max_len)
self.dev_dataset = Table2PivotDataset(path=dev_path, max_len=max_len)      

Wraps it in a torch DataLoader:

self.train_loader = torch.utils.data.DataLoader(dataset=self.train_dataset,
                                                batch_size=batch_size,
                                                collate_fn=collate_fn,
                                                shuffle=True)
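collate_fn is not shown in the excerpt; a minimal sketch of the padding it would typically perform, assuming each dataset item carries a variable-length list of token ids under a 'source' key (the real collate_fn also handles fields, positions, and labels):

import torch

def collate_fn(batch):
    seqs = [torch.tensor(item['source']) for item in batch]
    lengths = torch.tensor([len(s) for s in seqs])
    # pad every sequence in the batch to the longest one, using id 0
    padded = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True)
    return {'source': padded, 'lengths': lengths}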

The predictions are saved to predict-{10000}.txt:

for ids, source, field, lpos, rpos in zip(model_ids, sources, fields, lposs, rposs):
    # keep only the tokens predicted as key facts (id > 0),
    # together with their aligned field / position features
    words = [s for id, s in zip(ids, source) if id > 0]
    _field = [s for id, s in zip(ids, field) if id > 0]
    _lpos = [s for id, s in zip(ids, lpos) if id > 0]
    _rpos = [s for id, s in zip(ids, rpos) if id > 0]
    model_words.append(' '.join(words))
    _fields.append(' '.join(_field))
    _lposs.append(' '.join(_lpos))
    _rposs.append(' '.join(_rpos))

with open(os.path.join(self.data_path, 'predict-{0}.txt'.format(self.scale)), 'w') as f:
    print('\n'.join(model_words), file=f)

Model evaluation: def evaluate

class Pivot2TextCorpus:

Loads the data. Note that the test set here is test.predict.{scale}.jsonl, i.e. the key facts predicted by the first stage:

train_path = os.path.join(data_path, 'train.p2t.{0}.jsonl'.format(scale))
dev_path = os.path.join(data_path, 'valid.p2t.jsonl')
test_path = os.path.join(data_path, 'test.predict.{0}.jsonl'.format(scale))      

Converts the generated key facts into Instances and adds noise to the training data:

self.train_dataset = Pivot2TextDataset(path=train_path, src_max_len=src_max_len,
                                       tgt_max_len=tgt_max_len, share=share,
                                       append_rate=append_rate, drop_rate=drop_rate,
                                       blank_rate=blank_rate, use_feature=use_feature)

self.test_dataset = Pivot2TextDataset(path=test_path, src_max_len=src_max_len,
                                      tgt_max_len=tgt_max_len, share=share,
                                      use_feature=use_feature)

self.dev_dataset = Pivot2TextDataset(path=dev_path, src_max_len=src_max_len,
                                     tgt_max_len=tgt_max_len, share=share,
                                     use_feature=use_feature)

Evaluation: def evaluate

2. Model

Model used by the table2pivot module:

sequence_labeling.py

A bidirectional LSTM encoder followed by a linear classifier as decoder. The word embedding, field (key) embedding, and the two position embeddings are concatenated to form the model input x:

self.src_embedding = nn.Embedding(self.vocab_size, emb_size)
self.key_embedding = nn.Embedding(vocab.get_vocab_size('keys'), key_emb_size)
self.lpos_embedding = nn.Embedding(vocab.get_vocab_size('lpos'), pos_emb_size)
self.rpos_embedding = nn.Embedding(vocab.get_vocab_size('rpos'), pos_emb_size)

self.encoder = rnn.rnn_encoder(emb_size + key_emb_size + pos_emb_size * 2,
                               hidden_size, enc_layers, dropout, bidirectional)
self.decoder = nn.Linear(hidden_size, self.label_size)
self.accuracy = SequenceAccuracy()

# concatenate word, key, and position embeddings along the feature axis
src_embs = torch.cat([self.src_embedding(src),
                      self.key_embedding(keys),
                      self.lpos_embedding(lpos),
                      self.rpos_embedding(rpos)], dim=-1)

src_embs = pack(src_embs, lengths, batch_first=True)
encode_outputs = self.encoder(src_embs)
out_logits = self.decoder(encode_outputs['hidden_outputs'])
seq_mask = (src > 0).float()  # mask out padding positions

self.accuracy(predictions=out_logits, gold_labels=tgt, mask=seq_mask)
loss = sequence_cross_entropy_with_logits(logits=out_logits, targets=tgt,
                                          weights=seq_mask, average='token')
# outputs = out_logits.max(-1)[1] greedy

Prediction:

out_logits = self.decoder(encode_outputs['hidden_outputs'])
outputs = out_logits.max(-1)[1]  # greedy label per token

# undo the length-based sorting that was applied for packing
outputs = outputs.index_select(dim=0, index=rev_indices)
src = src.index_select(dim=0, index=rev_indices)

seq_mask = src > 0
correct_tokens = (tgt.eq(outputs) * seq_mask).sum()
total_tokens = seq_mask.sum()

# in the label vocabulary index 0 is the key-fact class (note the flip
# below), so these counters give hits / gold positives / predicted positives
h_tokens = (tgt.eq(outputs) * seq_mask * (tgt.eq(0))).sum()
r_total_tokens = (tgt.eq(0) * seq_mask).sum()
p_total_tokens = (outputs.eq(0) * seq_mask).sum()

# flip 0/1 so that 1 marks a key fact in the downstream files
output_ids = 1 - outputs

return {'correct': correct_tokens.item(), 'total': total_tokens.item(),
        'hit': h_tokens.item(), 'r_total': r_total_tokens.item(),
        'p_total': p_total_tokens.item(), 'output_ids': output_ids.tolist()}
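The reported metrics follow directly from these counters once they are summed over all batches; a hedged reconstruction:

def summarize(correct, total, hit, r_total, p_total):
    # counters summed over all batches from the dict returned above
    accuracy = 100.0 * correct / total   # token-level labeling accuracy
    precision = 100.0 * hit / p_total    # predicted key facts that are gold
    recall = 100.0 * hit / r_total       # gold key facts that were recovered
    f1 = 2 * precision * recall / (precision + recall)
    return accuracy, precision, recall, f1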

Model used by the pivot2text module:

transformer.py

A standard Transformer architecture:

        self.encoder = TransformerEncoder(n_hidden, ff_size, n_head, dropout, n_block)
        self.decoder = TransformerDecoder(n_hidden, ff_size, n_head, dropout, n_block)

        self.generator = nn.Linear(n_hidden, self.vocab_size)
        self.accuracy = SequenceAccuracy()

    def encode(self, src):
        embs = self.src_embedding(src)
        out = self.encoder(embs)
        return out

    def decode(self, tgt, memory, step_wise=False):
        embs = self.tgt_embedding(tgt)
        outputs = self.decoder(embs, memory, step_wise)
        out, attns = outputs['outs'], outputs['attns']
        out = self.generator(out)
        return out, attns

    def forward(self,
                src: Dict[str, torch.Tensor],
                tgt: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        src, tgt = src['tokens'], tgt['tokens']
        encode_outputs = self.encode(src)
        out_logits, _ = self.decode(tgt[:, :-1], encode_outputs)  # teacher forcing

        targets = tgt[:, 1:].contiguous()  # targets are shifted by one position
        seq_mask = (targets > 0).float()

        self.accuracy(predictions=out_logits, gold_labels=targets, mask=seq_mask)
        loss = sequence_cross_entropy_with_logits(logits=out_logits,
                                                  targets=targets,
                                                  weights=seq_mask,
                                                  average='token',
                                                  label_smoothing=self.label_smoothing)
        outputs = {'loss': loss}

        return outputs

Greedy search is used when decoding at prediction time:

    def greedy_search(self,
                      src: Dict[str, torch.Tensor],
                      max_decoding_step: int) -> Dict[str, torch.Tensor]:
        src = src['tokens']
        # start every sequence with the BOS token
        ys = torch.ones(src.size(0), 1).long().fill_(self._bos).cuda()

        encode_outputs = self.encode(src)

        for i in range(max_decoding_step):
            # step_wise=False: the whole prefix is re-decoded at every step
            # (no incremental cache), which is simple but quadratic in length
            outputs, attn = self.decode(ys, encode_outputs, step_wise=False)
            logits = outputs[:, -1]
            next_id = logits.max(1, keepdim=True)[1]
            ys = torch.cat([ys, next_id], dim=1)

        output_ids = ys[:, 1:]  # drop the BOS column
        attns = attn
        alignments = attns.max(2)[1]
        outputs = {'output_ids': output_ids.tolist(), 'alignments': alignments.tolist()}

        return outputs
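A hypothetical decoding call, assuming a trained model, an AllenNLP vocab, and a batch shaped like the training inputs:

outputs = model.greedy_search(batch['source'], max_decoding_step=50)
# turn the first sequence of ids back into words
# (truncation at the EOS token is omitted for brevity)
ids = outputs['output_ids'][0]
text = ' '.join(vocab.get_token_from_index(i, namespace='tokens') for i in ids)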

Overall pipeline:

table2pivot:

train:

train.t2p.10000.jsonl: 10,000 labeled examples are used to train the model.

The model is then evaluated on the validation set valid.t2p.jsonl (72,000 examples); based on validation performance, the best checkpoint is recorded and early stopping is triggered when needed.

test:

test.t2p.jsonl (72,000 examples): the best.th checkpoint is loaded and test.predict.10000.jsonl is generated, containing the predicted key facts (tokens whose predicted label id > 0).

Accuracy: 85.21618881852217, Precision: 91.67537147343486, Recall: 87.06278910894117, F1: 89.3095632971675

pivot2text:

train:

train.p2t.10000.jsonl: 580,000 examples are used for training.

accuracy: 0.9140, loss: 0.3335 ||: 100%|████| 9105/9105 [30:46<00:00, 5.84it/s]

The model is then evaluated on the validation set valid.p2t.jsonl (72,000 examples); based on validation performance, the best checkpoint is recorded and early stopping is triggered when needed.

100%|█████████████████████████████████████████| 569/569 [03:16<00:00, 3.06it/s] BLEU = 29.33, 54.4/32.9/24.6/19.3 (BP=0.967, ratio=0.967, hyp_len=1836100, ref_len=1898421)

test:

100%|█████████████████████████████████████████| 569/569 [03:16<00:00, 2.88it/s] BLEU = 29.63, 54.2/32.9/24.6/19.3 (BP=0.978, ratio=0.978, hyp_len=1856611, ref_len=1898421)

The final text is generated from test.predict.10000.jsonl and compared against the target.