[Mlir-commits] [mlir] [MLIR][Math] Fix math.ceil expansion to avoid undefined behavior on Inf/NaN (PR #170028)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Feb 7 18:27:40 PST 2026
https://github.com/hankluo6 updated https://github.com/llvm/llvm-project/pull/170028
>From 3a9f80fdd3c0c8c0146652199b63f2b414de7d92 Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Sun, 16 Nov 2025 23:48:19 -0800
Subject: [PATCH 1/4] Fix ceilf expansion to avoid undefined behavior on
Inf/NaN
---
.../lib/Dialect/Math/Transforms/ExpandOps.cpp | 34 ++++++++++++++++++-
mlir/test/Dialect/Math/expand-math.mlir | 11 +++++-
2 files changed, 43 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index cd68039d0d964..e9f4811aae3fe 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -232,6 +232,37 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
Type opType = operand.getType();
+
+ auto operandETy = getElementTypeOrSelf(opType);
+ unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
+ unsigned mantissaWidth =
+ llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
+ unsigned exponentWidth = bitWidth - mantissaWidth - 1;
+
+ Type iTy = rewriter.getIntegerType(bitWidth);
+ if (auto shapedTy = dyn_cast<ShapedType>(opType))
+ iTy = shapedTy.clone(iTy);
+
+ Value cMantissaWidth = createIntConst(op->getLoc(), iTy, mantissaWidth, b);
+ Value cBias =
+ createIntConst(op->getLoc(), iTy, (1ull << (exponentWidth - 1)) - 1, b);
+ Value cExpMask =
+ createIntConst(op->getLoc(), iTy, (1ull << exponentWidth) - 1, b);
+
+ // Any floating-point value with an unbiased exponent ≥ `mantissaWidth`
+ // falls into one of these categories:
+ // - a large finite value (|x| ≥ 2^mantissaWidth), where all representable
+ // numbers are already integral, or
+ // - a special value (NaN or ±Inf), which also satisfies this exponent
+ // condition.
+ // For all such cases, `ceilf(x)` is defined to return `x` directly.
+ Value operandBitcast = arith::BitcastOp::create(b, iTy, operand);
+ Value operandExp = arith::AndIOp::create(
+ b, arith::ShRUIOp::create(b, operandBitcast, cMantissaWidth), cExpMask);
+ Value operandBiasedExp = arith::SubIOp::create(b, operandExp, cBias);
+ Value isSpecialValOrLargeVal = arith::CmpIOp::create(
+ b, arith::CmpIPredicate::sge, operandBiasedExp, cMantissaWidth);
+
Value fpFixedConvert = createTruncatedFPValue(operand, b);
// Creating constants for later use.
@@ -243,7 +274,8 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
Value incrValue =
arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
- Value ret = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
+ Value add = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
+ Value ret = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand, add);
rewriter.replaceOp(op, ret);
return success();
}
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 615c607efc3c3..75f8e65b334a2 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -145,13 +145,22 @@ func.func @fmaf_func(%a: f64, %b: f64, %c: f64) -> f64 {
func.func @ceilf_func(%a: f64) -> f64 {
// CHECK-DAG: [[CST:%.+]] = arith.constant 0.000
// CHECK-DAG: [[CST_0:%.+]] = arith.constant 1.000
+ // CHECK-DAG: [[C52:%.*]] = arith.constant 52
+ // CHECK-DAG: [[C1023:%.*]] = arith.constant 1023
+ // CHECK-DAG: [[EXP_MASK:%.*]] = arith.constant 2047
+ // CHECK-NEXT: [[ARG_BITCAST:%.*]] = arith.bitcast [[ARG0]] : f64 to i64
+ // CHECK-NEXT: [[ARG_BITCAST_SHIFTED:%.*]] = arith.shrui [[ARG_BITCAST]], [[C52]]
+ // CHECK-NEXT: [[ARG_EXP:%.*]] = arith.andi [[ARG_BITCAST_SHIFTED]], [[EXP_MASK]]
+ // CHECK-NEXT: [[ARG_BIASED_EXP:%.*]] = arith.subi [[ARG_EXP]], [[C1023]]
+ // CHECK-NEXT: [[IS_SPECIAL_VAL:%.*]] = arith.cmpi sge, [[ARG_BIASED_EXP]], [[C52]]
// CHECK-NEXT: [[CVTI:%.+]] = arith.fptosi [[ARG0]]
// CHECK-NEXT: [[CVTF:%.+]] = arith.sitofp [[CVTI]]
// CHECK-NEXT: [[COPYSIGN:%.+]] = math.copysign [[CVTF]], [[ARG0]]
// CHECK-NEXT: [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[COPYSIGN]]
// CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
// CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
- // CHECK-NEXT: return [[ADDF]]
+ // CHECK-NEXT: [[RESULT:%.*]] = arith.select [[IS_SPECIAL_VAL]], [[ARG0]], [[ADDF]]
+ // CHECK-NEXT: return [[RESULT]]
// CHECK-FILTER: math.ceil
%ret = math.ceil %a : f64
return %ret : f64
>From 348f83aa9728b498ab98c9349944e1d72ce7c435 Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Sat, 7 Feb 2026 15:28:33 -0800
Subject: [PATCH 2/4] Support FNUZ-suffixed fp
---
.../lib/Dialect/Math/Transforms/ExpandOps.cpp | 48 ++++++++++++-------
mlir/test/Dialect/Math/expand-math.mlir | 45 +++++++++++++----
2 files changed, 66 insertions(+), 27 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index c1f3da8a75094..a541e603e24a4 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -232,36 +232,50 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value operand = op.getOperand();
Type opType = operand.getType();
+ Type operandETy = getElementTypeOrSelf(opType);
+ FloatType floatTy = llvm::dyn_cast<FloatType>(operandETy);
+ const llvm::fltSemantics &semantics = floatTy.getFloatSemantics();
- auto operandETy = getElementTypeOrSelf(opType);
- unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
- unsigned mantissaWidth =
- llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
+ unsigned bitWidth = floatTy.getWidth();
+ unsigned mantissaWidth = floatTy.getFPMantissaWidth() - 1;
unsigned exponentWidth = bitWidth - mantissaWidth - 1;
+ const int bias = (&semantics == &APFloat::Float8E8M0FNU())
+ ? -semantics.minExponent
+ : -(semantics.minExponent - 1);
+ bool hasNegativeZeroNaNEncoding =
+ (semantics.nanEncoding == llvm::fltNanEncoding::NegativeZero);
Type iTy = rewriter.getIntegerType(bitWidth);
if (auto shapedTy = dyn_cast<ShapedType>(opType))
iTy = shapedTy.clone(iTy);
- Value cMantissaWidth = createIntConst(op->getLoc(), iTy, mantissaWidth, b);
- Value cBias =
- createIntConst(op->getLoc(), iTy, (1ull << (exponentWidth - 1)) - 1, b);
- Value cExpMask =
- createIntConst(op->getLoc(), iTy, (1ull << exponentWidth) - 1, b);
-
- // Any floating-point value with an unbiased exponent ≥ `mantissaWidth`
- // falls into one of these categories:
+ // In IEEE-like floating-point formats, any value whose unbiased exponent
+ // is ≥ `mantissaWidth` falls into one of these categories:
// - a large finite value (|x| ≥ 2^mantissaWidth), where all representable
// numbers are already integral, or
// - a special value (NaN or ±Inf), which also satisfies this exponent
// condition.
// For all such cases, `ceilf(x)` is defined to return `x` directly.
Value operandBitcast = arith::BitcastOp::create(b, iTy, operand);
- Value operandExp = arith::AndIOp::create(
- b, arith::ShRUIOp::create(b, operandBitcast, cMantissaWidth), cExpMask);
- Value operandBiasedExp = arith::SubIOp::create(b, operandExp, cBias);
- Value isSpecialValOrLargeVal = arith::CmpIOp::create(
- b, arith::CmpIPredicate::sge, operandBiasedExp, cMantissaWidth);
+ Value cMask =
+ createIntConst(op->getLoc(), iTy, (1ull << (bitWidth - 1)) - 1, b);
+ Value unsignedBits = arith::AndIOp::create(b, operandBitcast, cMask);
+ Value cThreshold = createIntConst(
+ op->getLoc(), iTy, (uint64_t(bias + mantissaWidth)) << mantissaWidth, b);
+ Value isLargeExp =
+ arith::CmpIOp::create(b, arith::CmpIPredicate::uge, unsignedBits, cThreshold);
+ Value isSpecialValOrLargeVal = isLargeExp;
+
+ // In FNUZ-suffixed floating point, NaN is represented by a sign bit of 1 and
+ // all 0s in the exponent and mantissa, so it requires an explicit check.
+ if (hasNegativeZeroNaNEncoding) {
+ Value cNegZeroBits =
+ createIntConst(op->getLoc(), iTy, 1ull << (bitWidth - 1), b);
+ Value isNegZeroEncoding = arith::CmpIOp::create(
+ b, arith::CmpIPredicate::eq, operandBitcast, cNegZeroBits);
+ isSpecialValOrLargeVal =
+ arith::OrIOp::create(b, isLargeExp, isNegZeroEncoding);
+ }
Value fpFixedConvert = createTruncatedFPValue(operand, b);
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index c9eb13f845c2a..5f2843045b885 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -143,21 +143,18 @@ func.func @fmaf_func(%a: f64, %b: f64, %c: f64) -> f64 {
// CHECK-LABEL: func @ceilf_func
// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64
func.func @ceilf_func(%a: f64) -> f64 {
- // CHECK-DAG: [[CST:%.+]] = arith.constant 0.000
- // CHECK-DAG: [[CST_0:%.+]] = arith.constant 1.000
- // CHECK-DAG: [[C52:%.*]] = arith.constant 52
- // CHECK-DAG: [[C1023:%.*]] = arith.constant 1023
- // CHECK-DAG: [[EXP_MASK:%.*]] = arith.constant 2047
+ // CHECK-DAG: [[C_0:%.+]] = arith.constant 0.000
+ // CHECK-DAG: [[C_1:%.+]] = arith.constant 1.000
+ // CHECK-DAG: [[C_4841369599423283200:%.*]] = arith.constant 4841369599423283200
+ // CHECK-DAG: [[C_9223372036854775807:%.*]] = arith.constant 9223372036854775807
// CHECK-NEXT: [[ARG_BITCAST:%.*]] = arith.bitcast [[ARG0]] : f64 to i64
- // CHECK-NEXT: [[ARG_BITCAST_SHIFTED:%.*]] = arith.shrui [[ARG_BITCAST]], [[C52]]
- // CHECK-NEXT: [[ARG_EXP:%.*]] = arith.andi [[ARG_BITCAST_SHIFTED]], [[EXP_MASK]]
- // CHECK-NEXT: [[ARG_BIASED_EXP:%.*]] = arith.subi [[ARG_EXP]], [[C1023]]
- // CHECK-NEXT: [[IS_SPECIAL_VAL:%.*]] = arith.cmpi sge, [[ARG_BIASED_EXP]], [[C52]]
+ // CHECK-NEXT: [[ANDI:%.*]] = arith.andi [[ARG_BITCAST]], [[C_9223372036854775807]]
+ // CHECK-NEXT: [[IS_SPECIAL_VAL:%.*]] = arith.cmpi uge, [[ANDI]], [[C_4841369599423283200]]
// CHECK-NEXT: [[CVTI:%.+]] = arith.fptosi [[ARG0]]
// CHECK-NEXT: [[CVTF:%.+]] = arith.sitofp [[CVTI]]
// CHECK-NEXT: [[COPYSIGN:%.+]] = math.copysign [[CVTF]], [[ARG0]]
// CHECK-NEXT: [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[COPYSIGN]]
- // CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
+ // CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[C_1]], [[C_0]]
// CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
// CHECK-NEXT: [[RESULT:%.*]] = arith.select [[IS_SPECIAL_VAL]], [[ARG0]], [[ADDF]]
// CHECK-NEXT: return [[RESULT]]
@@ -168,6 +165,34 @@ func.func @ceilf_func(%a: f64) -> f64 {
// -----
+// CHECK-LABEL: func @ceilf_fnuz_func
+// CHECK-SAME: ([[ARG0:%.+]]: f8E5M2FNUZ) -> f8E5M2FNUZ
+func.func @ceilf_fnuz_func(%a: f8E5M2FNUZ) -> f8E5M2FNUZ {
+ // CHECK-DAG: [[C_0:%.+]] = arith.constant 0.000
+ // CHECK-DAG: [[C_1:%.+]] = arith.constant 1.000
+ // CHECK-DAG: [[C_NEG_128:%.*]] = arith.constant -128
+ // CHECK-DAG: [[C_72:%.*]] = arith.constant 72
+ // CHECK-DAG: [[C_127:%.*]] = arith.constant 127
+ // CHECK-NEXT: [[ARG_BITCAST:%.*]] = arith.bitcast [[ARG0]] : f8E5M2FNUZ to i8
+ // CHECK-NEXT: [[ANDI:%.*]] = arith.andi [[ARG_BITCAST]], [[C_127]]
+ // CHECK-NEXT: [[IS_LARGE:%.+]] = arith.cmpi uge, [[ANDI]], [[C_72]]
+ // CHECK-NEXT: [[IS_NAN:%.+]] = arith.cmpi eq, [[ARG_BITCAST]], [[C_NEG_128]]
+ // CHECK-NEXT: [[IS_SPECIAL_VAL:%.+]] = arith.ori [[IS_LARGE]], [[IS_NAN]]
+ // CHECK-NEXT: [[CVTI:%.+]] = arith.fptosi [[ARG0]]
+ // CHECK-NEXT: [[CVTF:%.+]] = arith.sitofp [[CVTI]]
+ // CHECK-NEXT: [[COPYSIGN:%.+]] = math.copysign [[CVTF]], [[ARG0]]
+ // CHECK-NEXT: [[COMP:%.+]] = arith.cmpf ogt, [[ARG0]], [[COPYSIGN]]
+ // CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[C_1]], [[C_0]]
+ // CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
+ // CHECK-NEXT: [[RESULT:%.*]] = arith.select [[IS_SPECIAL_VAL]], [[ARG0]], [[ADDF]]
+ // CHECK-NEXT: return [[RESULT]]
+ // CHECK-FILTER: math.ceil
+ %ret = math.ceil %a : f8E5M2FNUZ
+ return %ret : f8E5M2FNUZ
+}
+
+// -----
+
// CHECK-LABEL: func @exp2f_func
// CHECK-SAME: ([[ARG0:%.+]]: f64) -> f64
func.func @exp2f_func(%a: f64) -> f64 {
>From d6b88823dacff5441541f9511e07140f12ac197f Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Sat, 7 Feb 2026 18:00:24 -0800
Subject: [PATCH 3/4] Fix format
---
mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index a541e603e24a4..4fc8071b9e74d 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -262,8 +262,8 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
Value unsignedBits = arith::AndIOp::create(b, operandBitcast, cMask);
Value cThreshold = createIntConst(
op->getLoc(), iTy, (uint64_t(bias + mantissaWidth)) << mantissaWidth, b);
- Value isLargeExp =
- arith::CmpIOp::create(b, arith::CmpIPredicate::uge, unsignedBits, cThreshold);
+ Value isLargeExp = arith::CmpIOp::create(b, arith::CmpIPredicate::uge,
+ unsignedBits, cThreshold);
Value isSpecialValOrLargeVal = isLargeExp;
// In FNUZ-suffixed floating point, NaN is represented by a sign bit of 1 and
>From 03374d976ec28ec2a8fcac162fbb2debbeacbb67 Mon Sep 17 00:00:00 2001
From: hankluo6 <hankluo6 at gmail.com>
Date: Sat, 7 Feb 2026 18:27:21 -0800
Subject: [PATCH 4/4] Remove unused variable
---
mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp | 1 -
1 file changed, 1 deletion(-)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
index 4fc8071b9e74d..d563742da3361 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp
@@ -238,7 +238,6 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
unsigned bitWidth = floatTy.getWidth();
unsigned mantissaWidth = floatTy.getFPMantissaWidth() - 1;
- unsigned exponentWidth = bitWidth - mantissaWidth - 1;
const int bias = (&semantics == &APFloat::Float8E8M0FNU())
? -semantics.minExponent
: -(semantics.minExponent - 1);
More information about the Mlir-commits
mailing list