[llvm] [mlir] [ROCDL] Add gfx1250 WMMA intrinsics (PR #162343)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 7 11:53:07 PDT 2025
https://github.com/Muzammiluddin-Syed-ECE created https://github.com/llvm/llvm-project/pull/162343
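Adds ROCDL dialect ops for the gfx1250 WMMA intrinsics, including the scaled (`wmma.scale*`) variants, together with LLVM IR translation tests.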
>From 048d50d8c7f083e265ceb80899dac7011527b562 Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Tue, 7 Oct 2025 13:22:25 -0500
Subject: [PATCH 1/2] [ROCDL] Add gfx1250 WMMA intrinsic ops
Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 26 ++++++++++++-
mlir/test/Target/LLVMIR/rocdl.mlir | 39 +++++++++++++++++++-
2 files changed, 62 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index db1b7e3af62fd..d1a2b48c5f704 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -471,7 +471,7 @@ def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f
//===---------------------------------------------------------------------===//
// WMMA intrinsics
-class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
+class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands = [],
list<Trait> traits = []> :
ROCDL_IntrOp<mnemonic, [0], overloadedOperands, traits, 1>,
Arguments<(ins Variadic<LLVM_Type>:$args)> {
@@ -492,6 +492,30 @@ def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_b
def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_fp8", [1]>;
def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x32.iu4", [1]>;
+// Available from gfx1250
+def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x4.f32", [1]>;
+def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.bf16", [1]>;
+def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.f16", [1]>;
+def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x32.f16", [1]>;
+def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x32.bf16", [1]>;
+def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16f32.16x16x32.bf16", [1,5]>;
+def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_fp8", [0]>;
+def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_bf8", [0]>;
+def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_fp8", [0]>;
+def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_bf8", [0]>;
+def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_fp8", [0]>;
+def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_bf8", [0]>;
+def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_fp8", [0]>;
+def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_bf8", [0]>;
+def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_fp8", [0]>;
+def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_bf8", [0]>;
+def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_fp8", [0]>;
+def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_bf8", [0]>;
+def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_fp8", [0]>;
+def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_bf8", [0]>;
+def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_fp8", [0]>;
+def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_bf8", [0]>;
+def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]>;
//===---------------------------------------------------------------------===//
// LDS transpose intrinsics (available in GFX950)
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 1c0c2eba002aa..4b3e61b33a941 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -816,9 +816,12 @@ llvm.func @rocdl.mfma.scale.f32.16x16x128.f8f6f4(%arg0 : i32,
}
llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : vector<16 x i16>, %arg3 : vector<8 x i32>,
- %arg4 : vector<2xi32>, %arg5 : vector<4xi32>, %arg6 : vector<4xf32>, %arg7 : vector<8xf16>, %arg8 : vector<8xi16>) -> vector<8xf32> {
+ %arg4 : vector<2xi32>, %arg5 : vector<4xi32>, %arg6 : vector<4xf32>, %arg7 : vector<8xf16>, %arg8 : vector<8xi16>,
+ %arg9 : vector<32xf16>, %arg10 : vector<16xf32>, %arg11 : vector<4xf32>, %arg12 : vector<32xf32>,
+ %arg13 : vector<16xbf16>, %arg14 : vector<64xf32>, %arg15 : vector<64xi32>, %arg16 : i32, %arg17 : vector<8xbf16>) -> vector<8xf32> {
%zero = llvm.mlir.constant(false) : i1
-
+ %zero_i16 = llvm.mlir.constant(0 : i16) : i16
+ %zero_i32 = llvm.mlir.constant(0 : i32) : i32
// ---- Wave32 -----
// f16 -> f32
@@ -849,6 +852,38 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
// CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
%r6.gfx12 = rocdl.wmma.i32.16x16x32.iu4 %zero, %arg4, %zero, %arg4, %arg3, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+ // f32 -> f32
+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 false, <16 x float> %{{.*}}, i1 false, <16 x float> %{{.*}}, i16 0, <4 x float> %{{.*}}, i1 false, i1 false)
+ %r1.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %zero, %arg10, %zero, %arg10, %zero_i16, %arg11, %zero, %zero : (i1, vector<16xf32>, i1, vector<16xf32>, i16, vector<4xf32>, i1, i1) -> vector<4xf32>
+
+ // bf16 -> f32 (native bf16 operands, lowered to LLVM bfloat vectors)
+ // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16bf16(i1 false, <16 x bfloat> %{{.*}}, i1 false, <16 x bfloat> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 false, i1 false)
+ %r2.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %zero, %arg13, %zero, %arg13, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xbf16>, i1, vector<16xbf16>, i16, vector<32xf32>, i1, i1) -> vector<32xf32>
+
+ // f16 -> f32
+ // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 false, <16 x half> %{{.*}}, i1 false, <16 x half> %{{.*}}, i16 0, <32 x float> %{{.*}}, i1 false, i1 false)
+ %r3.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %zero, %arg1, %zero, %arg1, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xf16>, i1, vector<16xf16>, i16, vector<32xf32>, i1, i1) -> vector<32xf32>
+
+ // f16 -> f16
+ // CHECK: call <32 x half> @llvm.amdgcn.wmma.f16.16x16x32.f16.v32f16.v16f16(i1 false, <16 x half> %{{.*}}, i1 false, <16 x half> %{{.*}}, i16 0, <32 x half> %{{.*}}, i1 false, i1 false)
+ %r4.gfx1250 = rocdl.wmma.f16.16x16x32.f16 %zero, %arg1, %zero, %arg1, %zero_i16, %arg9, %zero, %zero : (i1, vector<16xf16>, i1, vector<16xf16>, i16, vector<32xf16>, i1, i1) -> vector<32xf16>
+
+ // bf16 -> bf16
+ // CHECK: call <8 x bfloat> @llvm.amdgcn.wmma.bf16.16x16x32.bf16.v8bf16.v16bf16(i1 false, <16 x bfloat> %{{.*}}, i1 false, <16 x bfloat> %{{.*}}, i16 0, <8 x bfloat> %{{.*}}, i1 false, i1 false)
+ %r5.gfx1250 = rocdl.wmma.bf16.16x16x32.bf16 %zero, %arg13, %zero, %arg13, %zero_i16, %arg17, %zero, %zero : (i1, vector<16xbf16>, i1, vector<16xbf16>, i16, vector<8xbf16>, i1, i1) -> vector<8xbf16>
+
+ // bf16 -> bf16, with an f32 accumulator (C)
+ // CHECK: call <8 x bfloat> @llvm.amdgcn.wmma.bf16f32.16x16x32.bf16.v8bf16.v16bf16.v8f32(i1 false, <16 x bfloat> %{{.*}}, i1 false, <16 x bfloat> %{{.*}}, i16 0, <8 x float> %{{.*}}, i1 false, i1 false)
+ %r6.gfx1250 = rocdl.wmma.bf16f32.16x16x32.bf16 %zero, %arg13, %zero, %arg13, %zero_i16, %arg0, %zero, %zero : (i1, vector<16xbf16>, i1, vector<16xbf16>, i16, vector<8xf32>, i1, i1) -> vector<8xbf16>
+
+ // f8 -> f32
+ // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.fp8.v64f32.v4i32(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i16 0, <64 x float> %{{.*}}, i1 false, i1 false)
+ %r7.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_fp8 %arg5, %arg5, %zero_i16, %arg14, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+
+ // iu8 -> i32
+ // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 false, <4 x i32> %5, i1 false, <4 x i32> %5, <64 x i32> %15, i1 false, i1 false)
+ %r8.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %zero, %arg5, %zero, %arg5, %arg15, %zero, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<64xi32>, i1, i1) -> vector<64xi32>
+
// ---- Wave64 -----
// f16 -> f32
>From ab6cfd98596a20ac3f845d3622c3c7b526afdf4f Mon Sep 17 00:00:00 2001
From: Muzammiluddin Syed <muzasyed at amd.com>
Date: Tue, 7 Oct 2025 13:28:29 -0500
Subject: [PATCH 2/2] [mlir][ROCDL] Add gfx1250 scaled WMMA ops
Signed-off-by: Muzammiluddin Syed <muzasyed at amd.com>
---
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 4 ++++
mlir/test/Target/LLVMIR/rocdl.mlir | 11 +++++++++++
2 files changed, 15 insertions(+)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index d1a2b48c5f704..3814a2dae0f3f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -516,6 +516,10 @@ def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8
def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_fp8", [0]>;
def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_bf8", [0]>;
def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]>;
+def ROCDL_wmma_scale_f32_16x16x128_f8f6f4 : ROCDL_Wmma_IntrOp<"wmma.scale.f32.16x16x128.f8f6f4", [1, 3]>;
+def ROCDL_wmma_scale16_f32_16x16x128_f8f6f4 : ROCDL_Wmma_IntrOp<"wmma.scale16.f32.16x16x128.f8f6f4", [1, 3]>;
+def ROCDL_wmma_scale_f32_32x16x128_f4 : ROCDL_Wmma_IntrOp<"wmma.scale.f32.32x16x128.f4", [0, 1]>;
+def ROCDL_wmma_scale16_f32_32x16x128_f4 : ROCDL_Wmma_IntrOp<"wmma.scale16.f32.32x16x128.f4", [0, 1]>;
//===---------------------------------------------------------------------===//
// LDS transpose intrinsics (available in GFX950)
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 4b3e61b33a941..e0400bf07b563 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -884,6 +884,17 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
// CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 false, <4 x i32> %5, i1 false, <4 x i32> %5, <64 x i32> %15, i1 false, i1 false)
%r8.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %zero, %arg5, %zero, %arg5, %arg15, %zero, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<64xi32>, i1, i1) -> vector<64xi32>
+
+ // f8f6f4 -> f32, scaled (32-bit scale values)
+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale.f32.16x16x128.f8f6f4.v4f32.v4i32.v4i32(i32 0, <4 x i32> %{{.*}}, i32 0, <4 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i32 0, i32 0, i32 %{{.*}}, i1 false, i1 false)
+ %r9.gfx1250 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %zero_i32, %arg5, %zero_i32, %arg5, %zero_i16, %arg11, %zero_i32, %zero_i32, %arg16, %zero_i32, %zero_i32, %arg16, %zero, %zero : (i32, vector<4xi32>, i32, vector<4xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+
+ // f8f6f4 -> f32, scaled (64-bit scale values)
+ %zero_i64 = llvm.mlir.constant(0 : i64) : i64
+ // CHECK: call <4 x float> @llvm.amdgcn.wmma.scale16.f32.16x16x128.f8f6f4.v4f32.v4i32.v4i32(i32 0, <4 x i32> %{{.*}}, i32 0, <4 x i32> %{{.*}}, i16 0, <4 x float> %{{.*}}, i32 0, i32 0, i64 0, i32 0, i32 0, i64 0, i1 false, i1 false)
+ %r10.gfx1250 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4 %zero_i32, %arg5, %zero_i32, %arg5, %zero_i16, %arg11, %zero_i32, %zero_i32, %zero_i64, %zero_i32, %zero_i32, %zero_i64, %zero, %zero : (i32, vector<4xi32>, i32, vector<4xi32>, i16, vector<4xf32>, i32, i32, i64, i32, i32, i64, i1, i1) -> vector<4xf32>
+
+ // TODO: add translation tests for rocdl.wmma.scale.f32.32x16x128.f4 and rocdl.wmma.scale16.f32.32x16x128.f4.
// ---- Wave64 -----
// f16 -> f32
More information about the llvm-commits
mailing list