CANN MLA Prolog算子文档
MlaProlog【免费下载链接】cann-recipes-infer本项目针对LLM与多模态模型推理业务中的典型模型、加速算法提供基于CANN平台的优化样例项目地址: https://gitcode.com/cann/cann-recipes-infer产品支持情况产品是否支持Atlas A2 推理系列产品√Atlas A3 推理系列产品√Ascend 950PR/Ascend 950DT 推理系列产品√功能说明MLA Prolog 模块将hidden states $x$ 转换为 $Query$和 ${Key-Value}$。计算公式$Query(q)$ 的计算 Query 的计算包括两次采样和 RmsNorm其中第二次 RmsNorm 权重恒为 1最后对 -1 轴的后 rope_dim 维度进行 inplace interleaved rope 计算$$ c^Q RmsNorm(x wq_a) $$$$ q RmsNorm(c^Q wq_b) $$$$ q[..., -rope_dim:] ROPE(q[..., -rope_dim:]) $$$Key-Value(kv)$ 的计算 kv 的计算包括一次下采样和 RmsNorm最后对 -1 轴的后 rope_dim 维度进行 inplace interleaved rope 计算$$ kv RmsNorm(x wkv) $$$$ kv[..., -rope_dim:] ROPE(kv[..., -rope_dim:]) $$函数原型torch.ops.pypto.mla_prolog_quant( token_x, wq_a, wq_b, wkv, rope_cos, rope_sin, gamma_cq, gamma_ckv, wq_b_scale ) - (Tensor, Tensor, Tensor, Tensor)参数说明token_xTensor公式中用于计算Query和Key-Value的输入tensor不支持非连续的 Tensor数据格式支持ND数据类型支持bfloat16shape为[t, h]。wq_aTensor公式中用于计算Query的下采样权重矩阵$wq_a$数据格式支持NZ/ND数据类型支持bfloat16shape为[h, q_lora_rank]。wq_bTensor公式中用于计算Query的上采样权重矩阵$wq_b$数据格式支持NZ/ND数据类型支持int8shape为[q_lora_rank, num_heads*head_dim]。wkvTensor公式中用于计算Key-Value的下采样权重矩阵$wkv$数据格式支持NZ/ND数据类型支持bfloat16shape为[h, head_dim]。rope_cosTensor用于计算旋转位置编码的余弦参数矩阵不支持非连续的 Tensor数据格式支持ND数据类型支持bfloat16shape为[t, rope_dim]。rope_sinTensor用于计算旋转位置编码的正弦参数矩阵不支持非连续的 Tensor数据格式支持ND数据类型支持bfloat16shape为[t, rope_dim]。gamma_cqTensor计算$c^Q$的RmsNorm公式中的$\gamma$参数不支持非连续的 Tensor数据格式支持ND数据类型支持bfloat16shape为[q_lora_rank]。gamma_ckvTensor计算$c^{KV}$的RmsNorm公式中的$\gamma$参数不支持非连续的 Tensor数据格式支持ND数据类型支持bfloat16shape为[head_dim]。wq_b_scaleTensor用于矩阵乘wq_b后反量化操作的per-channel参数不支持非连续的 Tensor。数据格式支持ND数据类型支持floatshape为[num_heads*head_dim, 1]。返回值说明q_outTensor公式中Query的输出tensor对应公式中的$q$不支持非连续的 Tensor。数据格式支持ND数据类型支持bfloat16shape为[t, num_heads, head_dim]。kv_outTensor公式中Key-Value的输出tensor对应公式中的$kv$不支持非连续的 Tensor。数据格式支持ND数据类型支持bfloat16shape为[t, head_dim]。qr_outTensor公式中Query做完第一次rmsnorm和quant后的输出tensor对应公式中的$c^Q$不支持非连续的 Tensor数据格式支持ND数据类型支持int8, shape为[t, q_lora_rank]。qr_scale_outTensor公式中Query做完第一次rmsnorm后的输出tensor对应公式中的$c^Q$不支持非连续的 Tensor数据格式支持ND数据类型支持float32, shape为[t, 1]。约束说明该接口支持推理场景下使用。该接口支持aclgraph模式。head_dim支持512h支持4096q_lora_rank支持1024num_heads支持64rope_dim支持64。t值域范围支持[1, 64k]。950PR/DT上暂不支持int8量化版本。非量化实现可以参考example。调用方法量化 python3 ops/pypto_python/example/test_mla_prolog_quant_pypto.py 非量化 python3 ops/pypto_python/example/test_mla_prolog_pypto.py【免费下载链接】cann-recipes-infer本项目针对LLM与多模态模型推理业务中的典型模型、加速算法提供基于CANN平台的优化样例项目地址: https://gitcode.com/cann/cann-recipes-infer创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考