掘金 人工智能 05月28日 15:18
RWKV-8 预览之 DeepEmbed:对端侧友好的稀疏设计,解决 MoE 显存占用
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

RWKV-8 "Heron" 架构引入了 DeepEmbed 技术,旨在提升大模型推理性能,特别是针对端侧设备。DeepEmbed 通过为词表中每个token训练可学习的高维向量,在推理时进行逐通道的乘性调制,从而减少显存占用。该技术允许稀疏的大模型在RAM/SSD中存储,并提升模型存储和调用世界知识的能力。DeepEmbed 不仅提升了模型性能,还在训练和推理过程中提供了高效的资源利用方式。

💡 DeepEmbed 的核心思想是在模型的每一层FFN中,为词表中的每个token训练一个可学习的高维向量,这些向量在推理时存储于RAM/SSD中,从而减少显存占用。

🔑 DeepEmbed 的工作原理是,在推理时,模型根据token index预读embedding向量,用于对FFN输出进行逐通道的乘性调制。这些向量构成一个稀疏的知识库,提升模型对token语义的感知能力。

⚙️ DeepEmbed 的实现方式是在FFN的输出或中间结果上,通过self.deepemb进行乘法缩放。该技术不直接修改FFN的参数,而是以token为单位,替代了原本基于输入动态生成的门控向量,实现对通道维度的调制。

📈 实验结果表明,更大规模的 DeepEmbed 能带来更显著的loss改善,并且可以结合LoRA技术降低显存和训练开销。

🌍 DeepEmbed 技术公布后,受到了广泛关注,包括 ViT 作者 Lucas Beyer。RWKV 社区提供了丰富的资源,方便用户了解和参与 RWKV 模型的研究。

RWKV-8 "Heron" 是我们的下一代架构,具有多个全新技术。在此我们首先公布其中的 DeepEmbed 技术,它可以实现类似 MoE 的优秀推理性能,同时无需占用显存,甚至无需占用内存,可以让稀疏的大模型真正部署到所有端侧设备。

下文将详细介绍 DeepEmbed 的原理与实现:

什么是 DeepEmbed

DeepEmbed 在模型的每一层 FFN 中为词表中的每个 token 训练一个可学习的高维向量,这可以写成 Embed 层。这些向量在训练阶段可被学习,而在推理阶段可存储于 RAM/SSD 中,对于每个 token 只需预读极少量参数,从而显著减少显存占用

推理时,模型根据 token index 可提前预读本层的 embedding 向量,用于对 FFN 输出进行逐通道的乘性调制(channelwise scaling)。

这些基于 token 的 embedding 向量构成了一个规模庞大但稀疏的知识库,能够显著提升模型存储和调用世界知识的能力。尽管这些向量看似增加了模型参数量,但不需要占用显存,且在训练过程中可通过 TP(Tensor Parallelism)避免 DP(Data Parallelism)中梯度同步的带宽开销,并可进一步 offload 至 RAM 或 SSD。

在端侧推理场景下,这些向量同样可存储于内存中,或通过 mmap 等机制直接从硬盘按需加载。每个 token 仅引入几十 KB 的额外访存开销,使该机制非常适合在边缘设备上部署。

DeepEmbed 代码示例

原始 ReLuSq FFN:

x = torch.relu(self.key(x)) ** 2  return self.value(x)

DeepEmbed_1x ReLuSq FFN:

self.deepemb = nn.Embedding(d_vocab, d_emb)...x = torch.relu(self.key(x)) ** 2return self.value(x) * self.deepemb(idx)

DeepEmbed_4x ReLuSq FFN(效果更佳,参数更多):

self.deepemb = nn.Embedding(d_vocab, d_emb * 4)...x = torch.relu(self.key(x)) ** 2return self.value(x * self.deepemb(idx))

从代码可以看出,self.deepemb 用于对 FFN 的输出或中间结果进行乘法缩放。这里的 idx 是输入的 token 索引。需要:注意将 self.deepemb 向量初始化为常数 1,以确保训练初期不会干扰模型行为。

self.deepemb 向量并不直接修改 FFN 的参数,而是以 token 为单位,替代了原本基于输入动态生成的门控向量在通道维度上的乘性调制功能,使模型对 token 的语义或类别具备更强的感知能力。

由于查表操作在推理时不占用 VRAM,这些向量在参数量层面几乎是"免费的"。因此,还可以进一步引入 n-gram(如 bigramtrigram),提升模型对词组/片段的建模能力。如果词表规模较大,也可结合 LoRA 技术降低显存和训练开销。

实验结果表明:相较于 baseline,更大规模的 DeepEmbed 能够带来更显著的 loss 改善

DeepEmbed 技术在 twitter 公布后,引发多方关注, 包括 ViT 作者 Lucas Beyer。

未来我们会逐渐公布 RWKV-8 的更多创新技术。

加入 RWKV 社区

欢迎大家加入 RWKV 社区,可以从 RWKV 中文官网了解 RWKV 模型,也可以加入 RWKV 论坛、QQ 频道和 QQ 群聊,一起探讨 RWKV 模型。

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

RWKV-8 DeepEmbed 大模型 端侧推理 模型优化
相关文章