Skip to content

vllm.model_executor.layers.mamba.mamba_utils

MambaStateCopyFunc module-attribute

# (state, block_ids, cur_block_idx, num_accepted_tokens) -> MambaCopySpec
MambaStateCopyFunc: TypeAlias = Callable[[Tensor, list[int], int, int], MambaCopySpec]

Type alias for a function that computes a MambaCopySpec for copying a state slice. Parameters: `state` (torch.Tensor) — the Mamba state tensor (e.g., conv or temporal states); `block_ids` (list[int]) — the block indices holding the state to copy; `cur_block_idx` (int) — the position within `block_ids` of the current block to copy from; `num_accepted_tokens` (int) — the number of accepted tokens, used to compute the copy offset, in the range 1 to 1 + num_speculative_tokens (inclusive).

MambaCopySpec dataclass

Data class specifying the memory-copy parameters for Mamba states used for prefix caching in align mode.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `start_addr` | `int` | Starting address for the memory copy operation. |
| `num_elements` | `int` | Number of elements to copy from the starting address. |

Source code in vllm/model_executor/layers/mamba/mamba_utils.py
@dataclass
class MambaCopySpec:
    """One memory copy of a Mamba state slice.

    Used by prefix caching in align mode to describe where a copy begins and
    how much data it covers.

    Attributes:
        start_addr (int): Address at which the copy begins.
        num_elements (int): How many elements to copy, starting at
            ``start_addr``.
    """

    # Address of the first element to copy.
    start_addr: int
    # Element count (not a byte count) to copy from ``start_addr``.
    num_elements: int

MambaStateShapeCalculator

Source code in vllm/model_executor/layers/mamba/mamba_utils.py
class MambaStateShapeCalculator:
    @classmethod
    def linear_attention_state_shape(
        cls,
        num_heads: int,
        tp_size: int,
        head_dim: int,
    ) -> tuple[tuple[int, int, int], ...]:
        state_shape = (num_heads // tp_size, head_dim, head_dim)
        return (state_shape,)

    @staticmethod
    def _orient_conv_shape(dim: int, state_len: int) -> tuple[int, int]:
        """Return (dim, state_len) for DS layout, (state_len, dim) for SD."""
        if is_conv_state_dim_first():
            return (dim, state_len)
        return (state_len, dim)

    @classmethod
    def mamba1_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        state_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        conv_dim = divide(intermediate_size, tp_world_size)
        conv_state_shape = cls._orient_conv_shape(conv_dim, conv_kernel - 1)

        temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size)

        return conv_state_shape, temporal_state_shape

    @classmethod
    def mamba2_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        n_groups: int,
        num_heads: int,
        head_dim: int,
        state_size: int,
        conv_kernel: int,
        num_spec: int = 0,
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        # if n_groups is not divisible by world_size, need to extend the shards
        # to ensure all groups needed by a head is sharded along with it
        n_groups = n_groups + cls.extra_groups_for_head_shards(n_groups, tp_world_size)
        # heads and n_groups are TP-ed
        conv_dim = intermediate_size + 2 * n_groups * state_size

        conv_state_shape = cls._orient_conv_shape(
            divide(conv_dim, tp_world_size), conv_kernel - 1 + num_spec
        )

        # These are not TP-ed as they depend on A, dt_bias, D
        # - they are typically small
        #   e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
        temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
        return conv_state_shape, temporal_state_shape

    @classmethod
    def short_conv_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int]]:
        conv_dim = divide(intermediate_size, tp_world_size)
        conv_state_shape = cls._orient_conv_shape(conv_dim, conv_kernel - 1)
        return (conv_state_shape,)

    @classmethod
    def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
        """Compute the increase in group numbers to account for
        replication in order to accompany the head shards."""

        # in the case ngoups % tp_size == 0, this will be zero
        if ngroups % tp_size == 0:
            return 0

        # for n_groups == 1, this is exactly tp_size - n_groups
        return tp_size - ngroups

    @classmethod
    def gated_delta_net_state_shape(
        cls,
        tp_world_size: int,
        num_k_heads: int,
        num_v_heads: int,
        head_k_dim: int,
        head_v_dim: int,
        conv_kernel_size: int,
        num_spec: int = 0,
    ):
        conv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads
        conv_state_shape = cls._orient_conv_shape(
            divide(conv_dim, tp_world_size),
            conv_kernel_size - 1 + num_spec,
        )

        temporal_state_shape = (
            divide(num_v_heads, tp_world_size),
            head_v_dim,
            head_k_dim,
        )
        return conv_state_shape, temporal_state_shape

    @classmethod
    def kda_state_shape(
        cls,
        tp_world_size: int,
        num_heads: int,
        head_dim: int,
        num_k_heads: int | None = None,
        head_k_dim: int | None = None,
        conv_kernel_size: int = 4,
        num_spec: int = 0,
    ) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int, int]]:
        if num_k_heads is None:
            num_k_heads = num_heads
        if head_k_dim is None:
            head_k_dim = head_dim

        proj_size = num_heads * head_dim
        proj_k_size = num_k_heads * head_k_dim

        conv_state_shape = cls._orient_conv_shape(
            divide(proj_size, tp_world_size), conv_kernel_size - 1
        )
        conv_state_k_shape = cls._orient_conv_shape(
            divide(proj_k_size, tp_world_size), conv_kernel_size - 1
        )
        recurrent_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim)
        return (
            conv_state_shape,
            conv_state_k_shape,
            conv_state_k_shape,
            recurrent_state_shape,
        )

_orient_conv_shape staticmethod

_orient_conv_shape(
    dim: int, state_len: int
) -> tuple[int, int]

Return (dim, state_len) for DS layout, (state_len, dim) for SD.

Source code in vllm/model_executor/layers/mamba/mamba_utils.py
@staticmethod
def _orient_conv_shape(dim: int, state_len: int) -> tuple[int, int]:
    """Order a conv-state shape to match the active layout.

    The DS layout yields ``(dim, state_len)``; the SD layout yields
    ``(state_len, dim)``.
    """
    return (dim, state_len) if is_conv_state_dim_first() else (state_len, dim)

extra_groups_for_head_shards classmethod

extra_groups_for_head_shards(ngroups: int, tp_size: int)

Compute the increase in group numbers to account for replication in order to accompany the head shards.

Source code in vllm/model_executor/layers/mamba/mamba_utils.py
@classmethod
def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int) -> int:
    """Compute the increase in group numbers to account for
    replication in order to accompany the head shards.

    Returns 0 when ``tp_size`` divides ``ngroups``. Otherwise, returns
    ``tp_size - ngroups``, which pads the group count up to ``tp_size``.

    Raises:
        ValueError: if ``ngroups`` is not divisible by ``tp_size`` and
            exceeds it. The padding would then be negative and would
            silently shrink the group count.
    """
    # in the case ngroups % tp_size == 0, this will be zero
    if ngroups % tp_size == 0:
        return 0

    # Replication can only pad the group count upwards; a negative
    # "extra" would drop groups and yield wrong state shapes.
    if ngroups > tp_size:
        raise ValueError(
            f"Cannot shard {ngroups} groups across {tp_size} tensor-parallel "
            "ranks: the group count must be divisible by the TP size or "
            "must not exceed it."
        )

    # for n_groups == 1, this is exactly tp_size - n_groups
    return tp_size - ngroups

get_conv_copy_spec

get_conv_copy_spec(
    state: Tensor,
    block_ids: list[int],
    cur_block_idx: int,
    num_accepted_tokens: int,
) -> MambaCopySpec

Return a MambaCopySpec for copying a convolutional state slice.

Works for both SD layout (num_blocks, state_len, dim) and DS layout (num_blocks, dim, state_len).

Source code in vllm/model_executor/layers/mamba/mamba_utils.py
def get_conv_copy_spec(
    state: torch.Tensor,
    block_ids: list[int],
    cur_block_idx: int,
    num_accepted_tokens: int,
) -> MambaCopySpec:
    """Describe the conv-state slice to copy out of the current block.

    Both per-block layouts are handled:

    * SD ``(num_blocks, state_len, dim)``: the slice skips the first
      ``num_accepted_tokens - 1`` rows of the source block.
    * DS ``(num_blocks, dim, state_len)``: the whole source block is copied.
      A non-zero offset is rejected.

    Raises:
        NotImplementedError: for the DS layout with ``num_accepted_tokens > 1``.
    """
    src_block = block_ids[cur_block_idx]
    skip = num_accepted_tokens - 1

    if not is_conv_state_dim_first():
        # SD: each row is a contiguous dim-vector, so the trailing rows of
        # one block form a single flat region.
        region = state[src_block, skip:]
    elif skip > 0:
        # DS: offsetting along the trailing state_len axis gives a strided
        # view (features are strided by state_len), not one flat region.
        raise NotImplementedError(
            "DS conv state layout does not yet support speculative "
            "decoding with mamba_cache_mode='align' "
            "(num_accepted_tokens > 1)."
        )
    else:
        region = state[src_block]

    return MambaCopySpec(start_addr=region.data_ptr(), num_elements=region.numel())

get_conv_state_layout cached

get_conv_state_layout() -> ConvStateLayoutType

Return the SSM conv state layout.

SD = (state_len, dim): dim is the innermost contiguous dimension. DS = (dim, state_len): the TP-sharded dim sits on axis 1 of the (num_blocks, dim, state_len) tensor, similar to the HND KV-cache layout and consistent with the SSM temporal state layout.

Source code in vllm/model_executor/layers/mamba/mamba_utils.py
@functools.lru_cache
def get_conv_state_layout() -> ConvStateLayoutType:
    """Resolve the SSM conv state layout; the result is cached after the first call.

    * ``"SD"``: per block ``(state_len, dim)``; ``dim`` is the innermost,
      contiguous axis. This is the default.
    * ``"DS"``: per block ``(dim, state_len)``; the TP-sharded ``dim`` axis
      comes first (analogous to the HND KV-cache layout), matching the SSM
      temporal state layout.

    When ``VLLM_SSM_CONV_STATE_LAYOUT`` is set, its value is used instead of
    the default.
    """
    override: ConvStateLayoutType | None = envs.VLLM_SSM_CONV_STATE_LAYOUT
    if override is None:
        return "SD"

    logger.info_once(
        "VLLM_SSM_CONV_STATE_LAYOUT env detected. "
        "Setting SSM conv state layout to %s.",
        override,
    )
    return override

get_temporal_copy_spec

get_temporal_copy_spec(
    state: Tensor,
    block_ids: list[int],
    cur_block_idx: int,
    num_accepted_tokens: int,
) -> MambaCopySpec

Return a MambaCopySpec for copying a temporal state slice.

Source code in vllm/model_executor/layers/mamba/mamba_utils.py
def get_temporal_copy_spec(
    state: torch.Tensor,
    block_ids: list[int],
    cur_block_idx: int,
    num_accepted_tokens: int,
) -> MambaCopySpec:
    """Describe the temporal-state block to copy.

    The source block sits ``num_accepted_tokens - 1`` positions after
    ``cur_block_idx`` in ``block_ids``, and its whole state is copied.
    """
    slot = cur_block_idx + (num_accepted_tokens - 1)
    block_state = state[block_ids[slot]]
    return MambaCopySpec(
        start_addr=block_state.data_ptr(),
        num_elements=block_state.numel(),
    )

is_conv_state_dim_first

is_conv_state_dim_first() -> bool

True when the conv state is stored as (dim, state_len) per block.

Source code in vllm/model_executor/layers/mamba/mamba_utils.py
def is_conv_state_dim_first() -> bool:
    """Report whether conv states use the DS ``(dim, state_len)`` block layout."""
    layout = get_conv_state_layout()
    return layout == "DS"