长文本CP 切分共2次All2All第一次AlltoAll输入按Seq维度汇总按Head维度切。切输入非TP维度的切参数[s/c, b, n/t, h] -AlltoAll- [s, b, n/(t*c), h]第二次AlltoAll恢复按Seq维度切按Head维度汇总。[s, b, n/(t*c), h] -AlltoAll- [s/c, b, n/t, h]其中t 为TP, c 为CP, n nHead数举例 CP 2, TP 4 , H 8192, nHead 16阶段形状说明输入[s/2, b, 8192]CP 切分后每 rank 持有半个序列MLA 解压后 Q/K/V[s/2, b, 16, 192]16 heads/rank64 heads ÷ TP4经过了TP的降维A2A 后scatter headgather seq[s, b, 8, 192]全序列head 减半Flash Attention 输出[s, b, 8, 128]全序列本地计算A2A 后scatter seqgather head[s/2, b, 16, 128]还原序列分片o_proj 后[s/2, b, 8192]还原 hidden_states, 经过TP升维compressed_kv [s, b, 576] ← kv_a_proj 压缩后的 latent是 _preprocess 的输入 │ ├── split → ct_kv [s, b, 512] ← kv_lora_rank 部分 │ k_pe [s, b, 64] ← rope 部分 │ ├── kv_a_layernorm(ct_kv) │ └── kv_b_proj (Up-projection, 解压) [s, b, 512] → [s, b, 16heads, 128128] k_nope [s, b, 16, 128] v [s, b, 16, 128] q_b_input (经过 q_b_proj 解压) q_nope [s, b, 16, 128] q_pe [s, b, 16, 64] 最终拼接: query_states [s, b, 16, 192] q_nope q_pe key_states [s, b, 16, 192] k_nope k_pe value_states [s, b, 16, 128]MLA attention:DeepseekV2Attention └── self.core_attention_flash FlashAttention(...) # 基础 flash attn ↓ (当 CP alltoall 时自动包装) └── self.core_attention_flash DistributedAttention(FlashAttention, cp_group)