vllm.v1.attention.backends.mamba2_attn ¶
Mamba2AttentionBackend ¶
Bases: AttentionBackend
Source code in vllm/v1/attention/backends/mamba2_attn.py
Mamba2AttentionMetadata dataclass ¶
Source code in vllm/v1/attention/backends/mamba2_attn.py
token_chunk_offset_ptr class-attribute instance-attribute ¶
__init__ ¶
__init__(
num_prefills: int,
num_prefill_tokens: int,
num_decodes: int,
num_decode_tokens: int,
query_start_loc: Tensor,
seq_lens: Tensor,
prep_initial_states: bool,
chunk_size: int,
has_initial_states_p: Optional[Tensor],
seq_idx_p: Optional[Tensor],
chunk_indices_p: Optional[Tensor],
chunk_offsets_p: Optional[Tensor],
state_indices_tensor: Tensor,
nums_dict: Optional[dict] = None,
cu_seqlen: Optional[int] = None,
batch_ptr: Optional[Tensor] = None,
token_chunk_offset_ptr: Optional[Tensor] = None,
) -> None
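In practice this object is produced by Mamba2AttentionMetadataBuilder.build() from the scheduler output rather than constructed by hand. The toy construction below (assuming vLLM is installed) is purely illustrative of what the fields carry; the values are made up, not produced by the engine.

```python
import torch

from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata

# Toy values for a batch with one single-token decode and one 5-token prefill.
# Real metadata is filled in by Mamba2AttentionMetadataBuilder.build().
metadata = Mamba2AttentionMetadata(
    num_prefills=1,
    num_prefill_tokens=5,
    num_decodes=1,
    num_decode_tokens=1,
    query_start_loc=torch.tensor([0, 1, 6]),    # cumulative token starts
    seq_lens=torch.tensor([4, 5]),              # per-request sequence lengths (toy)
    prep_initial_states=False,
    chunk_size=256,
    has_initial_states_p=None,                  # prefill-only fields left empty here
    seq_idx_p=None,
    chunk_indices_p=None,
    chunk_offsets_p=None,
    state_indices_tensor=torch.tensor([0, 1]),  # mamba state slot per request
)
```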
Mamba2AttentionMetadataBuilder ¶
Bases: BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
Source code in vllm/v1/attention/backends/mamba2_attn.py
__init__ ¶
__init__(
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: device,
)
Source code in vllm/v1/attention/backends/mamba2_attn.py
build ¶
build(
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> Mamba2AttentionMetadata
Source code in vllm/v1/attention/backends/mamba2_attn.py
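A rough usage sketch of the builder, assuming the engine-internal objects kv_cache_spec, vllm_config, and common_attn_metadata already exist (they come from the model runner); the layer name is a hypothetical placeholder, not taken from a real model.

```python
import torch

# Sketch only: kv_cache_spec, vllm_config and common_attn_metadata are assumed
# to be provided by the model runner; "model.layers.0.mixer" is a hypothetical
# layer name used purely for illustration.
builder = Mamba2AttentionMetadataBuilder(
    kv_cache_spec=kv_cache_spec,
    layer_names=["model.layers.0.mixer"],
    vllm_config=vllm_config,
    device=torch.device("cuda"),
)

attn_metadata = builder.build(
    common_prefix_len=0,
    common_attn_metadata=common_attn_metadata,
    fast_build=False,
)
```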
_query_start_loc_to_chunk_indices_offsets ¶
_query_start_loc_to_chunk_indices_offsets(
query_start_loc: Tensor,
chunk_size: int,
total_seqlens: int,
) -> tuple[Tensor, Tensor]
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| query_start_loc | Tensor | 1D tensor of cumulative sequence lengths, shape (num_seqs + 1,). The first element should be 0. Each entry represents the starting index of a sequence in the flattened token array. | required |
| chunk_size | int | The size of each physical mamba chunk (number of tokens per chunk). | required |
| total_seqlens | int | The total number of tokens in the batch. | required |
Returns:
| Type | Description |
|---|---|
| tuple[Tensor, Tensor] | A tuple containing: chunk_indices (Tensor), a 1D tensor giving the physical chunk index for each logical chunk, and chunk_offsets (Tensor), a 1D tensor giving the starting index of each logical chunk within its physical chunk. |
This function computes the chunk indices and offsets for the given query_start_loc and chunk_size. Both are tensors of integers with length N, where N is the number of logical (pseudo) chunks. A logical chunk is a sequence of tokens that are all part of the same sequence and are all in the same physical mamba chunk; in other words, a logical chunk changes every time we cross a sequence boundary or a physical mamba chunk boundary. Logical chunks are needed to handle batched requests with initial states (see _state_passing_fwd and _chunk_scan_fwd).

The chunk_indices tensor contains the index of the physical chunk for each logical chunk. The chunk_offsets tensor contains the offset (i.e., the starting index) of the logical chunk within its physical chunk.
Example:
    query_start_loc = [0, 5, 10]
    chunk_size = 8
    total_seqlens = 10
    -> chunk_indices = [0, 0, 1]
    -> chunk_offsets = [0, 5, 0]
In this example, we have 2 sequences, each with 5 tokens, and a physical chunk size of 8 tokens. This gives three logical chunks:
- the first logical chunk starts at token 0 in the first physical chunk and contains all 5 tokens from the first sequence
- the second logical chunk starts at token 5 in the first physical chunk and contains the first 3 tokens from the second sequence
- the third logical chunk starts at token 0 in the second physical chunk and contains the remaining 2 tokens from the second sequence
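The bookkeeping above can be reproduced with a short, simplified sketch (not the actual vLLM implementation): logical chunk starts are the union of the sequence starts and the physical chunk boundaries, and each start maps to a physical chunk index and an in-chunk offset. The function name below is illustrative.

```python
import torch


def logical_chunk_indices_offsets(
    query_start_loc: torch.Tensor,  # (num_seqs + 1,), cumulative, starts at 0
    chunk_size: int,
    total_seqlens: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Simplified sketch of the logical-chunk bookkeeping described above."""
    # A new logical chunk starts at every sequence boundary and at every
    # physical mamba chunk boundary, so collect both sets of start positions.
    seq_starts = query_start_loc[:-1].tolist()
    phys_starts = range(0, total_seqlens, chunk_size)
    starts = sorted(set(seq_starts) | set(phys_starts))

    # Each logical chunk lives in physical chunk (start // chunk_size) and
    # begins at offset (start % chunk_size) inside that chunk.
    chunk_indices = torch.tensor([s // chunk_size for s in starts])
    chunk_offsets = torch.tensor([s % chunk_size for s in starts])
    return chunk_indices, chunk_offsets


query_start_loc = torch.tensor([0, 5, 10])
chunk_indices, chunk_offsets = logical_chunk_indices_offsets(
    query_start_loc, chunk_size=8, total_seqlens=10
)
print(chunk_indices)  # tensor([0, 0, 1])
print(chunk_offsets)  # tensor([0, 5, 0])
```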