vllm.attention.ops.common ¶
CPTritonContext ¶
The CPTritonContext is used to avoid recompilation of the Triton JIT kernel on repeated calls.
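A minimal sketch of the idea, with hypothetical names (this is not the class's actual API): the context caches the handle returned by the first Triton kernel launch so later calls relaunch the already-compiled kernel instead of going through JIT specialization again.

```python
# Hypothetical sketch, not vLLM's implementation: cache the compiled-kernel
# handle returned by the first launch and reuse it on later launches.
class KernelLaunchCache:
    def __init__(self):
        self.compiled = None  # compiled-kernel handle, set on the first launch

    def launch(self, jit_kernel, grid, *args, **constexpr_args):
        if self.compiled is None:
            # First call: Triton compiles the kernel (specializing on the
            # constexpr arguments) and returns a compiled-kernel handle.
            self.compiled = jit_kernel[grid](*args, **constexpr_args)
        else:
            # Later calls: relaunch the compiled kernel directly; constexpr
            # values are already baked in, so only regular args are passed.
            self.compiled[grid](*args)
```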
_correct_attn_cp_out_kernel ¶
_correct_attn_cp_out_kernel(
outputs_ptr,
new_output_ptr,
lses_ptr,
vlse_ptr,
outputs_stride_B,
outputs_stride_H,
outputs_stride_D,
lses_stride_N,
lses_stride_B,
lses_stride_H,
lse_idx,
HEAD_DIM: constexpr,
N_ROUNDED: constexpr,
)
Apply the all-gathered lses to correct each local rank's attention output. We still need to perform a cross-rank reduction to obtain the final attention output.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `output` | | `[ B, H, D ]` | required |
| `lses` | | `[ N, B, H ]` | required |

Returns:

- `output`: `[ B, H, D ]`
- `lse`: `[ B, H ]`
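For reference, the correction can be written in plain PyTorch. This is a hedged sketch of the math, not the kernel's code: each rank holds a partial attention output over its slice of the KV sequence together with the matching log-sum-exp; after all-gathering the lses, the local output is rescaled by `exp(lse_local - lse_global)`, where `lse_global` is the logsumexp over ranks.

```python
import torch

# Hedged PyTorch sketch of the per-rank correction (illustrative names,
# not the Triton kernel's interface).
def correct_local_out(out: torch.Tensor,   # [B, H, D] this rank's partial output
                      lses: torch.Tensor,  # [N, B, H] all-gathered log-sum-exps
                      rank: int) -> tuple[torch.Tensor, torch.Tensor]:
    lse_global = torch.logsumexp(lses, dim=0)       # [B, H]
    scale = torch.exp(lses[rank] - lse_global)      # [B, H]
    corrected = out * scale.unsqueeze(-1)           # [B, H, D]
    return corrected, lse_global
```

Summing the corrected outputs over all ranks (the cross-rank reduction mentioned above) recovers the attention output computed over the full KV sequence.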
correct_attn_out ¶
correct_attn_out(
out: Tensor,
lses: Tensor,
cp_rank: int,
ctx: CPTritonContext,
)
Apply the all-gathered lses to correct each local rank's attention output. We still need to perform a cross-rank reduction to obtain the final attention output.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `output` | | `[ B, H, D ]` | required |
| `lses` | | `[ N, B, H ]` | required |

Returns:

- `output`: `[ B, H, D ]`
- `lse`: `[ B, H ]`
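A hedged usage sketch (the variable names and the final all-reduce step are assumptions, not taken from the source): reuse one `CPTritonContext` across calls so the Triton kernel is compiled only once, then finish with the cross-rank reduction.

```python
import torch

from vllm.attention.ops.common import CPTritonContext, correct_attn_out

# Keep one context alive across steps so the Triton kernel is compiled once.
ctx = CPTritonContext()

def combine_cp_attention(out: torch.Tensor,   # [B, H, D] local partial output
                         lses: torch.Tensor,  # [N, B, H] all-gathered lses
                         cp_rank: int, cp_group) -> torch.Tensor:
    # Rescale this rank's partial output against the all-gathered lses.
    out, lse = correct_attn_out(out, lses, cp_rank, ctx)
    # The corrected outputs still have to be summed across context-parallel
    # ranks; an all-reduce over the CP group is one way to do that.
    return cp_group.all_reduce(out)
```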
cp_lse_ag_out_rs ¶
cp_lse_ag_out_rs(
cp_attn_out: Tensor,
cp_attn_lse: Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext = None,
)
cp_attn_out: [ B, H, D ]
cp_attn_lse: [ B, H ]
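The name suggests the usual context-parallel pattern: all-gather the lse, correct the local output, then reduce-scatter. Below is a hedged sketch of that flow, assuming the `GroupCoordinator` exposes `all_gather`, `reduce_scatter`, `world_size`, and `rank_in_group`; the body is an assumption, not the function's actual source.

```python
import torch

from vllm.attention.ops.common import CPTritonContext, correct_attn_out

# Hedged sketch of the all-gather -> correct -> reduce-scatter flow the
# function name suggests; not the actual implementation.
def cp_lse_ag_out_rs_sketch(cp_attn_out: torch.Tensor,   # [B, H, D]
                            cp_attn_lse: torch.Tensor,   # [B, H]
                            cp_group, ctx=None) -> torch.Tensor:
    if ctx is None:
        ctx = CPTritonContext()
    # Gather every rank's lse so each rank sees the full [N, B, H] tensor.
    lses = cp_group.all_gather(cp_attn_lse.contiguous(), dim=0)
    lses = lses.view(cp_group.world_size, *cp_attn_lse.shape)
    # Rescale this rank's partial output against the global log-sum-exp.
    out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
    # Reduce-scatter sums the corrected outputs across ranks and leaves each
    # rank with its shard (assumed here to be split along the head dimension).
    return cp_group.reduce_scatter(out, dim=1)
```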