[Mlir-commits] [mlir] 4b8ff4c - [mlir][math] Fix the semantics of math.clampf (#175012)
llvmlistbot at llvm.org
Mon Jan 19 02:49:06 PST 2026
Author: Jack
Date: 2026-01-19T10:49:01Z
New Revision: 4b8ff4c63d681ac7376bb874f0302538ac9a6d27
URL: https://github.com/llvm/llvm-project/commit/4b8ff4c63d681ac7376bb874f0302538ac9a6d27
DIFF: https://github.com/llvm/llvm-project/commit/4b8ff4c63d681ac7376bb874f0302538ac9a6d27.diff
LOG: [mlir][math] Fix the semantics of math.clampf (#175012)
The semantics of the `math.clampf` op disagree with both the CUDA
reference implementation and the SPIRV spec, each of which defines a
clamp op.
- Fix the definition of `math.clampf` to agree with CUDA and SPIRV
- Explicitly state when `math.clampf` produces `ub.poison`
- Update the ExpandOps pass to reflect the corrected semantics
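As an illustrative sketch of why the operand order matters (not part of this commit), the two formulas can be compared in plain C++. The helper names `clampfOld` and `clampfNew` and the parameters `lo` and `hi` are hypothetical, and `std::min`/`std::max` stand in for `arith.minimumf`/`arith.maximumf` without modelling their NaN or signed-zero behavior or the poison case.

```cpp
#include <algorithm>

// Previous definition: maxf(minf(value, min), max).
// Assumption: std::min/std::max approximate minimumf/maximumf
// for ordinary (non-NaN) inputs only.
float clampfOld(float value, float lo, float hi) {
  return std::max(std::min(value, lo), hi);
}

// Corrected definition: maxf(minf(value, max), min).
// The op yields poison when lo > hi; this sketch does not model that.
float clampfNew(float value, float lo, float hi) {
  return std::max(std::min(value, hi), lo);
}
```

Under these assumptions, with `lo <= hi` the old formula yields `hi` for every input, since `min(value, lo)` never exceeds `hi`. The corrected formula returns `value` unchanged when it already lies in `[lo, hi]`.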
Added:
Modified:
mlir/include/mlir/Dialect/Math/IR/MathOps.td
mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
mlir/test/Dialect/Math/expand-math.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index af65af6fedec6..df5787dc48403 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -365,8 +365,9 @@ def Math_ClampFOp : Math_FloatTernaryOp<"clampf"> {
The semantics of the operation are described by:
```
- clampf(value, min, max) = maxf(minf(value, min), max)
+ clampf(value, min, max) = maxf(minf(value, max), min)
```
+ If `min > max` the resulting value is poison.
Example:
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index cd68039d0d964..249a95cc7924a 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -669,8 +669,8 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op,
static LogicalResult convertClampfOp(math::ClampFOp op,
PatternRewriter &rewriter) {
auto minOp = arith::MinimumFOp::create(rewriter, op.getLoc(), op.getValue(),
- op.getMin(), op.getFastmath());
- rewriter.replaceOpWithNewOp<arith::MaximumFOp>(op, minOp, op.getMax(),
+ op.getMax(), op.getFastmath());
+ rewriter.replaceOpWithNewOp<arith::MaximumFOp>(op, minOp, op.getMin(),
op.getFastmath());
return success();
}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 615c607efc3c3..1d26c826e8d6b 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -827,8 +827,8 @@ func.func @unranked_rsqrt_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
// CHECK-LABEL: func.func @clampf_scalar_op
// CHECK-SAME: (%[[ARG:.*]]: f16, %[[MIN:.*]]: f16, %[[MAX:.*]]: f16)
-// CHECK: %[[V0:.*]] = arith.minimumf %[[ARG]], %[[MIN]] : f16
-// CHECK: %[[V1:.*]] = arith.maximumf %[[V0]], %[[MAX]] : f16
+// CHECK: %[[V0:.*]] = arith.minimumf %[[ARG]], %[[MAX]] : f16
+// CHECK: %[[V1:.*]] = arith.maximumf %[[V0]], %[[MIN]] : f16
// CHECK: return %[[V1]] : f16
func.func @clampf_scalar_op(%arg: f16, %min: f16, %max: f16) -> f16 {
@@ -838,8 +838,8 @@ func.func @clampf_scalar_op(%arg: f16, %min: f16, %max: f16) -> f16 {
// CHECK-LABEL: func.func @clampf_vector_op
// CHECK-SAME: (%[[ARG:.*]]: vector<3x4xf32>, %[[MIN:.*]]: vector<3x4xf32>, %[[MAX:.*]]: vector<3x4xf32>)
-// CHECK: %[[V0:.*]] = arith.minimumf %[[ARG]], %[[MIN]] fastmath<fast> : vector<3x4xf32>
-// CHECK: %[[V1:.*]] = arith.maximumf %[[V0]], %[[MAX]] fastmath<fast> : vector<3x4xf32>
+// CHECK: %[[V0:.*]] = arith.minimumf %[[ARG]], %[[MAX]] fastmath<fast> : vector<3x4xf32>
+// CHECK: %[[V1:.*]] = arith.maximumf %[[V0]], %[[MIN]] fastmath<fast> : vector<3x4xf32>
// CHECK: return %[[V1]] : vector<3x4xf32>
func.func @clampf_vector_op(%arg: vector<3x4xf32>, %min: vector<3x4xf32>, %max: vector<3x4xf32>) -> vector<3x4xf32>{