[Mlir-commits] [mlir] [WIP][AMDGPU] Added support for Sparse WMMA ops (PR #183360)

Ravil Dorozhinskii llvmlistbot at llvm.org
Wed Feb 25 12:50:11 PST 2026


https://github.com/ravil-mobile updated https://github.com/llvm/llvm-project/pull/183360

>From 6f2e2fab0b38e534c0bd4f19157a82dd5a045436 Mon Sep 17 00:00:00 2001
From: ravil-mobile <ravil.aviva.com at gmail.com>
Date: Wed, 25 Feb 2026 18:32:33 +0000
Subject: [PATCH 1/2] [WIP][AMDGPU] Added support for Sparse WMMA ops

---
 .../mlir/Dialect/AMDGPU/IR/AMDGPUOps.td       | 93 ++++++++++++++++++
 .../AMDGPUToROCDL/swmmac-gfx12.mlir           | 51 ++++++++++
 .../AMDGPUToROCDL/swmmac-gfx1250.mlir         | 96 +++++++++++++++++++
 3 files changed, 240 insertions(+)
 create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx12.mlir
 create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx1250.mlir

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
index bc88877247546..5eeb2b0dc856c 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
@@ -1149,6 +1149,99 @@ def AMDGPU_SparseMFMAOp :
   let hasVerifier = 1;
 }
 
+// sparse_wmma (swmmac)
+def SWMMACSparseInTypes : AnyTypeOf<[
+    VectorOfLengthAndType<[4, 8, 16], [F16]>,
+    VectorOfLengthAndType<[4, 8, 16], [BF16]>,
+    VectorOfLengthAndType<[4, 8, 32], [I8]>,
+    VectorOfLengthAndType<[4, 8, 16, 32], [F8E4M3FN, F8E5M2]>,
+    VectorOfLengthAndType<[4, 8, 16, 32], [F8E4M3FNUZ, F8E5M2FNUZ]>
+]>;
+
+def SWMMACDenseInTypes : AnyTypeOf<[
+    VectorOfLengthAndType<[8, 16, 32], [F16]>,
+    VectorOfLengthAndType<[8, 16, 32], [BF16]>,
+    VectorOfLengthAndType<[4, 8, 16, 64], [I8]>,
+    VectorOfLengthAndType<[4, 8, 16, 64], [F8E4M3FN, F8E5M2]>,
+    VectorOfLengthAndType<[4, 8, 16, 64], [F8E4M3FNUZ, F8E5M2FNUZ]>
+]>;
+
+def SWMMACOutTypes : AnyTypeOf<[
+    VectorOfLengthAndType<[4, 8, 16], [F32]>,
+    VectorOfLengthAndType<[4, 8], [F16]>,
+    VectorOfLengthAndType<[4, 8], [BF16]>,
+    VectorOfLengthAndType<[4, 8], [I32]>
+]>;
+
+def SWMMACIdxTypes : AnyTypeOf<[
+    FixedVectorOfLengthAndType<[4], [I8]>
+]>;
+
+def AMDGPU_SparseWMMAOp :
+    AMDGPU_Op<"sparse_wmma", [AllTypesMatch<["destC", "destD"]>,
+                              Pure]>,
+    Arguments<(ins
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$m,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[32, 64, 128]>]>:$k,
+                   SWMMACSparseInTypes:$sourceA,
+                   SWMMACDenseInTypes:$sourceB,
+                   SWMMACOutTypes:$destC,
+                   SWMMACIdxTypes:$sparseIdx,
+                   UnitAttr:$unsignedA,
+                   UnitAttr:$unsignedB,
+                   UnitAttr:$reuseA,
+                   UnitAttr:$reuseB,
+                   UnitAttr:$clamp)>,
+    Results<(outs SWMMACOutTypes:$destD)> {
+  let summary = "MLIR wrapper for AMDGPU sparse wmma (swmmac) instructions";
+  let description = [{
+    The `amdgpu.sparse_wmma` op is an MLIR wrapper around intrinsics for various
+    `swmmac` instructions in the AMDGPU architecture, which perform matrix
+    multiply-accumulate operations using 2:4 structured sparsity on matrix A
+    with dense matrices B, C, and D.
+
+    On gfx12, swmmac intrinsics support (M = N = 16 in all cases):
+      - K=32 for f16, bf16, i8, i4, f8E4M3FN and f8E5M2 sources
+      - K=64 for i4 sources
+
+    On gfx1250, swmmac intrinsics additionally support:
+      - K=64 for f16 and bf16 sources
+      - K=128 for f8E4M3FN, f8E5M2 and i8 sources
+
+    The `sparseIdx` operand is a `vector<4xi8>` that holds the packed indices
+    (32 bits in total) identifying the positions of the non-zero elements in
+    the 2:4 sparse matrix A.
+
+    `unsignedA` and `unsignedB` mark the integer (i8 or i4) inputs as unsigned.
+
+    `reuseA` and `reuseB` map to the `reuseA`/`reuseB` flags of the gfx1250
+    ROCDL `swmmac` operations.
+
+    The `clamp` flag is used to saturate the output of type T to `numeric_limits<T>::max()`
+    in case of overflow.
+
+    Example:
+    ```mlir
+      %0 = amdgpu.sparse_wmma 16x16x32 %matA * %matB + %matC sparse(%idx : vector<4xi8>)
+        : vector<4xf16>, vector<8xf16>, vector<4xf32>
+
+      %1 = amdgpu.sparse_wmma 16x16x32 %matA * %matB + %matC sparse(%idx : vector<4xi8>)
+        : vector<8xf8E4M3FN>, vector<16xf8E4M3FN>, vector<8xf32>
+
+      %2 = amdgpu.sparse_wmma 16x16x32 %matA * %matB + %matC sparse(%idx : vector<4xi8>)
+        {unsignedB, clamp} : vector<8xi8>, vector<16xi8>, vector<8xi32>
+    ```
+  }];
+  let assemblyFormat = [{
+    custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
+    `sparse` `(` $sparseIdx `:` type($sparseIdx) `)`
+    attr-dict
+    `:` type($sourceA) `,` type($sourceB) `,` type($destC)
+  }];
+  let hasVerifier = 0;
+}
+
 def AMDGPU_GatherToLDSOp :
     AMDGPU_Op<"gather_to_lds", [AttrSizedOperandSegments]>,
     Arguments<(ins
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx12.mlir
new file mode 100644
index 0000000000000..cef8273e9e707
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx12.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1200 --split-input-file --verify-diagnostics | FileCheck %s
+
+
+func.func @rocdl.swmmac(
+  %v64i8 : vector<64xi8>, %v64f8 : vector<64xf8E4M3FN>, %v64bf8 : vector<64xf8E5M2>,
+  %v32f16 : vector<32xf16>, %v32bf16 : vector<32xbf16>, %v32i8 : vector<32xi8>, %v32f8 : vector<32xf8E4M3FN>, %v32bf8 : vector<32xf8E5M2>,
+  %v16f16 : vector<16xf16>, %v16bf16 : vector<16xbf16>, %v16i32 : vector<16xi32>, %v16i16 : vector<16xi16>, %v16i8 : vector<16xi8>,
+  %v16f8 : vector<16xf8E4M3FN>, %v16bf8 : vector<16xf8E5M2>,
+  %v8f32 : vector<8xf32>, %v8i32 : vector<8xi32>, %v8f16 : vector<8xf16>, %v8bf16 : vector<8xbf16>, %v8i16 : vector<8xi16>,  %v8i8 : vector<8xi8>,
+  %v8f8 : vector<8xf8E4M3FN>, %v8bf8 : vector<8xf8E5M2>,
+  %v4f32 : vector<4xf32>, %v4f16 : vector<4xf16>, %v4bf16 : vector<4xbf16>, %v4i32 : vector<4xi32>, %v4i16 : vector<4xi16>, %v4i8 : vector<4xi8>,
+  %v4f8 : vector<4xf8E4M3FN>, %v4bf8 : vector<4xf8E5M2>,
+  %v2i32 : vector<2xi32>, %v1i32 : i32, %idx : vector<4xi8>) {
+
+  // ---- Wave32 -----
+
+  // CHECK: rocdl.swmmac.f32.16x16x32.f16 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<8xf16>, vector<16xf16>, vector<8xf32>, i32) -> vector<8xf32>
+  %w32_0 = amdgpu.sparse_wmma 16x16x32 %v8f16 * %v16f16 + %v8f32 sparse(%idx : vector<4xi8>) : vector<8xf16>, vector<16xf16>, vector<8xf32>
+
+  // CHECK: rocdl.swmmac.f32.16x16x32.bf16 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<8xi16>, vector<16xi16>, vector<8xf32>, i32) -> vector<8xf32>
+  %w32_1 = amdgpu.sparse_wmma 16x16x32 %v8bf16 * %v16bf16 + %v8f32 sparse(%idx : vector<4xi8>) : vector<8xbf16>, vector<16xbf16>, vector<8xf32>
+
+  // CHECK: rocdl.swmmac.f16.16x16x32.f16 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<8xf16>, vector<16xf16>, vector<8xf16>, i32) -> vector<8xf16>
+  %w32_2 = amdgpu.sparse_wmma 16x16x32 %v8f16 * %v16f16 + %v8f16 sparse(%idx : vector<4xi8>) : vector<8xf16>, vector<16xf16>, vector<8xf16>
+
+  // CHECK: rocdl.swmmac.bf16.16x16x32.bf16 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<8xi16>, vector<16xi16>, vector<8xi16>, i32) -> vector<8xi16>
+  %w32_3 = amdgpu.sparse_wmma 16x16x32 %v8bf16 * %v16bf16 + %v8bf16 sparse(%idx : vector<4xi8>) : vector<8xbf16>, vector<16xbf16>, vector<8xbf16>
+
+  // CHECK: rocdl.swmmac.i32.16x16x32.iu8 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, clamp = true} : (vector<2xi32>, vector<4xi32>, vector<8xi32>, i32) -> vector<8xi32>
+  %w32_4 = amdgpu.sparse_wmma 16x16x32 %v8i8 * %v16i8 + %v8i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<8xi8>, vector<16xi8>, vector<8xi32>
+
+  // CHECK: rocdl.swmmac.i32.16x16x32.iu4 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, clamp = true} : (i32, vector<2xi32>, vector<8xi32>, i32) -> vector<8xi32>
+  %w32_5 = amdgpu.sparse_wmma 16x16x32 %v4i8 * %v8i8 + %v8i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<4xi8>, vector<8xi8>, vector<8xi32>
+
+  // CHECK: rocdl.swmmac.i32.16x16x64.iu4 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, clamp = true} : (vector<2xi32>, vector<4xi32>, vector<8xi32>, i32) -> vector<8xi32>
+  %w32_6 = amdgpu.sparse_wmma 16x16x64 %v8i8 * %v16i8 + %v8i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<8xi8>, vector<16xi8>, vector<8xi32>
+
+  // CHECK: rocdl.swmmac.f32.16x16x32.fp8.fp8 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<2xi32>, vector<4xi32>, vector<8xf32>, i32) -> vector<8xf32>
+  %w32_7 = amdgpu.sparse_wmma 16x16x32 %v8f8 * %v16f8 + %v8f32 sparse(%idx : vector<4xi8>) : vector<8xf8E4M3FN>, vector<16xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.swmmac.f32.16x16x32.fp8.bf8 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<2xi32>, vector<4xi32>, vector<8xf32>, i32) -> vector<8xf32>
+  %w32_8 = amdgpu.sparse_wmma 16x16x32 %v8f8 * %v16bf8 + %v8f32 sparse(%idx : vector<4xi8>) : vector<8xf8E4M3FN>, vector<16xf8E5M2>, vector<8xf32>
+
+  // CHECK: rocdl.swmmac.f32.16x16x32.bf8.fp8 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<2xi32>, vector<4xi32>, vector<8xf32>, i32) -> vector<8xf32>
+  %w32_9 = amdgpu.sparse_wmma 16x16x32 %v8bf8 * %v16f8 + %v8f32 sparse(%idx : vector<4xi8>) : vector<8xf8E5M2>, vector<16xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.swmmac.f32.16x16x32.bf8.bf8 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<2xi32>, vector<4xi32>, vector<8xf32>, i32) -> vector<8xf32>
+  %w32_10 = amdgpu.sparse_wmma 16x16x32 %v8bf8 * %v16bf8 + %v8f32 sparse(%idx : vector<4xi8>) : vector<8xf8E5M2>, vector<16xf8E5M2>, vector<8xf32>
+
+  func.return
+}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx1250.mlir
new file mode 100644
index 0000000000000..da903452c7d51
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx1250.mlir
@@ -0,0 +1,96 @@
+
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --split-input-file --verify-diagnostics | FileCheck %s
+
+func.func @rocdl.swmmac(
+  %v64i8 : vector<64xi8>, %v64f8 : vector<64xf8E4M3FN>, %v64bf8 : vector<64xf8E5M2>,
+  %v32f16 : vector<32xf16>, %v32bf16 : vector<32xbf16>, %v32i8 : vector<32xi8>, %v32f8 : vector<32xf8E4M3FN>, %v32bf8 : vector<32xf8E5M2>,
+  %v16f16 : vector<16xf16>, %v16bf16 : vector<16xbf16>, %v16i32 : vector<16xi32>, %v16i16 : vector<16xi16>, %v16i8 : vector<16xi8>,
+  %v16f8 : vector<16xf8E4M3FN>, %v16bf8 : vector<16xf8E5M2>,
+  %v8f32 : vector<8xf32>, %v8i32 : vector<8xi32>, %v8f16 : vector<8xf16>, %v8bf16 : vector<8xbf16>, %v8i16 : vector<8xi16>,  %v8i8 : vector<8xi8>,
+  %v8f8 : vector<8xf8E4M3FN>, %v8bf8 : vector<8xf8E5M2>,
+  %v4f32 : vector<4xf32>, %v4f16 : vector<4xf16>, %v4bf16 : vector<4xbf16>, %v4i32 : vector<4xi32>, %v4i16 : vector<4xi16>, %v4i8 : vector<4xi8>,
+  %v4f8 : vector<4xf8E4M3FN>, %v4bf8 : vector<4xf8E5M2>,
+  %v2i32 : vector<2xi32>, %v1i32 : i32, %idx : vector<4xi8>) {
+
+  // ---- Wave32 -----
+
+  // CHECK: rocdl.swmmac.f32.16x16x64.f16 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, reuseA = true, reuseB = true} : (vector<16xf16>, vector<32xf16>, vector<8xf32>, i32) -> vector<8xf32>
+  %w32_11 = amdgpu.sparse_wmma 16x16x64 %v16f16 * %v32f16 + %v8f32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, reuseA, reuseB} : vector<16xf16>, vector<32xf16>, vector<8xf32>
+
+  // CHECK: rocdl.swmmac.f32.16x16x64.bf16 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, reuseA = true, reuseB = true} : (vector<16xbf16>, vector<32xbf16>, vector<8xf32>, i32) -> vector<8xf32>
+  %w32_12 = amdgpu.sparse_wmma 16x16x64 %v16bf16 * %v32bf16 + %v8f32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, reuseA, reuseB} : vector<16xbf16>, vector<32xbf16>, vector<8xf32>
+
+  // CHECK: rocdl.swmmac.f16.16x16x64.f16 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, reuseA = true, reuseB = true} : (vector<16xf16>, vector<32xf16>, vector<8xf16>, i32) -> vector<8xf16>
+  %w32_13 = amdgpu.sparse_wmma 16x16x64 %v16f16 * %v32f16 + %v8f16 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, reuseA, reuseB} : vector<16xf16>, vector<32xf16>, vector<8xf16>
+
+  // CHECK: rocdl.swmmac.bf16.16x16x64.bf16 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, reuseA = true, reuseB = true} : (vector<16xbf16>, vector<32xbf16>, vector<8xbf16>, i32) -> vector<8xbf16>
+  %w32_14 = amdgpu.sparse_wmma 16x16x64 %v16bf16 * %v32bf16 + %v8bf16 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, reuseA, reuseB} : vector<16xbf16>, vector<32xbf16>, vector<8xbf16>
+
+  // TODO: cover rocdl.swmmac.bf16f32.16x16x64.bf16. Its operand/result types
+  // match the bf16 case above, so sparse_wmma presumably needs an extra
+  // attribute to select it.
+
+  // CHECK: rocdl.swmmac.f32.16x16x128.fp8.fp8 %{{.*}}, %{{.*}}, %{{.*}}, %index {reuseA = true, reuseB = true} : (vector<8xi32>, vector<16xi32>, vector<8xf32>, i32) -> vector<8xf32>
+  %w32_16 = amdgpu.sparse_wmma 16x16x128 %v32f8 * %v64f8 + %v8f32 sparse(%idx : vector<4xi8>) {reuseA, reuseB} : vector<32xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.swmmac.f32.16x16x128.fp8.bf8 %{{.*}}, %{{.*}}, %{{.*}}, %index {reuseA = true, reuseB = true} : (vector<8xi32>, vector<16xi32>, vector<8xf32>, i32) -> vector<8xf32>
+  %w32_17 = amdgpu.sparse_wmma 16x16x128 %v32f8 * %v64bf8 + %v8f32 sparse(%idx : vector<4xi8>) {reuseA, reuseB} : vector<32xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf32>
+
+  // CHECK: rocdl.swmmac.f32.16x16x128.bf8.fp8 %{{.*}}, %{{.*}}, %{{.*}}, %index {reuseA = true, reuseB = true} : (vector<8xi32>, vector<16xi32>, vector<8xf32>, i32) -> vector<8xf32>
+  %w32_18 = amdgpu.sparse_wmma 16x16x128 %v32bf8 * %v64f8 + %v8f32 sparse(%idx : vector<4xi8>) {reuseA, reuseB} : vector<32xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.swmmac.f32.16x16x128.bf8.bf8 %{{.*}}, %{{.*}}, %{{.*}}, %index {reuseA = true, reuseB = true} : (vector<8xi32>, vector<16xi32>, vector<8xf32>, i32) -> vector<8xf32>
+  %w32_19 = amdgpu.sparse_wmma 16x16x128 %v32bf8 * %v64bf8 + %v8f32 sparse(%idx : vector<4xi8>) {reuseA, reuseB} : vector<32xf8E5M2>, vector<64xf8E5M2>, vector<8xf32>
+
+  // CHECK: rocdl.swmmac.f16.16x16x128.fp8.fp8 %{{.*}}, %{{.*}}, %{{.*}}, %index {reuseA = true, reuseB = true} : (vector<8xi32>, vector<16xi32>, vector<8xf16>, i32) -> vector<8xf16>
+  %w32_20 = amdgpu.sparse_wmma 16x16x128 %v32f8 * %v64f8 + %v8f16 sparse(%idx : vector<4xi8>) {reuseA, reuseB} : vector<32xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16>
+
+  // CHECK: rocdl.swmmac.f16.16x16x128.fp8.bf8 %{{.*}}, %{{.*}}, %{{.*}}, %index {reuseA = true, reuseB = true} : (vector<8xi32>, vector<16xi32>, vector<8xf16>, i32) -> vector<8xf16>
+  %w32_21 = amdgpu.sparse_wmma 16x16x128 %v32f8 * %v64bf8 + %v8f16 sparse(%idx : vector<4xi8>) {reuseA, reuseB} : vector<32xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf16>
+
+  // CHECK: rocdl.swmmac.f16.16x16x128.bf8.fp8 %{{.*}}, %{{.*}}, %{{.*}}, %index {reuseA = true, reuseB = true} : (vector<8xi32>, vector<16xi32>, vector<8xf16>, i32) -> vector<8xf16>
+  %w32_22 = amdgpu.sparse_wmma 16x16x128 %v32bf8 * %v64f8 + %v8f16 sparse(%idx : vector<4xi8>) {reuseA, reuseB} : vector<32xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16>
+
+  // CHECK: rocdl.swmmac.f16.16x16x128.bf8.bf8 %{{.*}}, %{{.*}}, %{{.*}}, %index {reuseA = true, reuseB = true} : (vector<8xi32>, vector<16xi32>, vector<8xf16>, i32) -> vector<8xf16>
+  %w32_23 = amdgpu.sparse_wmma 16x16x128 %v32bf8 * %v64bf8 + %v8f16 sparse(%idx : vector<4xi8>) {reuseA, reuseB} : vector<32xf8E5M2>, vector<64xf8E5M2>, vector<8xf16>
+
+  // CHECK: rocdl.swmmac.i32.16x16x128.iu8 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, reuseA = true, reuseB = true, clamp = true} : (vector<8xi32>, vector<16xi32>, vector<8xi32>, i32) -> vector<8xi32>
+  %w32_24 = amdgpu.sparse_wmma 16x16x128 %v32i8 * %v64i8 + %v8i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, reuseA, reuseB, clamp} : vector<32xi8>, vector<64xi8>, vector<8xi32>
+
+  // ---- Wave64 -----
+
+  // CHECK: rocdl.swmmac.f32.16x16x32.f16 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32) -> vector<4xf32>
+  %w64_0 = amdgpu.sparse_wmma 16x16x32 %v4f16 * %v8f16 + %v4f32 sparse(%idx : vector<4xi8>) : vector<4xf16>, vector<8xf16>, vector<4xf32>
+
+  // CHECK: rocdl.swmmac.f32.16x16x32.bf16 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32) -> vector<4xf32>
+  %w64_1 = amdgpu.sparse_wmma 16x16x32 %v4bf16 * %v8bf16 + %v4f32 sparse(%idx : vector<4xi8>) : vector<4xbf16>, vector<8xbf16>, vector<4xf32>
+
+  // CHECK: rocdl.swmmac.f16.16x16x32.f16 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<4xf16>, vector<8xf16>, vector<4xf16>, i32) -> vector<4xf16>
+  %w64_2 = amdgpu.sparse_wmma 16x16x32 %v4f16 * %v8f16 + %v4f16 sparse(%idx : vector<4xi8>) : vector<4xf16>, vector<8xf16>, vector<4xf16>
+
+  // CHECK: rocdl.swmmac.bf16.16x16x32.bf16 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<4xi16>, vector<8xi16>, vector<4xi16>, i32) -> vector<4xi16>
+  %w64_3 = amdgpu.sparse_wmma 16x16x32 %v4bf16 * %v8bf16 + %v4bf16 sparse(%idx : vector<4xi8>) : vector<4xbf16>, vector<8xbf16>, vector<4xbf16>
+
+  // CHECK: rocdl.swmmac.i32.16x16x32.iu8 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, clamp = true} : (i32, vector<2xi32>, vector<4xi32>, i32) -> vector<4xi32>
+  %w64_4 = amdgpu.sparse_wmma 16x16x32 %v4i8 * %v8i8 + %v4i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<4xi8>, vector<8xi8>, vector<4xi32>
+
+  // CHECK: rocdl.swmmac.i32.16x16x32.iu4 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, clamp = true} : (i32, i32, vector<4xi32>, i32) -> vector<4xi32>
+  %w64_5 = amdgpu.sparse_wmma 16x16x32 %v4i8 * %v4i8 + %v4i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<4xi8>, vector<4xi8>, vector<4xi32>
+
+  // CHECK: rocdl.swmmac.i32.16x16x64.iu4 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, clamp = true} : (i32, vector<2xi32>, vector<4xi32>, i32) -> vector<4xi32>
+  %w64_6 = amdgpu.sparse_wmma 16x16x64 %v4i8 * %v16i8 + %v4i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<4xi8>, vector<16xi8>, vector<4xi32>
+
+  // CHECK: rocdl.swmmac.f32.16x16x32.fp8.fp8 %{{.*}}, %{{.*}}, %{{.*}}, %index : (i32, vector<2xi32>, vector<4xf32>, i32) -> vector<4xf32>
+  %w64_7 = amdgpu.sparse_wmma 16x16x32 %v4f8 * %v16f8 + %v4f32 sparse(%idx : vector<4xi8>) : vector<4xf8E4M3FN>, vector<16xf8E4M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.swmmac.f32.16x16x32.fp8.bf8 %{{.*}}, %{{.*}}, %{{.*}}, %index : (i32, vector<2xi32>, vector<4xf32>, i32) -> vector<4xf32>
+  %w64_8 = amdgpu.sparse_wmma 16x16x32 %v4f8 * %v16bf8 + %v4f32 sparse(%idx : vector<4xi8>) : vector<4xf8E4M3FN>, vector<16xf8E5M2>, vector<4xf32>
+
+  // CHECK: rocdl.swmmac.f32.16x16x32.bf8.fp8 %{{.*}}, %{{.*}}, %{{.*}}, %index : (i32, vector<2xi32>, vector<4xf32>, i32) -> vector<4xf32>
+  %w64_9 = amdgpu.sparse_wmma 16x16x32 %v4bf8 * %v16f8 + %v4f32 sparse(%idx : vector<4xi8>) : vector<4xf8E5M2>, vector<16xf8E4M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.swmmac.f32.16x16x32.bf8.bf8 %{{.*}}, %{{.*}}, %{{.*}}, %index : (i32, vector<2xi32>, vector<4xf32>, i32) -> vector<4xf32>
+  %w64_10 = amdgpu.sparse_wmma 16x16x32 %v4bf8 * %v16bf8 + %v4f32 sparse(%idx : vector<4xi8>) : vector<4xf8E5M2>, vector<16xf8E5M2>, vector<4xf32>
+
+  func.return
+}

>From bce23c190727961ea03e790935c860b72eea58e5 Mon Sep 17 00:00:00 2001
From: ravil-mobile <ravil.aviva.com at gmail.com>
Date: Wed, 25 Feb 2026 19:24:33 +0000
Subject: [PATCH 2/2] [WIP][AMDGPU] Added placeholders for verifier and rewrite
 for swmmac

---
 .../mlir/Dialect/AMDGPU/IR/AMDGPUOps.td       |  4 +++-
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           | 14 ++++++++++++++
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp      | 10 ++++++++++
 .../AMDGPUToROCDL/swmmac-gfx12.mlir           | 19 ++++++++-----------
 .../AMDGPUToROCDL/swmmac-gfx1250.mlir         | 16 +++++++---------
 5 files changed, 42 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
index 5eeb2b0dc856c..d00aa00a11257 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
@@ -1154,6 +1154,7 @@ def SWMMACSparseInTypes : AnyTypeOf<[
     VectorOfLengthAndType<[4, 8, 16], [F16]>,
     VectorOfLengthAndType<[4, 8, 16], [BF16]>,
     VectorOfLengthAndType<[4, 8, 32], [I8]>,
+    VectorOfLengthAndType<[8, 16], [I<4>]>,
     VectorOfLengthAndType<[4, 8, 16, 32], [F8E4M3FN, F8E5M2]>,
     VectorOfLengthAndType<[4, 8, 16, 32], [F8E4M3FNUZ, F8E5M2FNUZ]>
 ]>;
@@ -1162,6 +1163,7 @@ def SWMMACDenseInTypes : AnyTypeOf<[
     VectorOfLengthAndType<[8, 16, 32], [F16]>,
     VectorOfLengthAndType<[8, 16, 32], [BF16]>,
     VectorOfLengthAndType<[4, 8, 16, 64], [I8]>,
+    VectorOfLengthAndType<[8, 16, 32], [I<4>]>,
     VectorOfLengthAndType<[4, 8, 16, 64], [F8E4M3FN, F8E5M2]>,
     VectorOfLengthAndType<[4, 8, 16, 64], [F8E4M3FNUZ, F8E5M2FNUZ]>
 ]>;
@@ -1239,7 +1241,7 @@ def AMDGPU_SparseWMMAOp :
     attr-dict
     `:` type($sourceA) `,` type($sourceB) `,` type($destC)
   }];
-  let hasVerifier = 0;
+  let hasVerifier = 1;
 }
 
 def AMDGPU_GatherToLDSOp :
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3c2c61b2426e9..6f5520e966179 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1592,6 +1592,20 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
+struct SparseWMMAOpLowering : public ConvertOpToLLVMPattern<SparseWMMAOp> {
+  SparseWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<SparseWMMAOp>(converter), chipset(chipset) {}
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(SparseWMMAOp op, SparseWMMAOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // TODO(Ravil): emit the rocdl.swmmac.* ops. Until then, fail the match:
+    // returning success() without replacing `op` breaks the pattern contract.
+    return rewriter.notifyMatchFailure(op, "not yet implemented");
+  }
+};
+
 struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
   ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
       : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
index f452d2de15dc8..8d4ec6271842c 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
@@ -670,6 +670,16 @@ LogicalResult SparseMFMAOp::verify() {
   return success();
 }
 
+
+//===----------------------------------------------------------------------===//
+// SparseWMMAOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult SparseWMMAOp::verify() {
+  // TODO(Ravil): no op-specific checks yet; only the ODS constraints apply.
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // DPPOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx12.mlir
index cef8273e9e707..4add3ce4bfd6b 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx12.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx12.mlir
@@ -2,15 +2,12 @@
 
 
 func.func @rocdl.swmmac(
-  %v64i8 : vector<64xi8>, %v64f8 : vector<64xf8E4M3FN>, %v64bf8 : vector<64xf8E5M2>,
-  %v32f16 : vector<32xf16>, %v32bf16 : vector<32xbf16>, %v32i8 : vector<32xi8>, %v32f8 : vector<32xf8E4M3FN>, %v32bf8 : vector<32xf8E5M2>,
-  %v16f16 : vector<16xf16>, %v16bf16 : vector<16xbf16>, %v16i32 : vector<16xi32>, %v16i16 : vector<16xi16>, %v16i8 : vector<16xi8>,
-  %v16f8 : vector<16xf8E4M3FN>, %v16bf8 : vector<16xf8E5M2>,
-  %v8f32 : vector<8xf32>, %v8i32 : vector<8xi32>, %v8f16 : vector<8xf16>, %v8bf16 : vector<8xbf16>, %v8i16 : vector<8xi16>,  %v8i8 : vector<8xi8>,
-  %v8f8 : vector<8xf8E4M3FN>, %v8bf8 : vector<8xf8E5M2>,
-  %v4f32 : vector<4xf32>, %v4f16 : vector<4xf16>, %v4bf16 : vector<4xbf16>, %v4i32 : vector<4xi32>, %v4i16 : vector<4xi16>, %v4i8 : vector<4xi8>,
-  %v4f8 : vector<4xf8E4M3FN>, %v4bf8 : vector<4xf8E5M2>,
-  %v2i32 : vector<2xi32>, %v1i32 : i32, %idx : vector<4xi8>) {
+  %v32i4 : vector<32xi4>,
+  %v16f16 : vector<16xf16>, %v16bf16 : vector<16xbf16>, %v16i8 : vector<16xi8>,
+  %v16i4 : vector<16xi4>, %v16f8 : vector<16xf8E4M3FN>, %v16bf8 : vector<16xf8E5M2>,
+  %v8f32 : vector<8xf32>, %v8i32 : vector<8xi32>, %v8f16 : vector<8xf16>, %v8bf16 : vector<8xbf16>, %v8i8 : vector<8xi8>,
+  %v8i4 : vector<8xi4>, %v8f8 : vector<8xf8E4M3FN>, %v8bf8 : vector<8xf8E5M2>,
+  %idx : vector<4xi8>) {
 
   // ---- Wave32 -----
 
@@ -30,10 +27,10 @@ func.func @rocdl.swmmac(
   %w32_4 = amdgpu.sparse_wmma 16x16x32 %v8i8 * %v16i8 + %v8i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<8xi8>, vector<16xi8>, vector<8xi32>
 
   // CHECK: rocdl.swmmac.i32.16x16x32.iu4 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, clamp = true} : (i32, vector<2xi32>, vector<8xi32>, i32) -> vector<8xi32>
-  %w32_5 = amdgpu.sparse_wmma 16x16x32 %v4i8 * %v8i8 + %v8i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<4xi8>, vector<8xi8>, vector<8xi32>
+  %w32_5 = amdgpu.sparse_wmma 16x16x32 %v8i4 * %v16i4 + %v8i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<8xi4>, vector<16xi4>, vector<8xi32>
 
   // CHECK: rocdl.swmmac.i32.16x16x64.iu4 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, clamp = true} : (vector<2xi32>, vector<4xi32>, vector<8xi32>, i32) -> vector<8xi32>
-  %w32_6 = amdgpu.sparse_wmma 16x16x64 %v8i8 * %v16i8 + %v8i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<8xi8>, vector<16xi8>, vector<8xi32>
+  %w32_6 = amdgpu.sparse_wmma 16x16x64 %v16i4 * %v32i4 + %v8i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<16xi4>, vector<32xi4>, vector<8xi32>
 
   // CHECK: rocdl.swmmac.f32.16x16x32.fp8.fp8 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<2xi32>, vector<4xi32>, vector<8xf32>, i32) -> vector<8xf32>
   %w32_7 = amdgpu.sparse_wmma 16x16x32 %v8f8 * %v16f8 + %v8f32 sparse(%idx : vector<4xi8>) : vector<8xf8E4M3FN>, vector<16xf8E4M3FN>, vector<8xf32>
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx1250.mlir
index da903452c7d51..aa33595310f10 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx1250.mlir
@@ -3,14 +3,12 @@
 
 func.func @rocdl.swmmac(
   %v64i8 : vector<64xi8>, %v64f8 : vector<64xf8E4M3FN>, %v64bf8 : vector<64xf8E5M2>,
-  %v32f16 : vector<32xf16>, %v32bf16 : vector<32xbf16>, %v32i8 : vector<32xi8>, %v32f8 : vector<32xf8E4M3FN>, %v32bf8 : vector<32xf8E5M2>,
-  %v16f16 : vector<16xf16>, %v16bf16 : vector<16xbf16>, %v16i32 : vector<16xi32>, %v16i16 : vector<16xi16>, %v16i8 : vector<16xi8>,
-  %v16f8 : vector<16xf8E4M3FN>, %v16bf8 : vector<16xf8E5M2>,
-  %v8f32 : vector<8xf32>, %v8i32 : vector<8xi32>, %v8f16 : vector<8xf16>, %v8bf16 : vector<8xbf16>, %v8i16 : vector<8xi16>,  %v8i8 : vector<8xi8>,
-  %v8f8 : vector<8xf8E4M3FN>, %v8bf8 : vector<8xf8E5M2>,
-  %v4f32 : vector<4xf32>, %v4f16 : vector<4xf16>, %v4bf16 : vector<4xbf16>, %v4i32 : vector<4xi32>, %v4i16 : vector<4xi16>, %v4i8 : vector<4xi8>,
+  %v32f16 : vector<32xf16>, %v32bf16 : vector<32xbf16>, %v32i8 : vector<32xi8>, %v32i4 : vector<32xi4>, %v32f8 : vector<32xf8E4M3FN>, %v32bf8 : vector<32xf8E5M2>,
+  %v16f16 : vector<16xf16>, %v16bf16 : vector<16xbf16>, %v16f8 : vector<16xf8E4M3FN>, %v16bf8 : vector<16xf8E5M2>,
+  %v8f32 : vector<8xf32>, %v8i32 : vector<8xi32>, %v8f16 : vector<8xf16>, %v8bf16 : vector<8xbf16>, %v8i8 : vector<8xi8>, %v8i4 : vector<8xi4>,
+  %v4f32 : vector<4xf32>, %v4f16 : vector<4xf16>, %v4bf16 : vector<4xbf16>, %v4i32 : vector<4xi32>, %v4i8 : vector<4xi8>,
   %v4f8 : vector<4xf8E4M3FN>, %v4bf8 : vector<4xf8E5M2>,
-  %v2i32 : vector<2xi32>, %v1i32 : i32, %idx : vector<4xi8>) {
+  %idx : vector<4xi8>) {
 
   // ---- Wave32 -----
 
@@ -75,10 +73,10 @@ func.func @rocdl.swmmac(
   %w64_4 = amdgpu.sparse_wmma 16x16x32 %v4i8 * %v8i8 + %v4i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<4xi8>, vector<8xi8>, vector<4xi32>
 
   // CHECK: rocdl.swmmac.i32.16x16x32.iu4 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, clamp = true} : (i32, i32, vector<4xi32>, i32) -> vector<4xi32>
-  %w64_5 = amdgpu.sparse_wmma 16x16x32 %v4i8 * %v4i8 + %v4i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<4xi8>, vector<4xi8>, vector<4xi32>
+  %w64_5 = amdgpu.sparse_wmma 16x16x32 %v8i4 * %v8i4 + %v4i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32>
 
   // CHECK: rocdl.swmmac.i32.16x16x64.iu4 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, clamp = true} : (i32, vector<2xi32>, vector<4xi32>, i32) -> vector<4xi32>
-  %w64_6 = amdgpu.sparse_wmma 16x16x64 %v4i8 * %v16i8 + %v4i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<4xi8>, vector<16xi8>, vector<4xi32>
+  %w64_6 = amdgpu.sparse_wmma 16x16x64 %v8i4 * %v32i4 + %v4i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<8xi4>, vector<32xi4>, vector<4xi32>
 
   // CHECK: rocdl.swmmac.f32.16x16x32.fp8.fp8 %{{.*}}, %{{.*}}, %{{.*}}, %index : (i32, vector<2xi32>, vector<4xf32>, i32) -> vector<4xf32>
   %w64_7 = amdgpu.sparse_wmma 16x16x32 %v4f8 * %v16f8 + %v4f32 sparse(%idx : vector<4xi8>) : vector<4xf8E4M3FN>, vector<16xf8E4M3FN>, vector<4xf32>



More information about the Mlir-commits mailing list