ARM 中常用的乘累加(MAC)指令总结
在现代 CPU 和 AI 加速器的指令集中,乘-累加(Multiply-Accumulate, MAC) 是最核心的计算操作之一。
无论是 卷积、矩阵乘法 还是 深度学习中的 GEMM,其本质都是大量的 MAC 运算。
在 ARM 架构下,针对不同的数据类型和应用场景,演进出了多种 乘累加指令。本文将对常见的几类进行梳理,并配合伪代码展示它们的计算方式。
1. FMLA —— 基础乘累加
FMLA (Floating-point Multiply-Accumulate) 是最基础的向量乘累加指令。
-
指令集:
>= ARMv7 (float32),>= ARMv8 (float16) -
常见 NEON 接口:
float32x4_t vfmaq_f32(float32x4_t a, float32x4_t b, float32x4_t c)float16x8_t vfmaq_f16(float16x8_t a, float16x8_t b, float16x8_t c)
-
运算逻辑:逐元素乘加
for (int i = 0; i < N; i++) { dst[i] += src1[i] * src2[i]; }
典型用途:浮点向量加速,如卷积核点积、向量化计算。
2. DOT —— 向量点积
为了更高效地处理 int8 量化推理,ARM 在 ARMv8.2 中引入了 DOT (Dot Product) 指令。
-
指令集:
>= ARMv8.2 -
常见 NEON 接口:
int32x4_t vdotq_s32(int32x4_t r, int8x16_t a, int8x16_t b)int32x4_t vusdotq_s32(int32x4_t r, uint8x16_t a, int8x16_t b)
-
运算逻辑:分块点积 (4 × 4 → int32)
for (int i = 0; i < 4; i++) { for (int j = 0; j < 4; j++) { dst[i] += src1[i * 4 + j] * src2[i * 4 + j]; } }
典型用途:卷积、点积,特别适合 int8 量化模型。
3. MMLA —— 矩阵乘累加
在 ARMv8.6 之后,进一步推出了 MMLA (Matrix Multiply-Accumulate) 指令,可以直接把 小块向量视为矩阵,一次性完成 2×8 × 8×2 的矩阵乘法。
-
指令集:
>= ARMv8.6 -
常见 NEON 接口:
int32x4_t vmmlaq_s32(int32x4_t r, int8x16_t a, int8x16_t b)int32x4_t vusmmlaq_s32(int32x4_t r, uint8x16_t a, int8x16_t b)
-
运算逻辑:小矩阵乘法
for (int i = 0; i < 2; i++) { for (int j = 0; j < 2; j++) { for (int k = 0; k < 8; k++) { dst[i * 2 + j] += src1[i * 8 + k] * src2[j * 8 + k]; } } }
典型用途:高效的 int8 GEMM,加速深度学习推理。
4. BFMMLA —— BFloat16 矩阵乘累加
为了更好地支持 AI 训练与推理,ARMv8.6 又引入了 BFMMLA (BFloat16 Matrix Multiply-Accumulate),专门处理 bfloat16 × bfloat16 → float32。
-
指令集:
>= ARMv8.6 -
常见 NEON 接口:
float32x4_t vbfmmlaq_f32(float32x4_t r, bfloat16x8_t a, bfloat16x8_t b)
-
运算逻辑:小矩阵乘法
for (int i = 0; i < 2; i++) { for (int j = 0; j < 2; j++) { for (int k = 0; k < 4; k++) { dst[i * 2 + j] += src1[i * 4 + k] * src2[j * 4 + k]; } } }
典型用途:深度学习训练(bfloat16 是目前主流 AI 训练精度格式)。
| 指令 | 指令集 | 数据类型 | 运算规模 | 主要用途 |
|---|---|---|---|---|
| FMLA | ARMv7/ARMv8 | FP16/FP32 | 逐元素 | 浮点向量运算 |
| SDOT/USDOT | ARMv8.2 | int8/uint8 → int32 | 4×4 点积 | int8 卷积、点积 |
| SMMLA/USMMLA | ARMv8.6 | int8/uint8 → int32 | 2×2×8 矩阵 | int8 GEMM |
| BFMMLA | ARMv8.6 | bfloat16 → float32 | 2×2×4 矩阵 | 深度学习 (bfloat16) |