[Mlir-commits] [mlir] [MLIR][Math] Fix math.ceil expansion to avoid undefined behavior on Inf/NaN (PR #170028)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 30 01:26:34 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-math

Author: Hank (hankluo6)

<details>
<summary>Changes</summary>

Fixes #<!-- -->151786

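`fptosi` produces poison when its input is NaN, ±Inf, or too large to fit in the destination integer type, and any subsequent use of that poison leads to undefined behavior. This patch adds a bypass similar to the existing `round` expansion. Inputs whose unbiased exponent is at least the mantissa width are returned unchanged through a final `select`, which avoids the UB. These are large finite values, which are already integral, plus NaN and ±Inf.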

---
Full diff: https://github.com/llvm/llvm-project/pull/170028.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp (+33-1) 
- (modified) mlir/test/Dialect/Math/expand-math.mlir (+10-1) 


``````````diff
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index cd68039d0d964..e9f4811aae3fe 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -232,6 +232,37 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operand = op.getOperand();
   Type opType = operand.getType();
+
+  auto operandETy = getElementTypeOrSelf(opType);
+  unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
+  unsigned mantissaWidth =
+      llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
+  unsigned exponentWidth = bitWidth - mantissaWidth - 1;
+
+  Type iTy = rewriter.getIntegerType(bitWidth);
+  if (auto shapedTy = dyn_cast<ShapedType>(opType))
+    iTy = shapedTy.clone(iTy);
+
+  Value cMantissaWidth = createIntConst(op->getLoc(), iTy, mantissaWidth, b);
+  Value cBias =
+      createIntConst(op->getLoc(), iTy, (1ull << (exponentWidth - 1)) - 1, b);
+  Value cExpMask =
+      createIntConst(op->getLoc(), iTy, (1ull << exponentWidth) - 1, b);
+
+  // Any floating-point value with an unbiased exponent ≥ `mantissaWidth`
+  // falls into one of these categories:
+  //   - a large finite value (|x| ≥ 2^mantissaWidth), where all representable
+  //     numbers are already integral, or
+  //   - a special value (NaN or ±Inf), which also satisfies this exponent
+  //     condition.
+  // For all such cases, `ceilf(x)` is defined to return `x` directly.
+  Value operandBitcast = arith::BitcastOp::create(b, iTy, operand);
+  Value operandExp = arith::AndIOp::create(
+      b, arith::ShRUIOp::create(b, operandBitcast, cMantissaWidth), cExpMask);
+  Value operandBiasedExp = arith::SubIOp::create(b, operandExp, cBias);
+  Value isSpecialValOrLargeVal = arith::CmpIOp::create(
+      b, arith::CmpIPredicate::sge, operandBiasedExp, cMantissaWidth);
+
   Value fpFixedConvert = createTruncatedFPValue(operand, b);
 
   // Creating constants for later use.
@@ -243,7 +274,8 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   Value incrValue =
       arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
 
-  Value ret = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
+  Value add = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
+  Value ret = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand, add);
   rewriter.replaceOp(op, ret);
   return success();
 }
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 615c607efc3c3..75f8e65b334a2 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -145,13 +145,22 @@ func.func @fmaf_func(%a: f64, %b: f64, %c: f64) -> f64 {
 func.func @ceilf_func(%a: f64) -> f64 {
   // CHECK-DAG:   [[CST:%.+]] = arith.constant 0.000
   // CHECK-DAG:   [[CST_0:%.+]] = arith.constant 1.000
+  // CHECK-DAG:   [[C52:%.*]] = arith.constant 52
+  // CHECK-DAG:   [[C1023:%.*]] = arith.constant 1023
+  // CHECK-DAG:   [[EXP_MASK:%.*]] = arith.constant 2047
+  // CHECK-NEXT:   [[ARG_BITCAST:%.*]] = arith.bitcast [[ARG0]] : f64 to i64
+  // CHECK-NEXT:   [[ARG_BITCAST_SHIFTED:%.*]] = arith.shrui [[ARG_BITCAST]], [[C52]]
+  // CHECK-NEXT:   [[ARG_EXP:%.*]] = arith.andi [[ARG_BITCAST_SHIFTED]], [[EXP_MASK]]
+  // CHECK-NEXT:   [[ARG_BIASED_EXP:%.*]] = arith.subi [[ARG_EXP]], [[C1023]]
+  // CHECK-NEXT:   [[IS_SPECIAL_VAL:%.*]] = arith.cmpi sge, [[ARG_BIASED_EXP]], [[C52]]
   // CHECK-NEXT:   [[CVTI:%.+]] = arith.fptosi [[ARG0]]
   // CHECK-NEXT:   [[CVTF:%.+]] = arith.sitofp [[CVTI]]
   // CHECK-NEXT:   [[COPYSIGN:%.+]] = math.copysign [[CVTF]], [[ARG0]]
   // CHECK-NEXT:   [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[COPYSIGN]]
   // CHECK-NEXT:   [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
   // CHECK-NEXT:   [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
-  // CHECK-NEXT:   return [[ADDF]]
+  // CHECK-NEXT:   [[RESULT:%.*]] = arith.select [[IS_SPECIAL_VAL]], [[ARG0]], [[ADDF]]
+  // CHECK-NEXT:   return [[RESULT]]
   // CHECK-FILTER: math.ceil
   %ret = math.ceil %a : f64
   return %ret : f64

``````````

</details>


https://github.com/llvm/llvm-project/pull/170028


More information about the Mlir-commits mailing list