StridedSlice 算子 API 描述【免费下载链接】cann-bench评测AI在处理CANN领域代码任务的能力涵盖算子生成、算子优化等领域支撑模型选型、训练效果评估统一量化评估标准识别Agent能力短板构建CANN领域评测平台推动AI能力在CANN领域的持续演进。项目地址: https://gitcode.com/cann/cann-bench1. 算子简介使用步长对输入张量进行多维切片提取子张量。支持 begin_mask、end_mask 控制边界、shrink_axis_mask 收缩维度、new_axis_mask 插入新维度、ellipsis_mask 省略号等功能。主要应用场景深度学习模型中对特征图进行区域裁剪和下采样序列模型中按步长提取时间步或特征片段数据预处理中对多维张量进行灵活的切片操作模型推理中通过掩码机制实现复杂的维度操控算子特征难度等级L3LayoutTransform单输入单输出支持 0-8 维输入支持负数步长和多种掩码控制2. 算子定义数学公式$$ y[i,j,k,...] x[\text{begin}[0]:\text{end}[0]:\text{strides}[0],\ \text{begin}[1]:\text{end}[1]:\text{strides}[1],\ \text{begin}[2]:\text{end}[2]:\text{strides}[2],\ ...] $$各掩码参数的作用begin_mask二进制掩码位 1 表示该维度从 0 开始忽略 begin 值end_mask二进制掩码位 1 表示该维度切到末尾忽略 end 值ellipsis_mask二进制掩码位 1 表示该维度使用省略号标记shrink_axis_mask二进制掩码位 1 表示该维度被收缩掉维度大小为 1new_axis_mask二进制掩码位 1 表示该位置插入大小为 1 的新维度3. 接口规范算子原型cann_bench.strided_slice(Tensor x, int[] begin, int[] end, int[] strides, int begin_mask, int end_mask, int ellipsis_mask, int shrink_axis_mask, int new_axis_mask) - Tensor y输入参数说明参数类型默认值描述xTensor必选输入张量beginint[]必选切片起始位置数组长度等于输入维度数endint[]必选切片结束位置数组长度等于输入维度数stridesint[]必选切片步长数组长度等于输入维度数支持负数步长begin_maskint64_t—二进制掩码位 1 表示该维度从 0 开始位 0 使用 begin 值end_maskint64_t—二进制掩码位 1 表示该维度切到末尾位 0 使用 end 值ellipsis_maskint64_t—二进制掩码位 1 表示该维度使用省略号标记shrink_axis_maskint64_t—二进制掩码位 1 表示该维度被收缩掉维度大小为 1new_axis_maskint64_t—二进制掩码位 1 表示该位置插入大小为 1 的新维度输出参数Shapedtype描述y由 begin、end、strides 及各掩码决定与输入 x 相同输出张量切片结果数据类型输入 dtype输出 dtypeint8int8uint8uint8int32int32int64int64float16float16float32float32bfloat16bfloat16规则与约束输入支持 0-8 维张量begin、end、strides 数组长度必须等于输入维度数strides 中每个元素不能为 0支持负数步长表示逆序切片begin 和 end 支持负数索引表示从末尾倒数各掩码参数以二进制位的形式对应各维度低位对应低维度输出 dtype 与输入 dtype 一致4. 精度要求采用生态算子精度标准进行验证。误差指标平均相对误差MERE采样点中相对误差平均值$$ \text{MERE} \text{avg}(\frac{\text{abs}(actual - golden)}{\text{abs}(golden)\text{1e-7}}) $$最大相对误差MARE采样点中相对误差最大值$$ \text{MARE} \max(\frac{\text{abs}(actual - golden)}{\text{abs}(golden)\text{1e-7}}) $$通过标准数据类型FLOAT16BFLOAT16FLOAT32HiFLOAT32FLOAT8 E4M3FLOAT8 E5M2通过阈值(Threshold)2^-102^-72^-132^-112^-32^-2当平均相对误差 MERE Threshold最大相对误差 MARE 10 * Threshold 时判定为通过。5. 标准 Golden 代码import torch def strided_slice( x: torch.Tensor, begin: list, end: list, strides: list, begin_mask: int 0, end_mask: int 0, ellipsis_mask: int 0, shrink_axis_mask: int 0, new_axis_mask: int 0 ) - torch.Tensor: 使用步长对输入张量进行多维切片对标 TensorFlow strided_slice。 Args: x: 输入张量 begin: 切片起始位置数组 end: 切片结束位置数组 strides: 切片步长数组支持负数步长 begin_mask: 二进制掩码位1表示该维度从0开始 end_mask: 二进制掩码位1表示该维度切到末尾 ellipsis_mask: 二进制掩码位1表示省略号标记 shrink_axis_mask: 二进制掩码位1表示收缩该维度取单元素 new_axis_mask: 二进制掩码位1表示插入新维度 Returns: 输出张量切片结果 ndim x.dim() shape x.shape # 处理 ellipsis_mask ellipsis_pos None for i in range(32): if ellipsis_mask (1 i): ellipsis_pos i break # 计算 new_axis 数量 num_new_axis 0 for i in range(len(begin) if begin else 0): if new_axis_mask (1 i): num_new_axis 1 indices [] input_dim_idx 0 param_idx 0 if ellipsis_pos is not None: num_params len(begin) if begin else 0 num_ellipsis_dims ndim - (num_params - num_new_axis - 1) if num_ellipsis_dims 0: num_ellipsis_dims 0 while input_dim_idx ndim or param_idx (len(begin) if begin else 0): if param_idx len(begin) and (new_axis_mask (1 param_idx)): indices.append(None) param_idx 1 continue if ellipsis_pos is not None and param_idx ellipsis_pos: for _ in range(num_ellipsis_dims): indices.append(slice(None, None, None)) input_dim_idx 1 param_idx 1 continue if input_dim_idx ndim and param_idx len(begin): dim_size shape[input_dim_idx] b begin[param_idx] if param_idx len(begin) else 0 e end[param_idx] if param_idx len(end) else dim_size s strides[param_idx] if param_idx len(strides) else 1 if b 0: b b dim_size if e 0: e e dim_size if begin_mask (1 param_idx): b 0 if s 0 else dim_size - 1 if end_mask (1 param_idx): e dim_size if s 0 else -1 if shrink_axis_mask (1 param_idx): indices.append(b) else: indices.append(slice(b, e, s)) input_dim_idx 1 param_idx 1 elif input_dim_idx ndim: indices.append(slice(None, None, None)) input_dim_idx 1 else: if param_idx len(begin) and (new_axis_mask (1 param_idx)): indices.append(None) param_idx 1 return x[tuple(indices)]6. 额外信息算子调用示例import torch import cann_bench x torch.randn(1024, 1024, dtypetorch.float16, devicenpu) y cann_bench.strided_slice(x, [0, 0], [512, 512], [2, 2], 0, 0, 0, 0, 0) x torch.randn(2, 8, 256, 256, dtypetorch.float32, devicenpu) y cann_bench.strided_slice(x, [0, 0, 0, 0], [-1, -1, 128, 128], [1, 1, 2, 2], 0, 0, 0, 0, 0)【免费下载链接】cann-bench评测AI在处理CANN领域代码任务的能力涵盖算子生成、算子优化等领域支撑模型选型、训练效果评估统一量化评估标准识别Agent能力短板构建CANN领域评测平台推动AI能力在CANN领域的持续演进。项目地址: https://gitcode.com/cann/cann-bench创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考