[Mlir-commits] [mlir] [MLIR][Math] Fix math.ceil expansion to avoid undefined behavior on Inf/NaN (PR #170028)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Feb 7 18:27:40 PST 2026


https://github.com/hankluo6 updated https://github.com/llvm/llvm-project/pull/170028

>From 3a9f80fdd3c0c8c0146652199b63f2b414de7d92 Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Sun, 16 Nov 2025 23:48:19 -0800
Subject: [PATCH 1/4] Fix ceilf expansion to avoid undefined behavior on
 Inf/NaN

---
 .../lib/Dialect/Math/Transforms/ExpandOps.cpp | 34 ++++++++++++++++++-
 mlir/test/Dialect/Math/expand-math.mlir       | 11 +++++-
 2 files changed, 43 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index cd68039d0d964..e9f4811aae3fe 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -232,6 +232,37 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operand = op.getOperand();
   Type opType = operand.getType();
+
+  auto operandETy = getElementTypeOrSelf(opType);
+  unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
+  unsigned mantissaWidth =
+      llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
+  unsigned exponentWidth = bitWidth - mantissaWidth - 1;
+
+  Type iTy = rewriter.getIntegerType(bitWidth);
+  if (auto shapedTy = dyn_cast<ShapedType>(opType))
+    iTy = shapedTy.clone(iTy);
+
+  Value cMantissaWidth = createIntConst(op->getLoc(), iTy, mantissaWidth, b);
+  Value cBias =
+      createIntConst(op->getLoc(), iTy, (1ull << (exponentWidth - 1)) - 1, b);
+  Value cExpMask =
+      createIntConst(op->getLoc(), iTy, (1ull << exponentWidth) - 1, b);
+
+  // Any floating-point value with an unbiased exponent ≥ `mantissaWidth`
+  // falls into one of these categories:
+  //   - a large finite value (|x| ≥ 2^mantissaWidth), where all representable
+  //     numbers are already integral, or
+  //   - a special value (NaN or ±Inf), which also satisfies this exponent
+  //     condition.
+  // For all such cases, `ceilf(x)` is defined to return `x` directly.
+  Value operandBitcast = arith::BitcastOp::create(b, iTy, operand);
+  Value operandExp = arith::AndIOp::create(
+      b, arith::ShRUIOp::create(b, operandBitcast, cMantissaWidth), cExpMask);
+  Value operandBiasedExp = arith::SubIOp::create(b, operandExp, cBias);
+  Value isSpecialValOrLargeVal = arith::CmpIOp::create(
+      b, arith::CmpIPredicate::sge, operandBiasedExp, cMantissaWidth);
+
   Value fpFixedConvert = createTruncatedFPValue(operand, b);
 
   // Creating constants for later use.
@@ -243,7 +274,8 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   Value incrValue =
       arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
 
-  Value ret = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
+  Value add = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
+  Value ret = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand, add);
   rewriter.replaceOp(op, ret);
   return success();
 }
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 615c607efc3c3..75f8e65b334a2 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -145,13 +145,22 @@ func.func @fmaf_func(%a: f64, %b: f64, %c: f64) -> f64 {
 func.func @ceilf_func(%a: f64) -> f64 {
   // CHECK-DAG:   [[CST:%.+]] = arith.constant 0.000
   // CHECK-DAG:   [[CST_0:%.+]] = arith.constant 1.000
+  // CHECK-DAG:   [[C52:%.*]] = arith.constant 52
+  // CHECK-DAG:   [[C1023:%.*]] = arith.constant 1023
+  // CHECK-DAG:   [[EXP_MASK:%.*]] = arith.constant 2047
+  // CHECK-NEXT:   [[ARG_BITCAST:%.*]] = arith.bitcast [[ARG0]] : f64 to i64
+  // CHECK-NEXT:   [[ARG_BITCAST_SHIFTED:%.*]] = arith.shrui [[ARG_BITCAST]], [[C52]]
+  // CHECK-NEXT:   [[ARG_EXP:%.*]] = arith.andi [[ARG_BITCAST_SHIFTED]], [[EXP_MASK]]
+  // CHECK-NEXT:   [[ARG_BIASED_EXP:%.*]] = arith.subi [[ARG_EXP]], [[C1023]]
+  // CHECK-NEXT:   [[IS_SPECIAL_VAL:%.*]] = arith.cmpi sge, [[ARG_BIASED_EXP]], [[C52]]
   // CHECK-NEXT:   [[CVTI:%.+]] = arith.fptosi [[ARG0]]
   // CHECK-NEXT:   [[CVTF:%.+]] = arith.sitofp [[CVTI]]
   // CHECK-NEXT:   [[COPYSIGN:%.+]] = math.copysign [[CVTF]], [[ARG0]]
   // CHECK-NEXT:   [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[COPYSIGN]]
   // CHECK-NEXT:   [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
   // CHECK-NEXT:   [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
-  // CHECK-NEXT:   return [[ADDF]]
+  // CHECK-NEXT:   [[RESULT:%.*]] = arith.select [[IS_SPECIAL_VAL]], [[ARG0]], [[ADDF]]
+  // CHECK-NEXT:   return [[RESULT]]
   // CHECK-FILTER: math.ceil
   %ret = math.ceil %a : f64
   return %ret : f64

>From 348f83aa9728b498ab98c9349944e1d72ce7c435 Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Sat, 7 Feb 2026 15:28:33 -0800
Subject: [PATCH 2/4] Support FNUZ-suffixed fp

---
 .../lib/Dialect/Math/Transforms/ExpandOps.cpp | 48 ++++++++++++-------
 mlir/test/Dialect/Math/expand-math.mlir       | 45 +++++++++++++----
 2 files changed, 66 insertions(+), 27 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index c1f3da8a75094..a541e603e24a4 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -232,36 +232,50 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   ImplicitLocOpBuilder b(op->getLoc(), rewriter);
   Value operand = op.getOperand();
   Type opType = operand.getType();
+  Type operandETy = getElementTypeOrSelf(opType);
+  FloatType floatTy = llvm::dyn_cast<FloatType>(operandETy);
+  const llvm::fltSemantics &semantics = floatTy.getFloatSemantics();
 
-  auto operandETy = getElementTypeOrSelf(opType);
-  unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
-  unsigned mantissaWidth =
-      llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
+  unsigned bitWidth = floatTy.getWidth();
+  unsigned mantissaWidth = floatTy.getFPMantissaWidth() - 1;
   unsigned exponentWidth = bitWidth - mantissaWidth - 1;
+  const int bias = (&semantics == &APFloat::Float8E8M0FNU())
+                       ? -semantics.minExponent
+                       : -(semantics.minExponent - 1);
+  bool hasNegativeZeroNaNEncoding =
+      (semantics.nanEncoding == llvm::fltNanEncoding::NegativeZero);
 
   Type iTy = rewriter.getIntegerType(bitWidth);
   if (auto shapedTy = dyn_cast<ShapedType>(opType))
     iTy = shapedTy.clone(iTy);
 
-  Value cMantissaWidth = createIntConst(op->getLoc(), iTy, mantissaWidth, b);
-  Value cBias =
-      createIntConst(op->getLoc(), iTy, (1ull << (exponentWidth - 1)) - 1, b);
-  Value cExpMask =
-      createIntConst(op->getLoc(), iTy, (1ull << exponentWidth) - 1, b);
-
-  // Any floating-point value with an unbiased exponent ≥ `mantissaWidth`
-  // falls into one of these categories:
+  // For IEEE-like floating-point formats, any value with an unbiased
+  // exponent ≥ `mantissaWidth` falls into one of these categories:
   //   - a large finite value (|x| ≥ 2^mantissaWidth), where all representable
   //     numbers are already integral, or
   //   - a special value (NaN or ±Inf), which also satisfies this exponent
   //     condition.
   // For all such cases, `ceilf(x)` is defined to return `x` directly.
   Value operandBitcast = arith::BitcastOp::create(b, iTy, operand);
-  Value operandExp = arith::AndIOp::create(
-      b, arith::ShRUIOp::create(b, operandBitcast, cMantissaWidth), cExpMask);
-  Value operandBiasedExp = arith::SubIOp::create(b, operandExp, cBias);
-  Value isSpecialValOrLargeVal = arith::CmpIOp::create(
-      b, arith::CmpIPredicate::sge, operandBiasedExp, cMantissaWidth);
+  Value cMask =
+      createIntConst(op->getLoc(), iTy, (1ull << (bitWidth - 1)) - 1, b);
+  Value unsignedBits = arith::AndIOp::create(b, operandBitcast, cMask);
+  Value cThreshold = createIntConst(
+      op->getLoc(), iTy, (uint64_t(bias + mantissaWidth)) << mantissaWidth, b);
+  Value isLargeExp =
+      arith::CmpIOp::create(b, arith::CmpIPredicate::uge, unsignedBits, cThreshold);
+  Value isSpecialValOrLargeVal = isLargeExp;
+
+  // In FNUZ-suffixed floating point, NaN is represented by a sign bit of 1 and
+  // all 0s in the exponent and mantissa, so it requires an explicit check.
+  if (hasNegativeZeroNaNEncoding) {
+    Value cNegZeroBits =
+        createIntConst(op->getLoc(), iTy, 1ull << (bitWidth - 1), b);
+    Value isNegZeroEncoding = arith::CmpIOp::create(
+        b, arith::CmpIPredicate::eq, operandBitcast, cNegZeroBits);
+    isSpecialValOrLargeVal =
+        arith::OrIOp::create(b, isLargeExp, isNegZeroEncoding);
+  }
 
   Value fpFixedConvert = createTruncatedFPValue(operand, b);
 
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index c9eb13f845c2a..5f2843045b885 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -143,21 +143,18 @@ func.func @fmaf_func(%a: f64, %b: f64, %c: f64) -> f64 {
 // CHECK-LABEL:     func @ceilf_func
 // CHECK-SAME:      ([[ARG0:%.+]]: f64) -> f64
 func.func @ceilf_func(%a: f64) -> f64 {
-  // CHECK-DAG:   [[CST:%.+]] = arith.constant 0.000
-  // CHECK-DAG:   [[CST_0:%.+]] = arith.constant 1.000
-  // CHECK-DAG:   [[C52:%.*]] = arith.constant 52
-  // CHECK-DAG:   [[C1023:%.*]] = arith.constant 1023
-  // CHECK-DAG:   [[EXP_MASK:%.*]] = arith.constant 2047
+  // CHECK-DAG:   [[C_0:%.+]] = arith.constant 0.000
+  // CHECK-DAG:   [[C_1:%.+]] = arith.constant 1.000
+  // CHECK-DAG:   [[C_4841369599423283200:%.*]] = arith.constant 4841369599423283200
+  // CHECK-DAG:   [[C_9223372036854775807:%.*]] = arith.constant 9223372036854775807
   // CHECK-NEXT:   [[ARG_BITCAST:%.*]] = arith.bitcast [[ARG0]] : f64 to i64
-  // CHECK-NEXT:   [[ARG_BITCAST_SHIFTED:%.*]] = arith.shrui [[ARG_BITCAST]], [[C52]]
-  // CHECK-NEXT:   [[ARG_EXP:%.*]] = arith.andi [[ARG_BITCAST_SHIFTED]], [[EXP_MASK]]
-  // CHECK-NEXT:   [[ARG_BIASED_EXP:%.*]] = arith.subi [[ARG_EXP]], [[C1023]]
-  // CHECK-NEXT:   [[IS_SPECIAL_VAL:%.*]] = arith.cmpi sge, [[ARG_BIASED_EXP]], [[C52]]
+  // CHECK-NEXT:   [[ANDI:%.*]] = arith.andi [[ARG_BITCAST]], [[C_9223372036854775807]]
+  // CHECK-NEXT:   [[IS_SPECIAL_VAL:%.*]] = arith.cmpi uge, [[ANDI]], [[C_4841369599423283200]]
   // CHECK-NEXT:   [[CVTI:%.+]] = arith.fptosi [[ARG0]]
   // CHECK-NEXT:   [[CVTF:%.+]] = arith.sitofp [[CVTI]]
   // CHECK-NEXT:   [[COPYSIGN:%.+]] = math.copysign [[CVTF]], [[ARG0]]
   // CHECK-NEXT:   [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[COPYSIGN]]
-  // CHECK-NEXT:   [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
+  // CHECK-NEXT:   [[INCR:%.+]] = arith.select [[COMP]], [[C_1]], [[C_0]]
   // CHECK-NEXT:   [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
   // CHECK-NEXT:   [[RESULT:%.*]] = arith.select [[IS_SPECIAL_VAL]], [[ARG0]], [[ADDF]]
   // CHECK-NEXT:   return [[RESULT]]
@@ -168,6 +165,34 @@ func.func @ceilf_func(%a: f64) -> f64 {
 
 // -----
 
+// CHECK-LABEL:     func @ceilf_fnuz_func
+// CHECK-SAME:      ([[ARG0:%.+]]: f8E5M2FNUZ) -> f8E5M2FNUZ
+func.func @ceilf_fnuz_func(%a: f8E5M2FNUZ) -> f8E5M2FNUZ {
+  // CHECK-DAG:   [[C_0:%.+]] = arith.constant 0.000
+  // CHECK-DAG:   [[C_1:%.+]] = arith.constant 1.000
+  // CHECK-DAG:   [[C_NEG_128:%.*]] = arith.constant -128
+  // CHECK-DAG:   [[C_72:%.*]] = arith.constant 72
+  // CHECK-DAG:   [[C_127:%.*]] = arith.constant 127
+  // CHECK-NEXT:   [[ARG_BITCAST:%.*]] = arith.bitcast [[ARG0]] : f8E5M2FNUZ to i8
+  // CHECK-NEXT:   [[ANDI:%.*]] = arith.andi [[ARG_BITCAST]], [[C_127]]
+  // CHECK-NEXT:   [[IS_LARGE:%.+]] = arith.cmpi uge, [[ANDI]], [[C_72]]
+  // CHECK-NEXT:   [[IS_NAN:%.+]] = arith.cmpi eq, [[ARG_BITCAST]], [[C_NEG_128]]
+  // CHECK-NEXT:   [[IS_SPECIAL_VAL:%.+]] = arith.ori [[IS_LARGE]], [[IS_NAN]]
+  // CHECK-NEXT:   [[CVTI:%.+]] = arith.fptosi [[ARG0]]
+  // CHECK-NEXT:   [[CVTF:%.+]] = arith.sitofp [[CVTI]]
+  // CHECK-NEXT:   [[COPYSIGN:%.+]] = math.copysign [[CVTF]], [[ARG0]]
+  // CHECK-NEXT:   [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[COPYSIGN]]
+  // CHECK-NEXT:   [[INCR:%.+]] = arith.select [[COMP]], [[C_1]], [[C_0]]
+  // CHECK-NEXT:   [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
+  // CHECK-NEXT:   [[RESULT:%.*]] = arith.select [[IS_SPECIAL_VAL]], [[ARG0]], [[ADDF]]
+  // CHECK-NEXT:   return [[RESULT]]
+  // CHECK-FILTER: math.ceil
+  %ret = math.ceil %a : f8E5M2FNUZ
+  return %ret : f8E5M2FNUZ
+}
+
+// -----
+
 // CHECK-LABEL:     func @exp2f_func
 // CHECK-SAME:      ([[ARG0:%.+]]: f64) -> f64
 func.func @exp2f_func(%a: f64) -> f64 {

>From d6b88823dacff5441541f9511e07140f12ac197f Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Sat, 7 Feb 2026 18:00:24 -0800
Subject: [PATCH 3/4] Fix format

---
 mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index a541e603e24a4..4fc8071b9e74d 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -262,8 +262,8 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
   Value unsignedBits = arith::AndIOp::create(b, operandBitcast, cMask);
   Value cThreshold = createIntConst(
       op->getLoc(), iTy, (uint64_t(bias + mantissaWidth)) << mantissaWidth, b);
-  Value isLargeExp =
-      arith::CmpIOp::create(b, arith::CmpIPredicate::uge, unsignedBits, cThreshold);
+  Value isLargeExp = arith::CmpIOp::create(b, arith::CmpIPredicate::uge,
+                                           unsignedBits, cThreshold);
   Value isSpecialValOrLargeVal = isLargeExp;
 
   // In FNUZ-suffixed floating point, NaN is represented by a sign bit of 1 and

>From 03374d976ec28ec2a8fcac162fbb2debbeacbb67 Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Sat, 7 Feb 2026 18:27:21 -0800
Subject: [PATCH 4/4] Remove unused variable

---
 mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index 4fc8071b9e74d..d563742da3361 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -238,7 +238,6 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
 
   unsigned bitWidth = floatTy.getWidth();
   unsigned mantissaWidth = floatTy.getFPMantissaWidth() - 1;
-  unsigned exponentWidth = bitWidth - mantissaWidth - 1;
   const int bias = (&semantics == &APFloat::Float8E8M0FNU())
                        ? -semantics.minExponent
                        : -(semantics.minExponent - 1);



More information about the Mlir-commits mailing list