[Mlir-commits] [mlir] [MLIR][MathToLLVM] Add direct lowering of math.clampf to LLVM (PR #188776)
Mehdi Amini
llvmlistbot at llvm.org
Thu Mar 26 08:53:43 PDT 2026
https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/188776
`math.clampf` had no direct lowering to LLVM dialect, requiring an explicit expansion step via `math-expand-ops` before `convert-math-to-llvm`.
Add `ClampFOpLowering` which converts `clampf(value, min, max)` directly to `llvm.intr.maximum(llvm.intr.minimum(value, max), min)`, matching the semantics of IEEE 754-2019 minimum/maximum. Fast-math flags are propagated. Multi-dimensional vector types are handled via `handleMultidimensionalVectors`.
Fixes #164880
Assisted-by: Claude Code
From 278e652908f93c3675f7cb428c2686035edd6d20 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 26 Mar 2026 07:18:53 -0700
Subject: [PATCH] [MLIR][MathToLLVM] Add direct lowering of math.clampf to LLVM
`math.clampf` had no direct lowering to LLVM dialect, requiring an
explicit expansion step via `math-expand-ops` before `convert-math-to-llvm`.
Add `ClampFOpLowering` which converts `clampf(value, min, max)` directly
to `llvm.intr.maximum(llvm.intr.minimum(value, max), min)`, matching the
semantics of IEEE 754-2019 minimum/maximum. Fast-math flags are propagated.
Multi-dimensional vector types are handled via `handleMultidimensionalVectors`.
Fixes #164880
Assisted-by: Claude Code
---
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 51 +++++++++++++++++++
.../Conversion/MathToLLVM/math-to-llvm.mlir | 47 +++++++++++++++++
2 files changed, 98 insertions(+)
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 76ee9d00b53ad..446646913a4e1 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -403,6 +403,56 @@ struct IsFiniteOpLowering
}
};
+// Lowers `math.clampf` to `llvm.intr.maximum(llvm.intr.minimum(x, hi), lo)`,
+// keeping fast-math flags; n-D vectors use handleMultidimensionalVectors.
+struct ClampFOpLowering
+ : public ConvertOpToLLVMPattern<math::ClampFOp,
+ /*FailOnUnsupportedFP=*/true> {
+ using ConvertOpToLLVMPattern<
+ math::ClampFOp, /*FailOnUnsupportedFP=*/true>::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(math::ClampFOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ const auto &typeConverter = *this->getTypeConverter();
+ auto operandType = adaptor.getValue().getType();
+ auto llvmOperandType = typeConverter.convertType(operandType);
+ if (!llvmOperandType)
+ return failure();
+
+ auto loc = op.getLoc();
+ ConvertFastMath<math::ClampFOp, LLVM::MinimumOp> minAttrs(op);
+ ConvertFastMath<math::ClampFOp, LLVM::MaximumOp> maxAttrs(op);
+
+ if (!isa<LLVM::LLVMArrayType>(llvmOperandType)) {
+ auto minOp = LLVM::MinimumOp::create(
+ rewriter, loc, llvmOperandType,
+ ValueRange{adaptor.getValue(), adaptor.getMax()},
+ minAttrs.getAttrs());
+ rewriter.replaceOpWithNewOp<LLVM::MaximumOp>(
+ op, llvmOperandType, ValueRange{minOp.getResult(), adaptor.getMin()},
+ maxAttrs.getAttrs());
+ return success();
+ }
+
+ if (!isa<VectorType>(op.getResult().getType()))
+ return rewriter.notifyMatchFailure(op, "expected vector result type");
+
+ return LLVM::detail::handleMultidimensionalVectors(
+ op.getOperation(), adaptor.getOperands(), typeConverter,
+ [&](Type llvm1DVectorTy, ValueRange operands) {
+ // Operands are (value, min, max): cap at max first, then raise to min.
+ auto minOp = LLVM::MinimumOp::create(
+ rewriter, loc, llvm1DVectorTy,
+ ValueRange{operands[0], operands[2]}, minAttrs.getAttrs());
+ return LLVM::MaximumOp::create(
+ rewriter, loc, llvm1DVectorTy,
+ ValueRange{minOp.getResult(), operands[1]}, maxAttrs.getAttrs());
+ },
+ rewriter);
+ }
+};
+
struct ConvertMathToLLVMPass
: public impl::ConvertMathToLLVMPassBase<ConvertMathToLLVMPass> {
using Base::Base;
@@ -431,6 +481,7 @@ void mlir::populateMathToLLVMConversionPatterns(
AbsFOpLowering,
AbsIOpLowering,
CeilOpLowering,
+ ClampFOpLowering,
CopySignOpLowering,
CosOpLowering,
CoshOpLowering,
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 504dc1afb0eef..301a95bf716b7 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -641,3 +641,50 @@ func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: f4E2M1FN, %arg2: f4E2M1FN
%2 = math.fma %arg1, %arg1, %arg2 : f4E2M1FN
return
}
+
+// -----
+
+// CHECK-LABEL: func @clampf(
+// CHECK-SAME: %[[VAL:.*]]: f32, %[[MIN:.*]]: f32, %[[MAX:.*]]: f32
+func.func @clampf(%arg0: f32, %arg1: f32, %arg2: f32) -> f32 {
+ // CHECK: %[[MIN_VAL:.*]] = llvm.intr.minimum(%[[VAL]], %[[MAX]]) : (f32, f32) -> f32
+ // CHECK: %[[RESULT:.*]] = llvm.intr.maximum(%[[MIN_VAL]], %[[MIN]]) : (f32, f32) -> f32
+ %0 = math.clampf %arg0 to [%arg1, %arg2] : f32
+ return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @clampf_fmf(
+// CHECK-SAME: %[[VAL:.*]]: f32, %[[MIN:.*]]: f32, %[[MAX:.*]]: f32
+func.func @clampf_fmf(%arg0: f32, %arg1: f32, %arg2: f32) -> f32 {
+ // CHECK: %[[MIN_VAL:.*]] = llvm.intr.minimum(%[[VAL]], %[[MAX]]) {fastmathFlags = #llvm.fastmath<fast>} : (f32, f32) -> f32
+ // CHECK: %[[RESULT:.*]] = llvm.intr.maximum(%[[MIN_VAL]], %[[MIN]]) {fastmathFlags = #llvm.fastmath<fast>} : (f32, f32) -> f32
+ %0 = math.clampf %arg0 to [%arg1, %arg2] fastmath<fast> : f32
+ return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @clampf_vector(
+// CHECK-SAME: %[[VAL:.*]]: vector<4xf32>, %[[MIN:.*]]: vector<4xf32>, %[[MAX:.*]]: vector<4xf32>
+func.func @clampf_vector(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<4xf32> {
+ // CHECK: %[[MIN_VAL:.*]] = llvm.intr.minimum(%[[VAL]], %[[MAX]]) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
+ // CHECK: %[[RESULT:.*]] = llvm.intr.maximum(%[[MIN_VAL]], %[[MIN]]) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
+ %0 = math.clampf %arg0 to [%arg1, %arg2] : vector<4xf32>
+ return %0 : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @clampf_2dvector(
+func.func @clampf_2dvector(%arg0: vector<4x3xf32>, %arg1: vector<4x3xf32>, %arg2: vector<4x3xf32>) -> vector<4x3xf32> {
+ // CHECK: %[[EXTRACT_VAL:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
+ // CHECK: %[[EXTRACT_MIN:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
+ // CHECK: %[[EXTRACT_MAX:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
+ // CHECK: %[[MIN_VAL:.*]] = llvm.intr.minimum(%[[EXTRACT_VAL]], %[[EXTRACT_MAX]]) : (vector<3xf32>, vector<3xf32>) -> vector<3xf32>
+ // CHECK: %[[MAX_VAL:.*]] = llvm.intr.maximum(%[[MIN_VAL]], %[[EXTRACT_MIN]]) : (vector<3xf32>, vector<3xf32>) -> vector<3xf32>
+ // CHECK: llvm.insertvalue %[[MAX_VAL]], %{{.*}}[0] : !llvm.array<4 x vector<3xf32>>
+ %0 = math.clampf %arg0 to [%arg1, %arg2] : vector<4x3xf32>
+ return %0 : vector<4x3xf32>
+}
More information about the Mlir-commits
mailing list