智源社区 前天 20:18
矩阵乘法可以算得更快了!港中文10页论文证明:能源、时间均可节省
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

香港中文大学的研究团队提出了一种名为RXTX的新算法,专门针对矩阵与其转置的乘积(XXt)计算进行了优化。这项仅10页的论文通过结合机器学习搜索和组合优化技术,显著提升了计算效率。RXTX算法在数据分析、芯片设计、无线通信和LLM训练等领域具有潜在应用价值。实验结果表明,在处理大型矩阵时,RXTX的运算速度优于传统算法,为加速大模型训练和推理提供了新的思路。

💡RXTX算法的核心在于针对XXt矩阵乘法的优化。该算法通过将大矩阵分解为4x4子块,并结合机器学习搜索和组合优化技术,减少了计算量和运算次数,从而提升了计算效率。

⚙️RXTX算法的关键步骤包括分块与递归调用、对称乘积计算以及结果组合。它通过8次递归调用处理子问题,并计算26个一般矩阵乘积,最终得到XXt矩阵的结果。这种设计降低了渐近乘法常数,提高了计算速度。

📈实验结果表明,RXTX在计算XXt时具有显著优势。与传统算法相比,RXTX在乘法次数和总运算量上均有优化,尤其是在处理大型矩阵时。例如,在6144x6144矩阵的测试中,RXTX的平均运行时间比BLAS的默认实现快9%。

🤖RXTX算法的发现得益于机器学习与组合优化的结合。通过强化学习生成候选乘积,并利用MILP枚举与筛选以及大邻域搜索迭代,逐步优化算法,从而实现了对XXt计算的加速。

天下苦大模型矩阵乘法久矣。

毕竟不论是训练还是推理过程,矩阵乘法作为最主要的计算操作之一,往往都需要消耗大量的算力。

那么就没有一种更“快、好、省”的方法来搞这事儿吗?

有的,香港中文大学最新一篇仅10页的论文,便提出了一种新算法:

论文作者之一的Dmitry Rybin表示:

这项研究对数据分析、芯片设计、无线通信和LLM训练都有着深远的影响!

这么算矩阵乘法,更快!

矩阵乘法是计算机科学和数值线性代数中的核心问题之一。

自从Strassen和Winograd的开创性工作以来,研究者们一直在探索如何减少矩阵乘法所需的计算量。

尽管这类运算在统计、数据分析、深度学习和无线通信等领域有着广泛应用,例如协方差矩阵的计算和线性回归中的关键步骤,但对于具有特殊结构的矩阵乘法(如计算矩阵与其转置的乘积XXt)的研究相对较少。

从理论角度看,计算XXt与一般矩阵乘法具有相同的渐近复杂度,因此只能通过常数因子优化来提升速度。

因此,这篇论文《XXt Can Be Faster》提出了一种名为RXTX的新算法,通过结合机器学习搜索方法和组合优化技术,显著提升了XXt的计算效率。

我们先来了解一下RXTX。

整体来看,这个基于4×4分块矩阵的递归乘法,通过机器学习搜索与组合优化相结合的方法发现。

算法主要包含以下关键步骤:

    分块与递归调用
    :将矩阵X划分为16个4×4子块,通过8次递归调用处理子问题,并计算26个一般矩阵乘积m1至m26
    对称乘积计算
    :直接计算8个子块的对称乘积s1至m8
    结果组合
    :通过线性组合上述乘积结果,得到最终的XXt矩阵各分块元素C11至C44


与此前最先进的算法(基 Strassen的递归分治)相比,RXTX的递归关系式为 R(n)=8R(n/4) + 26M(n/4),而原算法为 S(n) = 4S(n/2) + 2M(n/2)。

这一设计使得RXTX的渐近乘法常数为 26/41≈0.6341,比原算法的2/3≈0.6667降低了约5%。

接下来,我们来看下乘法次数与运算总量分析。

通过论文中的定理1的推导,RXTX的乘法次数表达式为:

实验数据表明,当n为4的幂次时,RXTX的乘法次数比原算法低5%,且随着n增大,这一优势持续保持:


通过优化加法步骤(利用公共子表达式减少加法次数),RXTX的总运算量表达式为:

而原算法的总运算量包含对数项,导致其增长更快。

实验显示,当n≥256时,RXTX的总运算量优于原算法;当n≥1024时,显著优于朴素算法:


在6144×6144矩阵的测试中,RXTX的平均运行时间为2.524秒,比BLAS的默认实现快9%,且在99%的测试中表现更优:

尽管运行时间受硬件和内存管理影响,但理论分析表明,当n≥256时,RXTX即可展现速度优势。

值得一提的是,RXTX的发现得益于机器学习与组合优化的结合,具体流程如下:

    RL代理生成候选乘积:通过强化学习策略生成大量可能的秩-1双线性乘积。

    MILP枚举与筛选:

      MILP-A:枚举候选乘积与目标表达式(XXt的各分块)之间的线性关系。
      MILP-B:选择最小的乘积子集,确保所有目标表达式可通过线性组合表示。

    大邻域搜索迭代:通过迭代优化,逐步减少冗余乘积,提升算法效率。

这一方法借鉴了AlphaTensor的思路,但通过限制候选空间为二维张量,显著降低了计算复杂度,使得MILP求解器(如 Gurobi)能够高效处理。

论文地址:
https://arxiv.org/abs/2505.09814

参考链接:
[1]https://x.com/DmitryRybin1/status/1923349883945181392
[2]https://x.com/vikhyatk/status/1923541713618129273

—  —

📪 量子位AI主题策划正在征集中!欢迎参与专题365行AI落地方案,一千零一个AI应或与我们分享你在寻找的AI产品,或发现的AI新动向

💬 也欢迎你加入量子位每日AI交流群,一起来畅聊AI吧~


一键关注 👇 点亮星标

科技前沿进展每日见

一键三连「点赞」「转发」「小心心」

欢迎在评论区留下你的想法!


内容中包含的图片若涉及版权问题,请及时与我们联系删除

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

矩阵乘法 RXTX算法 机器学习 大模型
相关文章