[Mlir-commits] [mlir] [MLIR][Math] Fix math.ceil expansion to avoid undefined behavior on Inf/NaN (PR #170028)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Nov 30 01:26:34 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-math
Author: Hank (hankluo6)
<details>
<summary>Changes</summary>
Fixes #<!-- -->151786
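`fptosi` produces poison when the input is NaN, ±Inf, or otherwise out of range for the destination integer type, and any subsequent use of that poison leads to undefined behavior. This patch adds a guard, similar to the existing `round` expansion, that returns the input unchanged for special values (NaN, ±Inf) and for large finite values that are already integral, thereby avoiding the UB.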
---
Full diff: https://github.com/llvm/llvm-project/pull/170028.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp (+33-1)
- (modified) mlir/test/Dialect/Math/expand-math.mlir (+10-1)
``````````diff
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index cd68039d0d964..e9f4811aae3fe 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -232,6 +232,37 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
Type opType = operand.getType();
+
+ auto operandETy = getElementTypeOrSelf(opType);
+ unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
+ unsigned mantissaWidth =
+ llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
+ unsigned exponentWidth = bitWidth - mantissaWidth - 1;
+
+ Type iTy = rewriter.getIntegerType(bitWidth);
+ if (auto shapedTy = dyn_cast<ShapedType>(opType))
+ iTy = shapedTy.clone(iTy);
+
+ Value cMantissaWidth = createIntConst(op->getLoc(), iTy, mantissaWidth, b);
+ Value cBias =
+ createIntConst(op->getLoc(), iTy, (1ull << (exponentWidth - 1)) - 1, b);
+ Value cExpMask =
+ createIntConst(op->getLoc(), iTy, (1ull << exponentWidth) - 1, b);
+
+ // Any floating-point value with an unbiased exponent ≥ `mantissaWidth`
+ // falls into one of these categories:
+ // - a large finite value (|x| ≥ 2^mantissaWidth), where all representable
+ // numbers are already integral, or
+ // - a special value (NaN or ±Inf), which also satisfies this exponent
+ // condition.
+ // For all such cases, `ceilf(x)` is defined to return `x` directly.
+ Value operandBitcast = arith::BitcastOp::create(b, iTy, operand);
+ Value operandExp = arith::AndIOp::create(
+ b, arith::ShRUIOp::create(b, operandBitcast, cMantissaWidth), cExpMask);
+ Value operandBiasedExp = arith::SubIOp::create(b, operandExp, cBias);
+ Value isSpecialValOrLargeVal = arith::CmpIOp::create(
+ b, arith::CmpIPredicate::sge, operandBiasedExp, cMantissaWidth);
+
Value fpFixedConvert = createTruncatedFPValue(operand, b);
// Creating constants for later use.
@@ -243,7 +274,8 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
Value incrValue =
arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
- Value ret = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
+ Value add = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
+ Value ret = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand, add);
rewriter.replaceOp(op, ret);
return success();
}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 615c607efc3c3..75f8e65b334a2 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -145,13 +145,22 @@ func.func @fmaf_func(%a: f64, %b: f64, %c: f64) -> f64 {
func.func @ceilf_func(%a: f64) -> f64 {
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000
// CHECK-DAG: [[CST_0:%.+]] = arith.constant 1.000
+ // CHECK-DAG: [[C52:%.*]] = arith.constant 52
+ // CHECK-DAG: [[C1023:%.*]] = arith.constant 1023
+ // CHECK-DAG: [[EXP_MASK:%.*]] = arith.constant 2047
+ // CHECK-NEXT: [[ARG_BITCAST:%.*]] = arith.bitcast [[ARG0]] : f64 to i64
+ // CHECK-NEXT: [[ARG_BITCAST_SHIFTED:%.*]] = arith.shrui [[ARG_BITCAST]], [[C52]]
+ // CHECK-NEXT: [[ARG_EXP:%.*]] = arith.andi [[ARG_BITCAST_SHIFTED]], [[EXP_MASK]]
+ // CHECK-NEXT: [[ARG_BIASED_EXP:%.*]] = arith.subi [[ARG_EXP]], [[C1023]]
+ // CHECK-NEXT: [[IS_SPECIAL_VAL:%.*]] = arith.cmpi sge, [[ARG_BIASED_EXP]], [[C52]]
// CHECK-NEXT: [[CVTI:%.+]] = arith.fptosi [[ARG0]]
// CHECK-NEXT: [[CVTF:%.+]] = arith.sitofp [[CVTI]]
// CHECK-NEXT: [[COPYSIGN:%.+]] = math.copysign [[CVTF]], [[ARG0]]
// CHECK-NEXT: [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[COPYSIGN]]
// CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
// CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
- // CHECK-NEXT: return [[ADDF]]
+ // CHECK-NEXT: [[RESULT:%.*]] = arith.select [[IS_SPECIAL_VAL]], [[ARG0]], [[ADDF]]
+ // CHECK-NEXT: return [[RESULT]]
// CHECK-FILTER: math.ceil
%ret = math.ceil %a : f64
return %ret : f64
``````````
</details>
https://github.com/llvm/llvm-project/pull/170028
More information about the Mlir-commits
mailing list