// MNN fp32 <--> fp8

// Exponent biases used by the fp32 <--> fp8 conversions.
// The fp8 layout is [sign:1 | exponent:5 | mantissa:2], but with a bias of 24
// (not the IEEE-style 15): the 5-bit field [0, 31] maps to real exponents
// [-24, 7], i.e. magnitudes in [2^-24, 1.75 * 2^7 = 224]. There is no encoding
// for zero, infinity or NaN.
// NOTE(review): MNNFp16ToFp8 / MNNFp8ToFp16 in this file keep the fp16 exponent
// (bias 15 for IEEE binary16), so their fp8 bytes are not interchangeable with
// the ones produced/consumed here — presumably callers never mix them; confirm.
static const int FP32_EXP_BIAS = 127;
static const int FP8_EXP_BIAS = 24;   // [0, 31] --> [-24, 7] --> [1 / 2^24, 2^7]

/**
 * Convert `size` fp32 values to 8-bit floats [sign:1 | exponent:5 | mantissa:2]
 * with exponent bias FP8_EXP_BIAS.
 *
 * The mantissa is truncated (round toward zero), not rounded. Out-of-range
 * magnitudes saturate monotonically:
 *   - |x| >= 2^8 (including +/-inf and NaN) -> +/-224 (0x7F / 0xFF);
 *   - |x| <  2^-24 (including +/-0 and fp32 denormals) -> +/-2^-24 (0x00 / 0x80).
 *
 * @param dst  output buffer, at least `size` bytes
 * @param src  input buffer, at least `size` floats
 * @param size number of elements to convert (0 is allowed)
 */
void MNNFp32ToFp8(uint8_t* dst, const float* src, size_t size) {
    // Largest / smallest unbiased exponents the 5-bit fp8 field can hold.
    const int maxExp = 31 - FP8_EXP_BIAS;
    const int minExp = 0 - FP8_EXP_BIAS;
    for (size_t i = 0; i < size; ++i) {
        // Read the IEEE-754 bit pattern through a union rather than a pointer
        // cast, which would violate strict aliasing.
        union { float f; uint32_t u; } bits;
        bits.f = src[i];
        const uint32_t raw = bits.u;

        const uint32_t sign = (raw >> 31) & 1U;
        const int realExp = (int)((raw >> 23) & 0xFFU) - FP32_EXP_BIAS;
        // Keep only the two most significant of the 23 fp32 mantissa bits.
        uint32_t mant = (raw >> 21) & 0x3U;

        int fp8Exp;
        if (realExp > maxExp) {
            // Overflow: saturate to the largest magnitude. Keeping the source
            // mantissa here would make the mapping non-monotonic
            // (512.0f would encode as 128 while 255.0f encodes as 224).
            fp8Exp = maxExp;
            mant = 0x3U;
        } else if (realExp < minExp) {
            // Underflow: clamp to the smallest representable magnitude.
            fp8Exp = minExp;
            mant = 0U;
        } else {
            fp8Exp = realExp;
        }

        const uint32_t exp = (uint32_t)(fp8Exp + FP8_EXP_BIAS);
        dst[i] = (uint8_t)((sign << 7) | (exp << 2) | mant);
    }
}

/**
 * Convert `size` 8-bit floats [sign:1 | exponent:5 | mantissa:2] with exponent
 * bias FP8_EXP_BIAS back to fp32.
 *
 * Every fp8 code maps exactly to a normal fp32 value: the real exponent lies in
 * [0 - FP8_EXP_BIAS, 31 - FP8_EXP_BIAS] and the two mantissa bits become the
 * top two bits of the fp32 mantissa. An exponent field of 0 therefore decodes
 * to +/-2^-FP8_EXP_BIAS, not to zero.
 *
 * @param dst  output buffer, at least `size` floats
 * @param src  input buffer, at least `size` bytes
 * @param size number of elements to convert (0 is allowed)
 */
void MNNFp8ToFp32(float* dst, const uint8_t* src, size_t size) {
    for (size_t i = 0; i < size; ++i) {
        const uint32_t code = src[i];
        const uint32_t sign = (code >> 7) & 1U;
        const int realExp = (int)((code >> 2) & 0x1FU) - FP8_EXP_BIAS;
        // Place the 2 fp8 mantissa bits at the top of the 23-bit fp32 mantissa.
        const uint32_t mant = (code & 0x3U) << 21;
        const uint32_t exp = (uint32_t)(realExp + FP32_EXP_BIAS);

        // Build the float through a union rather than a pointer cast, which
        // would violate strict aliasing.
        union { uint32_t u; float f; } bits;
        bits.u = (sign << 31) | (exp << 23) | mant;
        dst[i] = bits.f;
    }
}
// fp16 <--> fp8
/**
 * Convert `size` fp16 bit patterns to 8-bit floats by keeping the high byte of
 * each value and dropping the low byte (truncation of the mantissa).
 *
 * For an IEEE binary16 input the kept byte is [sign:1 | exponent:5 | mantissa:2],
 * still using the fp16 exponent bias — not FP8_EXP_BIAS used by MNNFp32ToFp8.
 * NOTE(review): assumes `src` holds IEEE binary16 bit patterns — confirm with callers.
 *
 * @param dst  output buffer, at least `size` bytes
 * @param src  input buffer, at least `size` uint16_t values
 * @param size number of elements to convert (0 is allowed)
 */
void MNNFp16ToFp8(uint8_t* dst, const uint16_t* src, size_t size) {
    size_t i = 0;
#ifdef MNN_USE_NEON
#ifdef __aarch64__
    // 16 halves per iteration: load them as two byte vectors and keep every
    // odd-indexed byte, which on a little-endian target is the high byte of
    // each uint16_t.
    for (; i + 16 <= size; i += 16) {
        const uint8x16_t lo = vld1q_u8(reinterpret_cast<const uint8_t*>(src + i));
        const uint8x16_t hi = vld1q_u8(reinterpret_cast<const uint8_t*>(src + i + 8));
        vst1q_u8(dst + i, vuzp2q_u8(lo, hi));
    }
#else
    // 8 halves per iteration: shift each lane right by 8 and narrow to bytes.
    for (; i + 8 <= size; i += 8) {
        vst1_u8(dst + i, vshrn_n_u16(vld1q_u16(src + i), 8));
    }
#endif // __aarch64__
#endif // MNN_USE_NEON
    // Scalar path: the remaining tail, or every element when NEON is unavailable.
    for (; i < size; ++i) {
        dst[i] = static_cast<uint8_t>(src[i] >> 8);
    }
}
/**
 * Convert `size` 8-bit floats back to fp16 bit patterns by placing each byte in
 * the high byte of a uint16_t and zero-filling the low byte (the inverse of
 * MNNFp16ToFp8, exact for every fp8 code).
 *
 * @param dst  output buffer, at least `size` uint16_t values
 * @param src  input buffer, at least `size` bytes
 * @param size number of elements to convert (0 is allowed)
 */
void MNNFp8ToFp16(uint16_t* dst, const uint8_t* src, size_t size) {
    size_t i = 0;
#ifdef MNN_USE_NEON
    // 8 bytes per iteration: widen each byte to 16 bits and shift it left by 8.
    for (; i + 8 <= size; i += 8) {
        vst1q_u16(dst + i, vshll_n_u8(vld1_u8(src + i), 8));
    }
#endif // MNN_USE_NEON
    // Scalar path: the remaining tail, or every element when NEON is unavailable.
    for (; i < size; ++i) {
        dst[i] = static_cast<uint16_t>(src[i] << 8);
    }
}