[Mlir-commits] [mlir] [mlir][AMDGPU] Add scaled wmma ops for gfx1250 (PR #169854)

Jakub Kuderski llvmlistbot at llvm.org
Sun Nov 30 06:26:26 PST 2025


================
@@ -1218,6 +1227,56 @@ def AMDGPU_ScaledMFMAOp :
   let hasCanonicalizer = 1;
 }
 
+def AMDGPU_ScaledWMMAOp
+    : AMDGPU_Op<"scaled_wmma", [AllTypesMatch<["destC", "destD"]>, Pure]>,
+      Arguments<(ins ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$m,
+          ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
+          ConfinedAttr<I32Attr, [IntIsOneOf<[128]>]>:$k,
+          ScaledWMMAInTypes:$sourceA, ScaledWMMAInTypes:$sourceB,
+          ScaledWMMAOutTypes:$destC,
+          VectorOfLengthAndType<[4, 8], [F8E8M0FNU, F8E4M3FN]>:$scaleA,
+          ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>:$scaleAIdx,
+          VectorOfLengthAndType<[4, 8], [F8E8M0FNU, F8E4M3FN]>:$scaleB,
+          ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>:$scaleBIdx)>,
+      Results<(outs ScaledWMMAOutTypes:$destD)> {
+  let summary = "MLIR wrapper for scaled wmma instructions";
+  let description = [{
+    The `amdgpu.scaled_wmma` op is an MLIR wrapper around intrinsics for scaled
+    `wmma` instructions. These instructions perform matrix multiplication with
+    per-block scaling of inputs, supporting fp4, fp6, and fp8 data formats.
+
+    The scaled instructions support two tile sizes:
+    - 16x16x128 with mixed f8/f6/f4 formats (output: vector<8xf32>)
+    - 32x16x128 with f4 format only (output: vector<16xf32>)
+
+    Scale parameters (`scaleA`, `scaleB`) are 4- or 8-element vectors of f8
+    scale values (either f8E8M0FNU or f8E4M3FN). The index attributes
+    (`scaleAIdx`, `scaleBIdx`) select which element from the scale vector to use
+    for scaling. During lowering, these vectors are packed into i32 (4 elements)
+    or i64 (8 elements) values for the hardware intrinsics.
+
+    Example:
+    ```mlir
+      // 16x16x128: fp8 inputs
+      %0 = amdgpu.scaled_wmma (%scaleVecA[0] * %matA) * (%scaleVecB[0] * %matB) + %matC
+        { m = 16, n = 16, k = 128 } : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>,
+        vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<8xf32>
+
+      // 32x16x128: fp4 inputs
+      %1 = amdgpu.scaled_wmma (%scaleVecC[1] * %matD) * (%scaleVecD[0] * %matE) + %matF
+        { m = 32, n = 16, k = 128 } : vector<8xf8E4M3FN>, vector<128xf4E2M1FN>,
+        vector<8xf8E4M3FN>, vector<64xf4E2M1FN>, vector<16xf32>
----------------
kuhar wrote:

The example doesn't use the `MNKDimension` list syntax from the assembly format.
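The two tile sizes in the op description determine the per-lane vector lengths that appear in the example's types. Below is a minimal Python sketch of that arithmetic. It assumes gfx1250 executes these ops in wave32 mode (32 lanes) and that each lane holds an equal share of A (MxK), B (KxN) and C/D (MxN). The helper `per_lane_lengths` is hypothetical and is not part of the MLIR API. Lengths are counted in elements, not bytes, so an fp4 A operand still reports 128 elements for `vector<128xf4E2M1FN>`.

```python
# Sketch only: per-lane operand lengths for one WMMA tile.
# Assumption: the op runs in wave32 (32 lanes) and each lane holds an
# equal share of every matrix; this is not taken from the PR text.
WAVE_SIZE = 32


def per_lane_lengths(m: int, n: int, k: int, wave_size: int = WAVE_SIZE):
    """Return (len_A, len_B, len_CD) in elements for an MxNxK tile."""
    shapes = {"A": m * k, "B": k * n, "C/D": m * n}
    for name, elems in shapes.items():
        if elems % wave_size != 0:
            raise ValueError(
                f"{name} has {elems} elements, not divisible by {wave_size}")
    return tuple(elems // wave_size for elems in shapes.values())


if __name__ == "__main__":
    for m, n, k in ((16, 16, 128), (32, 16, 128)):
        a, b, cd = per_lane_lengths(m, n, k)
        print(f"{m}x{n}x{k}: A={a}, B={b}, C/D={cd} (f32)")
```

Under these assumptions, 16x16x128 gives (64, 64, 8) and 32x16x128 gives (128, 64, 16). These values match the `vector<64x...>`, `vector<128x...>`, `vector<8xf32>` and `vector<16xf32>` types used in the example.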

https://github.com/llvm/llvm-project/pull/169854
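The op description says the scale vectors are packed into i32/i64 values during lowering. Below is a minimal Python sketch of what that packing amounts to. It assumes the f8 scales are treated as raw bytes and bitcast little-endian, so element 0 lands in the least significant byte. This is an illustration, not the actual lowering code, and `pack_scales` and `e8m0_value` are hypothetical helpers. The decoder relies only on the f8E8M0FNU encoding: a bare 8-bit exponent with bias 127, where 0xFF is NaN.

```python
# Sketch only: pack 4 or 8 raw f8 scale bytes into one integer, the way a
# bitcast of vector<4xi8> -> i32 or vector<8xi8> -> i64 would on a
# little-endian target. The byte order is an assumption, not from the PR.
from typing import Sequence


def pack_scales(scale_bytes: Sequence[int]) -> int:
    """Pack raw f8 scale bytes; element 0 goes to the least significant byte."""
    if len(scale_bytes) not in (4, 8):
        raise ValueError("scale vectors must have 4 or 8 elements")
    packed = 0
    for i, byte in enumerate(scale_bytes):
        if not 0 <= byte <= 0xFF:
            raise ValueError(f"element {i} is not a byte: {byte}")
        packed |= byte << (8 * i)
    return packed


def e8m0_value(bits: int) -> float:
    """Decode an f8E8M0FNU byte: a bare exponent with bias 127; 0xFF is NaN."""
    if not 0 <= bits <= 0xFF:
        raise ValueError(f"not a byte: {bits}")
    return float("nan") if bits == 0xFF else 2.0 ** (bits - 127)


if __name__ == "__main__":
    scales = [0x7F, 0x80, 0x81, 0x7E]  # 1.0, 2.0, 4.0, 0.5 as f8E8M0FNU
    print(hex(pack_scales(scales)), [e8m0_value(b) for b in scales])
```

For example, `pack_scales([0x7F, 0x80, 0x81, 0x7E])` returns `0x7E81807F`. In f8E8M0FNU, 0x7F decodes to 1.0 and 0x80 to 2.0.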

