[Mlir-commits] [mlir] [mlir][amdgpu][rocdl] Allow for graceful wmma conversion failures (PR #165616)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 29 12:36:20 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Jakub Kuderski (kuhar)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/165616.diff
4 Files Affected:
- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+3-3)
- (modified) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir (+2-2)
- (modified) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir (+12-1)
- (modified) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir (+15-5)
``````````diff
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 1eca43d96fe85..41e333c621eda 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1043,7 +1043,7 @@ wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
}
- llvm_unreachable("Unsupported k value");
+ return std::nullopt;
}
/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
@@ -1135,7 +1135,7 @@ static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
return std::nullopt;
}
- llvm_unreachable("Unsupported k value");
+ return std::nullopt;
}
/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
@@ -1164,7 +1164,7 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
elemDestType, k);
- llvm_unreachable("unhandled WMMA case");
+ return std::nullopt;
}
namespace {
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
index d1301d0089220..9fcc1473d4a18 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 | FileCheck %s
// CHECK-LABEL: @wmma_to_rocdl
func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
@@ -32,5 +32,5 @@ func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 :
// CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
- func.return
+ return
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
index b897323340402..57883473bbf06 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1200 \
+// RUN: --split-input-file --verify-diagnostics | FileCheck %s
+
// CHECK-LABEL: @wmma_to_rocdl
func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
%arg2 : vector<8xf32>, %arg3 : vector<4xf32>,
@@ -66,3 +68,12 @@ func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>,
func.return
}
+
+// -----
+
+func.func @wmma_unsupported_k(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<8xf16>) {
+ // expected-error@below {{'amdgpu.wmma' op no intrinsic matching WMMA on the given chipset}}
+ // expected-error@below {{failed to legalize operation 'amdgpu.wmma'}}
+ amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg1 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16>
+ func.return
+}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
index bcbdef040ebe3..267ae8bc4f4c0 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
@@ -1,10 +1,11 @@
-// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 \
+// RUN: --split-input-file --verify-diagnostics | FileCheck %s
// CHECK-LABEL: @wmma_k4
func.func @wmma_k4(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>) {
// CHECK: rocdl.wmma.f32.16x16x4.f32 %arg0, %arg0, %arg1
amdgpu.wmma 16x16x4 %arg0 * %arg0 + %arg1 : vector<2xf32>, vector<2xf32>, vector<8xf32>
- func.return
+ return
}
// CHECK-LABEL: @wmma_k32
@@ -22,7 +23,7 @@ func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vec
// CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}}, {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1)
amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg4 : vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
- func.return
+ return
}
// CHECK-LABEL: @wmma_k64
@@ -55,7 +56,7 @@ func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 :
// CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg5 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf16>
- func.return
+ return
}
// CHECK-LABEL: @wmma_k128
@@ -85,5 +86,14 @@ func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
// CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg3 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16>
- func.return
+ return
+}
+
+// -----
+
+func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) {
+ // expected-error@below {{'amdgpu.wmma' op no intrinsic matching WMMA on the given chipset}}
+ // expected-error@below {{failed to legalize operation 'amdgpu.wmma'}}
+ amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<8xf16>, vector<8xf16>, vector<8xf32>
+ return
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/165616
More information about the Mlir-commits mailing list