class DeepseekV4Attention(nn.Module, AttentionLayerBase, ABC):
    """DeepseekV4 MLA attention layer.

    Abstract base shared by the platform backends. It builds the q/kv and
    output projection layers, the rotary embedding, the SWA cache layer and
    the optional compressor / indexer, and drives the forward pass up to the
    sparse attention kernel.

    The platform-specific sparse-MLA forward (``forward_mqa`` /
    ``get_padded_num_q_heads`` / ``_o_proj`` / ``backend_cls``) is provided by a
    subclass — ``DeepseekV4FlashMLAAttention`` / ``DeepseekV4FlashInferMLAAttention``
    (CUDA) or ``DeepseekV4ROCMAiterMLAAttention`` (ROCm) — selected by the
    platform-specific deepseek_v4 model module. The base is never instantiated
    directly.
    """

    # Provided by the platform subclass; returned as-is by get_attn_backend().
    backend_cls: ClassVar[type[AttentionBackend]]
    # KV-cache per-token block format (both layouts are paged). True (default)
    # = FlashMLA / ROCm fp8_ds_mla (UE8M0 block-scaled fp8 packed as uint8);
    # False = FlashInfer plain bf16 / per-tensor fp8 KV row.
    # Read in __init__ to resolve the kv-cache dtype of this layer.
    use_flashmla_fp8_layout: ClassVar[bool] = True
    # Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather
    # workspace allocated in _forward_prefill and is also read by the dummy-run
    # path to pre-reserve that workspace.
    # NOTE(review): the unit of this chunk size is not visible in this class —
    # confirm against _forward_prefill before changing the value.
    PREFILL_CHUNK_SIZE: ClassVar[int] = 4
@classmethod
@abstractmethod
def get_padded_num_q_heads(cls, num_heads: int) -> int:
    """Return the padded Q head count used for the q/output buffers.

    The q and output buffers are allocated with shape
    ``[N, get_padded_num_q_heads(n_local_heads), head_dim]``, so an
    implementation must never return fewer than ``num_heads``. A backend
    that has no padding constraint returns ``num_heads`` unchanged.

    Raises:
        NotImplementedError: Always, in this base; subclasses override it.
    """
    raise NotImplementedError()
@abstractmethod
def forward_mqa(
    self,
    q: torch.Tensor,
    kv: torch.Tensor,
    positions: torch.Tensor,
    output: torch.Tensor,
) -> None:
    """Run the platform-specific sparse MLA attention.

    The result is not returned: implementations write the attention
    output into ``output`` in place.

    Raises:
        NotImplementedError: Always, in this base; subclasses override it.
    """
    raise NotImplementedError()
@abstractmethod
def _o_proj(self, o: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
"""Inverse-RoPE + wo_a + wo_b output projection (platform-specific)."""
raise NotImplementedError
def __init__(
    self,
    vllm_config: VllmConfig,
    prefix: str,
    topk_indices_buffer: torch.Tensor | None = None,
    aux_stream_list: list[torch.cuda.Stream] | None = None,
) -> None:
    """Build the projection layers, caches and optional compressor/indexer.

    Args:
        vllm_config: Engine configuration; the model (HF), quantization,
            cache, scheduler and compilation settings are read from it.
        prefix: Layer name. Used to derive the sub-module prefixes and the
            layer index, and as the key under which this layer registers
            itself in ``compilation_config.static_forward_context``.
        topk_indices_buffer: Buffer handed to the indexer (which is only
            built when ``compress_ratio == 4``); also stored on the layer.
        aux_stream_list: Auxiliary streams used to overlap work, or
            ``None`` (e.g. on ROCm) to run everything sequentially.

    Raises:
        AssertionError: If the attention-head count or the output-group
            count is not divisible by the tensor-parallel world size, or
            if ``cache_config`` is ``None``.
        ValueError: If ``prefix`` is already registered in the static
            forward context.
    """
    super().__init__()
    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    cache_config = vllm_config.cache_config
    # Validate before the first use: cache_config is handed to the indexer
    # and to the KV-cache setup further down.
    assert cache_config is not None, "DeepseekV4 attention requires cache_config"

    tp_size = get_tensor_model_parallel_world_size()
    layer_id = extract_layer_index(prefix)

    self.prefix = prefix  # Alias for compatibility with compressor
    self.hidden_size = config.hidden_size

    self.n_heads = config.num_attention_heads
    assert self.n_heads % tp_size == 0
    self.n_local_heads = self.n_heads // tp_size
    self.q_lora_rank = config.q_lora_rank
    self.o_lora_rank = config.o_lora_rank
    self.head_dim = config.head_dim
    self.rope_head_dim = config.qk_rope_head_dim
    self.nope_head_dim = self.head_dim - self.rope_head_dim
    self.n_groups = config.o_groups
    # wo_a runs as a batched matmul over n_local_groups (see below); a
    # truncating division would leave the per-rank batch size inconsistent
    # with the global group count, so require an exact split.
    assert self.n_groups % tp_size == 0, (
        f"o_groups ({self.n_groups}) must be divisible by the "
        f"tensor-parallel size ({tp_size})"
    )
    self.n_local_groups = self.n_groups // tp_size
    self.window_size = config.sliding_window

    # NOTE(zyongye) The compress ratio can't be 0, so it is clamped to at
    # least 1. Layers past num_hidden_layers (the MTP layer) are not
    # included in the compress ratio list and fall back to 1.
    if layer_id < config.num_hidden_layers:
        self.compress_ratio = max(1, config.compress_ratios[layer_id])
    else:
        self.compress_ratio = 1
    self.eps = config.rms_norm_eps
    self.scale = self.head_dim**-0.5

    # Padded Q head count is dictated by the platform subclass.
    self.padded_heads = self.get_padded_num_q_heads(self.n_local_heads)

    # Sink padded to the same head count, initialized to -inf (no sink
    # effect). Weight loading fills the first n_local_heads slots.
    self.attn_sink = nn.Parameter(
        torch.full((self.padded_heads,), -float("inf"), dtype=torch.float32),
        requires_grad=False,
    )

    self.fused_wqa_wkv = MergedColumnParallelLinear(
        self.hidden_size,
        [self.q_lora_rank, self.head_dim],
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.fused_wqa_wkv",
        disable_tp=True,  # fused ReplicatedLinear
    )
    self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
    self.wq_b = ColumnParallelLinear(
        self.q_lora_rank,
        self.n_heads * self.head_dim,
        bias=False,
        quant_config=quant_config,
        return_bias=False,
        prefix=f"{prefix}.wq_b",
    )
    self.kv_norm = RMSNorm(self.head_dim, self.eps)
    self.wo_a = ColumnParallelLinear(
        self.n_heads * self.head_dim // self.n_groups,
        self.n_groups * self.o_lora_rank,
        bias=False,
        quant_config=quant_config,
        return_bias=False,
        prefix=f"{prefix}.wo_a",
    )
    self.wo_a.is_bmm = True
    self.wo_a.bmm_batch_size = self.n_local_groups
    self.wo_b = RowParallelLinear(
        self.n_groups * self.o_lora_rank,
        self.hidden_size,
        bias=False,
        quant_config=quant_config,
        return_bias=False,
        prefix=f"{prefix}.wo_b",
    )

    # Initialize rotary embedding before the indexer/compressor consume it.
    self.rotary_emb = build_deepseek_v4_rope(
        config,
        head_dim=self.head_dim,
        rope_head_dim=self.rope_head_dim,
        max_position_embeddings=config.max_position_embeddings,
        compress_ratio=self.compress_ratio,
    )
    self.indexer_rotary_emb = self.rotary_emb

    self.topk_indices_buffer = topk_indices_buffer
    self.indexer = None
    if self.compress_ratio == 4:
        # Only C4A uses sparse attention and hence has indexer.
        # aux_stream_list[2] is free here (outer GEMMs joined) for the inner
        # overlap of wq_b+fused_indexer_q_rope_quant vs compressor. None on
        # ROCm, where aux_stream_list is None.
        indexer_aux_stream = (
            aux_stream_list[2] if aux_stream_list is not None else None
        )
        self.indexer = DeepseekV4Indexer(
            vllm_config,
            config=config,
            hidden_size=self.hidden_size,
            q_lora_rank=self.q_lora_rank,
            quant_config=quant_config,
            cache_config=cache_config,
            topk_indices_buffer=topk_indices_buffer,
            compress_ratio=self.compress_ratio,
            prefix=f"{prefix}.indexer",
            aux_stream=indexer_aux_stream,
        )

    # Will be None on ROCm for now.
    self.aux_stream_list = aux_stream_list
    # [0]: GEMM start / post-GEMM event0. [1..3]: GEMM done events;
    # [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins
    # before post-GEMM starts.
    self.ln_events = [torch.cuda.Event() for _ in range(4)]

    # ---- Attention / KV-cache setup ----
    self.max_num_batched_tokens = (
        vllm_config.scheduler_config.max_num_batched_tokens
    )
    self.max_model_len = vllm_config.model_config.max_model_len

    # Resolve the kv-cache dtype from this backend's block format (a
    # ClassVar set by the subclass): fp8_ds_mla (UE8M0 block-scaled fp8 as
    # uint8) for FlashMLA / ROCm, vs a plain bf16 / per-tensor fp8 row for
    # FlashInfer. The same resolution drives the SWA cache tensor dtype
    # below.
    self.kv_cache_dtype, self.kv_cache_torch_dtype = _resolve_dsv4_kv_cache_dtype(
        self.use_flashmla_fp8_layout, cache_config.cache_dtype, cache_config
    )

    self.swa_cache_layer = DeepseekV4SWACache(
        head_dim=self.head_dim,
        window_size=self.window_size,
        dtype=self.kv_cache_torch_dtype,
        prefix=f"{prefix}.swa_cache",
        cache_config=cache_config,
    )

    # Register with compilation context for metadata lookup.
    compilation_config = vllm_config.compilation_config
    if prefix and prefix in compilation_config.static_forward_context:
        raise ValueError(f"Duplicate layer name: {prefix}")
    if prefix:
        compilation_config.static_forward_context[prefix] = self

    self.kv_cache = torch.tensor([])

    # Create the compressor for layers with compress_ratio > 1; after the
    # attention setup above so its KV-cache prefix (self.prefix) is set.
    self.compressor = None
    if self.compress_ratio > 1:
        self.compressor = DeepseekCompressor(
            vllm_config=vllm_config,
            compress_ratio=self.compress_ratio,
            hidden_size=self.hidden_size,
            head_dim=self.head_dim,
            rotate=True,
            prefix=f"{prefix}.compressor",
            k_cache_prefix=self.prefix,
        )
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None = None,
) -> torch.Tensor:
# Pre-allocate attention output with FlashMLA-padded head count.
# The op writes into `o_padded`; we slice to n_local_heads after.
num_tokens = hidden_states.shape[0]
o_padded = torch.empty(
(num_tokens, self.padded_heads, self.head_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
# attention_impl is wrapped with @eager_break_during_capture: this is
# where the breakable cudagraph capture breaks (the attention op runs
# eagerly between captured graph segments).
self.attention_impl(hidden_states, positions, o_padded)
o = o_padded[:, : self.n_local_heads, :]
# Inverse-RoPE + wo_a + wo_b output projection (platform-specific).
return self._o_proj(o, positions)
def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]:
    """Run the layer's input GEMMs, overlapped on aux streams when possible.

    ``fused_wqa_wkv`` is the primary job. Up to three lighter GEMMs are
    submitted as auxiliary jobs, each only when its owning module exists:
    slot 0 is the compressor kv-score, slot 1 the indexer weights
    projection, slot 2 the indexer-compressor kv-score.

    Args:
        hidden_states: Input activations; the first dimension is the token
            count compared against
            ``envs.VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD``.

    Returns:
        ``(qr_kv, kv_score, indexer_kv_score, indexer_weights)``. The last
        three are whatever ``execute_in_parallel`` yields for an empty
        slot when the owning module is absent (presumably ``None`` —
        confirm against ``execute_in_parallel``).
    """
    streams = self.aux_stream_list
    if streams is not None:
        assert len(streams) >= 3
        streams = streams[:3]

    # Primary job (heaviest GEMM), passed as the first argument below.
    def run_fused_wqa_wkv() -> torch.Tensor:
        # MergedColumnParallelLinear returns (output, bias); bias is None.
        projected, _ = self.fused_wqa_wkv(hidden_states)
        return projected

    # One slot per aux stream / done event; a slot stays None when the
    # module that owns the GEMM does not exist on this layer.
    side_jobs: list[Callable[[], Any] | None] = [None] * 3

    if self.compressor is not None:
        # Bound to a local so the closure sees a non-None type (mypy).
        mla_compressor = self.compressor

        def run_compressor_kv_score() -> torch.Tensor:
            return torch.mm(
                hidden_states,
                mla_compressor.fused_wkv_wgate.weight.T,
                out_dtype=torch.float32,
            )

        side_jobs[0] = run_compressor_kv_score

    if self.indexer is not None:
        sparse_indexer = self.indexer

        def run_indexer_weights_proj() -> torch.Tensor:
            # ReplicatedLinear returns (output, bias); bias is None.
            proj, _ = sparse_indexer.weights_proj(hidden_states)
            return proj

        def run_indexer_kv_score() -> torch.Tensor:
            return torch.mm(
                hidden_states,
                sparse_indexer.compressor.fused_wkv_wgate.weight.T,
                out_dtype=torch.float32,
            )

        side_jobs[1] = run_indexer_weights_proj
        side_jobs[2] = run_indexer_kv_score

    # ln_events[0] is the fan-out start event; ln_events[1..3] are the
    # per-aux done events. On ROCm, streams is None and
    # execute_in_parallel runs serially.
    overlap = hidden_states.shape[0] <= envs.VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD
    qr_kv, side_results = execute_in_parallel(
        run_fused_wqa_wkv,
        side_jobs,
        self.ln_events[0],
        self.ln_events[1:4],
        streams,
        enable=overlap,
    )
    kv_score, indexer_weights, indexer_kv_score = side_results
    return qr_kv, kv_score, indexer_kv_score, indexer_weights
@eager_break_during_capture
def attention_impl(
    self,
    hidden_states: torch.Tensor,
    positions: torch.Tensor,
    out: torch.Tensor,  # [num_tokens, padded_heads, head_dim], written in place
) -> None:
    """Compute attention for ``hidden_states`` and write it into ``out``.

    Steps: run the input GEMMs, split and RMS-normalize the q/kv
    projections, build q (``wq_b`` + fused q-norm/RoPE/KV-cache insert)
    while the indexer and/or compressor run alongside it when they exist,
    then call the platform-specific ``forward_mqa``.

    Args:
        hidden_states: Layer input activations.
        positions: Token positions.
        out: Pre-allocated ``[num_tokens, padded_heads, head_dim]`` buffer
            that ``forward_mqa`` writes into; nothing is returned.
    """
    attn_metadata = get_forward_context().attn_metadata

    qr_kv, kv_score, indexer_kv_score, indexer_weights = (
        self.attn_gemm_parallel_execute(hidden_states)
    )
    qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1)
    qr, kv = fused_q_kv_rmsnorm(
        qr,
        kv,
        self.q_norm.weight.data,
        self.kv_norm.weight.data,
        self.eps,
    )

    # wq_b + kv_insert always ride on the default stream so q stays on its
    # consumer stream (forward_mqa downstream reads q on default).
    def project_q_and_insert_kv() -> torch.Tensor:
        q_heads = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim)
        return self._fused_qnorm_rope_kv_insert(
            q_heads, kv, positions, attn_metadata
        )

    streams = self.aux_stream_list

    if self.indexer is not None:
        # Locals keep a non-None type inside the closures (mypy).
        sparse_indexer = self.indexer
        assert self.compressor is not None
        mla_compressor = self.compressor

        def run_indexer():
            return sparse_indexer(
                hidden_states,
                qr,
                indexer_kv_score,
                indexer_weights,
                positions,
                self.indexer_rotary_emb,
            )

        def run_compressor():
            return mla_compressor(kv_score, positions, self.rotary_emb)

        # 3-way overlap (matches TRT-LLM PR #14142 Level 1): default runs
        # wq_b+kv_insert; slot [0] runs the full indexer; slot [1] runs the
        # MLA compressor. Slot [2] is reserved for the indexer's inner
        # overlap. ROCm (streams is None) falls back to sequential.
        q, _ = execute_in_parallel(
            project_q_and_insert_kv,
            [run_indexer, run_compressor],
            self.ln_events[0],
            [self.ln_events[1], self.ln_events[2]],
            [streams[0], streams[1]] if streams is not None else None,
            enable=streams is not None,
        )
    elif self.compressor is not None:
        # wq_b + kv_insert on default, compressor on aux.
        mla_compressor = self.compressor
        compressor_stream = streams[0] if streams is not None else None
        q, _ = maybe_execute_in_parallel(
            project_q_and_insert_kv,
            lambda: mla_compressor(kv_score, positions, self.rotary_emb),
            self.ln_events[0],
            self.ln_events[1],
            compressor_stream,
        )
    else:
        # SWA-only layer: no compressor, no overlap.
        q = project_q_and_insert_kv()

    # MLA attention writes into the pre-allocated `out` buffer
    # ([num_tokens, padded_heads, head_dim]).
    self.forward_mqa(q, kv, positions, out)
def _fused_qnorm_rope_kv_insert(
self,
q: torch.Tensor,
kv: torch.Tensor,
positions: torch.Tensor,
attn_metadata: (
dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None
),
) -> torch.Tensor:
if not isinstance(attn_metadata, dict):
# Profile run: kernel doesn't fire; produce a padded tensor so
# downstream FlashMLA gets the right shape.
if self.n_local_heads < self.padded_heads:
return F.pad(
q,
(0, 0, 0, self.padded_heads - self.n_local_heads),
value=0.0,
)
return q
swa_metadata = cast(
"DeepseekSparseSWAMetadata | None",
attn_metadata.get(self.swa_cache_layer.prefix),
)
assert swa_metadata is not None
swa_kv_cache = self.swa_cache_layer.kv_cache
# The fused insert ops require int64 position_ids; the runner's positions
# buffer is already int64, so no cast is needed.
assert positions.dtype == torch.int64
cos_sin_cache = self.rotary_emb.cos_sin_cache
cache_dtype = swa_kv_cache.dtype
# kv is unchanged; attention reads kv solely via swa_kv_cache.
if cache_dtype == torch.uint8:
# Legacy FlashMLA UE8M0 paged path. Horizontally fused:
# Q side: per-head RMSNorm (no weight) + GPT-J RoPE, zero-filling
# the padding head slots; the kernel allocates and returns
# the padded q tensor.
# KV side: GPT-J RoPE + UE8M0 FP8 quant + paged cache insert.
swa_kv_cache_2d = swa_kv_cache.view(swa_kv_cache.shape[0], -1)
return torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert(
q,
kv,
swa_kv_cache_2d,
swa_metadata.slot_mapping,
positions,
cos_sin_cache,
self.padded_heads,
self.eps,
swa_metadata.block_size,
)
# FlashInfer full-cache path: the [num_blocks, block_size, 512] cache
# stores the KV row in its plain dtype (no Q padding). bf16 rewrites q
# in place; per-tensor fp8 writes a separately-allocated fp8 q and
# quantizes the KV row.
block_size = swa_metadata.block_size
swa_kv_cache_3d = swa_kv_cache.view(-1, block_size, self.head_dim)
if cache_dtype == torch.bfloat16:
torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_bf16_insert(
q,
kv,
swa_kv_cache_3d,
swa_metadata.slot_mapping,
positions,
cos_sin_cache,
self.eps,
block_size,
)
return q
# per-tensor fp8 (torch.float8_e4m3fn)
q_fp8 = torch.empty_like(q, dtype=torch.float8_e4m3fn)
torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_full_cache_fp8_insert(
q,
kv,
q_fp8,
swa_kv_cache_3d,
swa_metadata.slot_mapping,
positions,
cos_sin_cache,
self._flashinfer_fp8_kv_scale,
self._flashinfer_fp8_q_scale_inv,
self.eps,
block_size,
)
return q_fp8
def get_attn_backend(self) -> type[AttentionBackend]:
    """Return the attention backend class declared by the platform subclass."""
    return self.backend_cls
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
    """Describe the compressed KV cache this layer needs, if any.

    Args:
        vllm_config: Engine configuration; only
            ``cache_config.block_size`` is read.

    Returns:
        ``None`` when ``compress_ratio <= 1`` (the SWA part is allocated
        separately as ``DeepseekV4SWACache``), otherwise an
        ``MLAAttentionSpec`` for a single KV head of size ``head_dim``.
    """
    if self.compress_ratio <= 1:
        # SWA part. Allocated separately as DeepseekV4SWACache.
        return None

    # FlashMLA uses the fp8_ds_mla block format (UE8M0 block-scaled fp8 as
    # uint8, 576B aligned); FlashInfer stores a plain bf16 / per-tensor fp8
    # row with no extra alignment.
    if self.kv_cache_dtype == "fp8_ds_mla":
        spec_dtype, spec_alignment = torch.uint8, 576
    else:
        spec_dtype, spec_alignment = self.kv_cache_torch_dtype, None

    return MLAAttentionSpec(
        block_size=vllm_config.cache_config.block_size,
        num_kv_heads=1,
        head_size=self.head_dim,
        dtype=spec_dtype,
        compress_ratio=self.compress_ratio,
        cache_dtype_str=self.kv_cache_dtype,
        alignment=spec_alignment,
        model_version="deepseek_v4",
    )