解锁大模型长上下文能力

【本文已在同名 微信公众号 / 知乎 / 个人博客linsight.cn 上线】


步入2024年Q2,大模型在RAG、文档对话、大模型Agent能力等方向的发展持续升温。在平时的日常生活和工作中,大模型工具提供的文档总结、文本润色、代码生成等能力已经是提高效率的必备帮手,甚至在一些复杂或者不熟悉的场景上,大模型也已经能提供一些比较专业的帮助。

在这些方向上,大模型(超)长上下文的能力都是基础。无论是使用详细的CoT/ToT,还是通过多篇检索文档提供专业知识,抑或是使用相关样例提升回复质量,都需要模型具备处理很长的输入输出信息的能力。这不仅要求模型在较长的位置编码下依然具有良好的语言建模能力,而且还需要模型能够进行长距离的、细致的阅、准确的阅读和理解。

本篇将梳理几个通过轻量级训练解锁大模型长上下文能力的工作。

支持128k上下文的数据工程

论文:Data Engineering for Scaling Language Models to 128K Context

时间:2024年2月

阶段:预训练

长度:128k

评测指标

模型的长上下文能力不仅体现在文本较长的时候,模型的PPL依然能保持在较低的水平,还体现在对于长上下文输入,模型依然能够进行准确的阅读理解和推理。

以往一些工作仅使用validation dataset上的PPL作为评测指标,并不能很好地表征模型的真实长上下文能力。而目前被广泛使用的 Needle in a Haystack,或者叫大海捞针任务,是对长上下文能力的一个比较好的评测。这篇论文主要就以大海捞针任务为标准,对不同的模型和方案进行对比。

两个PPL几乎相同的模型,在大海捞针任务上的差距可以很大,如下图所示,颜色越绿代表正确率越高

目前已有的一些扩展大模型上下文窗口的方法,比如LongLoRA和Mistral所采用的YaRN,虽然理论上来说,能够支持>100k的上下文长度,但是实际上在大海捞针任务的表现却不太好。相关模型在大海捞针任务上的效果对比如下所示,只有GPT-4的效果比较好。

数据分布

这篇论文认为,在<=4k窗口长度完成预训练的模型,其实就已经基本具备在128k或者更大的上下文窗口进行推理的能力,只需要进行轻量级的继续预训练(e.g. <5B token),就能够解锁这种能力。

(而一些其他的工作在这方面则有着相反的观点,比如在32k窗口训练了400B token的《Effective long-context scaling of foundation models》,以及Xverse)

要做继续预训练,最重要的一点就是要决定使用什么样的数据。

这篇论文里的实验是基于LLAMA的,因此使用了和LLAMA预训练数据具有相近领域分布的SlimPajama数据集作为基础。

对于长上下文的继续预训练数据,需要仔细考虑数据长度和领域分布的影响。通常来说,某些领域天然会有更高比例的长文本数据,比如书籍、论文和github,而一些其他领域的长数据就较少,比如新闻。如果直接从整体数据中挑选长数据而忽略领域分布,就可能造成训练数据在领域分布上的偏移。

论文使用了几种不同的数据处理策略,用于后面的实验对比:
- Cut at 4K:把所有的数据按4k长度进行分块,这样不会影响领域分布。这也是很多4k预训练模型所采样的方案,比如LLAMA。
- Cut at 128K:截断长度提升到128k,可以保留长文本内部信息的依赖关系。LongLoRA就是这么做的。
- Per-source Upsampling:在保持各个领域的比例不变的前提下,对长文本进行上采样,提高长文本的比例。这是这篇论文所推荐的方法,实验效果最好。
- Global Upsampling:不管领域,直接对长文本进行上采样。
- Upsample Arxiv/ Book/ Github:提高特定领域的数据比例,对长文本进行上采样。

这些策略基本涵盖了大部分长文本相关工作在数据上的处理策略。

不同数据处理策略下,SlimPajama数据内各领域的分布如下图所示

Per-source Upsampling是效果最好的,也是这篇论文所推荐的数据工程策略。

实验配置

实验上,用80k的窗口长度训练LLAMA2-7B模型,用64k的窗口训练LLAMA2-13B模型。

虽然理论上,计算复杂度度和模型训练窗口长度是平方关系,但是实际实现上,由于有FlashAttention等方案,可以把Attention的计算通过设备间通讯,在多个设备间并行起来。而设备间的通讯(包括GPU和CPU,GPU和GPU之间)成本都是constant或者linear,因此实际上80k窗口的的训练耗时只是4k长度的训练的3倍,而不是理论上的400倍。

当然,实际所需的计算量并没有减少,但是至少时间成本从平方变成了线性。剩下的,只要堆jia卡qian就可以提速。

Per-source Upsampling和其他工作的数据处理策略的对比如下

训练的配置和耗时如下所示

实验的其他配置:
- lr = 2e-5
- RoPE base从1,0000改为500,000
- batch size = 4M token

训练量

前面提到,论文认为只需要轻量级的继续预训练就可以解锁长上下文能力,那么到底需要训练多少token呢?

论文分别取了训练了100M、300M、500M、1B、5B、10B token的中间checkpoint进行PPL和海底捞针任务评测,结果如下

结论是,在训练了500M token的时候,模型基本解锁了长上下文的能力;在训练了5B token的时候,模型已经收敛,而且继续训练到10B token也没有进一步收益了。

数据策略对比

使用前面提到的不同数据策略在LLAMA2-7B模型用5B token进行训练,并对比效果。

LLAMA2的预训练长度为4k,因此对比的时候分成了0-4k和4k-128k两段,分别评测模型经过长文本训练后,在短文本上的效果是否有变差,以及在长文本上是否有提升。

各个数据策略在不同领域的效果变化如下

可以得到几个结论:
- 在0-4k长度上,除了Per-source Upsampling以外,各种数据策略都会对模型效果有损害
- 在一些领域上的提升,并不能很好地迁移到其他领域,比如Book和Github之间就有点跷跷板效应,其中一个效果好了,另一个可能就有损失
- 在4k-128k,Per-source Upsampling在各个领域的效果相对较为平衡(绿色的数量最多)

此外,length upsampling很重要。Per-source Upsampling的策略在领域上可以和源数据保持一致,而提升长文本的比例。

用同样80k的训练窗口在LLAMA2-7B进行实验,一个使用原数据进行拼接,另一个使用Per-source Upsampling,结果如下。在PPL基本相同的情况下,Per-source Upsampling在大海捞针的效果远超原数据。这说明提高长文本的比例,能极大优化模型远距离建模的能力。

结论

通过实验,论文提出提升模型长上下文能力的数据工程实践的几个关键点:
- 在长窗口上进行轻量级训练,可以提升模型实际的远距离建模能力,而不仅仅是保持PPL较低
- 领域之间有竞争关系,最好和原预训练模型所用的分布保持一致
- 长度上采样对最终效果有很大影响,要提高各领域内长文本的比例

Paraphrasing

论文:Training With "Paraphrasing the Original Text" Improves Long-Context Performance

时间:2023年12月

阶段:微调

长度:在50k长度依然能有较好的效果,如下所示。

检索能力

对于长上下文的任务,有用的信息通常是稀疏的,一般只有少量的句子或者段落包含了可以用于回答问题的有用信息。可以隐式地将这样长上下文的任务拆分成两个子任务,即相关信息的检索,和根据相关信息回答问题两个任务。

目前一些支持长上下文的方法,比如位置编码相关的线性插值、NTK插值、YaRN等,虽然使得模型在形式上支持了长上下文的任务,但是在任务的准确性上效果却不佳。

使用这些优化方案的模型依然会遇到lost in the middle的问题,即模型天然更容易关注到输入文本的开头和结尾部分的信息,而更容易忽略中间部分的信息,注意力迷失在大量无关内容上,而无法集中到少数相关的位置上。而对于长上下文的任务,大量的信息是处于middle的位置的,如果忽略这些信息自然会使得任务效果不好。而效果不好的原因就是模型在长上下文的情况下,retrieval能力偏弱,不能找到有用的信息。

相关工作

一些工作直接把模型在长窗口下进行训练,比如:
- Together的LLaMA-2-7B-32K(https://huggingface.co/datasets/togethercomputer/Long-Data-Collections);Together开源了Multipassage-QA-from-Natural-Questions和BookSum微调数据集。
- LongAlpaca(《LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models》)
- Ziya-Reader(《Never Lost in the Middle:Improving Large Language Models via Attention Strengthening Question Answering》)

直接在长窗口训练有一定的效果,但是依然有几个问题:
- 模型推理窗口越来越大,所需的训练数据集长度也要不断更新。
- 随着长度增大,训练成本变高。
- 构建长上下文数据集的成本比价高,高质量的数据并不容易获得。虽然有一些开源的数据集,但是在实际场景上可能还需要做领域适配,分布调整等工作。

一个更简单一点的方法是优化prompt的设计,比如CoT。

在长上下文的场景下,可以通过prompt让模型显式地先找到原文的相关信息再进行回答。比如Claude-2.1就通过在prompt增加“Here is the most relevant sentence in the context”让长文本问答的准确率从27%提升到98%(https://www.anthropic.com/news/claude-2-1-prompting)。

也可以对输入内容进行重新的编排:
- LongLLMLingua(《LongLLMLingua: Accelerating and Enhancing LLMs in Long Context Scenarios via Prompt Compression》)对输入文本进行了压缩。
- Attention Sorting(《Attention Sorting Combats Recency Bias In Long Context Language Models》)在decode过程中根据各个文档被分配到的注意力值,对文档进行重新排序。

提高检索能力

这篇论文提出了一个叫检索相关度(retrieval relevance)的指标,一个token(或者n-gram) \(x\) 的检索相关度 \(R(x)\) 定义如下。

\[R(x)=\frac{n^\prime}n\log\frac N{N^\prime+1}\]

这个指标和TF-IDF很像。其中,\(n^\prime\) 表示 \(x\) 在gold-chunk中的频率,而 \(n\) 是gold-chunk中的总token数;\(N\) 表示整个上下文中总chunk数,\(N^\prime\) 是包含x的chunk的数量。

基于token \(x\) 的检索相关度 \(R(x)\) ,定义训练样本 \(S\) 的检索相关度如下

\[\mathcal{R}(S)=\frac{1}{|S_a|}\sum_{x\in\mathcal{S}_a}R(x)\]

其中 \(S_a\) 表示 \(S\) 的答案部分。

通过 \(\mathcal{R}(S)\) 这个指标可以反映出一个训练样本对模型提高检索能力的贡献。\(\mathcal{R}(S)\) 越高,这个样本对提高模型检索能力的贡献越大。

那么一个简单有效提升训练样本检索相关度的做法,就是把gold-chunk放到答案中,即paraphrasing the original text。

一个paraphrasing和其他答案设计方案对比的例子如下

其中高亮部分的token是高检索相关度的token,明显paraphrasing拥有更高的比例。

论文使用GPT-4来构建包含paraphrasing的问答对,流程实际如下

这种方式收集了一批单文档问答和多文档问答的数据,再加上一些传统文本摘要数据(摘要不好用这种方式构建,因此直接使用)等,构成一个包含10,825条英文数据,8,454条中文数据,长度在8k和32k之间的数据集。数据集详细的领域分布如下所示

论文构建的数据集和Multi-passage-QA-from-NQ的检索相关性指标对比如下

使用这个数据集微调的模型,和其他模型在LongBench上的效果对比如下

另外,在这个数据集上微调之后,模型对于lost in the middle的问题也有一定的缓解,如下所示

PoSE

论文:PoSE: Efficient Context Window Extension of LLMs via Positional Skip-wise Training

时间:2023年9月

阶段:微调

长度:128k

背景

目前大部分流行的大模型使用旋转位置编码RoPE。在短文本上训练的模型,在长输入上效果不好的原因之一,就是长文本有很多模型没有见过没有训练过的位置编码。

基于位置编码的长上下文优化,比如线性插值、NTK插值和YaRN等,依然需要进行目标长度的训练才能有比价好的效果。而随着目标长度越来越长(8k,32k,128k...),这样的训练成本也越来越高,逐渐变得不容易进行。

这篇论文提出Positional Skip-wisE,PoSE,通过在短的训练窗口模拟长距离的位置编码,提升模型处理长上下文的能力。模型可以在2k的训练窗口进行训练,而在128k的长度进行推理。相比直接训练128k模型效率更高。

也有一些工作的思路和这篇文章有相近之处,比如RandPos(《Randomized positional encodings boost length generalization of transformers》),但是RandPos主要用于预训练阶段,并且相邻token之间的位置是不连续的,而PoSE主要用于微调阶段,相邻token之间的位置是连续的。

位置模拟

PoSE提出两个设计原则:
- 模拟所用的位置编码index要覆盖目标长度的范围。如果我们想在128k的窗口进行推理,那就要保证训练的时候,模型从1-128k的位置编码都见过。
- 为了不损害原模型的能力,位置编码应该尽量保持原来预训练的结构,即尽量连续,和保持顺序关系。

假设我们的训练窗口长度为 \(L_c\),首先我们随机把它切成 \(N\) 个chunk, \(c_0,c_1,\ldots,c_{N-1}\),长度分别为 \(l_0,l_1,\ldots,l_{N-1}\)。对于chunk \(i\),其中token的位置编码下标如下

\[\mathrm{Pos}(c_i)=\{st_i,st_i+1,\ldots,st_i+l_i-1\},\quad st_i=\sum_{j=0}^{i-1}l_j\]

然后我们给每个chunk,从uniform distribution \(\mathcal{U}(S)\) 中随机采样一个skipping bias \(u_i\),把这个bias加到这个对应chunk的token的位置编码下标中,就有

\[\mathrm{PoSE}(c_i)=\{u_i+st_i,u_i+st_i+1,\ldots,u_i+st_i+l_i-1\}\]

这里要注意,处理后各个chunk的位置编码下标不能有overlap,所以要求 \(u_i\geq u_{i-1}\)

直观地说,引入skipping bias使模型能接触到更大范围的位置编码。为了全面覆盖目标上下文窗口,我们为每个训练sample单独采样每个chunk的长度和skipping bias。

此外,位置编码index在每个chunk内的连续性,与原模型预训练期间所采用的结构非常相似。因此,在这些新的index上进行微调,不会损害模型原有的能力。

现在,位置编码的下标决定好了,我们还需要决定每个chunk的token使用哪些。

token的采样和位置编码下标的采样类似,具体来说,我们采样\(v_i\sim\mathcal{U}(\{v_{i-1},\ldots,L_x-L_c\})\),那么 \(c_i\) 的token如下

\[c_i=\boldsymbol{x}[v_i+st_i:v_i+st_i+l_i]\]

论文对一些采样变体,比如 \(v_i=u_i\)\(v_i=0\) 等进行了探索,发现基本没有什么影响,因此 \(v_i\) 保持原来的采样方案即可。

在实际训练中,\(N\) 设置为2,因为如果太大可能对原模型的能力造成损害。而 \(u_0\)\(v_0\) 设为了0。

PoSE方案如下图所示

实验上,使用了LLAMA-7B模型,在2k的窗口上进行了1,000步的训练,batch size为64。使用lr=2e-5,warmup step=10。

PoSE和其他模型在PPL上的对比如下,基本能达到和Full-length训练相近的水平。

而在passkey retrieval任务上,也有不错的效果,如下图所示

相比其他方案,PoSE的一个优势是可以在没有任何成本增加的情况下,支持更长的推理长度。比如可以通过简单修改采样策略的参数,PoSE就可以支持到1M,甚至更大的窗口长度,这是其他方法难以做到的。

小结

  1. 有了FlashAttention等方案之后,在128k这个长度,我们也有能力在合理的成本下,进行继续预训练,使用5B左右的token解锁模型的长上下文能力。
  2. 预训练中,长文本对模型的远距离建模能力很重要,要提高长文本的比例才有更好的效果。此外,领域的分布也是一个需要关注的点。
  3. 在长窗口的微调上,精心设计输入输出形式能带来一些收益。
  4. 对于更长的窗口,比如M级别这种几乎无法直接训练/微调的长度,PoSE这种模拟的方案能够在不增加成本的情况下,在效果上达到接近直接训练/微调的表现。

读到这了,来一发点赞收藏关注吧~

博客:http://www.linsight.cn/
知乎:Linsight
微信公众号:Linsight


【往期文章】

MoE模型的前世今生
LLM长上下文的问题
解锁大模型长上下文能力
理解Attention:从起源到MHA,MQA和GQA
Yi技术报告-划重点看细节
transformer中normalization的二三事
从代码实现看normalization-到底做了什么
稀疏注意力计算:sliding window attention
理解LLM位置编码:RoPE
大模型算法题(1)
大模型算法题(2)
大模型算法题(3)
大模型算法题(4)
大模型算法题(5)


Reference

【1】Data Engineering for Scaling Language Models to 128K Context https://arxiv.org/abs/2402.10171
【2】Training With "Paraphrasing the Original Text" Improves Long-Context Performance https://arxiv.org/abs/2312.11193
【3】PoSE: Efficient Context Window Extension of LLMs via Positional Skip-wise Training https://arxiv.org/abs/2309.10400