Contents

在大模型推理中击败不确定性[译]

前OpenAI核心人物解释LLMs不确定性来源 可复现性,是科学进步的基石。然而,想让大语言模型(LLM)给出可复现的结果,却出奇地难。

举个例子,你可能会发现,多次向 ChatGPT 问同一个问题,会得到不同的答案。这本身不奇怪,因为从语言模型获得结果涉及一个“采样”过程,也就是将模型的输出转换成一个概率分布,然后依概率选择一个词元(token)。

但更令人惊讶的或许是,即便我们把“温度”(temperature)参数调到 0(这在理论上会让采样过程变成确定性的,即总是选择概率最高的词元,也叫“贪心采样”),LLM 的 API 在实际中仍然不是确定性的(可以看看过去的一些讨论:这里这里这里)。甚至当你在自己的硬件上,用 vLLM 或 SGLang 这样的开源推理库来运行时,采样结果依旧不是确定性的(见这里这里)。

可为什么 LLM 的推理引擎不是确定性的呢?一个常见的假说是,浮点数运算的非结合性(non-associativity)与并发执行(concurrent execution)共同导致了不确定性——哪个并发核心先算完,结果就可能不一样。我们称之为 LLM 推理非确定性的“并发+浮点数”假说。例如,最近一篇 arXiv 预印本中写道:

GPU 中的浮点数算术存在非结合性,即 $(a+b)+c \neq a+(b+c)$,这是由有限精度和舍入误差造成的。这一特性直接影响了 Transformer 架构中注意力分数和 logits 的计算,其中跨多线程的并行操作会因执行顺序的不同而产生不同的结果。

你也能在别处看到类似的“并发+浮点数”假说,比如这里(“为了速度需要做权衡,为了让端点更快,用了 GPU,而 GPU 做的是并行[非确定性]计算。任何现代 GPU 神经网络计算都会受此影响。”),或者这里(“因为 GPU 是高度并行的,加法或乘法的顺序在每次执行时可能不同,这会像滚雪球一样导致输出的微小差异。”)。

这个假说不能说全错,但它没有揭示事情的全貌。比如,即便在 GPU 上,对同样的数据反复执行同一个矩阵乘法,每次总能得到逐比特(bitwise)完全相同的结果。我们确实在用浮点数,我们的 GPU 也确实有大量并发,那为什么在这个测试里就看不到不确定性呢?

A = torch.randn(2048, 2048, device='cuda', dtype=torch.bfloat16)
B = torch.randn(2048, 2048, device='cuda', dtype=torch.bfloat16)
ref = torch.mm(A, B)
for _ in range(1000):
    assert (torch.mm(A, B) - ref).abs().max().item() == 0

要理解 LLM 推理非确定性的真正原因,我们必须挖得更深。

麻烦的是,就连给 LLM 推理的“确定性”下个定义都很难。下面这些陈述可能看起来相互矛盾,但它们都是对的:

  1. GPU 上的某些内核(kernel)是非确定性的。
  2. 但是,语言模型前向传播(forward pass)中用到的所有内核都是确定性的。
  3. 并且,LLM 推理服务器(如 vLLM)的前向传播也可以声称是确定性的。
  4. 尽管如此,从任何一个使用推理服务器的用户的角度来看,结果都是非确定性的。

在这篇文章里,我们会解释为什么“并发+浮点数”假说没有抓住要害,揭开 LLM 推理非确定性背后的真凶,并说明如何战胜它,在 LLM 推理中获得真正可复现的结果。


原罪:浮点数运算的非结合性

在讨论非确定性之前,我们有必要先解释一下,为什么数值上会出现差异。毕竟,我们通常认为机器学习模型是遵循交换律或结合律等结构规则的数学函数。难道机器学习库不应该给我们一个“数学上正确”的结果吗?

罪魁祸首是浮点数运算的非结合性。也就是说,对于浮点数:

$$ (a+b)+c \neq a+(b+c) $$

(0.1 + 1e20) - 1e20
# 输出: 0.0

0.1 + (1e20 - 1e20)
# 输出: 0.1

讽刺的是,恰恰是打破了结合律,才让浮点数变得有用。

浮点数之所以有用,是因为它们允许一种“动态”的精度。为了便于解释,我们用十进制(而不是二进制),假设浮点数的格式是“尾数 $\times 10^{\text{指数}}$”,并且尾数有 3 位数字,指数有 1 位。

比如,数值 3450,我们可以精确表示为 $3.45 \times 10^3$。我们也能表示小得多的数,比如 0.486,表示为 $4.86 \times 10^{-1}$。这样,浮点数就能同时表示极大和极小的数值。在科学上,我们或许会说,浮点数让我们能保持恒定数量的“有效数字”。

如果你把两个指数相同的浮点数相加,过程和整数加法差不多。例如,$123$ ($1.23 \times 10^2$) + $456$ ($4.56 \times 10^2$) 结果是 $579$ ($5.79 \times 10^2$)。

但当我们加两个指数不同的浮点数时,比如 1230 和 23.4,会发生什么?精确结果是 1253.4。然而,我们一次只能保持 3 位数的精度。因此,浮点数加法会丢掉最后两位,得到 $1.25 \times 10^3$(也就是 1250)。

$1.23 \times 10^3$ + $2.34 \times 10^1$ = $1.2534 \times 10^3$ (精确值: 1253.4)

我们需要 3 位精度来表示 1230,也需要 3 位精度来表示 23.4。但把它们加起来,得到的数字(1253.4)需要 5 位精度才能表示。我们的浮点数格式必须舍掉末尾的 34。从某种意义上说,我们相当于在相加前,把原来的 23.4 四舍五入到了 20.0。

到了这一步,信息就已经丢失了。注意,每当我们加两个“尺度”(即指数)不同的浮点数时,都可能发生这种情况。而加指数不同的浮点数是家常便饭。实际上,如果我们能保证永远不需要不同的指数,那我们直接用整数就行了!

换句话说,每次我们以不同的顺序对浮点数求和,都可能得到完全不同的结果。举个极端的例子,对下面这个数组求和,根据顺序不同,可能产生 102 种不同的结果。

import random

vals = [1e-10, 1e-5, 1e-2, 1]
vals = vals + [-v for v in vals]

results = []
random.seed(42)
for _ in range(10000):
    random.shuffle(vals)
    results.append(sum(vals))

results = sorted(set(results))
print(f"There are {len(results)} unique results: {results}")

# 输出:
# There are 102 unique results: [-8.326672684688674e-17, -7.45931094670027e-17, ..., 8.326672684688674e-17]

虽然这是输出不一致的根本原因,但它没有直接回答非确定性从何而来。它没能帮我们理解,浮点数为什么会以不同的顺序相加,这种情况何时发生,以及如何避免。

答案,就藏在内核的实现方式之中。


为什么内核不总按相同顺序加总数字?

如前所述,一个常见的解释是“并发+浮点数”假说。该假说认为,如果并发线程完成的顺序是不确定的,并且累加的顺序依赖于线程完成的顺序(比如使用原子加法 atomic add),那么我们的累加顺序也会是不确定的。

令人困惑的是,虽然这确实会导致内核不确定,但并发(和原子加法)最终却和 LLM 推理的非确定性完全无关!要解释真正的罪魁祸首,我们先来理解为什么现代 GPU 内核很少需要原子加法。

什么时候需要原子加法?

通常,GPU 会在许多“核心”(即 SM,流式多处理器)上并发启动一个程序。由于这些核心之间没有内在的同步机制,如果它们需要相互通信,就会带来挑战。例如,如果所有核心都必须累加到同一个元素上,你可以使用“原子加法”(有时也叫“取值并加”)。原子加法是“非确定性”的——结果累加的顺序纯粹取决于哪个核心先完成计算。

具体来说,想象你用 100 个核心来对一个 100 元素的向量求和(例如 torch.sum())。虽然你可以并行加载所有 100 个元素,但最终我们必须把它们规约(reduce)成一个元素。一种方法是使用某种“原子加法”原语,硬件保证所有的加法都会被处理,但不保证顺序。

原子加法确保每个核心的贡献都会反映在最终总和中。但它不保证贡献被相加的顺序。顺序完全取决于哪个核心先完成,这是一个不确定性的属性。因此,多次执行同一个并行程序可能导致不确定的输出。

这通常就是人们所说的“非确定性”——你用完全相同的输入执行同一个内核两次,却得到了不同的结果。这被称为**“逐次运行非确定性”**(run-to-run nondeterminism),即你用完全相同的依赖项运行同一个 Python 脚本两次,却得到不同的结果。

虽然并发的原子加法确实会使内核变得非确定性,但对于绝大多数内核来说,原子加法并非必需。事实上,在 LLM 的典型前向传播中,通常一个原子加法都没有

这可能有些出人意料,毕竟并行化一个规约操作可以从原子加法中受益。原子加法之所以最终变得非必需,主要有两个原因:

  1. 沿着“批次”(batch)维度的并行性通常已经足够,我们不需要再沿着规约维度进行并行化。例如,假设我们不是对一个 100 维向量进行规约,而是并行规约 500 个这样的向量。在这种情况下,我们可以在每个核心中处理一个完整的向量,让每个核心操作不同的向量。
  2. 随着时间推移,大多数神经网络库都采用各种策略,在不牺牲性能的前提下实现确定性。例如,我们可以执行一个“分裂”(或树状)规约,把 100 个元素的规约分成五个 20 元素的规约(从而实现五路并行)。然后,为了合并剩下的五个元素,我们可以要么执行一个单独的“收尾”规约(这个过程不并行,但操作的元素足够少,开销很小),要么利用信号量(semaphore)来确保每个并发的线程块会以确定性的顺序进行累加。

由于这两个因素,对于绝大多数神经网络操作而言,避免原子加法带来的性能损失可以忽略不计。

不过,仍有少数常见操作,若要避免原子加法会有显著的性能损失。例如,PyTorch 中的 scatter_add (a[b] += c)。然而,在 LLM 中唯一常用到的是 FlashAttention 的反向传播。

但是,LLM 的前向传播不涉及任何需要原子加法的操作。因此,LLM 的前向传播实际上是“逐次运行确定性”的。

[Image illustrating deterministic server behavior with fixed inputs]

从推理服务器的角度看,它确定性的。给定完全相同的用户请求,它将总是提供相同的确定性输出。

维基百科写道:“确定性算法是指,给定一个特定的输入,将总是产生相同输出的算法。” 在这种情况下,给定完全相同的输入(即推理服务器正在处理的那些确切请求),前向传播总是产生完全相同的输出。

然而,前向传播本身是“确定性”的,并不足以保证包含它的整个系统是确定性的。例如,如果我们的请求输出依赖于并行的其他用户请求(比如批归一化 batch-norm),那会怎样?由于单个请求无法知道并行的请求会是什么,从它的角度看,我们整个 LLM 推理也是非确定性的!

事实证明,我们的请求输出确实依赖于并行的其他用户请求。不是因为我们在批次间以某种方式泄露了信息,而是因为我们的前向传播缺乏**“批次不变性”(batch invariance),导致我们请求的输出依赖于前向传播的批次大小**。


批次不变性与“确定性”

为了解释批次不变性,我们把系统简化,只看矩阵乘法(matmul)。你可以假设所有的矩阵乘法实现都是“逐次运行确定性”的。然而,它们并不是“批次不变”的。换句话说,当批次大小改变时,批次中的每个元素都可能得到不同的结果。

从数学角度看,这是一个相当不寻常的性质。矩阵乘法沿着批次中的每个元素应该是“独立”的——批次中的其他元素,以及批次有多大,都不应该影响批次中某个特定元素的计算结果。

然而,通过实验我们可以观察到,事实并非如此。

import torch
torch.set_default_device('cuda') 

B = 2048
D = 4096
a = torch.linspace(-1000, 1000, B*D).reshape(B, D)
b = torch.linspace(-1000, 1000, D*D).reshape(D, D)

# 取批次中的第一个元素做矩阵-向量乘法
out1 = torch.mm(a[:1], b)

# 做完整的矩阵-矩阵乘法,然后取结果的第一个元素
out2 = torch.mm(a, b)[:1]

print((out1 - out2).abs().max()) # tensor(1669.2500, device='cuda:0')

请注意,这“逐次运行确定性”的。如果你多次运行这个脚本,它会确定性地返回相同的结果。但它不是“硬件/软件版本不变”的——你的 GPU/PyTorch 版本可能会返回不同的值,但对于一个特定环境,它应该确定性地返回相同的值。

然而,当一个非批次不变的内核被用在一个更大的推理系统中时,整个系统就可能变得非确定性。当你向一个推理端点发出查询时,从用户的角度看,服务器的负载量实际上是“非确定性”的。负载决定了内核运行时所用的批次大小,从而改变了每个独立请求的最终结果!

[Image illustrating nondeterministic user experience due to varying server load]

虽然推理服务器本身可以声称是“确定性”的,但对单个用户来说,情况就不同了。从单个用户的角度看,其他并发用户不是系统的“输入”,而是系统的一个不确定属性。这使得 LLM 推理从每个用户的角度看是“非确定性”的。

如果你把内核不具备不变性的某个属性(即批次大小)与该属性的非确定性(即服务器负载)组合在一起,你就会得到一个非确定性的系统。

换句话说,几乎所有 LLM 推理端点都是非确定性的,主要原因就是负载(从而导致批次大小)在非确定性地变化! 这种非确定性并非 GPU 独有——由 CPU 或 TPU 提供的 LLM 推理端点同样会存在这个不确定性来源。

所以,如果我们想在推理服务器中避免非确定性,就必须让我们的内核实现批次不变性。为了理解如何实现这一点,我们先来看看为什么内核一开始就没有批次不变性。


如何让内核具有批次不变性?

要让一个 Transformer 实现具有批次不变性,我们必须让它的每一个内核都具有批次不变性。幸运的是,我们可以假设所有逐点(pointwise)操作都具有批次不变性。因此,我们只需要关注那 3 个涉及规约的操作:RMSNorm、矩阵乘法和注意力机制。

方便的是,它们的实现难度也是递增的。每一个都需要一些额外的考虑才能在保持合理性能的同时实现批次不变性。我们先从 RMSNorm 说起。

批次不变的 RMSNorm

RMSNorm 可以实现为:

# x: [batch_size, hidden_dim]
# weight: [hidden_dim]
def rms_norm(x, weight):
    return x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True)) * weight

批次不变性的要求是,对于每个元素,其规约顺序必须是固定的,不受内核批次大小的影响。 注意,这不意味着我们必须总是使用同一种规约策略。

因此,只有当批次大小影响到规约策略时,我们才会破坏批次不变性。

我们来看看 RMSNorm 的标准并行策略。通常,并行算法受益于最小化核心间的通信。所以,我们可以从一个策略开始:将每个批次元素分配给一个核心,如上图所示。

数据并行的 RMSNorm:理想情况下,我们希望在并行策略中避免核心间的通信。一种方法是将每个批次元素分配给一个核心,从而保证每次规约都完全在单个核心内完成。这就是所谓的“数据并行”策略,因为我们只是沿着一个不需要通信的维度进行并行化。

增加批次大小并不会影响我们的规约策略;如果 200 的批次大小能为我们的内核提供足够的并行度,那么 2000 的批次大小肯定也能。

更大批次的数据并行 RMSNorm:将数据并行策略扩展到更大的批次很简单——不是让每个核心处理一行,而是让每个核心顺序处理不同的行。这保留了批次不变性,因为每个批次元素的规约策略保持不变。

反之,减小批次大小可能会带来挑战。因为我们把每个批次元素分配给一个核心,减小批次大小最终会导致核心数多于批次元素数,使得一些核心闲置。

遇到这种情况,一个优秀的内核工程师会采用前面提到的解决方案之一(原子加法或分裂规约),以保持良好的并行度和性能。不幸的是,这改变了规约策略,使得这个内核不再具有批次不变性。

分裂规约的 RMSNorm:如果我们的批次大小很小,数据并行策略可能不再有足够的并行度来饱和我们的核心。在这种情况下,将一个规约“分裂”到多个核心上可能更高效,让我们能充分利用 GPU。然而,这失去了批次不变性,因为我们不再以相同的顺序规约每个元素。

最简单的解决方案就是完全忽略这些情况。这并非完全不合理——小批次大小意味着内核无论如何都可能很快执行完毕,所以一点减速可能不是灾难性的。

如果我们确实需要优化这种情况,一种方法是始终使用一种即使在批次大小很小时也具有足够并行度的规约策略。这样的策略在批次较大时会导致过度的并行,但能让我们在所有尺寸范围内都达到还不错的(但非峰值)性能。

批次不变的矩阵乘法

其核心思想是,你可以把矩阵乘法看作是一个逐点操作后跟着一个规约。然后,如果我们通过将输出分块成瓦片(tile)来并行化矩阵乘法,我们就得到了一个类似的“数据并行”内核策略,它将每个规约都保持在单个核心内。

数据并行的矩阵乘法:与 RMSNorm 类似,矩阵乘法的标准并行策略是“数据并行”策略,将整个规约保持在一个核心内。最直接的理解方式是将输出张量分割成二维瓦片,并将每个瓦片分配给不同的核心。每个核心随后计算属于该瓦片的所有点积,再次将整个规约在单个核心内完成。

与 RMSNorm 类似,我们的“批次”维度(M 和 N)也可能变得太小,迫使我们沿规约维度(K)进行分裂。在矩阵乘法中沿规约维度进行分裂被称为 Split-K Matmul。和 RMSNorm 一样,使用这种策略会破坏批次不变性。

Split-K 矩阵乘法:如果我们的批次维度相当小,我们可能没有足够的并行度,需要进行 Split-K 矩阵乘法。在这个例子中,我们将每个规约任务分散到两个核心上,它们会分别进行累加,最后再合并结果。

此外,矩阵乘法还有一个额外的复杂性——张量核心(tensor core)指令。对于规约,我们可以一次只操作一行,但高效的矩阵乘法内核必须一次操作一整个“瓦片”。不同的张量核心指令(比如 wgmma.mma_async.sync.aligned.m64n128k16)内部可能有不同的规约顺序。批次大小很小可能是使用不同张量核心指令的一个原因。例如,批次大小为 1 时,最快的内核通常根本不使用张量核心。

填充的张量核心指令:如果批次太小,我们可能连一个二维瓦片都放不进输出。此时,切换到更小的张量核心指令或者干脆不用张量核心会更高效!然而,这两种选择都让我们的内核无法保持批次不变。

所以,确保矩阵乘法批次不变性的最简单方法是,编译一个内核配置,并将其用于所有形状。虽然我们会损失一些性能,但这在 LLM 推理中通常不是灾难性的。特别是,Split-K 在 M 和 N 很小时最需要,而幸运的是,在我们的场景中,N(即模型维度)通常都很大!

[Image comparing performance of cuBLAS vs. a batch-invariant matmul kernel]

尽管实现了批次不变性,我们相比 cuBLAS 只损失了大约 20% 的性能。注意这还不是一个优化过的 Triton 内核。然而,性能曲线中的一些模式揭示了我们的批次不变性要求在何处损失性能。首先,在非常小的批次大小时,由于指令过大和并行度不足,我们损失了大量性能。其次,随着批次大小增加,出现了一个“锯齿”模式,这是由量化效应(瓦片和波次)引起的,通常通过改变瓦片大小来缓解。

批次不变的注意力机制

在为矩阵乘法实现了批次不变性后,注意力机制引入了两个额外的难题——因为它包含两个矩阵乘法。

  1. 与 RMSNorm 和矩阵乘法只在特征维度上规约不同,我们现在既要在特征维度上规约,也要在序列维度上规约。
  2. 因此,注意力机制必须处理各种影响序列处理方式的推理优化(如分块预填充、前缀缓存等)。

因此,要在 LLM 推理中实现确定性,我们的数值计算必须对同时处理的请求数量以及每个请求在推理引擎中如何被切分保持不变。

我们先来看看注意力的标准并行策略,这个策略最早在 FlashAttention2 中引入。与 RMSNorm 和 Matmul 类似,默认策略也是“数据并行”的。因为我们在键/值(K/V)张量上进行规约,数据并行策略只能在查询(Q)张量上进行并行化。

FlashAttention2 策略:我们沿着 Q 进行并行化,同时对 K/V 进行规约。这意味着我们的整个规约可以保持在单个核心内,使其成为另一种数据并行策略。

为了实现“批次不变性”,对于一个给定的词元,其规约顺序不能依赖于它所在序列中同时被处理的其他词元的数量。如果你像 vLLM 的 Triton 注意力内核那样,将 KV 缓存中的 K/V 值与当前正在处理的词元的 K/V 值分开处理,这就无法实现。

[Image illustrating how separate KV cache handling breaks batch invariance]

带 KV 缓存的 FlashAttention:明确地将 KV 缓存与当前 KV 值分开处理会破坏批次不变性,原因有点微妙,与“边界条件”有关。具体来说,想象你的块大小是 32,但我们当前 KV 缓存中有 80 个元素。然后我们又计算了 48 个未缓存的元素。在这种情况下,我们需要三个块来计算“P 缓存”,另外两个块来计算“P”。这总共需要五个块来完成我们的规约,而我们总共只有四个块的元素量(即 128),这肯定会改变我们的规约顺序。

要解决这个问题,我们可以在注意力内核本身之前就更新 KV 缓存和页表,确保无论正在处理多少词元,我们的键和值总是以一致的方式布局。

加上这个细节(以及前几节提到的所有东西,比如一致的瓦片大小),我们就能够实现一个批次不变的注意力实现了!

然而,这里有一个重要问题。与矩阵乘法不同,我们在 LLM 推理中遇到的注意力形状常常确实需要一个分裂规约内核,通常称为 Split-KV 或 FlashDecoding。这是因为,如果不在规约维度上并行,我们就只能在批次维度、头维度和“查询长度”维度上并行。在解码阶段,查询长度非常小(通常是 1),所以除非我们有非常大的批次大小,否则通常无法饱和 GPU。

固定数量 Split-KV 策略(即 FlashDecode):如果查询长度变得非常小(就像解码时那样),我们可能会陷入内核并行度极低的情况。此时,我们需要再次沿着规约维度——这次是 KV 维度——进行分裂。

不幸的是,这个问题不像 RMSNorm 和 Matmul 那样容易忽略。

此外,注意力机制常用的分裂规约策略也给批次不变性带来了挑战。例如,FlashInfer 的“平衡调度算法”会选择能够饱和所有 GPU 核心的最大分裂尺寸,这使得规约策略不是“批次不变”的。

要实现批次不变性,我们必须采用**“固定分裂尺寸”的策略。换句话说,我们不固定分裂的数量**,而是固定每次分裂的大小,从而得到一个可变的分裂数量。通过这种方式,我们可以保证无论正在处理多少词元,我们总是执行相同的规约顺序。

固定尺寸 Split-KV 策略:这个策略与前一个唯一的区别在于我们的分裂是“固定尺寸”的。例如,如果我们的 KV 长度是 1000,我们不会把它分成四个等长的 250 的分片,而是会把它分成三个固定尺寸 256 的分片和一个长度 232 的分片。这让我们能够保留批次不变性,因为我们的规约策略不再依赖于我们一次处理多少查询词元!


实现

我们提供了一个在 vLLM 之上实现确定性推理的演示,利用了其 FlexAttention 后端以及 torch.Library。通过 torch.Library,我们能以一种非侵入性的方式替换掉大部分相关的 PyTorch 操作符。你可以在 thinking-machines-lab/batch-invariant-ops 找到“批次不变”内核库,以及在 vLLM 中以“确定性”模式运行的示例。


实验

生成结果的非确定性有多大?

我们使用 Qwen/Qwen3-235B-A22B-Instruct-2507 模型,在温度为 0 的设置下,用提示“告诉我关于理查德·费曼的事”(非思维链模式)采样 1000 次,每次生成 1000 个词元。令人惊讶的是,我们生成了 80 种独特的回复,其中最常见的一种出现了 78 次。

观察这些回复的差异点,我们发现它们在开头 102 个词元上其实是完全相同的!第一次出现分歧是在第 103 个词元。所有的回复都生成了序列“费曼出生于 1918 年 5 月 11 日,在”,但 992 个回复接着生成了“纽约皇后区”,而有 8 个回复生成了“纽约市”。

另一方面,当我们启用批次不变内核后,所有 1000 次生成的结果都完全相同。这正是我们从数学上对采样器的期望,但没有批次不变内核,我们无法获得确定性的结果。

性能

我们尚未投入大量精力来优化批次不变内核的性能。不过,我们还是做些实验来验证其性能是否可用。

我们用一块 GPU 搭建一个运行 Qwen-3-8B 的 API 服务器,请求 1000 个序列,输出长度在 90 到 110 之间。

配置 时间(秒)
vLLM 默认 26
未优化的确定性 vLLM 55
+ 改进的注意力内核 42

性能下降大部分源于 vLLM 中的 FlexAttention 集成尚未经过深度优化。尽管如此,我们看到性能并非灾难性的。

真正的同策略(On-Policy)强化学习

正如研究人员指出的,训练和推理之间的数值差异,会暗中把我们的同策略强化学习(On-Policy RL)变成异策略(Off-Policy)强化学习。

当然,如果我们连两次相同的推理请求都无法得到逐比特相同的结果,那么在训练和推理之间获得逐比特相同的结果是不可能的。确定性推理让我们也能修改我们的训练栈,从而在采样和训练之间获得逐比特相同的结果,最终实现真正的同策略强化学习。

我们在 RLVR(视觉推理的强化学习)设置下,在 Bigmath 数据集上进行了实验。

如果我们不进行异策略校正(即重要性权重)进行训练,我们的奖励在训练中途就会崩溃。但是,如果我们能在采样器和训练器之间实现逐比特相同的结果,我们就完全是同策略的(即 KL 散度为 0),也能够顺利训练。

我们还可以绘制采样器和训练器之间对数概率的 KL 散度,三组实验的行为截然不同。当使用重要性权重时,KL 散度维持在 0.001 左右,偶尔有尖峰。然而,使用重要性权重最终会导致 KL 散度在奖励崩溃的同时飙升。当然,在运行“真正的同策略强化学习”时,我们的 KL 散度稳定在 0,表明训练策略和采样策略之间没有任何分歧。

注意,没有重要性权重的运行在第 318 步附近有一个显著的损失尖峰,与之对应的是对数概率的 KL 散度也出现了一个尖峰。同时,使用异策略校正或运行“真正的同策略”都能让 RL 顺利进行。显示“真正的同策略”的蓝线不是 bug——它就是一条平直的 0 线。


结论

现代软件系统包含许多抽象层。在机器学习中,当我们遇到非确定性和微小的数值差异时,往往很容易选择视而不见。毕竟,我们的系统已经是“概率性”的了,多一点非确定性又有什么关系呢?把失败的单元测试里的 atol/rtol 容忍度调高一点又有什么问题呢?训练器和采样器之间的对数概率差异可能不是真正的 bug,对吧?

我们拒绝这种失败主义。只要稍加努力,我们能够理解非确定性的根本原因,甚至解决它们!我们希望这篇博文能为社区提供一个坚实的理解,关于如何解决我们推理系统中的非确定性问题,并激励其他人去完全理解他们自己的系统。