vllm.compilation.passes.fusion.rocm_aiter_fusion ¶
AddAiterRMSNormPadPattern ¶
This pattern replaces an aiter_rmsnorm_with_add & a pad op with a custom triton_add_rmsnorm_pad op from AITER.
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
AiterFusedAddRMSFp8GroupQuantPattern ¶
Bases: AiterRMSNormQuantPattern
This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops into an aiter rms_norm_with_add_group_fp8_quant op.
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
AiterFusedAddRMSNormDynamicQuantPattern ¶
Bases: AiterRMSNormQuantPattern
AITER RMSNorm Fused Add + Dynamic Quantization pattern.
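As a rough illustration of what the unfused ops compute, the following is a minimal pure-Python reference for one row. It assumes the usual fused-add RMSNorm semantics (add the residual, then normalize the sum) and a per-row dynamic scale derived from the largest magnitude. The function names and the `qmax` parameter are hypothetical and are not the AITER kernel interface.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Reference RMSNorm for one row: x * weight / sqrt(mean(x^2) + eps)."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def fused_add_rms_norm(x, residual, weight, eps=1e-6):
    """Add the residual, then normalize the sum.

    Returns (normalized, updated_residual).
    """
    updated = [a + b for a, b in zip(x, residual)]
    return rms_norm(updated, weight, eps), updated


def dynamic_scale(row, qmax):
    """Scale computed at runtime from the row itself (hypothetical helper).

    qmax is the largest magnitude representable in the target format.
    """
    amax = max(abs(v) for v in row)
    return amax / qmax if amax > 0 else 1.0


normed, new_residual = fused_add_rms_norm([1.0, 2.0], [2.0, 2.0], [1.0, 1.0])
scale = dynamic_scale(normed, qmax=4.0)
```

A fused kernel would produce the same three results (normalized output, scale, updated residual) in one launch instead of materializing the intermediate tensors.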
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
AiterFusedAddRMSNormMXFP4QuantPattern ¶
Bases: AiterRMSNormQuantPattern
Fuse AITER fused_add_rms_norm + dynamic MXFP4 quant into a single kernel.
Matched 3-node subgraph::
torch.ops.vllm_ir.fused_add_rms_norm(x, residual, weight, eps)
→ torch.ops.vllm.rocm_aiter_dynamic_mxfp4_quant(z)
Replacement: single AITER fused Triton call rocm_aiter_rmsnorm_add_mxfp4_quant(x, residual, weight, eps), returning (fp4_data, scale, updated_residual).
Registered before AiterRMSNormMXFP4QuantPattern so that the larger subgraph is attempted first (greedy matching).
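Why registration order matters can be shown with a toy rewriter over a flat list of op names. This is only a sketch of the ordering argument: it is not the torch pattern matcher's algorithm, and the op names and helper functions are hypothetical.

```python
def replace_all(ops, seq, replacement):
    """Replace every occurrence of the op sequence `seq` with one fused op."""
    out, i, n = [], 0, len(seq)
    while i < len(ops):
        if ops[i:i + n] == seq:
            out.append(replacement)
            i += n
        else:
            out.append(ops[i])
            i += 1
    return out


def run_patterns(ops, patterns):
    """Apply patterns one after another, in registration order."""
    for seq, replacement in patterns:
        ops = replace_all(ops, seq, replacement)
    return ops


graph = ["add", "rms_norm", "quant", "matmul"]
large = (["add", "rms_norm", "quant"], "fused_add_rms_quant")
small = (["rms_norm", "quant"], "fused_rms_quant")

large_first = run_patterns(graph, [large, small])
small_first = run_patterns(graph, [small, large])
```

With the larger pattern first, the whole three-op chain collapses into one fused op. With the smaller pattern first, it consumes `rms_norm` and `quant`, and the larger pattern can no longer match.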
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
AiterRMSFp8GroupQuantPattern ¶
Bases: AiterRMSNormQuantPattern
This pattern fuses aiter rms_norm & group fp8 quant custom ops into an aiter rms_norm_group_fp8_quant op.
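The two steps being fused can be sketched in plain Python for a single row. The sketch assumes that group quantization means one dynamic scale per contiguous group of elements along the last dimension. The group size and `qmax` are hypothetical parameters, and the final cast to an FP8 dtype is omitted.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Reference RMSNorm for one row: x * weight / sqrt(mean(x^2) + eps)."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def group_quant(row, group_size, qmax):
    """Scale each contiguous group so its largest magnitude maps to qmax.

    Returns (scaled_values, scales). The cast to FP8 is left out.
    """
    scaled, scales = [], []
    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        amax = max(abs(v) for v in group)
        scale = amax / qmax if amax > 0 else 1.0
        scales.append(scale)
        scaled.extend(max(-qmax, min(qmax, v / scale)) for v in group)
    return scaled, scales


def rms_norm_group_quant_reference(x, weight, group_size, qmax, eps=1e-6):
    """Unfused reference: normalize first, then quantize group by group."""
    return group_quant(rms_norm(x, weight, eps), group_size, qmax)
```

Fusing the two steps avoids writing the normalized row to memory and reading it back for the quantization pass.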
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
AiterRMSNormDynamicQuantPattern ¶
Bases: AiterRMSNormQuantPattern
AITER RMSNorm + Dynamic Quantization pattern.
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
AiterRMSNormMXFP4QuantPattern ¶
Bases: AiterRMSNormQuantPattern
Fuse AITER rms_norm + dynamic MXFP4 quant into a single kernel.
Matched 2-node subgraph::
torch.ops.vllm_ir.rms_norm(x, weight, eps)
→ torch.ops.vllm.rocm_aiter_dynamic_mxfp4_quant(z)
Replacement: single AITER fused Triton call rocm_aiter_rmsnorm_mxfp4_quant(x, weight, eps).
Registered in RocmAiterRMSNormQuantFusionPass only when rocm_aiter_ops.has_fused_rmsnorm_mxfp4_quant() returns True (i.e. aiter.ops.triton.fused_mxfp4_quant is importable).
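Gating a pattern on whether a module can be imported might look like the sketch below. It uses only the standard library; `is_importable` and `select_patterns` are hypothetical helpers and do not reproduce how rocm_aiter_ops performs its check.

```python
import importlib.util


def is_importable(module_name):
    """Return True if the module can be located without importing it fully."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # A missing parent package raises ModuleNotFoundError (an ImportError).
        return False


def select_patterns(always, optional):
    """Keep unconditional patterns plus optional ones whose module is present.

    `optional` is a list of (module_name, pattern_name) pairs.
    """
    selected = list(always)
    for module_name, pattern_name in optional:
        if is_importable(module_name):
            selected.append(pattern_name)
    return selected


patterns = select_patterns(
    ["rms_norm_dynamic_quant"],
    [("aiter.ops.triton.fused_mxfp4_quant", "rms_norm_mxfp4_quant")],
)
```

On a machine without AITER installed, `patterns` contains only the unconditional entry, so the pass never tries to emit a kernel that is not available.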
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
AiterSiluMulFp8GroupQuantPattern ¶
Bases: VllmPatternReplacement
This pattern fuses aiter silu_and_mul & group fp8 quant custom ops into an aiter silu_and_mul_group_fp8_quant op.
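For orientation, the activation half of this pattern can be written as a short reference for one row. The sketch assumes the common gated-activation convention in which the first half of the row is the gate and the second half is the multiplier. The quantization step is omitted.

```python
import math


def silu(v):
    """SiLU activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def silu_and_mul(row):
    """Split the row in half, apply SiLU to the first half, multiply by the second."""
    half = len(row) // 2
    gate, up = row[:half], row[half:]
    return [silu(g) * u for g, u in zip(gate, up)]


out = silu_and_mul([0.0, 1.0, 3.0, 5.0])
```

The output has half the width of the input; the fused op would quantize that output group by group before writing it back.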
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
MLADualRMSNormFusionPass ¶
Bases: VllmFusionPatternMatcherPass
Post-grad PatternMatcher pass that fuses paired q / kv RMS norms in MLA attention into fused_mla_dual_rms_norm backed by aiter's fused_qk_rmsnorm HIP kernel.
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
MLADualRMSNormPattern ¶
Bases: VllmPatternReplacement[..., tuple[Tensor, Tensor, Tensor]]
Fuse paired q_a_layernorm + kv_a_layernorm in MLA attention into AITER's fused_qk_rmsnorm HIP kernel.
Target FX-graph pattern (unfused, vllm_ir stage)::
gemm -> split_with_sizes([q_dim, kv_dim])
          +-- q_c     -> vllm_ir.rms_norm(q_c, q_w, eps)
          +-- kv_lora -> split_with_sizes([kv_c_dim, k_pe_dim])
                           +-- kv_c -> vllm_ir.rms_norm(kv_c, kv_w, eps)
                           +-- k_pe
The pattern covers the connected subgraph rooted at the first split_with_sizes (which produces q_c and kv_lora), through the two rms_norm calls, and the k_pe passthrough.
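The unfused subgraph described above can be mirrored in plain Python for a single row. This is a reference sketch only: the helper names are hypothetical, lists stand in for tensors, and the three-element return value is an assumption based on the tuple[Tensor, Tensor, Tensor] type in the base class.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Reference RMSNorm for one row: x * weight / sqrt(mean(x^2) + eps)."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def split_with_sizes(row, sizes):
    """Cut a row into consecutive chunks of the given sizes."""
    assert sum(sizes) == len(row), "sizes must cover the whole row"
    parts, start = [], 0
    for size in sizes:
        parts.append(row[start:start + size])
        start += size
    return parts


def mla_dual_rms_norm_reference(gemm_out, q_dim, kv_c_dim, k_pe_dim,
                                q_w, kv_w, eps=1e-6):
    """Two splits, two independent RMSNorms, and the k_pe passthrough."""
    q_c, kv_lora = split_with_sizes(gemm_out, [q_dim, kv_c_dim + k_pe_dim])
    kv_c, k_pe = split_with_sizes(kv_lora, [kv_c_dim, k_pe_dim])
    return rms_norm(q_c, q_w, eps), rms_norm(kv_c, kv_w, eps), k_pe


q, kv, k_pe = mla_dual_rms_norm_reference(
    [3.0, 4.0, 1.0, 1.0, 7.0], q_dim=2, kv_c_dim=2, k_pe_dim=1,
    q_w=[1.0, 1.0], kv_w=[1.0, 1.0], eps=0.0,
)
```

The two norms read disjoint slices of the same GEMM output and do not depend on each other, which is what makes a single fused kernel possible.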
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
RocmAiterRMSNormQuantFusionPass ¶
Bases: VllmPatternMatcherPass
This pass fuses aiter rms_norm & vllm/aiter quant custom ops into a fused rms_norm_quant op. It also supports fused_add_rms_norm.
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
RocmAiterSiluMulFp8GroupQuantFusionPass ¶
Bases: VllmFusionPatternMatcherPass
This pass fuses a pre-defined set of custom ops into fused ops. It uses the torch pattern matcher to find the patterns and replace them.
Because patterns can only be registered once, the pass is a singleton. This will be addressed in a future version of PyTorch: https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
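One generic way to guarantee that registration runs only once is to cache a single instance per class, as in the sketch below. `SingletonFusionPass` and its pattern name are hypothetical; this does not reproduce how vLLM enforces the constraint.

```python
class SingletonFusionPass:
    """Toy pass whose patterns are registered exactly once per process."""

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            instance = super().__new__(cls)
            instance.registered = []
            instance._register_patterns()
            cls._instance = instance
        return cls._instance

    def _register_patterns(self):
        # Stand-in for registering patterns with a pattern matcher.
        self.registered.append("silu_mul_group_quant")


first = SingletonFusionPass()
second = SingletonFusionPass()
```

Both names refer to the same object, and the registration hook runs only on the first construction, so a second instantiation cannot register the same pattern twice.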
Source code in vllm/compilation/passes/fusion/rocm_aiter_fusion.py
RocmAiterTritonAddRMSNormPadFusionPass ¶
Bases: VllmPatternMatcherPass
This pass replaces an AITER CK RMSNorm + residual add and a pad op with a triton_add_rmsnorm_pad op from AITER.