// Exponent bias of IEEE 754 binary32.
static const int FP32_EXP_BIAS = 127;
// Exponent bias of the FP8 (s1e5m2) encoding used by the FP32 <--> FP8
// converters: the 5-bit exponent field [0, 31] maps to real exponents [-24, 7].
static const int FP8_EXP_BIAS = 24; // [0, 31] --> [-24, 7] --> [1 / 2^24, 2^7]
/**
 * Convert `size` FP32 values into 8-bit s1e5m2 codes.
 *
 * Output layout per byte: [1 bit sign | 5 bit exponent | 2 bit mantissa],
 * with the exponent stored using FP8_EXP_BIAS.
 *
 * - The mantissa is truncated to the top 2 bits of the FP32 fraction (no rounding).
 * - Magnitudes below 2^-24 (including zero and FP32 denormals) saturate to the
 *   smallest code (exponent field 0, mantissa 0).
 * - Magnitudes above the representable range (including Inf/NaN bit patterns)
 *   saturate to the largest code (exponent field 31, mantissa 3).
 *   Saturating the mantissa together with the exponent keeps the mapping
 *   monotonic for out-of-range inputs.
 *
 * @param dst  output buffer, at least `size` bytes.
 * @param src  input buffer, at least `size` floats.
 * @param size number of elements to convert.
 */
void MNNFp32ToFp8(uint8_t* dst, const float* src, size_t size) {
    // Real (unbiased) exponents that the 5-bit exponent field can express.
    const int minRealExp = 0 - FP8_EXP_BIAS;
    const int maxRealExp = 31 - FP8_EXP_BIAS;
    for (size_t i = 0; i < size; ++i) {
        // View the binary32 bit pattern as an integer so fields can be sliced out.
        // NOTE(review): reading a float through a uint32_t pointer is the idiom this
        // file already uses; a memcpy-based copy would be the strictly portable form.
        const uint32_t rawData = *reinterpret_cast<const uint32_t*>(&src[i]);
        // Bit 31: sign (0 = positive, 1 = negative).
        const uint32_t sign = (rawData >> 31) & 1U;
        // Bits 21..22: the two most significant fraction bits.
        uint32_t mant = (rawData >> 21) & 0x3U;
        // Bits 23..30: biased exponent; remove the FP32 bias to get the real exponent.
        int realExp = static_cast<int>((rawData >> 23) & 0xFFU) - FP32_EXP_BIAS;
        if (realExp < minRealExp) {
            // Too small: snap to the smallest representable magnitude.
            realExp = minRealExp;
            mant    = 0U;
        } else if (realExp > maxRealExp) {
            // Too large: snap to the largest representable magnitude.
            realExp = maxRealExp;
            mant    = 0x3U;
        }
        // Re-bias for FP8 and pack the three fields into one byte.
        const uint32_t exp = static_cast<uint32_t>(realExp + FP8_EXP_BIAS);
        dst[i] = static_cast<uint8_t>((sign << 7) | (exp << 2) | mant);
    }
}
/**
 * Expand `size` FP8 (s1e5m2) codes into FP32 values.
 *
 * Each input byte is read as [1 bit sign | 5 bit exponent | 2 bit mantissa].
 * The exponent is re-biased from FP8_EXP_BIAS to FP32_EXP_BIAS and the two
 * mantissa bits become the two most significant FP32 fraction bits; the
 * remaining fraction bits are zero. Every code, including exponent field 0,
 * is decoded as a normal number.
 *
 * @param dst  output buffer, at least `size` floats.
 * @param src  input buffer, at least `size` bytes.
 * @param size number of elements to convert.
 */
void MNNFp8ToFp32(float* dst, const uint8_t* src, size_t size) {
    for (size_t i = 0; i < size; ++i) {
        const uint32_t code = src[i];
        const uint32_t sign = (code >> 7) & 1U;
        // Drop the FP8 bias to get the real exponent, then apply the FP32 bias.
        const int realExp   = static_cast<int>((code >> 2) & 0x1fU) - FP8_EXP_BIAS;
        const uint32_t exp  = static_cast<uint32_t>(realExp + FP32_EXP_BIAS);
        // The 2 mantissa bits land in fraction bits 21..22.
        const uint32_t mant = (code & 3U) << 21;
        const uint32_t rawData = (sign << 31) | (exp << 23) | mant;
        // Reinterpret the assembled bit pattern as a float.
        dst[i] = *reinterpret_cast<const float*>(&rawData);
    }
}
// fp16 <--> fp8
/**
 * Narrow `size` 16-bit elements to 8 bits by keeping only the high byte
 * (bits 8..15) of each element; the low byte is discarded (truncation).
 *
 * NOTE(review): presumably `src` holds IEEE half-precision bit patterns, in
 * which case the kept byte is sign + 5 exponent bits + top 2 mantissa bits
 * and the exponent keeps the half-precision bias - confirm against callers.
 *
 * @param dst  output buffer, at least `size` bytes.
 * @param src  input buffer, at least `size` 16-bit elements.
 * @param size number of elements to convert.
 */
void MNNFp16ToFp8(uint8_t* dst, const uint16_t* src, size_t size) {
    size_t i = 0;
#ifdef MNN_USE_NEON
#ifdef __aarch64__
    // 16 elements per step: load 2 x 16 bytes and keep the odd-indexed bytes,
    // which are the high bytes of each 16-bit element on a little-endian target.
    const size_t vecEnd = size - (size % 16);
    for (; i < vecEnd; i += 16) {
        uint8x16_t v1  = vld1q_u8(reinterpret_cast<const uint8_t*>(src + i));
        uint8x16_t v2  = vld1q_u8(reinterpret_cast<const uint8_t*>(src + i + 8));
        uint8x16_t res = vuzp2q_u8(v1, v2);
        vst1q_u8(dst + i, res);
    }
#else
    // 8 elements per step: shift each lane right by 8 and narrow to 8 bits.
    const size_t vecEnd = size - (size % 8);
    for (; i < vecEnd; i += 8) {
        uint16x8_t vec = vld1q_u16(src + i);
        uint8x8_t res  = vshrn_n_u16(vec, 8);
        vst1_u8(dst + i, res);
    }
#endif // ARM64
#endif // USE_NEON
    // Scalar path: the whole range without NEON, the remainder with it.
    for (; i < size; ++i) {
        dst[i] = static_cast<uint8_t>(src[i] >> 8);
    }
}
/**
 * Widen `size` 8-bit codes to 16 bits by placing each byte in the high byte
 * (bits 8..15) of the output element; the low byte is zero.
 *
 * @param dst  output buffer, at least `size` 16-bit elements.
 * @param src  input buffer, at least `size` bytes.
 * @param size number of elements to convert.
 */
void MNNFp8ToFp16(uint16_t* dst, const uint8_t* src, size_t size) {
    size_t i = 0;
#ifdef MNN_USE_NEON
    // 8 elements per step: widen each byte to 16 bits while shifting left by 8.
    const size_t vecEnd = size - (size % 8);
    for (; i < vecEnd; i += 8) {
        uint8x8_t vec8x8    = vld1_u8(src + i);
        uint16x8_t vec16x8  = vshll_n_u8(vec8x8, 8);
        vst1q_u16(dst + i, vec16x8);
    }
#endif // USE_NEON
    // Scalar path: the whole range without NEON, the remainder with it.
    for (; i < size; ++i) {
        dst[i] = static_cast<uint16_t>(static_cast<uint16_t>(src[i]) << 8);
    }
}