[Mlir-commits] [mlir] 1e6df64 - [MLIR][ROCDL] Add math.clampf -> rocdl.fmed3 conversion (#163259)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 14 13:21:31 PDT 2025


Author: Keshav Vinayak Jha
Date: 2025-10-14T15:21:26-05:00
New Revision: 1e6df640e2bae9e20b640bb9ebdba140ba7bc0c1

URL: https://github.com/llvm/llvm-project/commit/1e6df640e2bae9e20b640bb9ebdba140ba7bc0c1
DIFF: https://github.com/llvm/llvm-project/commit/1e6df640e2bae9e20b640bb9ebdba140ba7bc0c1.diff

LOG: [MLIR][ROCDL] Add math.clampf -> rocdl.fmed3 conversion (#163259)

Added a pattern for lowering `math::ClampFOp` to `ROCDL::FMed3Op`.
Also added a `chipset` option to the `MathToROCDL` pass so that it can check
whether the target architecture supports the required ISA instructions.

Solves [#157052](https://github.com/llvm/llvm-project/issues/157052)

Reapplies https://github.com/llvm/llvm-project/pull/160100

---------

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
    mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
    mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
index 46573e7966ccc..770f257d89bd5 100644
--- a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
+++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
@@ -9,6 +9,7 @@
 #define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
 
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
 #include "mlir/IR/PatternMatch.h"
 #include <memory>
 
@@ -20,7 +21,8 @@ class Pass;
 
 /// Populate the given list with patterns that convert from Math to ROCDL calls.
 void populateMathToROCDLConversionPatterns(const LLVMTypeConverter &converter,
-                                           RewritePatternSet &patterns);
+                                           RewritePatternSet &patterns,
+                                           amdgpu::Chipset chipset);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 25e9d34f3e653..a2eb335faac6c 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -778,6 +778,10 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
   let summary = "Convert Math dialect to ROCDL library calls";
   let description = [{
     This pass converts supported Math ops to ROCDL library calls.
+
+    The chipset option specifies the target AMDGPU architecture. If the chipset
+    is empty, none of the chipset-dependent patterns are added and the pass
+    will not attempt to parse the chipset.
   }];
   let dependentDialects = [
     "arith::ArithDialect",
@@ -785,6 +789,9 @@ def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
     "ROCDL::ROCDLDialect",
     "vector::VectorDialect",
   ];
+  let options = [Option<"chipset", "chipset", "std::string",
+                        /*default=*/"\"gfx000\"",
+                        "Chipset that these operations will run on">];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index b215211e131d4..c03f3a5d3889c 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -484,5 +484,5 @@ void mlir::populateGpuToROCDLConversionPatterns(
                GPUSubgroupBroadcastOpToROCDL>(converter);
   patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);
 
-  populateMathToROCDLConversionPatterns(converter, patterns);
+  populateMathToROCDLConversionPatterns(converter, patterns, chipset);
 }

diff  --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index df219f3ff4f6e..4ba7eab64a785 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -10,6 +10,8 @@
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
+#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -42,8 +44,65 @@ static void populateOpPatterns(const LLVMTypeConverter &converter,
                                            f32ApproxFunc, f16Func);
 }
 
+/// Lowers `math.clampf` on f16/f32 scalars and vectors to `rocdl.fmed3`.
+struct ClampFOpConversion final
+    : public ConvertOpToLLVMPattern<math::ClampFOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+  ClampFOpConversion(const LLVMTypeConverter &converter,
+                     amdgpu::Chipset chipset)
+      : ConvertOpToLLVMPattern<math::ClampFOp>(converter), chipset(chipset) {}
+
+  LogicalResult
+  matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // fmed3 only supports f16 and f32, so check the (element) type first.
+    Type elemTy = op.getType();
+    if (auto vectorType = dyn_cast<VectorType>(elemTy))
+      elemTy = vectorType.getElementType();
+    if (!isa<Float16Type, Float32Type>(elemTy))
+      return rewriter.notifyMatchFailure(
+          op, "fmed3 only supports f16 and f32 types");
+
+    // Bail out instead of casting a null type if conversion fails.
+    Type resultType = getTypeConverter()->convertType(op.getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(op, "failed to convert result type");
+
+    // Multi-dimensional vectors become LLVM arrays; unroll to 1-D vectors.
+    if (isa<LLVM::LLVMArrayType>(resultType)) {
+      return LLVM::detail::handleMultidimensionalVectors(
+          op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
+          [&](Type llvm1DVectorTy, ValueRange operands) -> Value {
+            math::ClampFOp::Adaptor inner(operands);
+            return ROCDL::FMed3Op::create(rewriter, op.getLoc(), llvm1DVectorTy,
+                                          inner.getValue(), inner.getMin(),
+                                          inner.getMax());
+          },
+          rewriter);
+    }
+
+    // Scalars and 1-D vectors: build fmed3 from the converted operands.
+    rewriter.replaceOpWithNewOp<ROCDL::FMed3Op>(
+        op, resultType, adaptor.getValue(), adaptor.getMin(), adaptor.getMax());
+    return success();
+  }
+
+  amdgpu::Chipset chipset;
+};
+
+/// Registers the patterns whose applicability depends on the target chipset.
+static void addChipsetDependentPatterns(const LLVMTypeConverter &converter,
+                                        RewritePatternSet &patterns,
+                                        amdgpu::Chipset chipset) {
+  // Targets older than gfx9 lack V_MED3_F16/F32, so skip the fmed3 lowering.
+  if (chipset.majorVersion < 9)
+    return;
+  patterns.add<ClampFOpConversion>(converter, chipset);
+}
+
 void mlir::populateMathToROCDLConversionPatterns(
-    const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+    const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    amdgpu::Chipset chipset) {
   // Handled by mathToLLVM: math::AbsIOp
   // Handled by mathToLLVM: math::AbsFOp
   // Handled by mathToLLVM: math::CopySignOp
@@ -118,15 +177,17 @@ void mlir::populateMathToROCDLConversionPatterns(
   // worth creating a separate pass for it.
   populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
                                     "__ocml_fmod_f64", "__ocml_fmod_f16");
+
+  addChipsetDependentPatterns(converter, patterns, chipset);
 }
 
-namespace {
-struct ConvertMathToROCDLPass
-    : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
-  ConvertMathToROCDLPass() = default;
+struct ConvertMathToROCDLPass final
+    : impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
+  using impl::ConvertMathToROCDLBase<
+      ConvertMathToROCDLPass>::ConvertMathToROCDLBase;
+
   void runOnOperation() override;
 };
-} // namespace
 
 void ConvertMathToROCDLPass::runOnOperation() {
   auto m = getOperation();
@@ -135,10 +196,20 @@ void ConvertMathToROCDLPass::runOnOperation() {
   RewritePatternSet patterns(&getContext());
   LowerToLLVMOptions options(ctx, DataLayout(m));
   LLVMTypeConverter converter(ctx, options);
-  populateMathToROCDLConversionPatterns(converter, patterns);
+
+  // Only populate chipset-dependent patterns if chipset is specified
+  if (!chipset.empty()) {
+    FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
+    if (failed(maybeChipset)) {
+      return signalPassFailure();
+    }
+    populateMathToROCDLConversionPatterns(converter, patterns, *maybeChipset);
+  }
+
   ConversionTarget target(getContext());
-  target.addLegalDialect<BuiltinDialect, func::FuncDialect,
-                         vector::VectorDialect, LLVM::LLVMDialect>();
+  target
+      .addLegalDialect<BuiltinDialect, func::FuncDialect, vector::VectorDialect,
+                       LLVM::LLVMDialect, ROCDL::ROCDLDialect>();
   target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
                       LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
                       LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,

diff  --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index dbff23339d8b3..455f886839604 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -convert-math-to-rocdl -allow-unregistered-dialect -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx803})' | FileCheck %s --check-prefixes=CHECK,PRE9
+// RUN: mlir-opt %s -allow-unregistered-dialect -split-input-file -pass-pipeline='builtin.module(convert-math-to-rocdl{chipset=gfx942})' | FileCheck %s --check-prefixes=CHECK,POST9
 
 module @test_module {
   // CHECK: llvm.func @__ocml_fmod_f16(f16, f16) -> f16
@@ -596,3 +597,76 @@ module @test_module {
     func.return %result : vector<2x2xf16>
   }
 }
+
+// -----
+
+// f16 clamp → rocdl.fmed3 on gfx9+
+// CHECK-LABEL: func.func @clampf_f16
+func.func @clampf_f16(%x: f16, %lo: f16, %hi: f16) -> f16 {
+  %r = math.clampf %x to [%lo, %hi] : f16
+  return %r : f16
+  // POST9: rocdl.fmed3 {{.*}} : f16
+  // PRE9-NOT: rocdl.fmed3
+  // PRE9: math.clampf {{.*}} : f16
+}
+
+// f32 clamp → rocdl.fmed3 on gfx9+
+// CHECK-LABEL: func.func @clampf_f32
+func.func @clampf_f32(%x: f32, %lo: f32, %hi: f32) -> f32 {
+  %r = math.clampf %x to [%lo, %hi] : f32
+  return %r : f32
+  // POST9: rocdl.fmed3 {{.*}} : f32
+  // PRE9-NOT: rocdl.fmed3
+  // PRE9: math.clampf {{.*}} : f32
+}
+
+// -----
+
+// Vector f16 clamp → rocdl.fmed3 on gfx9+
+// CHECK-LABEL: func.func @clampf_vector_f16
+func.func @clampf_vector_f16(%x: vector<2xf16>, %lo: vector<2xf16>, %hi: vector<2xf16>) -> vector<2xf16> {
+  %r = math.clampf %x to [%lo, %hi] : vector<2xf16>
+  return %r : vector<2xf16>
+  // POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
+  // PRE9-NOT: rocdl.fmed3
+  // PRE9: math.clampf {{.*}} : vector<2xf16>
+}
+
+// -----
+
+// Vector f32 clamp → rocdl.fmed3 on gfx9+
+// CHECK-LABEL: func.func @clampf_vector_f32
+func.func @clampf_vector_f32(%x: vector<2xf32>, %lo: vector<2xf32>, %hi: vector<2xf32>) -> vector<2xf32> {
+  %r = math.clampf %x to [%lo, %hi] : vector<2xf32>
+  return %r : vector<2xf32>
+  // POST9: rocdl.fmed3 {{.*}} : vector<2xf32>
+  // PRE9-NOT: rocdl.fmed3
+  // PRE9: math.clampf {{.*}} : vector<2xf32>
+}
+
+// -----
+
+// Multi-dimensional vector f16 clamp → rocdl.fmed3 on gfx9+ (unrolled to 1D vectors)
+// CHECK-LABEL: func.func @clampf_vector_2d_f16
+func.func @clampf_vector_2d_f16(%x: vector<2x2xf16>, %lo: vector<2x2xf16>, %hi: vector<2x2xf16>) -> vector<2x2xf16> {
+  %r = math.clampf %x to [%lo, %hi] : vector<2x2xf16>
+  return %r : vector<2x2xf16>
+  // POST9: builtin.unrealized_conversion_cast {{.*}} : vector<2x2xf16> to !llvm.array<2 x vector<2xf16>>
+  // POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+  // POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
+  // POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+  // POST9: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+  // POST9: rocdl.fmed3 {{.*}} : vector<2xf16>
+  // POST9: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf16>>
+  // PRE9-NOT: rocdl.fmed3
+  // PRE9: math.clampf {{.*}} : vector<2x2xf16>
+}
+
+// -----
+// CHECK-LABEL: func.func @clampf_bf16
+func.func @clampf_bf16(%x: bf16, %lo: bf16, %hi: bf16) -> bf16 {
+  %r = math.clampf %x to [%lo, %hi] : bf16
+  return %r : bf16
+  // CHECK: math.clampf {{.*}} : bf16
+  // CHECK-NOT: rocdl.fmed3
+}


        


More information about the Mlir-commits mailing list