Transformer-XL: Attentive Language Models Beyonds a Fixed-Length Context


1. Transformer-XL的由来

在正式讨论Transformer-XL之前,我们先来看看经典的Transformer(后文称Vanilla Transformer)是如何处理数据和训练评估模型的,如图1所示。

image-20210604171637306

图1 Vanilla Transformer 训练和评估阶段

数据处理方面,给定一串较长的文本串,Vanilla Transformer会按照固定的长度(比如512),直接将该文本串进行划分成若干Segment。这个处理方式不会关注文本串中语句本身的边界(比如标点或段落),这样”粗暴”的划分通常会将一句完整的话切分到两个Segment里面,导致上下文碎片化(context fragmentation)。另外,Transformer本身能够维持的依赖长度很有可能会超出这个固定的划分长度,从而导致Transformer能够捕获的最大依赖长度不超过这个划分长度,Transformer本身达不到更好的性能。

模型训练方面,如图1a所示,Vanilla Transformer每次传给模型一个Segment进行训练,第1个Segment训练完成后,传入第2个Segment进行训练,然而前后的这两个Segment是没有任何联系的,也就是前后的训练是独立的。但事实是前后的Segment其实是有关联的。

模型评估方面,如图1b所示,Vanilla Transformer会采用同训练阶段一致的划分长度,但仅仅预测最后一个位置的token,完成之后,整个序列向后移动一个位置,预测下一个token。这个处理方式保证了模型每次预测都能使用足够长的上下文信息,也缓解了训练过程中的context framentation问题。但是每次的Segment都会重新计算,计算代价很大。

基于上边的这些不足,Transformer-XL被提出来解决这些问题。它主要提出了两个技术:Segment-Level 循环机制相对位置编码Transformer-XL能够建模更长的序列依赖,比RNN长80%,比Vanilla Transformer长450%。同时具有更快的评估速度,比Vanilla Transformer快1800+倍。同时在多项任务上也达到了SoTA的效果。

2. Transformer-XL 建模更长序列

2.1 Segment-Level 循环机制

Transformer-XL通过引入Segment-Level recurrence mechanism来建模更长序列,它通过融合前后两个Segment的信息来到这个目的。

这里循环机制和RNN循环机制类似,在RNN中,每个时刻的RNN单元会接收上个时刻的输出和当前时刻的输入,然后将两者融合计算得出当前时刻的输出。Transformer-XL同样是接收上个时刻的输出和当前时刻的输入,然后将两者融合计算得出当前时刻的输出。但是两者的处理单位并不相同,RNN的处理单位是一个词,Transformer-XL的处理单位是一个Segment。图2展示了Transformer-XL在训练阶段和评估阶段的Segment处理方式。

image-20210604181648404

图2 Transformer-XL的训练和评估阶段

模型训练阶段,如图2a所示,Transformer-XL会缓存前一个Segment的输出序列,在计算下一个Segment的输出时会使用上一个Segment的缓存信息,将前后不同Segment的信息进行融合,能够帮助模型看见更远的地方,建模更长的序列依赖能力,同时也避免了context fragmentation问题。

模型评估阶段,如图2b所示,Transformer-XL通过缓存前一个Segment的输出序列,当下一个Segment需要用这些输出时(前后两个Segment具有大部分的重复),不需要重新计算,从而加快了推理速度。

下边我们来具体聊聊这些事情是怎么做的,假设前后的两个Segment分别为:\(\text{s}_{\tau}=[x_{\tau,1},x_{\tau,2},...,x_{\tau,L}]\)\(\text{s}_{\tau+1}=[x_{\tau+1,1},x_{\tau+1,2},...,x_{\tau+1,L}]\),其中序列长度为\(L\)。另外假定\(h_{\tau}^n \in \mathbb{R}^{L \times d}\)为由\(\text{s}_{\tau}\)计算得出的第\(n\)层的状态向量,则下一个Segment \(\text{s}_{\tau+1}\)的第\(n\)层可按照如下方式计算:

\[\begin{split} \begin{align} & \tilde{h}_{\tau+1}^{n-1} = \left[ \text{SG}(h_{\tau}^{n-1}) \; \circ \;h_{\tau+1}^{n-1} \right] \\ & q_{\tau+1}^{n}, \; k_{\tau+1}^n, \; v_{\tau+1}^n = h_{\tau+1}^{n-1}W_{q}^{\mathrm{ T }}, \; \tilde{h}_{\tau+1}^{n-1}W_{k}^{\mathrm{ T }}, \; \tilde{h}_{\tau+1}^{n-1}W_{v}^{\mathrm{ T }} \\ & h_{\tau+1}^n = \text{Transformer-Layer}(q_{\tau+1}^{n}, \; k_{\tau+1}^n, \; v_{\tau+1}^n) \end{align} \end{split}\]

其中,\(\text{SG}(h_{\tau}^{n-1}) \)表示不使用梯度,\(\left[ \text{SG}(h_{\tau}^{n-1}) \; \circ \;h_{\tau+1}^{n-1} \right]\)表示将前后两个Segment的输出向量在序列维度上进行拼接。中间的公式表示获取Self-Attention计算中相应的\(q,k,v\)矩阵,其中在计算\(q\)的时候仅仅使用了当前Segment的向量,在计算\(k\)\(v\)的时候同时使用前一个Segment和当前Segment的信息。最后通过Self-Attention融合计算,得出当前Segment的输出向量序列。

2.2 相对位置编码

Segment-Level recurrence mechanism看起来已经做到了长序列建模,但是这里有个问题需要进一步讨论一下。我们知道,在Vanilla Transformer使用了绝对位置编码,我们来看看如果将绝对位置编码应用到Segment-Level recurrence mechanism中会怎样。

还是假设前后的两个Segment分别为:\(\text{s}_{\tau}=[x_{\tau,1},x_{\tau,2},...,x_{\tau,L}]\)\(\text{s}_{\tau+1}=[x_{\tau+1,1},x_{\tau+1,2},...,x_{\tau+1,L}]\),其中序列长度为\(L\)。每个Segment的Position Embedding矩阵为\(U_{1:L} \in \mathbb{R}^{L \times d}\), 每个Segment \(\text{s}_{\tau}\)的词向量矩阵为\(E_{\text{s}_{\tau}} \in \mathbb{R}^{L \times d}\),在Vanilla Transformer中,两者相加输入模型参与计算,如下式所示:

\[\begin{split} h_{\tau+1} = f(h_{\tau},\; E_{\text{s}_{\tau+1}}+U_{1:L}) \\ h_{\tau} = f(h_{\tau-1},\; E_{\text{s}_{\tau}}+U_{1:L}) \end{split}\]

很明显,如果按照这个方式计算,前后两个段\(E_{\text{s}_{\tau}}\)\(E_{\text{s}_{\tau+1}}\)将具有相同的位置编码,这样两者信息融合的时候肯定会造成位置信息混乱。为了避免这份尴尬的操作,Transformer-XL使用了相对位置编码

相对位置是通过计算两个token之间的距离定义的,例如第5个token相对第2个token之间的距离是3, 那么位置\(i\)相对位置\(j\)的距离是\(i-j\),假设序列之中的最大相对距离\(L_{max}\),则我们可以定义这样的一个相对位置矩阵\(R \in \mathbb{R}^{L_{max} \times d}\),其中\(R_k\)表示两个token之间距离是\(k\)的相对位置编码向量。注意在Transformer-XL中,相对位置编码向量不是可训练的参数,以\(R_k = [r_{k,1}, r_{k,2},...,r_{k,d}]\)为例,每个元素通过如下形式生成:

\[ r_{b,2j} = \text{sin}(\frac{b}{10000^{2j/d}}), \quad r_{b,2j+1} = \text{cos}(\frac{b}{10000^{(2j)/d}}) \]

Transformer-XL将相对位置编码向量融入了Self-Attention机制的计算过程中,这里可能会有些复杂,我们先来看看Vanilla Transformer的Self-Attention计算过程,如下:

\[\begin{split} \begin{align} A_{i,j}^{\text{abs}} &= (W_q(E_{x_i}+U_i))^{\text{T}}(W_k(E_{x_j}+U_j))) \\ &= \underbrace {E_{x_i}^{\text{T}} W_q^{\text{T}} W_k E_{x_j}}_{(a)} + \underbrace {E_{x_i}^{\text{T}} W_q^{\text{T}} W_k U_j}_{(b)} + \underbrace {U_{i}^{\text{T}} W_q^{\text{T}} W_k E_{x_j}}_{(c)} + \underbrace {U_{i}^{\text{T}} W_q^{\text{T}} W_k U_{j}}_{(d)} \end{align} \end{split}\]

其中\(E_{x_i}\)表示token \(x_i\)的词向量,\(U_i\)表示其绝对位置编码,根据这个展开公式,Transformer-XL将相对位置编码信息融入其中,如下:

\[ \begin{align} A_{i,j}^{\text{rel}} = \underbrace {E_{x_i}^{\text{T}} W_q^{\text{T}} W_{k,E} E_{x_j}}_{(a)} + \underbrace {E_{x_i}^{\text{T}} W_q^{\text{T}} W_{k,R} R_{i-j}}_{(b)} + \underbrace {u^{\text{T}} W_{k,E} E_{x_j}}_{(c)} + \underbrace {v^{\text{T}} W_{k,R} R_{i-j}}_{(d)} \end{align} \]

这里做了这样几处改变以融入相对位置编码:

  1. 在分项\((b)\)\((d)\)中,使用相对位置编码\(R_{i-j}\)取代绝对位置编码\(U_j\)

  2. 在分项\((c)\)\((d)\)中,使用可训练参数\(u\)\(v\)取代\(U_{i}^{\text{T}} W_q^{\text{T}}\)。因为\(U_{i}^{\text{T}} W_q^{\text{T}}\)表示第\(i\)个位置的query 向量,这个query向量对于其他要进行Attention的位置来说都是一样的,因此可以直接使用统一的可训练参数进行替换。

  3. 在所有分项中,使用\(W_{k,E}\)\(W_{k,R}\)计算基于内容(词向量)的key向量和基于位置的key向量。

式子中的每个分项分别代表的含义如下:

  1. \((a)\)描述了基于内容的Attention

  2. \((b)\)描述了内容对于每个相对位置的bias

  3. \((c)\)描述了内容的全局bias

  4. \((d)\)描述了位置的全局bias

2.3 完整的Self-Attention计算过程

上边描述了Transformer-XL中的两个核心技术:Segment-Level 循环机制相对位置编码,引入了这两项技术之后,Transformer-XL中从第\(n-1\)层到第\(n\)层完整的计算过程是这样的:

\[\begin{split} \begin{align} \tilde{h}_{\tau}^{n-1} &= \left[ \text{SG}(h_{\tau-1}^{n-1}) \; \circ \;h_{\tau}^{n-1} \right] \\ q_{\tau}^{n}, \; k_{\tau}^n, \; v_{\tau}^n &= h_{\tau}^{n-1}{W_{q}^n}^{\mathrm{ T }}, \; \tilde{h}_{\tau}^{n-1}{W_{k,E}^n}^{\mathrm{ T }}, \; \tilde{h}_{\tau}^{n-1}{W_{v}^n}^{\mathrm{ T }} \\ A_{\tau,i,j}^{n} &= {q_{\tau, i}^{n}}^{\text{T}}k_{\tau,j}^{n} + {q_{\tau, i}^{n}}^{\text{T}}W_{k,R}^{n}R_{i-j} + u^{\text{T}}k_{\tau,j} + v^{\text{T}}W_{k,R}^{n}R_{i-j} \\ {\alpha}_{\tau}^n &= \text{Masked-Softmax}(A_{\tau}^n)v_{\tau}^n \\ {\omicron}_{\tau}^n & = \text{LayerNorm}(\text{Linear}({\alpha}_{\tau}^n)+h_{\tau}^{n-1}) \\ h_{\tau}^n &= \text{Positionwise-Feed-Forward}({\omicron}_{\tau}^n) \end{align} \end{split}\]