[Mlir-commits] [mlir] [mlir][amdgpu][rocdl] Allow for graceful wmma conversion failures (PR #165616)

Jakub Kuderski llvmlistbot at llvm.org
Wed Oct 29 12:36:43 PDT 2025


https://github.com/kuhar updated https://github.com/llvm/llvm-project/pull/165616

>From 6ed2feeff360c60f256a3135b96f54fa744c46a7 Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 29 Oct 2025 15:34:43 -0400
Subject: [PATCH 1/2] [mlir][amdgpu][rocdl] Allow for graceful wmma conversion
 failures

---
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           |  6 +++---
 .../Conversion/AMDGPUToROCDL/wmma-gfx11.mlir  |  4 ++--
 .../Conversion/AMDGPUToROCDL/wmma-gfx12.mlir  | 13 +++++++++++-
 .../AMDGPUToROCDL/wmma-gfx1250.mlir           | 20 ++++++++++++++-----
 4 files changed, 32 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 1eca43d96fe85..41e333c621eda 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1043,7 +1043,7 @@ wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
       return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
   }
 
-  llvm_unreachable("Unsupported k value");
+  return std::nullopt;
 }
 
 /// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
@@ -1135,7 +1135,7 @@ static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
     return std::nullopt;
   }
 
-  llvm_unreachable("Unsupported k value");
+  return std::nullopt;
 }
 
 /// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
@@ -1164,7 +1164,7 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
     return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
                                     elemDestType, k);
 
-  llvm_unreachable("unhandled WMMA case");
+  return std::nullopt;
 }
 
 namespace {
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
index d1301d0089220..9fcc1473d4a18 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 | FileCheck %s
 
 // CHECK-LABEL: @wmma_to_rocdl
 func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
@@ -32,5 +32,5 @@ func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 :
   // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
   amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
 
-  func.return
+  return
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
index b897323340402..57883473bbf06 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1200 \
+// RUN:   --split-input-file --verify-diagnostics | FileCheck %s
+
 // CHECK-LABEL: @wmma_to_rocdl
 func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
                          %arg2 : vector<8xf32>, %arg3 : vector<4xf32>,
@@ -66,3 +68,12 @@ func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
 
   func.return
 }
+
+// -----
+
+func.func @wmma_unsupported_k(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<8xf16>) {
+  // expected-error@below {{'amdgpu.wmma' op no intrinsic matching WMMA on the given chipset}}
+  // expected-error@below {{failed to legalize operation 'amdgpu.wmma'}}
+  amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg1 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16>
+  func.return
+}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
index bcbdef040ebe3..267ae8bc4f4c0 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
@@ -1,10 +1,11 @@
-// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 \
+// RUN:   --split-input-file --verify-diagnostics | FileCheck %s
 
 // CHECK-LABEL: @wmma_k4
 func.func @wmma_k4(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>) {
   // CHECK: rocdl.wmma.f32.16x16x4.f32 %arg0, %arg0, %arg1
   amdgpu.wmma 16x16x4 %arg0 * %arg0 + %arg1 : vector<2xf32>, vector<2xf32>, vector<8xf32>
-  func.return
+  return
 }
 
 // CHECK-LABEL: @wmma_k32
@@ -22,7 +23,7 @@ func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vec
   // CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}}, {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1)
   amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg4 : vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
 
-  func.return
+  return
 }
 
 // CHECK-LABEL: @wmma_k64
@@ -55,7 +56,7 @@ func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 :
   // CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
   amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg5 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf16>
 
-  func.return
+  return
 }
 
 // CHECK-LABEL: @wmma_k128
@@ -85,5 +86,14 @@ func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
   // CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
   amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg3 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16>
 
-  func.return
+  return
+}
+
+// -----
+
+func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) {
+  // expected-error@below {{'amdgpu.wmma' op no intrinsic matching WMMA on the given chipset}}
+  // expected-error@below {{failed to legalize operation 'amdgpu.wmma'}}
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<8xf16>, vector<8xf16>, vector<8xf32>
+  return
 }

>From 30d4397511f868913704a2be3569ce8381ad65cb Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Wed, 29 Oct 2025 15:36:33 -0400
Subject: [PATCH 2/2] Naming

---
 mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
index 267ae8bc4f4c0..5e77a3add3184 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
@@ -91,7 +91,7 @@ func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
 
 // -----
 
-func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) {
+func.func @wmma_unsupported_k(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) {
   // expected-error@below {{'amdgpu.wmma' op no intrinsic matching WMMA on the given chipset}}
   // expected-error@below {{failed to legalize operation 'amdgpu.wmma'}}
   amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<8xf16>, vector<8xf16>, vector<8xf32>



More information about the Mlir-commits mailing list