[Mlir-commits] [mlir] [MLIR][MathToLLVM] Add direct lowering of math.clampf to LLVM (PR #188776)

Mehdi Amini llvmlistbot at llvm.org
Thu Mar 26 08:53:43 PDT 2026


https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/188776

`math.clampf` had no direct lowering to the LLVM dialect, so it required an explicit expansion step via `math-expand-ops` before `convert-math-to-llvm`.

Add `ClampFOpLowering`, which converts `clampf(value, min, max)` directly to `llvm.intr.maximum(llvm.intr.minimum(value, max), min)`, matching the semantics of IEEE 754-2019 minimum/maximum. Fast-math flags are propagated to both intrinsics. Multi-dimensional vector types are handled via `handleMultidimensionalVectors`.

Fixes #164880

Assisted-by: Claude Code

>From 278e652908f93c3675f7cb428c2686035edd6d20 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 26 Mar 2026 07:18:53 -0700
Subject: [PATCH] [MLIR][MathToLLVM] Add direct lowering of math.clampf to LLVM

`math.clampf` had no direct lowering to the LLVM dialect, requiring an
explicit expansion step via `math-expand-ops` before `convert-math-to-llvm`.

Add `ClampFOpLowering`, which converts `clampf(value, min, max)` directly
to `llvm.intr.maximum(llvm.intr.minimum(value, max), min)`, matching the
semantics of IEEE 754-2019 minimum/maximum. Fast-math flags are propagated
to both intrinsics. Multi-dimensional vector types are handled via
`handleMultidimensionalVectors`.

Fixes #164880

Assisted-by: Claude Code
---
 mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 51 +++++++++++++++++++
 .../Conversion/MathToLLVM/math-to-llvm.mlir   | 47 +++++++++++++++++
 2 files changed, 98 insertions(+)

diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 76ee9d00b53ad..446646913a4e1 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -403,6 +403,56 @@ struct IsFiniteOpLowering
   }
 };
 
+// Lowers `math.clampf` directly to LLVM: `minimum(value, max)` followed by
+// `maximum(result, min)`, i.e. clampf(x, lo, hi) = maximum(minimum(x, hi), lo)
+struct ClampFOpLowering
+    : public ConvertOpToLLVMPattern<math::ClampFOp,
+                                    /*FailOnUnsupportedFP=*/true> {
+  using ConvertOpToLLVMPattern<
+      math::ClampFOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
+  // Emits the two intrinsics; fails when the type conversion yields null.
+  LogicalResult
+  matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const auto &typeConverter = *this->getTypeConverter();
+    auto operandType = adaptor.getValue().getType();
+    auto llvmOperandType = typeConverter.convertType(operandType);
+    if (!llvmOperandType)
+      return failure();
+    // Fast-math attributes of `op` are converted once per target intrinsic.
+    auto loc = op.getLoc();
+    ConvertFastMath<math::ClampFOp, LLVM::MinimumOp> minAttrs(op);
+    ConvertFastMath<math::ClampFOp, LLVM::MaximumOp> maxAttrs(op);
+    // Non-array converted types (e.g. f32, vector<4xf32>) lower directly.
+    if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
+      auto minOp = LLVM::MinimumOp::create(
+          rewriter, loc, llvmOperandType,
+          ValueRange{adaptor.getValue(), adaptor.getMax()},
+          minAttrs.getAttrs());
+      rewriter.replaceOpWithNewOp<LLVM::MaximumOp>(
+          op, llvmOperandType, ValueRange{minOp.getResult(), adaptor.getMin()},
+          maxAttrs.getAttrs());
+      return success();
+    }
+    // Array-typed results must originate from a (multi-dim) vector type.
+    if (!isa<VectorType>(op.getResult().getType()))
+      return rewriter.notifyMatchFailure(op, "expected vector result type");
+    // Build the same minimum/maximum pair per 1-D vector inside the helper.
+    return LLVM::detail::handleMultidimensionalVectors(
+        op.getOperation(), adaptor.getOperands(), typeConverter,
+        [&](Type llvm1DVectorTy, ValueRange operands) {
+          // Operand order is (value, min, max): [0]=value, [1]=min, [2]=max.
+          auto minOp = LLVM::MinimumOp::create(
+              rewriter, loc, llvm1DVectorTy,
+              ValueRange{operands[0], operands[2]}, minAttrs.getAttrs());
+          return LLVM::MaximumOp::create(
+              rewriter, loc, llvm1DVectorTy,
+              ValueRange{minOp.getResult(), operands[1]}, maxAttrs.getAttrs());
+        },
+        rewriter);
+  }
+};
+
 struct ConvertMathToLLVMPass
     : public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
   using Base::Base;
@@ -431,6 +481,7 @@ void mlir::populateMathToLLVMConversionPatterns(
     AbsFOpLowering,
     AbsIOpLowering,
     CeilOpLowering,
+    ClampFOpLowering,
     CopySignOpLowering,
     CosOpLowering,
     CoshOpLowering,
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 504dc1afb0eef..301a95bf716b7 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -641,3 +641,50 @@ func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: f4E2M1FN, %arg2: f4E2M1FN
   %2 = math.fma %arg1, %arg1, %arg2 : f4E2M1FN
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @clampf(
+// CHECK-SAME: %[[VAL:.*]]: f32, %[[MIN:.*]]: f32, %[[MAX:.*]]: f32
+func.func @clampf(%arg0: f32, %arg1: f32, %arg2: f32) -> f32 {
+  // CHECK: %[[MIN_VAL:.*]] = llvm.intr.minimum(%[[VAL]], %[[MAX]]) : (f32, f32) -> f32
+  // CHECK: %[[RESULT:.*]] = llvm.intr.maximum(%[[MIN_VAL]], %[[MIN]]) : (f32, f32) -> f32
+  %0 = math.clampf %arg0 to [%arg1, %arg2] : f32
+  return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @clampf_fmf(
+// CHECK-SAME: %[[VAL:.*]]: f32, %[[MIN:.*]]: f32, %[[MAX:.*]]: f32
+func.func @clampf_fmf(%arg0: f32, %arg1: f32, %arg2: f32) -> f32 {
+  // CHECK: %[[MIN_VAL:.*]] = llvm.intr.minimum(%[[VAL]], %[[MAX]]) {fastmathFlags = #llvm.fastmath<fast>} : (f32, f32) -> f32
+  // CHECK: %[[RESULT:.*]] = llvm.intr.maximum(%[[MIN_VAL]], %[[MIN]]) {fastmathFlags = #llvm.fastmath<fast>} : (f32, f32) -> f32
+  %0 = math.clampf %arg0 to [%arg1, %arg2] fastmath<fast> : f32
+  return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @clampf_vector(
+// CHECK-SAME: %[[VAL:.*]]: vector<4xf32>, %[[MIN:.*]]: vector<4xf32>, %[[MAX:.*]]: vector<4xf32>
+func.func @clampf_vector(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
+  // CHECK: %[[MIN_VAL:.*]] = llvm.intr.minimum(%[[VAL]], %[[MAX]]) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
+  // CHECK: %[[RESULT:.*]] = llvm.intr.maximum(%[[MIN_VAL]], %[[MIN]]) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
+  %0 = math.clampf %arg0 to [%arg1, %arg2] : vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @clampf_2dvector(
+func.func @clampf_2dvector(%arg0: vector<4x3xf32>, %arg1: vector<4x3xf32>, %arg2: vector<4x3xf32>) -> vector<4x3xf32> {
+  // CHECK: %[[EXTRACT_VAL:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
+  // CHECK: %[[EXTRACT_MIN:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
+  // CHECK: %[[EXTRACT_MAX:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
+  // CHECK: %[[MIN_VAL:.*]] = llvm.intr.minimum(%[[EXTRACT_VAL]], %[[EXTRACT_MAX]]) : (vector<3xf32>, vector<3xf32>) -> vector<3xf32>
+  // CHECK: %[[MAX_VAL:.*]] = llvm.intr.maximum(%[[MIN_VAL]], %[[EXTRACT_MIN]]) : (vector<3xf32>, vector<3xf32>) -> vector<3xf32>
+  // CHECK: llvm.insertvalue %[[MAX_VAL]], %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
+  %0 = math.clampf %arg0 to [%arg1, %arg2] : vector<4x3xf32>
+  return %0 : vector<4x3xf32>
+}



More information about the Mlir-commits mailing list