[Mlir-commits] [mlir] [WIP][AMDGPU] Added support for Sparse WMMA ops (PR #183360)
Ravil Dorozhinskii
llvmlistbot at llvm.org
Wed Feb 25 12:50:11 PST 2026
https://github.com/ravil-mobile updated https://github.com/llvm/llvm-project/pull/183360
>From 6f2e2fab0b38e534c0bd4f19157a82dd5a045436 Mon Sep 17 00:00:00 2001
From: ravil-mobile <ravil.aviva.com at gmail.com>
Date: Wed, 25 Feb 2026 18:32:33 +0000
Subject: [PATCH 1/2] [WIP][AMDGPU] Added support for Sparse WMMA ops
---
.../mlir/Dialect/AMDGPU/IR/AMDGPUOps.td | 93 ++++++++++++++++++
.../AMDGPUToROCDL/swmmac-gfx12.mlir | 51 ++++++++++
.../AMDGPUToROCDL/swmmac-gfx1250.mlir | 96 +++++++++++++++++++
3 files changed, 240 insertions(+)
create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx12.mlir
create mode 100644 mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx1250.mlir
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
index bc88877247546..5eeb2b0dc856c 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
@@ -1149,6 +1149,99 @@ def AMDGPU_SparseMFMAOp :
let hasVerifier = 1;
}
+// sparse_wmma (swmmac) operand type constraints. The accepted vector lengths
+// depend on the wave size and on K: e.g. for a 16x16x32 f16 op, A is
+// vector<8xf16> on wave32 and vector<4xf16> on wave64.
+// Types accepted for the 2:4-sparse A operand (`sourceA`).
+def SWMMACSparseInTypes : AnyTypeOf<[
+ VectorOfLengthAndType<[4, 8, 16], [F16]>,
+ VectorOfLengthAndType<[4, 8, 16], [BF16]>,
+ VectorOfLengthAndType<[4, 8, 32], [I8]>,
+ VectorOfLengthAndType<[4, 8, 16, 32], [F8E4M3FN, F8E5M2]>,
+ VectorOfLengthAndType<[4, 8, 16, 32], [F8E4M3FNUZ, F8E5M2FNUZ]>
+]>;
+
+// Types accepted for the dense B operand (`sourceB`).
+def SWMMACDenseInTypes : AnyTypeOf<[
+ VectorOfLengthAndType<[8, 16, 32], [F16]>,
+ VectorOfLengthAndType<[8, 16, 32], [BF16]>,
+ VectorOfLengthAndType<[4, 8, 16, 64], [I8]>,
+ VectorOfLengthAndType<[4, 8, 16, 64], [F8E4M3FN, F8E5M2]>,
+ VectorOfLengthAndType<[4, 8, 16, 64], [F8E4M3FNUZ, F8E5M2FNUZ]>
+]>;
+
+// Types accepted for the accumulator C (`destC`) and the result D (`destD`).
+def SWMMACOutTypes : AnyTypeOf<[
+ VectorOfLengthAndType<[4, 8, 16], [F32]>,
+ VectorOfLengthAndType<[4, 8], [F16]>,
+ VectorOfLengthAndType<[4, 8], [BF16]>,
+ VectorOfLengthAndType<[4, 8], [I32]>
+]>;
+
+// Packed 2:4 sparsity indices for `sparse_wmma`: 32 bits of index data given
+// either as four 8-bit or as two 16-bit indices (both forms are documented below).
+def SWMMACIdxTypes : AnyTypeOf<[
+ FixedVectorOfLengthAndType<[4], [I8]>,
+ FixedVectorOfLengthAndType<[2], [I16]>
+]>;
+
+def AMDGPU_SparseWMMAOp :
+ AMDGPU_Op<"sparse_wmma", [AllTypesMatch<["destC", "destD"]>,
+ Pure]>,
+ Arguments<(ins
+ ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$m,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[32, 64, 128]>]>:$k,
+ SWMMACSparseInTypes:$sourceA,
+ SWMMACDenseInTypes:$sourceB,
+ SWMMACOutTypes:$destC,
+ SWMMACIdxTypes:$sparseIdx,
+ UnitAttr:$unsignedA,
+ UnitAttr:$unsignedB,
+ UnitAttr:$reuseA,
+ UnitAttr:$reuseB,
+ UnitAttr:$clamp)>,
+ Results<(outs SWMMACOutTypes: $destD)> {
+ let summary = "MLIR wrapper for gfx12+ sparse wmma (swmmac) instructions";
+ let description = [{
+ The `amdgpu.sparse_wmma` op is an MLIR wrapper around intrinsics for various
+ `swmmac` instructions in the AMDGPU architecture, which perform matrix
+ multiply-accumulate operations using 2:4 structured sparsity on matrix A
+ with dense matrices B, C, and D.
+
+ On gfx12, swmmac intrinsics support:
+ - M=N=16, K=32 for f16, bf16, i8, i4 and fp8/bf8 sources
+ - M=N=16, K=64 for i4 sources
+
+ On gfx1250, swmmac intrinsics additionally support:
+ - M=N=16, K=64 for f16 and bf16 sources
+ - M=N=16, K=128 for fp8/bf8 and i8 sources
+
+ The `sparseIdx` parameter contains packed indices identifying the positions
+ of non-zero elements in the 2:4 sparse matrix A. For 16-bit source data,
+ use `vector<4xi8>` (four 8-bit indices). For 8-bit source data, use
+ `vector<2xi16>` (two 16-bit indices).
+
+ `unsignedA` and `unsignedB` flag that the integer (`i8`/`i4`) inputs are
+ unsigned. `reuseA` and `reuseB` set the corresponding `reuseA`/`reuseB`
+ flags of the gfx1250 swmmac intrinsics.
+
+ The `clamp` flag is used to saturate the output of type T to `numeric_limits<T>::max()`
+ in case of overflow.
+
+ Example:
+ ```mlir
+ %0 = amdgpu.sparse_wmma 16x16x32 %matA * %matB + %matC sparse(%idx : vector<4xi8>)
+ : vector<4xf16>, vector<8xf16>, vector<4xf32>
+
+ %1 = amdgpu.sparse_wmma 16x16x32 %matA * %matB + %matC sparse(%idx : vector<2xi16>)
+ {unsignedB, clamp}
+ : vector<8xi8>, vector<16xi8>, vector<8xi32>
+
+ %2 = amdgpu.sparse_wmma 16x16x128 %matA * %matB + %matC sparse(%idx : vector<2xi16>)
+ {reuseA, reuseB}
+ : vector<32xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf32>
+ ```
+ }];
+ let assemblyFormat = [{
+ custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
+ `sparse` `(` $sparseIdx `:` type($sparseIdx) `)`
+ attr-dict
+ `:` type($sourceA) `,` type($sourceB) `,` type($destC)
+ }];
+ let hasVerifier = 0;
+}
+
def AMDGPU_GatherToLDSOp :
AMDGPU_Op<"gather_to_lds", [AttrSizedOperandSegments]>,
Arguments<(ins
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx12.mlir
new file mode 100644
index 0000000000000..cef8273e9e707
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx12.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1200 --split-input-file --verify-diagnostics | FileCheck %s
+
+
+func.func @rocdl.swmmac(
+ %v64i8 : vector<64xi8>, %v64f8 : vector<64xf8E4M3FN>, %v64bf8 : vector<64xf8E5M2>,
+ %v32f16 : vector<32xf16>, %v32bf16 : vector<32xbf16>, %v32i8 : vector<32xi8>, %v32f8 : vector<32xf8E4M3FN>, %v32bf8 : vector<32xf8E5M2>,
+ %v16f16 : vector<16xf16>, %v16bf16 : vector<16xbf16>, %v16i32 : vector<16xi32>, %v16i16 : vector<16xi16>, %v16i8 : vector<16xi8>,
+ %v16f8 : vector<16xf8E4M3FN>, %v16bf8 : vector<16xf8E5M2>,
+ %v8f32 : vector<8xf32>, %v8i32 : vector<8xi32>, %v8f16 : vector<8xf16>, %v8bf16 : vector<8xbf16>, %v8i16 : vector<8xi16>, %v8i8 : vector<8xi8>,
+ %v8f8 : vector<8xf8E4M3FN>, %v8bf8 : vector<8xf8E5M2>,
+ %v4f32 : vector<4xf32>, %v4f16 : vector<4xf16>, %v4bf16 : vector<4xbf16>, %v4i32 : vector<4xi32>, %v4i16 : vector<4xi16>, %v4i8 : vector<4xi8>,
+ %v4f8 : vector<4xf8E4M3FN>, %v4bf8 : vector<4xf8E5M2>,
+ %v2i32 : vector<2xi32>, %v1i32 : i32, %idx : vector<4xi8>) {
+
+ // ---- Wave32 -----
+
+ // Each CHECK below pins the rocdl.swmmac.* intrinsic expected for the
+ // amdgpu.sparse_wmma op that immediately follows it.
+ // CHECK: rocdl.swmmac.f32.16x16x32.f16 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<8xf16>, vector<16xf16>, vector<8xf32>, i32) -> vector<8xf32>
+ %w32_0 = amdgpu.sparse_wmma 16x16x32 %v8f16 * %v16f16 + %v8f32 sparse(%idx : vector<4xi8>) : vector<8xf16>, vector<16xf16>, vector<8xf32>
+
+ // CHECK: rocdl.swmmac.f32.16x16x32.bf16 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<8xi16>, vector<16xi16>, vector<8xf32>, i32) -> vector<8xf32>
+ %w32_1 = amdgpu.sparse_wmma 16x16x32 %v8bf16 * %v16bf16 + %v8f32 sparse(%idx : vector<4xi8>) : vector<8xbf16>, vector<16xbf16>, vector<8xf32>
+
+ // CHECK: rocdl.swmmac.f16.16x16x32.f16 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<8xf16>, vector<16xf16>, vector<8xf16>, i32) -> vector<8xf16>
+ %w32_2 = amdgpu.sparse_wmma 16x16x32 %v8f16 * %v16f16 + %v8f16 sparse(%idx : vector<4xi8>) : vector<8xf16>, vector<16xf16>, vector<8xf16>
+
+ // CHECK: rocdl.swmmac.bf16.16x16x32.bf16 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<8xi16>, vector<16xi16>, vector<8xi16>, i32) -> vector<8xi16>
+ %w32_3 = amdgpu.sparse_wmma 16x16x32 %v8bf16 * %v16bf16 + %v8bf16 sparse(%idx : vector<4xi8>) : vector<8xbf16>, vector<16xbf16>, vector<8xbf16>
+
+ // CHECK: rocdl.swmmac.i32.16x16x32.iu8 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, clamp = true} : (vector<2xi32>, vector<4xi32>, vector<8xi32>, i32) -> vector<8xi32>
+ %w32_4 = amdgpu.sparse_wmma 16x16x32 %v8i8 * %v16i8 + %v8i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<8xi8>, vector<16xi8>, vector<8xi32>
+
+ // CHECK: rocdl.swmmac.i32.16x16x32.iu4 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, clamp = true} : (i32, vector<2xi32>, vector<8xi32>, i32) -> vector<8xi32>
+ %w32_5 = amdgpu.sparse_wmma 16x16x32 %v4i8 * %v8i8 + %v8i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<4xi8>, vector<8xi8>, vector<8xi32>
+
+ // CHECK: rocdl.swmmac.i32.16x16x64.iu4 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, clamp = true} : (vector<2xi32>, vector<4xi32>, vector<8xi32>, i32) -> vector<8xi32>
+ %w32_6 = amdgpu.sparse_wmma 16x16x64 %v8i8 * %v16i8 + %v8i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<8xi8>, vector<16xi8>, vector<8xi32>
+
+ // CHECK: rocdl.swmmac.f32.16x16x32.fp8.fp8 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<2xi32>, vector<4xi32>, vector<8xf32>, i32) -> vector<8xf32>
+ %w32_7 = amdgpu.sparse_wmma 16x16x32 %v8f8 * %v16f8 + %v8f32 sparse(%idx : vector<4xi8>) : vector<8xf8E4M3FN>, vector<16xf8E4M3FN>, vector<8xf32>
+
+ // CHECK: rocdl.swmmac.f32.16x16x32.fp8.bf8 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<2xi32>, vector<4xi32>, vector<8xf32>, i32) -> vector<8xf32>
+ %w32_8 = amdgpu.sparse_wmma 16x16x32 %v8f8 * %v16bf8 + %v8f32 sparse(%idx : vector<4xi8>) : vector<8xf8E4M3FN>, vector<16xf8E5M2>, vector<8xf32>
+
+ // CHECK: rocdl.swmmac.f32.16x16x32.bf8.fp8 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<2xi32>, vector<4xi32>, vector<8xf32>, i32) -> vector<8xf32>
+ %w32_9 = amdgpu.sparse_wmma 16x16x32 %v8bf8 * %v16f8 + %v8f32 sparse(%idx : vector<4xi8>) : vector<8xf8E5M2>, vector<16xf8E4M3FN>, vector<8xf32>
+
+ // CHECK: rocdl.swmmac.f32.16x16x32.bf8.bf8 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<2xi32>, vector<4xi32>, vector<8xf32>, i32) -> vector<8xf32>
+ %w32_10 = amdgpu.sparse_wmma 16x16x32 %v8bf8 * %v16bf8 + %v8f32 sparse(%idx : vector<4xi8>) : vector<8xf8E5M2>, vector<16xf8E5M2>, vector<8xf32>
+
+ func.return
+}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx1250.mlir
new file mode 100644
index 0000000000000..da903452c7d51
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx1250.mlir
@@ -0,0 +1,96 @@
+
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --split-input-file --verify-diagnostics | FileCheck %s
+
+func.func @rocdl.swmmac(
+ %v64i8 : vector<64xi8>, %v64f8 : vector<64xf8E4M3FN>, %v64bf8 : vector<64xf8E5M2>,
+ %v32f16 : vector<32xf16>, %v32bf16 : vector<32xbf16>, %v32i8 : vector<32xi8>, %v32f8 : vector<32xf8E4M3FN>, %v32bf8 : vector<32xf8E5M2>,
+ %v16f16 : vector<16xf16>, %v16bf16 : vector<16xbf16>, %v16i32 : vector<16xi32>, %v16i16 : vector<16xi16>, %v16i8 : vector<16xi8>,
+ %v16f8 : vector<16xf8E4M3FN>, %v16bf8 : vector<16xf8E5M2>,
+ %v8f32 : vector<8xf32>, %v8i32 : vector<8xi32>, %v8f16 : vector<8xf16>, %v8bf16 : vector<8xbf16>, %v8i16 : vector<8xi16>, %v8i8 : vector<8xi8>,
+ %v8f8 : vector<8xf8E4M3FN>, %v8bf8 : vector<8xf8E5M2>,
+ %v4f32 : vector<4xf32>, %v4f16 : vector<4xf16>, %v4bf16 : vector<4xbf16>, %v4i32 : vector<4xi32>, %v4i16 : vector<4xi16>, %v4i8 : vector<4xi8>,
+ %v4f8 : vector<4xf8E4M3FN>, %v4bf8 : vector<4xf8E5M2>,
+ %v2i32 : vector<2xi32>, %v1i32 : i32, %idx : vector<4xi8>) {
+
+ // ---- Wave32 -----
+
+ // Each CHECK below pins the rocdl.swmmac.* intrinsic expected for the
+ // amdgpu.sparse_wmma op that immediately follows it.
+ // CHECK: rocdl.swmmac.f32.16x16x64.f16 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, reuseA = true, reuseB = true} : (vector<16xf16>, vector<32xf16>, vector<8xf32>, i32) -> vector<8xf32>
+ %w32_11 = amdgpu.sparse_wmma 16x16x64 %v16f16 * %v32f16 + %v8f32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, reuseA, reuseB} : vector<16xf16>, vector<32xf16>, vector<8xf32>
+
+ // CHECK: rocdl.swmmac.f32.16x16x64.bf16 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, reuseA = true, reuseB = true} : (vector<16xbf16>, vector<32xbf16>, vector<8xf32>, i32) -> vector<8xf32>
+ %w32_12 = amdgpu.sparse_wmma 16x16x64 %v16bf16 * %v32bf16 + %v8f32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, reuseA, reuseB} : vector<16xbf16>, vector<32xbf16>, vector<8xf32>
+
+ // CHECK: rocdl.swmmac.f16.16x16x64.f16 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, reuseA = true, reuseB = true} : (vector<16xf16>, vector<32xf16>, vector<8xf16>, i32) -> vector<8xf16>
+ %w32_13 = amdgpu.sparse_wmma 16x16x64 %v16f16 * %v32f16 + %v8f16 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, reuseA, reuseB} : vector<16xf16>, vector<32xf16>, vector<8xf16>
+
+ // CHECK: rocdl.swmmac.bf16.16x16x64.bf16 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, reuseA = true, reuseB = true} : (vector<16xbf16>, vector<32xbf16>, vector<8xbf16>, i32) -> vector<8xbf16>
+ %w32_14 = amdgpu.sparse_wmma 16x16x64 %v16bf16 * %v32bf16 + %v8bf16 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, reuseA, reuseB} : vector<16xbf16>, vector<32xbf16>, vector<8xbf16>
+
+ // TODO: add a case for rocdl.swmmac.bf16f32.16x16x64.bf16 once its intended
+ // accumulator/result types (possibly f32) are confirmed.
+
+ // CHECK: rocdl.swmmac.f32.16x16x128.fp8.fp8 %{{.*}}, %{{.*}}, %{{.*}}, %index {reuseA = true, reuseB = true} : (vector<8xi32>, vector<16xi32>, vector<8xf32>, i32) -> vector<8xf32>
+ %w32_16 = amdgpu.sparse_wmma 16x16x128 %v32f8 * %v64f8 + %v8f32 sparse(%idx : vector<4xi8>) {reuseA, reuseB} : vector<32xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32>
+
+ // CHECK: rocdl.swmmac.f32.16x16x128.fp8.bf8 %{{.*}}, %{{.*}}, %{{.*}}, %index {reuseA = true, reuseB = true} : (vector<8xi32>, vector<16xi32>, vector<8xf32>, i32) -> vector<8xf32>
+ %w32_17 = amdgpu.sparse_wmma 16x16x128 %v32f8 * %v64bf8 + %v8f32 sparse(%idx : vector<4xi8>) {reuseA, reuseB} : vector<32xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf32>
+
+ // CHECK: rocdl.swmmac.f32.16x16x128.bf8.fp8 %{{.*}}, %{{.*}}, %{{.*}}, %index {reuseA = true, reuseB = true} : (vector<8xi32>, vector<16xi32>, vector<8xf32>, i32) -> vector<8xf32>
+ %w32_18 = amdgpu.sparse_wmma 16x16x128 %v32bf8 * %v64f8 + %v8f32 sparse(%idx : vector<4xi8>) {reuseA, reuseB} : vector<32xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf32>
+
+ // CHECK: rocdl.swmmac.f32.16x16x128.bf8.bf8 %{{.*}}, %{{.*}}, %{{.*}}, %index {reuseA = true, reuseB = true} : (vector<8xi32>, vector<16xi32>, vector<8xf32>, i32) -> vector<8xf32>
+ %w32_19 = amdgpu.sparse_wmma 16x16x128 %v32bf8 * %v64bf8 + %v8f32 sparse(%idx : vector<4xi8>) {reuseA, reuseB} : vector<32xf8E5M2>, vector<64xf8E5M2>, vector<8xf32>
+
+ // CHECK: rocdl.swmmac.f16.16x16x128.fp8.fp8 %{{.*}}, %{{.*}}, %{{.*}}, %index {reuseA = true, reuseB = true} : (vector<8xi32>, vector<16xi32>, vector<8xf16>, i32) -> vector<8xf16>
+ %w32_20 = amdgpu.sparse_wmma 16x16x128 %v32f8 * %v64f8 + %v8f16 sparse(%idx : vector<4xi8>) {reuseA, reuseB} : vector<32xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16>
+
+ // CHECK: rocdl.swmmac.f16.16x16x128.fp8.bf8 %{{.*}}, %{{.*}}, %{{.*}}, %index {reuseA = true, reuseB = true} : (vector<8xi32>, vector<16xi32>, vector<8xf16>, i32) -> vector<8xf16>
+ %w32_21 = amdgpu.sparse_wmma 16x16x128 %v32f8 * %v64bf8 + %v8f16 sparse(%idx : vector<4xi8>) {reuseA, reuseB} : vector<32xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf16>
+
+ // CHECK: rocdl.swmmac.f16.16x16x128.bf8.fp8 %{{.*}}, %{{.*}}, %{{.*}}, %index {reuseA = true, reuseB = true} : (vector<8xi32>, vector<16xi32>, vector<8xf16>, i32) -> vector<8xf16>
+ %w32_22 = amdgpu.sparse_wmma 16x16x128 %v32bf8 * %v64f8 + %v8f16 sparse(%idx : vector<4xi8>) {reuseA, reuseB} : vector<32xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16>
+
+ // CHECK: rocdl.swmmac.f16.16x16x128.bf8.bf8 %{{.*}}, %{{.*}}, %{{.*}}, %index {reuseA = true, reuseB = true} : (vector<8xi32>, vector<16xi32>, vector<8xf16>, i32) -> vector<8xf16>
+ %w32_23 = amdgpu.sparse_wmma 16x16x128 %v32bf8 * %v64bf8 + %v8f16 sparse(%idx : vector<4xi8>) {reuseA, reuseB} : vector<32xf8E5M2>, vector<64xf8E5M2>, vector<8xf16>
+
+ // CHECK: rocdl.swmmac.i32.16x16x128.iu8 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, reuseA = true, reuseB = true, clamp = true} : (vector<8xi32>, vector<16xi32>, vector<8xi32>, i32) -> vector<8xi32>
+ %w32_24 = amdgpu.sparse_wmma 16x16x128 %v32i8 * %v64i8 + %v8i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, reuseA, reuseB, clamp} : vector<32xi8>, vector<64xi8>, vector<8xi32>
+
+ // ---- Wave64 -----
+
+ // CHECK: rocdl.swmmac.f32.16x16x32.f16 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32) -> vector<4xf32>
+ %w64_0 = amdgpu.sparse_wmma 16x16x32 %v4f16 * %v8f16 + %v4f32 sparse(%idx : vector<4xi8>) : vector<4xf16>, vector<8xf16>, vector<4xf32>
+
+ // CHECK: rocdl.swmmac.f32.16x16x32.bf16 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32) -> vector<4xf32>
+ %w64_1 = amdgpu.sparse_wmma 16x16x32 %v4bf16 * %v8bf16 + %v4f32 sparse(%idx : vector<4xi8>) : vector<4xbf16>, vector<8xbf16>, vector<4xf32>
+
+ // CHECK: rocdl.swmmac.f16.16x16x32.f16 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<4xf16>, vector<8xf16>, vector<4xf16>, i32) -> vector<4xf16>
+ %w64_2 = amdgpu.sparse_wmma 16x16x32 %v4f16 * %v8f16 + %v4f16 sparse(%idx : vector<4xi8>) : vector<4xf16>, vector<8xf16>, vector<4xf16>
+
+ // CHECK: rocdl.swmmac.bf16.16x16x32.bf16 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<4xi16>, vector<8xi16>, vector<4xi16>, i32) -> vector<4xi16>
+ %w64_3 = amdgpu.sparse_wmma 16x16x32 %v4bf16 * %v8bf16 + %v4bf16 sparse(%idx : vector<4xi8>) : vector<4xbf16>, vector<8xbf16>, vector<4xbf16>
+
+ // CHECK: rocdl.swmmac.i32.16x16x32.iu8 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, clamp = true} : (i32, vector<2xi32>, vector<4xi32>, i32) -> vector<4xi32>
+ %w64_4 = amdgpu.sparse_wmma 16x16x32 %v4i8 * %v8i8 + %v4i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<4xi8>, vector<8xi8>, vector<4xi32>
+
+ // CHECK: rocdl.swmmac.i32.16x16x32.iu4 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, clamp = true} : (i32, i32, vector<4xi32>, i32) -> vector<4xi32>
+ %w64_5 = amdgpu.sparse_wmma 16x16x32 %v4i8 * %v4i8 + %v4i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<4xi8>, vector<4xi8>, vector<4xi32>
+
+ // CHECK: rocdl.swmmac.i32.16x16x64.iu4 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, clamp = true} : (i32, vector<2xi32>, vector<4xi32>, i32) -> vector<4xi32>
+ %w64_6 = amdgpu.sparse_wmma 16x16x64 %v4i8 * %v16i8 + %v4i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<4xi8>, vector<16xi8>, vector<4xi32>
+
+ // CHECK: rocdl.swmmac.f32.16x16x32.fp8.fp8 %{{.*}}, %{{.*}}, %{{.*}}, %index : (i32, vector<2xi32>, vector<4xf32>, i32) -> vector<4xf32>
+ %w64_7 = amdgpu.sparse_wmma 16x16x32 %v4f8 * %v16f8 + %v4f32 sparse(%idx : vector<4xi8>) : vector<4xf8E4M3FN>, vector<16xf8E4M3FN>, vector<4xf32>
+
+ // CHECK: rocdl.swmmac.f32.16x16x32.fp8.bf8 %{{.*}}, %{{.*}}, %{{.*}}, %index : (i32, vector<2xi32>, vector<4xf32>, i32) -> vector<4xf32>
+ %w64_8 = amdgpu.sparse_wmma 16x16x32 %v4f8 * %v16bf8 + %v4f32 sparse(%idx : vector<4xi8>) : vector<4xf8E4M3FN>, vector<16xf8E5M2>, vector<4xf32>
+
+ // CHECK: rocdl.swmmac.f32.16x16x32.bf8.fp8 %{{.*}}, %{{.*}}, %{{.*}}, %index : (i32, vector<2xi32>, vector<4xf32>, i32) -> vector<4xf32>
+ %w64_9 = amdgpu.sparse_wmma 16x16x32 %v4bf8 * %v16f8 + %v4f32 sparse(%idx : vector<4xi8>) : vector<4xf8E5M2>, vector<16xf8E4M3FN>, vector<4xf32>
+
+ // CHECK: rocdl.swmmac.f32.16x16x32.bf8.bf8 %{{.*}}, %{{.*}}, %{{.*}}, %index : (i32, vector<2xi32>, vector<4xf32>, i32) -> vector<4xf32>
+ %w64_10 = amdgpu.sparse_wmma 16x16x32 %v4bf8 * %v16bf8 + %v4f32 sparse(%idx : vector<4xi8>) : vector<4xf8E5M2>, vector<16xf8E5M2>, vector<4xf32>
+
+ func.return
+}
>From bce23c190727961ea03e790935c860b72eea58e5 Mon Sep 17 00:00:00 2001
From: ravil-mobile <ravil.aviva.com at gmail.com>
Date: Wed, 25 Feb 2026 19:24:33 +0000
Subject: [PATCH 2/2] [WIP][AMDGPU] Added placeholders for verifier and rewrite
for swmmac
---
.../mlir/Dialect/AMDGPU/IR/AMDGPUOps.td | 4 +++-
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 14 ++++++++++++++
mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp | 10 ++++++++++
.../AMDGPUToROCDL/swmmac-gfx12.mlir | 19 ++++++++-----------
.../AMDGPUToROCDL/swmmac-gfx1250.mlir | 16 +++++++---------
5 files changed, 42 insertions(+), 21 deletions(-)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
index 5eeb2b0dc856c..d00aa00a11257 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUOps.td
@@ -1154,6 +1154,7 @@ def SWMMACSparseInTypes : AnyTypeOf<[
VectorOfLengthAndType<[4, 8, 16], [F16]>,
VectorOfLengthAndType<[4, 8, 16], [BF16]>,
VectorOfLengthAndType<[4, 8, 32], [I8]>,
+ VectorOfLengthAndType<[8, 16], [I<4>]>,
VectorOfLengthAndType<[4, 8, 16, 32], [F8E4M3FN, F8E5M2]>,
VectorOfLengthAndType<[4, 8, 16, 32], [F8E4M3FNUZ, F8E5M2FNUZ]>
]>;
@@ -1162,6 +1163,7 @@ def SWMMACDenseInTypes : AnyTypeOf<[
VectorOfLengthAndType<[8, 16, 32], [F16]>,
VectorOfLengthAndType<[8, 16, 32], [BF16]>,
VectorOfLengthAndType<[4, 8, 16, 64], [I8]>,
+ VectorOfLengthAndType<[8, 16, 32], [I<4>]>,
VectorOfLengthAndType<[4, 8, 16, 64], [F8E4M3FN, F8E5M2]>,
VectorOfLengthAndType<[4, 8, 16, 64], [F8E4M3FNUZ, F8E5M2FNUZ]>
]>;
@@ -1239,7 +1241,7 @@ def AMDGPU_SparseWMMAOp :
attr-dict
`:` type($sourceA) `,` type($sourceB) `,` type($destC)
}];
- let hasVerifier = 0;
+ let hasVerifier = 1;
}
def AMDGPU_GatherToLDSOp :
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 3c2c61b2426e9..6f5520e966179 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1592,6 +1592,20 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};
+/// Conversion pattern for `amdgpu.sparse_wmma` (lowering to the ROCDL
+/// `swmmac` intrinsics is still to be written).
+///
+/// Until the rewrite exists, the pattern reports a match failure. A conversion
+/// pattern that returns success is required to replace, erase, or update its
+/// root op in place; returning success() while leaving the op untouched breaks
+/// that contract and leaves an illegal op behind.
+struct SparseWMMAOpLowering : public ConvertOpToLLVMPattern<SparseWMMAOp> {
+  SparseWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<SparseWMMAOp>(converter), chipset(chipset) {}
+
+  /// Chipset the op is being lowered for.
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(SparseWMMAOp op, SparseWMMAOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // TODO: emit the rocdl.swmmac.* intrinsic matching the operand types.
+    return rewriter.notifyMatchFailure(
+        op, "lowering of amdgpu.sparse_wmma is not implemented yet");
+  }
+};
+
struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
index f452d2de15dc8..8d4ec6271842c 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUOps.cpp
@@ -670,6 +670,16 @@ LogicalResult SparseMFMAOp::verify() {
return success();
}
+
+//===----------------------------------------------------------------------===//
+// SparseWMMAOp
+//===----------------------------------------------------------------------===//
+
+/// Structural checks on top of the ODS type constraints: A and B must share an
+/// element type (only 8-bit float A/B may differ, e.g. fp8 x bf8), and float
+/// sources must accumulate into a float C/D while integer sources (iu8/iu4)
+/// accumulate into an integer C/D.
+LogicalResult SparseWMMAOp::verify() {
+  auto sourceAType = cast<VectorType>(getSourceA().getType());
+  auto sourceBType = cast<VectorType>(getSourceB().getType());
+  auto destType = cast<VectorType>(getDestC().getType());
+  Type sourceAElemType = sourceAType.getElementType();
+  Type sourceBElemType = sourceBType.getElementType();
+
+  // Only the 8-bit float forms (fp8.bf8 / bf8.fp8) mix A and B element types.
+  auto isFloat8 = [](Type t) {
+    return isa<FloatType>(t) && t.getIntOrFloatBitWidth() == 8;
+  };
+  if (sourceAElemType != sourceBElemType &&
+      !(isFloat8(sourceAElemType) && isFloat8(sourceBElemType)))
+    return emitOpError("expected matching source element types (except for "
+                       "mixed 8-bit floats) but got ")
+           << sourceAElemType << " and " << sourceBElemType;
+
+  bool isSrcFloat = isa<FloatType>(sourceAElemType);
+  bool isDestFloat = isa<FloatType>(destType.getElementType());
+  if (isSrcFloat && !isDestFloat)
+    return emitOpError(
+               "expected a float destination for float sources but got ")
+           << destType;
+  if (!isSrcFloat && isDestFloat)
+    return emitOpError(
+               "expected an integer destination for integer sources but got ")
+           << destType;
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// DPPOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx12.mlir
index cef8273e9e707..4add3ce4bfd6b 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx12.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx12.mlir
@@ -2,15 +2,12 @@
func.func @rocdl.swmmac(
- %v64i8 : vector<64xi8>, %v64f8 : vector<64xf8E4M3FN>, %v64bf8 : vector<64xf8E5M2>,
- %v32f16 : vector<32xf16>, %v32bf16 : vector<32xbf16>, %v32i8 : vector<32xi8>, %v32f8 : vector<32xf8E4M3FN>, %v32bf8 : vector<32xf8E5M2>,
- %v16f16 : vector<16xf16>, %v16bf16 : vector<16xbf16>, %v16i32 : vector<16xi32>, %v16i16 : vector<16xi16>, %v16i8 : vector<16xi8>,
- %v16f8 : vector<16xf8E4M3FN>, %v16bf8 : vector<16xf8E5M2>,
- %v8f32 : vector<8xf32>, %v8i32 : vector<8xi32>, %v8f16 : vector<8xf16>, %v8bf16 : vector<8xbf16>, %v8i16 : vector<8xi16>, %v8i8 : vector<8xi8>,
- %v8f8 : vector<8xf8E4M3FN>, %v8bf8 : vector<8xf8E5M2>,
- %v4f32 : vector<4xf32>, %v4f16 : vector<4xf16>, %v4bf16 : vector<4xbf16>, %v4i32 : vector<4xi32>, %v4i16 : vector<4xi16>, %v4i8 : vector<4xi8>,
- %v4f8 : vector<4xf8E4M3FN>, %v4bf8 : vector<4xf8E5M2>,
- %v2i32 : vector<2xi32>, %v1i32 : i32, %idx : vector<4xi8>) {
+ %v32i4 : vector<32xi4>,
+ %v16f16 : vector<16xf16>, %v16bf16 : vector<16xbf16>, %v16i8 : vector<16xi8>,
+ %v16i4 : vector<16xi4>, %v16f8 : vector<16xf8E4M3FN>, %v16bf8 : vector<16xf8E5M2>,
+ %v8f32 : vector<8xf32>, %v8i32 : vector<8xi32>, %v8f16 : vector<8xf16>, %v8bf16 : vector<8xbf16>, %v8i8 : vector<8xi8>,
+ %v8i4 : vector<8xi4>, %v8f8 : vector<8xf8E4M3FN>, %v8bf8 : vector<8xf8E5M2>,
+ %idx : vector<4xi8>) {
// ---- Wave32 -----
@@ -30,10 +27,10 @@ func.func @rocdl.swmmac(
%w32_4 = amdgpu.sparse_wmma 16x16x32 %v8i8 * %v16i8 + %v8i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<8xi8>, vector<16xi8>, vector<8xi32>
// CHECK: rocdl.swmmac.i32.16x16x32.iu4 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, clamp = true} : (i32, vector<2xi32>, vector<8xi32>, i32) -> vector<8xi32>
- %w32_5 = amdgpu.sparse_wmma 16x16x32 %v4i8 * %v8i8 + %v8i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<4xi8>, vector<8xi8>, vector<8xi32>
+ %w32_5 = amdgpu.sparse_wmma 16x16x32 %v8i4 * %v16i4 + %v8i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<8xi4>, vector<16xi4>, vector<8xi32>
// CHECK: rocdl.swmmac.i32.16x16x64.iu4 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, clamp = true} : (vector<2xi32>, vector<4xi32>, vector<8xi32>, i32) -> vector<8xi32>
- %w32_6 = amdgpu.sparse_wmma 16x16x64 %v8i8 * %v16i8 + %v8i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<8xi8>, vector<16xi8>, vector<8xi32>
+ %w32_6 = amdgpu.sparse_wmma 16x16x64 %v16i4 * %v32i4 + %v8i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<16xi4>, vector<32xi4>, vector<8xi32>
// CHECK: rocdl.swmmac.f32.16x16x32.fp8.fp8 %{{.*}}, %{{.*}}, %{{.*}}, %index : (vector<2xi32>, vector<4xi32>, vector<8xf32>, i32) -> vector<8xf32>
%w32_7 = amdgpu.sparse_wmma 16x16x32 %v8f8 * %v16f8 + %v8f32 sparse(%idx : vector<4xi8>) : vector<8xf8E4M3FN>, vector<16xf8E4M3FN>, vector<8xf32>
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx1250.mlir
index da903452c7d51..aa33595310f10 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/swmmac-gfx1250.mlir
@@ -3,14 +3,12 @@
func.func @rocdl.swmmac(
%v64i8 : vector<64xi8>, %v64f8 : vector<64xf8E4M3FN>, %v64bf8 : vector<64xf8E5M2>,
- %v32f16 : vector<32xf16>, %v32bf16 : vector<32xbf16>, %v32i8 : vector<32xi8>, %v32f8 : vector<32xf8E4M3FN>, %v32bf8 : vector<32xf8E5M2>,
- %v16f16 : vector<16xf16>, %v16bf16 : vector<16xbf16>, %v16i32 : vector<16xi32>, %v16i16 : vector<16xi16>, %v16i8 : vector<16xi8>,
- %v16f8 : vector<16xf8E4M3FN>, %v16bf8 : vector<16xf8E5M2>,
- %v8f32 : vector<8xf32>, %v8i32 : vector<8xi32>, %v8f16 : vector<8xf16>, %v8bf16 : vector<8xbf16>, %v8i16 : vector<8xi16>, %v8i8 : vector<8xi8>,
- %v8f8 : vector<8xf8E4M3FN>, %v8bf8 : vector<8xf8E5M2>,
- %v4f32 : vector<4xf32>, %v4f16 : vector<4xf16>, %v4bf16 : vector<4xbf16>, %v4i32 : vector<4xi32>, %v4i16 : vector<4xi16>, %v4i8 : vector<4xi8>,
+ %v32f16 : vector<32xf16>, %v32bf16 : vector<32xbf16>, %v32i8 : vector<32xi8>, %v32i4 : vector<32xi4>, %v32f8 : vector<32xf8E4M3FN>, %v32bf8 : vector<32xf8E5M2>,
+ %v16f16 : vector<16xf16>, %v16bf16 : vector<16xbf16>, %v16f8 : vector<16xf8E4M3FN>, %v16bf8 : vector<16xf8E5M2>,
+ %v8f32 : vector<8xf32>, %v8i32 : vector<8xi32>, %v8f16 : vector<8xf16>, %v8bf16 : vector<8xbf16>, %v8i8 : vector<8xi8>, %v8i4 : vector<8xi4>,
+ %v4f32 : vector<4xf32>, %v4f16 : vector<4xf16>, %v4bf16 : vector<4xbf16>, %v4i32 : vector<4xi32>, %v4i8 : vector<4xi8>,
%v4f8 : vector<4xf8E4M3FN>, %v4bf8 : vector<4xf8E5M2>,
- %v2i32 : vector<2xi32>, %v1i32 : i32, %idx : vector<4xi8>) {
+ %idx : vector<4xi8>) {
// ---- Wave32 -----
@@ -75,10 +73,10 @@ func.func @rocdl.swmmac(
%w64_4 = amdgpu.sparse_wmma 16x16x32 %v4i8 * %v8i8 + %v4i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<4xi8>, vector<8xi8>, vector<4xi32>
// CHECK: rocdl.swmmac.i32.16x16x32.iu4 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, clamp = true} : (i32, i32, vector<4xi32>, i32) -> vector<4xi32>
- %w64_5 = amdgpu.sparse_wmma 16x16x32 %v4i8 * %v4i8 + %v4i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<4xi8>, vector<4xi8>, vector<4xi32>
+ %w64_5 = amdgpu.sparse_wmma 16x16x32 %v8i4 * %v8i4 + %v4i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32>
// CHECK: rocdl.swmmac.i32.16x16x64.iu4 %{{.*}}, %{{.*}}, %{{.*}}, %index {signA = true, signB = true, clamp = true} : (i32, vector<2xi32>, vector<4xi32>, i32) -> vector<4xi32>
- %w64_6 = amdgpu.sparse_wmma 16x16x64 %v4i8 * %v16i8 + %v4i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<4xi8>, vector<16xi8>, vector<4xi32>
+ %w64_6 = amdgpu.sparse_wmma 16x16x64 %v8i4 * %v32i4 + %v4i32 sparse(%idx : vector<4xi8>) {unsignedA, unsignedB, clamp} : vector<8xi4>, vector<32xi4>, vector<4xi32>
// CHECK: rocdl.swmmac.f32.16x16x32.fp8.fp8 %{{.*}}, %{{.*}}, %{{.*}}, %index : (i32, vector<2xi32>, vector<4xf32>, i32) -> vector<4xf32>
%w64_7 = amdgpu.sparse_wmma 16x16x32 %v4f8 * %v16f8 + %v4f32 sparse(%idx : vector<4xi8>) : vector<4xf8E4M3FN>, vector<16xf8E4M3FN>, vector<4xf32>
More information about the Mlir-commits
mailing list