首页 分享 Beam Search的学习笔记(附代码实现)

Beam Search的学习笔记(附代码实现)

来源:萌宠菠菠乐园 时间:2024-09-21 15:38

引言

Beam Search 是一种受限的宽度优先搜索方法,经常用在各种 NLP 生成类任务中,例如机器翻译、对话系统、文本摘要。本文首先介绍 Beam Search 的基本思想,然后再介绍一些beam search的优化方法,最后附上自己的代码实现。

1. Beam Search的基础版本

在生成文本的时候,通常需要进行解码操作,贪心搜索 (Greedy Search) 是比较简单的解码。Beam Search 对贪心搜索进行了改进,扩大了搜索空间,更容易得到全局最优解。Beam Search 包含一个参数 beam size k,表示每一时刻均保留得分最高的 k 个序列,然后下一时刻用这 k 个序列继续生成。示意图如下所示:
在这里插入图片描述
假设我们生成词表中有三个单词{我,爱,你}。我们设 K = 2 K=2 K=2。那么我们在第一时刻确定两个候选输出是{我,你}。紧接着我们要考虑第二个输出,具体步骤如下:

确定单词“我”为第一时刻输出,并将其作为第二时刻输入,在已知 p ( x , 我 ) p(x,我) p(x,我)的情况下,各个单词的输出概率为3种情况,每个组合的概率为 P ( 我 ∣ x ) P ( y 2 ∣ x , 我 ) P(我|x)P(y_2|x,我) P(我∣x)P(y2​∣x,我)。同样我们把“你”也作为第二时刻输入,同样也有三种组合。最后我们在六种组合中选择概率最大的三个组合。

接下来要做的重复这个过程,逐步生成单词,直到遇到结束标识符停止。最后得到概率最大的那个生成序列。其概率为:在这里插入图片描述

以上就是Beam search算法的思想,当beam size=1时,就变成了贪心算法。

2. Beam Search的优化

Beam search算法也有许多改进的地方。

2.1 Length normalization:惩罚短句

根据最后的概率公式可知,该算法倾向于选择最短的句子,因为在这个连乘操作中,每个因子都是小于1的数,因子越多,最后的概率就越小。解决这个问题的方式,最后的概率值除以这个生成序列的单词数,这样比较的就是每个单词的平均概率大小。此外,连乘因子较多时,可能会超过浮点数的最小值,可以考虑取对数来缓解这个问题。谷歌给的公式如下:
在这里插入图片描述
其中α∈[0,1],谷歌建议取值为[0.6,0.7]之间,α用于length normalization。

2.2 Coverage normalization:惩罚重复

另外我们在序列到序列任务中经常会发现一个问题,2016 年, 华为诺亚方舟实验室的论文提到,机器翻译的时候会存在over translation or undertranslation due to attention coverage。 作者提出coverage-based atttention机制来解决coverage 问题。 Google machine system 利用了如下的方式进行了length normalization 和 coverage penalty。

还是上述公式,β用于控制coverage penalty
在这里插入图片描述

coverage penalty 主要用于使用 Attention 的场合,通过 coverage penalty 可以让 Decoder 均匀地关注于输入序列 x x x 的每一个 token,防止一些 token 获得过多的 Attention

2.3 End of sentence normalization:抑制长句

有的时候我们发现生成的序列一直生成下去不会停止,有的时候我们可以显式的设置最大生成长度进行控制,这里我们可以采用下式来进行约束:
在这里插入图片描述
其中 ∣ X ∣ |X| ∣X∣是source的长度, ∣ Y ∣ |Y| ∣Y∣是当前target的长度,那么由上式可知,target长度越长的话,上述得分越低,这样就会防止出现生成一直不停止的情况。


3. Beam Search的代码实现

总的来说,beam search不保证全局最优,但是比greedy search搜索空间更大,一般结果比greedy search要好。下面附上一些代码实现:

首先,首先定义一个 Beam 类,作为一个存放候选序列的容器,属性需维护当前序列中的 token 以及对应的对数概率,同时还需维护跟当前 timestep 的 Decoder 相关的一些变量。此外,还需要给 Beam 类实现两个函数:一个 extend 函数用以扩展当前的序列(即添加新的 time step的 token 及相关变量);一个 score 函数用来计算当前序列的分数(在Beam类下的seq_score函数中有Length normalization以及Coverage normalization)。

class Beam(object): def __init__(self, tokens, log_probs, decoder_states, coverage_vector): self.tokens = tokens self.log_probs = log_probs self.decoder_states = decoder_states self.coverage_vector = coverage_vector def extend(self, token, log_prob, decoder_states, coverage_vector): return Beam(tokens=self.tokens + [token], log_probs=self.log_probs + [log_prob], decoder_states=decoder_states, coverage_vector=coverage_vector) def seq_score(self): """ This function calculate the score of the current sequence. """ len_Y = len(self.tokens) # Lenth normalization ln = (5+len_Y)**config.alpha / (5+1)**config.alpha cn = config.beta * torch.sum( # Coverage normalization torch.log( config.eps + torch.where( self.coverage_vector < 1.0, self.coverage_vector, torch.ones((1, self.coverage_vector.shape[1])).to(torch.device(config.DEVICE)) ) ) ) score = sum(self.log_probs) / ln + cn return score def __lt__(self, other): return self.seq_score() < other.seq_score() def __le__(self, other): return self.seq_score() <= other.seq_score()

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647

接着我们需要实现一个 best_k 函数,作用是将一个 Beam 容器中当前 time step 的变量传入 Decoder 中,计算出新一轮的词表概率分布,并从中选出概率最大的 k 个 token 来扩展当前序列(其中加入了End of sentence normalization),得到 k 个新的候选序列。

def best_k(self, beam, k, encoder_output, x_padding_masks, x, len_oovs): """Get best k tokens to extend the current sequence at the current time step. """ # use decoder to generate vocab distribution for the next token x_t = torch.tensor(beam.tokens[-1]).reshape(1, 1) x_t = x_t.to(self.DEVICE) # Get context vector from attention network. context_vector, attention_weights, coverage_vector = self.model.attention(beam.decoder_states, encoder_output, x_padding_masks, beam.coverage_vector) # Replace the indexes of OOV words with the index of OOV token # to prevent index-out-of-bound error in the decoder. p_vocab, decoder_states, p_gen = self.model.decoder(replace_oovs(x_t, self.vocab), beam.decoder_states, context_vector) final_dist = self.model.get_final_distribution(x, p_gen, p_vocab, attention_weights, torch.max(len_oovs)) # Calculate log probabilities. log_probs = torch.log(final_dist.squeeze()) # Filter forbidden tokens. # EOS token penalty. Follow the definition in # https://opennmt.net/OpenNMT/translation/beam_search/. log_probs[self.vocab.EOS] *= config.gamma * x.size()[1] / len(beam.tokens) log_probs[self.vocab.UNK] = -float('inf') # Get top k tokens and the corresponding logprob. topk_probs, topk_idx = torch.topk(log_probs, k) # Extend the current hypo with top k tokens, resulting k new hypos. best_k = [beam.extend(x, log_probs[x], decoder_states, coverage_vector) for x in topk_idx.tolist()] return best_k

123456789101112131415161718192021222324252627282930313233343536373839404142434445

最后我们实现主函数 beam_search。初始化encoder、attention和decoder的输⼊,然后对于每⼀个decodestep,对于现有的k个beam,我们分别利⽤best_k函数来得到各⾃最佳的k个extended beam,也就是每个decode step我们会得到k*k个新的beam,然后只保留分数最⾼的k个,作为下⼀轮需要扩展的k个beam。为了只保留分数最⾼的k个beam,我们可以⽤⼀个堆(heap)来实现,堆的中只保存k个节点,根结点保存分数最低的beam。

def beam_search(self, x, max_sum_len, beam_width, len_oovs, x_padding_masks): """Using beam search to generate summary. """ # run body_sequence input through encoder encoder_output, encoder_states = self.model.encoder( replace_oovs(x, self.vocab)) coverage_vector = torch.zeros((1, x.shape[1])).to(self.DEVICE) # initialize decoder states with encoder forward states decoder_states = self.model.reduce_state(encoder_states) # initialize the hypothesis with a class Beam instance. init_beam = Beam([self.vocab.SOS], [0], decoder_states, coverage_vector) # get the beam size and create a list for stroing current candidates # and a list for completed hypothesis k = beam_width curr, completed = [init_beam], [] # use beam search for max_sum_len (maximum length) steps for _ in range(max_sum_len): # get k best hypothesis when adding a new token topk = [] for beam in curr: # When an EOS token is generated, add the hypo to the completed # list and decrease beam size. if beam.tokens[-1] == self.vocab.EOS: completed.append(beam) k -= 1 continue for can in self.best_k(beam, k, encoder_output, x_padding_masks, x, torch.max(len_oovs) ): # Using topk as a heap to keep track of top k candidates. # Using the sequence scores of the hypos to campare # and object ids to break ties. add2heap(topk, (can.seq_score(), id(can), can), k) curr = [items[2] for items in topk] # stop when there are enough completed hypothesis if len(completed) == beam_width: break # When there are not engouh completed hypotheses, # take whatever when have in current best k as the final candidates. completed += curr # sort the hypothesis by normalized probability and choose the best one result = sorted(completed, key=lambda x: x.seq_score(), reverse=True)[0].tokens return result

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263

相关知识

详解pytorch实现猫狗识别98%附代码
Web课程设计:宠物网站设计——萌宠有家(5页) HTML+CSS 简单DIV布局宠物介绍网页模板代码
HTML学生个人网站作业设计:宠物网站设计——萌宠有家(5页) HTML+CSS 简单DIV布局宠物介绍网页模板代码 DW学生个人网站制作成品下载
机器学习实战笔记3(决策树与随机森林)
基于深度学习的鸟类识别系统(网页版+YOLOv8/v7/v6/v5代码+训练数据集)
第12周实训任务:实现宠物管理功能页面
FastAI 课程学习笔记 lesson 1:宠物图片分类
基于深度学习的犬种识别系统详解(网页版+YOLOv8/v7/v6/v5代码+训练数据集)
HTML做一个简单漂亮的宠物网页(纯html代码)
猫的遗传相关学习笔记

网址: Beam Search的学习笔记(附代码实现) https://www.mcbbbk.com/newsview205980.html

所属分类:萌宠日常
上一篇: 怎样训练狗狗自己上厕所便便?
下一篇: 调皮的日子阅读题=设计

推荐分享