DotGeneralOp 到 Ascend Op 的优化转换【免费下载链接】xla-npuXLA-NPU 是一个面向华为昇腾NPU硬件的 XLA后端实现。本项目通过接入OpenXLA/XLA开源项目将XLA开源生态与华为 CANN软件栈集成对接JAX框架。JAX框架运行时可以直接加载XLA-NPU使得基于JAX框架开发的模型可以运行在昇腾NPU上提供推理场景图编译加速能力。项目地址: https://gitcode.com/cann/xla-npu问题分析从日志和错误信息分析发现 Ascend 的 MatMul 操作对 batch 维度的处理存在问题原始错误OpName:[MatMul215] [InferShape] The k-axis of a(8) and b(14) tensors must be the same输入形状lhs:[14, 8, 64]rhs:[1, 14, 64, 8]转换后lhs:[14, 8, 64]rhs:[14, 64, 8]问题Ascend MatMul 将[14, 8, 64]解释为 K8将[14, 64, 8]解释为 K14导致 K 轴不匹配。解决方案Ascend MatMul 操作对比通过分析 Ascend 的 Op 定义发现有以下几种 MatMul 操作MatMul基本的矩阵乘法可能不支持 batch 维度输入x1, x2, bias (optional)属性transpose_x1, transpose_x2适用于2D 矩阵乘法[M, K] x [K, N] - [M, N]BatchMatMul专门支持 batch 维度的矩阵乘法输入x1, x2属性adj_x1, adj_x2适用于batch 矩阵乘法[batch..., M, K] x [batch..., K, N] - [batch..., M, N]MatMulV2增强版本支持更多数据类型输入x1, x2, bias (optional), offset_w (optional)属性transpose_x1, transpose_x2, offset_x适用于需要更多数据类型支持的场景优化策略根据 StableHLOdot_general的输入特征选择最合适的 Ascend Op场景StableHLO dot_generalAscend Op输入形状无 batch 维度contracting_dims [1] x [0]MatMul[M, K] x [K, N]有 batch 维度batching_dims [0] x [1]BatchMatMul[B, M, K] x [B, K, N]实现细节1. 添加 BatchMatMulOp 定义在mair_ops.td中添加def Air_BatchMatMulOp : Air_OpBatchMatMul, [Pure] { let summary Batch matrix multiplication operation; let description [{ Performs batch matrix multiplication on two input tensors. Supports batch dimensions: [batch..., M, K] x [batch..., K, N] - [batch..., M, N] }]; let arguments (ins Air_Tensor:$x1, Air_Tensor:$x2, DefaultValuedAttrBoolAttr, false:$adj_x1, DefaultValuedAttrBoolAttr, false:$adj_x2 ); let results (outs Air_Tensor:$output ); }2. 修改 ConvertMatMulOp根据是否有 batch 维度选择不同的操作if (!lhsBatchingDims.empty()) { // 有 batch 维度使用 BatchMatMul lhsReshapeShape {lhsBatchSize, lhsNonContractSize, lhsContractSize}; rhsReshapeShape {rhsBatchSize, rhsContractSize, rhsNonContractSize}; matmulResultShape {lhsBatchSize, lhsNonContractSize, rhsNonContractSize}; matmulResult rewriter.createBatchMatMulOp( op.getLoc(), matmulResultType, lhsReshaped, rhsReshaped, false, false).getResult(); } else { // 无 batch 维度使用 MatMul lhsReshapeShape {lhsNonContractSize, lhsContractSize}; rhsReshapeShape {rhsContractSize, rhsNonContractSize}; matmulResultShape {lhsNonContractSize, rhsNonContractSize}; matmulResult rewriter.createMatMulOp( op.getLoc(), matmulResultType, lhsReshaped, rhsReshaped, nullptr, false, false).getResult(); }3. 转换流程例子 1有 batch 维度输入stablehlo.dot_general %299, %296, batching_dims [0] x [1], contracting_dims [2] x [2] : (tensor14x8x64xf32, tensor1x14x64x8xf32) - tensor14x8x1x8xf32转换步骤维度识别lhs:[14, 8, 64]→ batch14, M8, K64rhs:[1, 14, 64, 8]→ batch14, K64, N8Transposelhs:[14, 8, 64]→[14, 8, 64](无需转置)rhs:[1, 14, 64, 8]→[14, 64, 1, 8]→[14, 64, 8]Reshapelhs:[14, 8, 64]→[14, 8, 64]rhs:[14, 64, 8]→[14, 64, 8]BatchMatMul[14, 8, 64]x[14, 64, 8]→[14, 8, 8]Reshape[14, 8, 8]→[14, 8, 1, 8]例子 2无 batch 维度输入stablehlo.dot_general %24, %arg13, contracting_dims [2] x [0] : (tensor1x8x896xf32, tensor896x128xf32) - tensor1x8x128xf32转换步骤维度识别lhs:[1, 8, 896]→ M8, K896rhs:[896, 128]→ K896, N128Reshapelhs:[1, 8, 896]→[8, 896]rhs:[896, 128]→[896, 128]MatMul[8, 896]x[896, 128]→[8, 128]Reshape[8, 128]→[1, 8, 128]优势语义正确使用 BatchMatMul 正确处理 batch 维度性能优化避免不必要的维度展平和恢复操作代码清晰根据输入特征选择最合适的操作可扩展性易于添加更多 MatMul 变体的支持修改的文件mair_ops.td添加 BatchMatMulOp 定义mair_passes.cc修改 ConvertMatMulOp根据 batch 维度选择不同的操作测试建议建议创建以下测试用例无 batch 维度的 dot_general→ 使用 MatMul有 batch 维度的 dot_general→ 使用 BatchMatMul多个 batch 维度的 dot_general→ 验证 BatchMatMul 的多 batch 支持边界情况维度大小为 1 的情况总结通过分析 Ascend 的不同 MatMul 操作我们优化了 StableHLOdot_general到 Ascend Op 的转换无 batch 维度使用 MatMul保持原有的 2D 矩阵乘法语义有 batch 维度使用 BatchMatMul正确处理 batch 维度这种优化不仅解决了 K 轴不匹配的问题还提高了转换的效率和正确性。【免费下载链接】xla-npuXLA-NPU 是一个面向华为昇腾NPU硬件的 XLA后端实现。本项目通过接入OpenXLA/XLA开源项目将XLA开源生态与华为 CANN软件栈集成对接JAX框架。JAX框架运行时可以直接加载XLA-NPU使得基于JAX框架开发的模型可以运行在昇腾NPU上提供推理场景图编译加速能力。项目地址: https://gitcode.com/cann/xla-npu创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考