
vllm.attention.ops.common

CPTritonContext

The CPTritonContext is used to avoid recompilation of the Triton JIT.

Source code in vllm/attention/ops/common.py
class CPTritonContext:
    """ The CPTritonContext is used to avoid recompilation of the Triton JIT.
    """

    def __init__(self):
        self.inner_kernel = None

    def call_kernel(self, kernel, grid, *regular_args, **const_args):
        if self.inner_kernel is None:
            self.inner_kernel = kernel[grid](*regular_args, **const_args)
        else:
            self.inner_kernel[grid](*regular_args)

inner_kernel instance-attribute

inner_kernel = None

__init__

__init__()
Source code in vllm/attention/ops/common.py
def __init__(self):
    self.inner_kernel = None

call_kernel

call_kernel(kernel, grid, *regular_args, **const_args)
Source code in vllm/attention/ops/common.py
def call_kernel(self, kernel, grid, *regular_args, **const_args):
    if self.inner_kernel is None:
        self.inner_kernel = kernel[grid](*regular_args, **const_args)
    else:
        self.inner_kernel[grid](*regular_args)
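
The caching works because of Triton's launch API: the first kernel[grid](*regular_args, **const_args) call JIT-compiles the kernel and returns the compiled handle, which call_kernel stores in inner_kernel; later launches index that handle with the same grid and pass only the regular arguments, skipping recompilation. A minimal usage sketch, assuming a CUDA device and a toy _scale_kernel written here purely for illustration (it is not part of vLLM):

import torch
import triton
import triton.language as tl

@triton.jit
def _scale_kernel(x_ptr, out_ptr, scale, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    tl.store(out_ptr + offs, tl.load(x_ptr + offs) * scale)

ctx = CPTritonContext()
x = torch.arange(8, dtype=torch.float32, device="cuda")
out = torch.empty_like(x)
for _ in range(3):
    # same grid and constexpr values on every step, so only the first
    # iteration compiles; later calls reuse ctx.inner_kernel directly
    ctx.call_kernel(_scale_kernel, (1, ), x, out, 2.0, BLOCK=8)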

_correct_attn_cp_out_kernel

_correct_attn_cp_out_kernel(
    outputs_ptr,
    new_output_ptr,
    lses_ptr,
    vlse_ptr,
    outputs_stride_B,
    outputs_stride_H,
    outputs_stride_D,
    lses_stride_N,
    lses_stride_B,
    lses_stride_H,
    lse_idx,
    HEAD_DIM: constexpr,
    N_ROUNDED: constexpr,
)

Apply the all-gathered lses to correct each local rank's attention output. We still need to perform a cross-rank reduction to obtain the final attention output.

Parameters:

    output: [ B, H, D ] (required)
    lses:   [ N, B, H ] (required)

Return:

    output: [ B, H, D ]
    lse:    [ B, H ]

Source code in vllm/attention/ops/common.py
@triton.jit
def _correct_attn_cp_out_kernel(outputs_ptr, new_output_ptr, lses_ptr,
                                vlse_ptr, outputs_stride_B, outputs_stride_H,
                                outputs_stride_D, lses_stride_N, lses_stride_B,
                                lses_stride_H, lse_idx, HEAD_DIM: tl.constexpr,
                                N_ROUNDED: tl.constexpr):
    """
    Apply the all-gathered lses to correct each local rank's attention
    output. We still need to perform a cross-rank reduction to obtain the
    final attention output.

    Args:
        output: [ B, H, D ]
        lses   : [ N, B, H ]
        (N = cp world size, B = batch, H = q_heads, D = v_head_dim)
    Return:
        output: [ B, H, D ]
        lse   : [ B, H ]
    """
    batch_idx = tl.program_id(axis=0).to(tl.int64)
    head_idx = tl.program_id(axis=1).to(tl.int64)
    d_offsets = tl.arange(0, HEAD_DIM)
    num_n_offsets = tl.arange(0, N_ROUNDED)

    # shape = [N]
    lse_offsets = num_n_offsets * lses_stride_N + batch_idx * \
        lses_stride_B + head_idx * lses_stride_H

    # calc final lse
    lse = tl.load(lses_ptr + lse_offsets)
    lse = tl.where((lse != lse) | (lse == float('inf')), -float('inf'), lse)
    lse_max = tl.max(lse, axis=0)
    lse -= lse_max
    lse_exp = tl.exp(lse)
    lse_acc = tl.sum(lse_exp, axis=0)
    lse = tl.log(lse_acc)
    lse += lse_max

    lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
    tl.store(vlse_ptr + lse_offsets, lse)

    # shape = [D]
    output_offsets = batch_idx * outputs_stride_B + \
                    head_idx * outputs_stride_H + \
                    d_offsets * outputs_stride_D

    # correct output
    lse_offset = lse_idx * lses_stride_N + batch_idx * \
        lses_stride_B + head_idx * lses_stride_H
    lse_tmp = tl.load(lses_ptr + lse_offset)
    lse_finally = lse_tmp - lse
    lse_finally = tl.where(
        (lse_finally != lse_finally) | (lse_finally == float('inf')),
        -float('inf'), lse_finally)
    factor = tl.exp(lse_finally)
    output = tl.load(outputs_ptr + output_offsets)
    output = output * factor

    tl.store(new_output_ptr + output_offsets, output)
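
For readers who think in tensors rather than Triton programs, the per-(batch, head) math above corresponds roughly to the following PyTorch sketch. It is an illustrative reimplementation with assumed names (correct_attn_out_reference is not a vLLM function), not the code that actually runs:

import torch

def correct_attn_out_reference(out: torch.Tensor, lses: torch.Tensor, rank: int):
    # out:  [B, H, D], this rank's partial attention output
    # lses: [N, B, H], all-gathered log-sum-exp values, one row per rank
    lses = torch.where(torch.isnan(lses) | torch.isinf(lses),
                       torch.full_like(lses, float("-inf")), lses)
    lse = torch.logsumexp(lses, dim=0)          # [B, H], combined lse
    factor = torch.exp(lses[rank] - lse)        # [B, H], this rank's weight
    factor = torch.nan_to_num(factor, nan=0.0)  # guards the (-inf) - (-inf) case
    return out * factor.unsqueeze(-1), lse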

correct_attn_out

correct_attn_out(
    out: Tensor,
    lses: Tensor,
    cp_rank: int,
    ctx: CPTritonContext,
)

Apply the all-gathered lses to correct each local rank's attention output. We still need to perform a cross-rank reduction to obtain the final attention output.

Parameters:

    output: [ B, H, D ] (required)
    lses:   [ N, B, H ] (required)

Return:

    output: [ B, H, D ]
    lse:    [ B, H ]

Source code in vllm/attention/ops/common.py
def correct_attn_out(out: torch.Tensor, lses: torch.Tensor, cp_rank: int,
                     ctx: CPTritonContext):
    """
    Apply the all-gathered lses to correct each local rank's attention
    output. We still need to perform a cross-rank reduction to obtain the
    final attention output.

    Args:
        output: [ B, H, D ]
        lses   : [ N, B, H ]
    Return:
        output: [ B, H, D ]
        lse   : [ B, H ]
    """
    if ctx is None:
        ctx = CPTritonContext()

    lse = torch.empty_like(lses[0])

    grid = (out.shape[0], out.shape[1], 1)
    regular_args = (out, out, lses, lse, *out.stride(), *lses.stride(),
                    cp_rank)
    const_args = {
        "HEAD_DIM": out.shape[-1],
        "N_ROUNDED": lses.shape[0],
    }

    ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args,
                    **const_args)
    return out, lse
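
A shape-level usage sketch with hypothetical sizes (CUDA device assumed). Note that out is passed as both the input and the output buffer in regular_args, so the correction is applied in place:

import torch

B, H, D, N, cp_rank = 4, 8, 128, 2, 0
out = torch.randn(B, H, D, device="cuda")    # this rank's partial output
lses = torch.randn(N, B, H, device="cuda")   # all-gathered lses, one row per rank
out, lse = correct_attn_out(out, lses, cp_rank, CPTritonContext())
# out: [B, H, D], corrected partial output (still needs a cross-rank sum)
# lse: [B, H],    combined log-sum-exp across the N ranks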

cp_lse_ag_out_rs

cp_lse_ag_out_rs(
    cp_attn_out: Tensor,
    cp_attn_lse: Tensor,
    cp_group: GroupCoordinator,
    ctx: CPTritonContext = None,
)

cp_attn_out: [ B, H, D ]
cp_attn_lse: [ B, H ]

Source code in vllm/attention/ops/common.py
def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor,
                     cp_attn_lse: torch.Tensor,
                     cp_group: GroupCoordinator,
                     ctx: CPTritonContext = None):
    """
    cp_attn_out: [ B, H, D ]
    cp_attn_lse: [ B, H ]
    """
    if cp_group.world_size == 1:
        return cp_attn_out

    if ctx is None:
        ctx = CPTritonContext()

    lses = torch.empty((cp_group.world_size, ) + cp_attn_lse.shape,
                       dtype=cp_attn_lse.dtype,
                       device=cp_attn_lse.device)

    cp_attn_lse = cp_attn_lse.contiguous()
    lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
    out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
    assert out.is_contiguous()
    out = cp_group.reduce_scatter(out, dim=1)
    return out
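
The identity behind this all-gather / correct / reduce-scatter flow can be checked on a single process. The sketch below uses toy sizes and helper names of my own (no GroupCoordinator involved): it splits the keys/values into N shards, weights each shard's output by exp(lse_r - lse_total), and verifies that the resulting sum, which reduce_scatter would produce across ranks, matches attention over the full sequence:

import torch

B, H, D, N = 2, 3, 4, 2                       # toy batch, heads, head_dim, "ranks"
q = torch.randn(B, H, 1, D)
k = torch.randn(B, H, 8, D)
v = torch.randn(B, H, 8, D)

def attn(q, k, v):
    s = q @ k.transpose(-1, -2) / D**0.5      # [B, H, 1, S]
    return s.softmax(-1) @ v, torch.logsumexp(s, dim=-1)

full_out, _ = attn(q, k, v)

# each "rank" attends only to its own shard of the keys/values
outs, lses = zip(*(attn(q, k.chunk(N, 2)[r], v.chunk(N, 2)[r]) for r in range(N)))
lses = torch.stack(lses).squeeze(-1)          # [N, B, H]
lse_total = torch.logsumexp(lses, dim=0)      # [B, H]
corrected = sum(o.squeeze(2) * torch.exp(lses[r] - lse_total).unsqueeze(-1)
                for r, o in enumerate(outs))
assert torch.allclose(corrected, full_out.squeeze(2), atol=1e-5)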