[Mlir-commits] [mlir] [WIP][ROCDL] Added math.clampf -> rocdl.fmed3 conversion (PR #160100)
Keshav Vinayak Jha
llvmlistbot at llvm.org
Mon Sep 22 06:22:54 PDT 2025
https://github.com/keshavvinayak01 updated https://github.com/llvm/llvm-project/pull/160100
>From 2761997ffff3355d4e0fd392603c51fde6405380 Mon Sep 17 00:00:00 2001
From: keshavvinayak01 <keshavvinayakjha at gmail.com>
Date: Mon, 22 Sep 2025 13:17:34 +0000
Subject: [PATCH 1/2] Added math.clampf -> rocdl.fmed3 conversion
Signed-off-by: keshavvinayak01 <keshavvinayakjha at gmail.com>
---
mlir/include/mlir/Conversion/Passes.td | 6 +++
.../Conversion/MathToROCDL/MathToROCDL.cpp | 41 +++++++++++++++----
.../Conversion/MathToROCDL/math-to-rocdl.mlir | 33 ++++++++++++++-
3 files changed, 72 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 1a37d057776e2..060c7183fcb3a 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -785,6 +785,12 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
"ROCDL::ROCDLDialect",
"vector::VectorDialect",
];
+  // Chipset-gated lowerings (math.clampf -> rocdl.fmed3) require gfx9+.
+  let options = [
+    Option<"chipset", "chipset", "std::string",
+           /*default=*/"\"gfx000\"",
+           "Chipset that these operations will run on (e.g. gfx942)">
+  ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index df219f3ff4f6e..6da7e9a850ef7 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -10,6 +10,7 @@
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -120,25 +121,81 @@ void mlir::populateMathToROCDLConversionPatterns(
"__ocml_fmod_f64", "__ocml_fmod_f16");
}
namespace {
+/// Lowers `math.clampf %x to [%lo, %hi]` to `rocdl.fmed3 %x, %lo, %hi`.
+/// The median of three values equals the clamp whenever `%lo <= %hi`.
+/// The pattern only fires on gfx9+ chipsets and only for f16/f32 scalars or
+/// 1-D vectors; any other `math.clampf` is left unconverted.
+struct ClampFOpConversion final : ConvertOpToLLVMPattern<math::ClampFOp> {
+  ClampFOpConversion(const LLVMTypeConverter &converter,
+                     amdgpu::Chipset chipset)
+      : ConvertOpToLLVMPattern<math::ClampFOp>(converter), chipset(chipset) {}
+
+  LogicalResult
+  matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // V_MED3_F16 only exists on gfx9+ architectures. NOTE(review): V_MED3_F32
+    // is believed to exist on older GCN targets as well; the gate is kept
+    // conservative for both types for now.
+    if (chipset.majorVersion < 9) {
+      std::string msg =
+          ("pre-gfx9 (gfx" + std::to_string(chipset.majorVersion) +
+           "): V_MED3_F16 / V_MED3_F32 not supported.");
+      return rewriter.notifyMatchFailure(op, msg);
+    }
+
+    // fmed3 only handles f16/f32, and the result must stay a single LLVM
+    // scalar or vector (the LLVM type converter turns n-D vectors into arrays).
+    Type elementType = op.getType();
+    if (auto vectorType = dyn_cast<VectorType>(elementType)) {
+      if (vectorType.getRank() != 1)
+        return rewriter.notifyMatchFailure(
+            op, "only scalars and 1-D vectors are supported");
+      elementType = vectorType.getElementType();
+    }
+    if (!elementType.isF16() && !elementType.isF32())
+      return rewriter.notifyMatchFailure(
+          op, "rocdl.fmed3 only supports f16 and f32 element types");
+
+    // Use the adaptor so that already-converted operands are consumed.
+    rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(
+        op, op.getType(), adaptor.getValue(), adaptor.getMin(),
+        adaptor.getMax());
+    return success();
+  }
+
+  /// Target chipset; selects whether the V_MED3 lowering is available.
+  amdgpu::Chipset chipset;
+};
+
-struct ConvertMathToROCDLPass
-    : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
-  ConvertMathToROCDLPass() = default;
+struct ConvertMathToROCDLPass final
+    : impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
+  using impl::ConvertMathToROCDLBase<
+      ConvertMathToROCDLPass>::ConvertMathToROCDLBase;
+
void runOnOperation() override;
};
} // namespace
void ConvertMathToROCDLPass::runOnOperation() {
auto m = getOperation();
MLIRContext *ctx = m.getContext();
+  // Reject malformed chipset names instead of dereferencing a failed parse.
+  FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
+  if (failed(maybeChipset)) {
+    emitError(m.getLoc(), "Invalid chipset name: " + chipset);
+    return signalPassFailure();
+  }
RewritePatternSet patterns(&getContext());
LowerToLLVMOptions options(ctx, DataLayout(m));
LLVMTypeConverter converter(ctx, options);
+  patterns.add<ClampFOpConversion>(converter, *maybeChipset);
populateMathToROCDLConversionPatterns(converter, patterns);
ConversionTarget target(getContext());
-  target.addLegalDialect<BuiltinDialect, func::FuncDialect,
-                         vector::VectorDialect, LLVM::LLVMDialect>();
+  target
+      .addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,
+                       LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index dbff23339d8b3..488133ad8bddc 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -convert-math-to-rocdl -allow-unregistered-dialect -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx803})' | FileCheck %s --check-prefixes=CHECK,PRE9
+// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx942})' | FileCheck %s --check-prefixes=CHECK,POST9
module @test_module {
// CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16
@@ -596,3 +597,33 @@ module @test_module {
func.return %result : vector<2x2xf16>
}
}
+
+// -----
+
+// f16 clamp: rocdl.fmed3 on gfx9+, stays math.clampf before gfx9.
+func.func @clampf_f16(%x: f16, %lo: f16, %hi: f16) -> f16 {
+  %r = math.clampf %x to [%lo, %hi] : f16
+  return %r : f16
+}
+
+// f32 clamp: rocdl.fmed3 on gfx9+, stays math.clampf before gfx9.
+func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 {
+  %r = math.clampf %x to [%lo, %hi] : f32
+  return %r : f32
+}
+
+// POST9-LABEL: func.func @clampf_f16
+// POST9: rocdl.fmed3 {{.*}} : f16
+// POST9: return
+
+// POST9-LABEL: func.func @clampf_f32
+// POST9: rocdl.fmed3 {{.*}} : f32
+// POST9: return
+
+// PRE9-LABEL: func.func @clampf_f16
+// PRE9-NOT: rocdl.fmed3
+// PRE9: math.clampf {{.*}} : f16
+
+// PRE9-LABEL: func.func @clampf_f32
+// PRE9-NOT: rocdl.fmed3
+// PRE9: math.clampf {{.*}} : f32
\ No newline at end of file
>From da4645f37fe1382f590a27103b01abf9e16f03c5 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <31160700+keshavvinayak01 at users.noreply.github.com>
Date: Mon, 22 Sep 2025 18:52:45 +0530
Subject: [PATCH 2/2] Update math-to-rocdl.mlir
---
mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index 488133ad8bddc..541d8d53cac4c 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -626,4 +626,4 @@ func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 {
// PRE9-LABEL: func.func @clampf_f32
// PRE9-NOT: rocdl.fmed3
-// PRE9: math.clampf {{.*}} : f32
\ No newline at end of file
+// PRE9: math.clampf {{.*}} : f32
More information about the Mlir-commits
mailing list