← 返回列表

MemAgent: Reshaping Long-Context LLM with Multi-Conv RL-based Memory Agent

作者 ByteDance Seed, Tsinghua University AIR, SIA-Lab
年份 2025
会议/期刊 arXiv 2025
评分
标签 长上下文 强化学习 记忆机制
摘要 固定长度 token 记忆 + 分段读写 + Multi-Conv DAPO 强化学习,32K 训练外推至 3.5M token,精度衰减 <5%,O(N) 线性复杂度

核心思想

如何让 LLM 处理百万级 token 的超长文档?传统方法要么扩展上下文窗口($O(n^2)$ 注意力成本),要么做压缩但损失信息。

MemAgent 提出全新范式:给 LLM 一块固定长度的 token 记忆(1024 tokens),让模型分段读取文档,每读完一段就用 RL 学到的策略更新记忆 → 最终只用记忆 + 查询生成答案。

关键结果:在 32K 文档上训练,能外推到 3.5M token,精度衰减 <5%。

背景知识

长上下文处理的三种范式

范式 方法 复杂度 问题
扩展窗口 RoPE 外推、稀疏注意力 $O(n^2)$(或降低常数) 计算成本随长度平方增长
压缩表示 KV-cache 压缩、隐式记忆 $O(n)$ 但固定 不可解释,压缩质量不可控
分段处理 MemAgent(本文) $O(N)$ 线性 需要学习何时保留/丢弃信息

为什么需要 RL

方法 记忆更新策略 问题
启发式规则 固定摘要/截断 无法适应不同任务
SFT 监督 模仿专家轨迹 缺乏”压缩-保留”的权衡信号
RL 端到端优化最终答案质量 直接优化”信息保留 vs 压缩”的权衡

方法详解

1. 记忆读写机制

将文档分成 $K$ 个 chunk,每个长度 $C$,记忆长度固定为 $M$(实验中 $M = 1024$):

\[p(m^k | c^k, m^{k-1})\]
  • $c^k$:第 $k$ 个文档片段
  • $m^{k-1}$:上一轮的记忆内容(token 序列)
  • $m^k$:更新后的记忆

两个模块

  1. 上下文处理模块:逐 chunk 读取,每次更新记忆
  2. 答案生成模块:仅用查询 + 最终记忆生成答案

2. 计算复杂度分析

方法 每步成本 总成本
全注意力 $O((q+c+o)^2)$ 随长度平方增长
MemAgent $O(C + M)$ $O(N)$ 线性增长

从 8K 到 4M token,MemAgent 的 FLOPS 线性扩展。

3. Multi-Conv DAPO 训练

标准 DAPO/GRPO 假设每个样本只有一个对话。MemAgent 需要处理多轮对话(每个 chunk 是一轮)→ 扩展维度。

优势函数

\[\hat{A}(i,j,t) = r_i - \text{mean}(\{R_i\}_{i=1}^G)\]

奖励从最终答案均匀回传到所有轮次。

损失函数(从 (group, token) 扩展到 (group, conversation, token)):

\[J_{\text{DAPO}}(\theta) = \mathbb{E}\left[\frac{1}{\sum|o_{i,j}|} \sum_{i,j,t} \left(C_{i,j,t} - \beta D_{\text{KL}}\right)\right]\]

其中 PPO-clip 项:

\[C_{i,j,t} = \min\left(r_{i,j,t} \hat{A}_{i,j,t}, \text{clip}(r_{i,j,t}, 1 \pm \epsilon) \hat{A}_{i,j,t}\right)\]

4. 奖励设计

单答案任务

\[R = \max_y \mathbb{I}(\text{is\_equiv}(y, \hat{y}))\]

多值任务

\[R = \frac{|\{y \in Y | y \in \hat{y}\}|}{|Y|}\]

5. 训练配置

  • 基座模型:Qwen2.5-7B/14B-Instruct
  • 训练数据:32,768 个 HotpotQA 样本(28K tokens)
  • 上下文分配:1024 查询 + 5000 文档 + 1024 记忆 + 1024 输出
  • 超参数:KL 因子 1e-3,学习率 1e-6,rollout batch 128-256,group size 16

实验结果

主要结果:RULER-HotpotQA 准确率

模型 7K 28K 112K 448K 896K 3.5M
RL-MemAgent-14B 83.6% 84.4% 76.6% 75.0% 77.3% 78.1%
RL-MemAgent-7B 82.0% 78.9% 79.7% 74.2% 76.6% 71.1%
QwenLong-L1-32B 72.7% 72.7% 31.3% 13.3% 11.7% N/A
Qwen2.5-14B-1M 60.2% 50.0% 50.0% 8.6% 0.0% N/A

关键发现

  • MemAgent-14B 在 3.5M token 仍保持 78.1%
  • 传统 1M 上下文模型在 896K 就完全崩溃(0%)
  • 32B 参数的 QwenLong 在 112K 后急剧下降

OOD 泛化:RULER 任务

MemAgent-14B 在 8K-512K 上下文的 RULER 任务上达到 >95% 准确率,包括:

  • 单/多键 NIAH(大海捞针)
  • 多值提取
  • 变量追踪
  • 高频词提取

消融实验

配置 效果
无记忆 112K 后急剧下降
有记忆但无 RL 中等改善,仍随长度衰减
有记忆 + RL 所有长度一致,衰减 <5%

RL 训练是关键——没有 RL,记忆机制只提供结构支撑,无法优化压缩-保留权衡。

个人思考

  1. “固定长度 token 记忆” 是优雅的简约设计:不需要修改 Transformer 架构,不需要特殊的注意力机制 → 任何 LLM 都可以即插即用。
  2. 32K → 3.5M 的外推是惊人的 100× 泛化:说明 RL 学到的不是”如何处理 28K 文档”,而是通用的信息压缩-保留策略
  3. “记忆可解释” 是独特优势:与 KV-cache 压缩不同,token 级记忆可以被人类阅读和检查 → 可以验证模型保留了什么信息。
  4. Multi-Conv DAPO 解决了多轮对话 RL 的信用分配问题:最终奖励均匀回传到所有轮次,虽然简单但有效。
  5. 传统 1M 上下文模型在 896K 就崩溃说明”扩展窗口”路线的根本局限——注意力机制在超长序列上的信息检索能力是瓶颈,不仅是计算成本。
← 返回列表