[Mlir-commits] [mlir] 4b8ff4c - [mlir][math] Fix the semantics of math.clampf (#175012)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 19 02:49:06 PST 2026


Author: Jack
Date: 2026-01-19T10:49:01Z
New Revision: 4b8ff4c63d681ac7376bb874f0302538ac9a6d27

URL: https://github.com/llvm/llvm-project/commit/4b8ff4c63d681ac7376bb874f0302538ac9a6d27
DIFF: https://github.com/llvm/llvm-project/commit/4b8ff4c63d681ac7376bb874f0302538ac9a6d27.diff

LOG: [mlir][math] Fix the semantics of math.clampf (#175012)

The `math.clampf` op was semantically incorrect compared with both the
CUDA reference implementation and the SPIR-V spec, each of which defines
a clamp op. For non-NaN inputs with `min < max`, the old definition
`maxf(minf(value, min), max)` always returns `max` instead of restricting
`value` to the range `[min, max]`.
- Fix the definition of `math.clampf` to agree with CUDA and SPIR-V
- Explicitly state when `math.clampf` produces `ub.poison`
- Update the ExpandOps pass to reflect the corrected semantics
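A minimal scalar model in C++ makes the difference between the two definitions concrete. It is a sketch, not MLIR code: it assumes `arith.minimumf`/`arith.maximumf` behave as NaN-propagating min/max that order `-0.0` below `+0.0`. The helper names `minimumf`, `maximumf`, `clampfOld`, `clampfNew` and `checkClampfModels` are hypothetical and do not exist in MLIR.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Hypothetical scalar stand-ins for arith.minimumf / arith.maximumf:
// any NaN operand yields NaN, and -0.0 is treated as smaller than +0.0.
float minimumf(float a, float b) {
  if (std::isnan(a) || std::isnan(b))
    return std::numeric_limits<float>::quiet_NaN();
  if (a == b) // equal values, possibly +0.0 vs -0.0: prefer the negative one
    return std::signbit(a) ? a : b;
  return a < b ? a : b;
}

float maximumf(float a, float b) {
  if (std::isnan(a) || std::isnan(b))
    return std::numeric_limits<float>::quiet_NaN();
  if (a == b) // equal values, possibly +0.0 vs -0.0: prefer the positive one
    return std::signbit(a) ? b : a;
  return a > b ? a : b;
}

// Definition before the fix: maxf(minf(value, min), max).
float clampfOld(float value, float lo, float hi) {
  return maximumf(minimumf(value, lo), hi);
}

// Definition after the fix: maxf(minf(value, max), min).
float clampfNew(float value, float lo, float hi) {
  return maximumf(minimumf(value, hi), lo);
}

// Spot checks with min = 0 and max = 1.
void checkClampfModels() {
  assert(clampfOld(-2.0f, 0.0f, 1.0f) == 1.0f);   // old: always max
  assert(clampfOld(0.25f, 0.0f, 1.0f) == 1.0f);   // old: always max
  assert(clampfNew(-2.0f, 0.0f, 1.0f) == 0.0f);   // new: raised to min
  assert(clampfNew(0.25f, 0.0f, 1.0f) == 0.25f);  // new: unchanged
  assert(clampfNew(3.0f, 0.0f, 1.0f) == 1.0f);    // new: lowered to max
}
```

With `min = 0` and `max = 1`, the old formula returns `1` for inputs `-2`, `0.25` and `3`, whereas the new one returns `0`, `0.25` and `1`.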

Modified: 
    mlir/include/mlir/Dialect/Math/IR/MathOps.td
    mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
    mlir/test/Dialect/Math/expand-math.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index af65af6fedec6..df5787dc48403 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -365,8 +365,9 @@ def Math_ClampFOp : Math_FloatTernaryOp<"clampf"> {
 
     The semantics of the operation are described by:
     ```
-      clampf(value, min, max) = maxf(minf(value, min), max)
+      clampf(value, min, max) = maxf(minf(value, max), min)
     ```
+    If `min > max` the resulting value is poison.
 
     Example:
 

diff  --git a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index cd68039d0d964..249a95cc7924a 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -669,8 +669,8 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op,
 static LogicalResult convertClampfOp(math::ClampFOp op,
                                      PatternRewriter &rewriter) {
   auto minOp = arith::MinimumFOp::create(rewriter, op.getLoc(), op.getValue(),
-                                         op.getMin(), op.getFastmath());
-  rewriter.replaceOpWithNewOp<arith::MaximumFOp>(op, minOp, op.getMax(),
+                                         op.getMax(), op.getFastmath());
+  rewriter.replaceOpWithNewOp<arith::MaximumFOp>(op, minOp, op.getMin(),
                                                  op.getFastmath());
   return success();
 }

diff  --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 615c607efc3c3..1d26c826e8d6b 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -827,8 +827,8 @@ func.func @unranked_rsqrt_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
 
 // CHECK-LABEL:    func.func @clampf_scalar_op
 // CHECK-SAME:     (%[[ARG:.*]]: f16, %[[MIN:.*]]: f16, %[[MAX:.*]]: f16)
-// CHECK:          %[[V0:.*]] = arith.minimumf %[[ARG]], %[[MIN]] : f16
-// CHECK:          %[[V1:.*]] = arith.maximumf %[[V0]], %[[MAX]] : f16
+// CHECK:          %[[V0:.*]] = arith.minimumf %[[ARG]], %[[MAX]] : f16
+// CHECK:          %[[V1:.*]] = arith.maximumf %[[V0]], %[[MIN]] : f16
 // CHECK:          return %[[V1]] : f16
 
 func.func @clampf_scalar_op(%arg: f16, %min: f16, %max: f16) -> f16 {
@@ -838,8 +838,8 @@ func.func @clampf_scalar_op(%arg: f16, %min: f16, %max: f16) -> f16 {
 
 // CHECK-LABEL:    func.func @clampf_vector_op
 // CHECK-SAME:     (%[[ARG:.*]]: vector<3x4xf32>, %[[MIN:.*]]: vector<3x4xf32>, %[[MAX:.*]]: vector<3x4xf32>)
-// CHECK:          %[[V0:.*]] = arith.minimumf %[[ARG]], %[[MIN]] fastmath<fast> : vector<3x4xf32>
-// CHECK:          %[[V1:.*]] = arith.maximumf %[[V0]], %[[MAX]] fastmath<fast> : vector<3x4xf32>
+// CHECK:          %[[V0:.*]] = arith.minimumf %[[ARG]], %[[MAX]] fastmath<fast> : vector<3x4xf32>
+// CHECK:          %[[V1:.*]] = arith.maximumf %[[V0]], %[[MIN]] fastmath<fast> : vector<3x4xf32>
 // CHECK:          return %[[V1]] : vector<3x4xf32>
 
 func.func @clampf_vector_op(%arg: vector<3x4xf32>, %min: vector<3x4xf32>, %max: vector<3x4xf32>) -> vector<3x4xf32>{
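The updated pattern in `ExpandOps.cpp` applies `arith.minimumf` against `max` first and `arith.maximumf` against `min` second. One way to see why the revised description leaves `min > max` as poison instead of fixing a result is to compare that order with the opposite one. The sketch below is plain C++, assumes finite inputs only (so `std::min`/`std::max` can stand in for the arith ops), and uses hypothetical function names.

```cpp
#include <algorithm>
#include <cassert>

// Order used by the updated expansion:
// arith.minimumf(value, max), then arith.maximumf(_, min).
float clampMinThenMax(float value, float lo, float hi) {
  return std::max(std::min(value, hi), lo);
}

// Hypothetical alternative lowering: max with `min` first, then min with `max`.
float clampMaxThenMin(float value, float lo, float hi) {
  return std::min(std::max(value, lo), hi);
}

void checkOrders() {
  // With min <= max, both orders clamp identically.
  const float values[] = {-3.0f, 0.5f, 7.0f};
  for (float v : values)
    assert(clampMinThenMax(v, 0.0f, 1.0f) == clampMaxThenMin(v, 0.0f, 1.0f));

  // With min > max they diverge: the first order always yields `min`,
  // the second always yields `max`.
  assert(clampMinThenMax(0.5f, 2.0f, 1.0f) == 2.0f);
  assert(clampMaxThenMin(0.5f, 2.0f, 1.0f) == 1.0f);
}
```

Because the two orders disagree only when `min > max`, treating that case as poison lets a lowering pick either order without changing the op's defined results.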