[Mlir-commits] [mlir] [MLIR][Complex] Check for FastMathFlag in DivOp folder (PR #176249)
Amr Hesham
llvmlistbot at llvm.org
Fri Feb 20 12:26:13 PST 2026
https://github.com/AmrDeveloper updated https://github.com/llvm/llvm-project/pull/176249
>From e44f984c63c42779ad8dc8eb842237d202de9761 Mon Sep 17 00:00:00 2001
From: Amr Hesham <amr96 at programmer.net>
Date: Thu, 15 Jan 2026 21:46:48 +0100
Subject: [PATCH 1/2] [MLIR][Complex] Check for FastMathFlag in DivOp folder
---
mlir/lib/Dialect/Complex/IR/ComplexOps.cpp | 47 +++++++++++----------
mlir/test/Dialect/Complex/canonicalize.mlir | 46 +++++++++++++++++++-
2 files changed, 69 insertions(+), 24 deletions(-)
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 25af2f3be3067..e3633a82074b7 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -371,35 +371,38 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
- auto rhs = adaptor.getRhs();
- auto lhs = adaptor.getLhs();
-
- // We can't fold without knowing that LHS isn't NaN
- if (!rhs || !lhs)
- return {};
+ Attribute rhs = adaptor.getRhs();
+ Attribute lhs = adaptor.getLhs();
+
+ // complex.div(complex.constant<NaN, NaN>, a) -> complex.constant<NaN, NaN>
+ // complex.div(complex.constant<NaN, a>, b) -> complex.constant<NaN, NaN>
+ // complex.div(complex.constant<a, NaN>, b) -> complex.constant<NaN, NaN>
+ bool isLhsComplexHasNan = false;
+ ArrayAttr lhsArrayAttr = dyn_cast_if_present<ArrayAttr>(lhs);
+ if (lhsArrayAttr && lhsArrayAttr.size() == 2) {
+ APFloat lhsReal = cast<FloatAttr>(lhsArrayAttr[0]).getValue();
+ APFloat lhsImag = cast<FloatAttr>(lhsArrayAttr[1]).getValue();
+ isLhsComplexHasNan = lhsReal.isNaN() || lhsImag.isNaN();
+ if (isLhsComplexHasNan) {
+ Attribute nanValue = lhsReal.isNaN() ? lhsArrayAttr[0] : lhsArrayAttr[1];
+ return ArrayAttr::get(getContext(), {nanValue, nanValue});
+ }
+ }
- ArrayAttr rhsArrayAttr = dyn_cast<ArrayAttr>(rhs);
+ ArrayAttr rhsArrayAttr = dyn_cast_if_present<ArrayAttr>(rhs);
if (!rhsArrayAttr || rhsArrayAttr.size() != 2)
return {};
- ArrayAttr lhsArrayAttr = dyn_cast<ArrayAttr>(lhs);
- if (!lhsArrayAttr || lhsArrayAttr.size() != 2)
- return {};
-
+ // Fold only if RHS is complex.constant<1.0, 0.0>
APFloat rhsImag = cast<FloatAttr>(rhsArrayAttr[1]).getValue();
- if (!rhsImag.isZero())
+ APFloat rhsReal = cast<FloatAttr>(rhsArrayAttr[0]).getValue();
+ if (rhsReal != APFloat(rhsReal.getSemantics(), 1) || !rhsImag.isZero())
return {};
- APFloat lhsReal = cast<FloatAttr>(lhsArrayAttr[0]).getValue();
- APFloat lhsImag = cast<FloatAttr>(lhsArrayAttr[1]).getValue();
- if (lhsReal.isNaN() || lhsImag.isNaN()) {
- Attribute nanValue = lhsReal.isNaN() ? lhsArrayAttr[0] : lhsArrayAttr[1];
- return ArrayAttr::get(getContext(), {nanValue, nanValue});
- }
-
- // complex.div(a, complex.constant<1.0, 0.0>) -> a
- APFloat rhsReal = cast<FloatAttr>(rhsArrayAttr[0]).getValue();
- if (rhsReal == APFloat(rhsReal.getSemantics(), 1))
+ // Fold to LHS if it doesn't contains NaNs or fast math flag nan is exists
+ // complex.div(a, complex.constant<1.0, 0.0>) fastmath<nnan> -> a
+ if ((lhsArrayAttr && !isLhsComplexHasNan) ||
+ arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan))
return getLhs();
return {};
diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
index b3f49eb3f44c1..1c5216c82e5c3 100644
--- a/mlir/test/Dialect/Complex/canonicalize.mlir
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -327,8 +327,8 @@ func.func @div_one_f128() -> complex<f128> {
return %div : complex<f128>
}
-// CHECK-LABEL: div_op_with_rhs_has_nan
-func.func @div_op_with_rhs_has_nan() -> complex<f32> {
+// CHECK-LABEL: div_op_with_rhs_has_nan_real
+func.func @div_op_with_rhs_has_nan_real() -> complex<f32> {
%a = complex.constant [0x7fffffff : f32, 1.0 : f32]: complex<f32>
%b = complex.constant [1.0: f32, 0.0 : f32]: complex<f32>
%div = complex.div %a, %b : complex<f32>
@@ -336,3 +336,45 @@ func.func @div_op_with_rhs_has_nan() -> complex<f32> {
// CHECK: return %[[DIV]] : complex<f32>
return %div : complex<f32>
}
+
+// CHECK-LABEL: div_op_with_rhs_has_nan_imag
+func.func @div_op_with_rhs_has_nan_imag() -> complex<f32> {
+ %a = complex.constant [1.0 : f32, 0x7fffffff : f32]: complex<f32>
+ %b = complex.constant [1.0: f32, 0.0 : f32]: complex<f32>
+ %div = complex.div %a, %b : complex<f32>
+ // CHECK: %[[DIV:.*]] = complex.constant [0x7FFFFFFF : f32, 0x7FFFFFFF : f32] : complex<f32>
+ // CHECK: return %[[DIV]] : complex<f32>
+ return %div : complex<f32>
+}
+
+// CHECK-LABEL: div_op_with_rhs_has_nan_real_imag
+func.func @div_op_with_rhs_has_nan_real_imag() -> complex<f32> {
+ %a = complex.constant [0x7fffffff : f32, 0x7fffffff : f32]: complex<f32>
+ %b = complex.constant [1.0: f32, 0.0 : f32]: complex<f32>
+ %div = complex.div %a, %b : complex<f32>
+ // CHECK: %[[DIV:.*]] = complex.constant [0x7FFFFFFF : f32, 0x7FFFFFFF : f32] : complex<f32>
+ // CHECK: return %[[DIV]] : complex<f32>
+ return %div : complex<f32>
+}
+
+// CHECK-LABEL: div_op_non_constant_lhs_with_fast_math
+func.func @div_op_non_constant_lhs_with_fast_math(%arg0: f32, %arg1: f32) -> complex<f32> {
+ %a = complex.create %arg0, %arg1 : complex<f32>
+ %b = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
+ %div = complex.div %a, %b fastmath<nnan> : complex<f32>
+ // CHECK: %[[COMPLEX:.*]] = complex.create %arg0, %arg1 : complex<f32>
+ // CHECK: return %[[COMPLEX]] : complex<f32>
+ return %div: complex<f32>
+}
+
+// CHECK-LABEL: div_op_non_constant_lhs_without_fast_math
+func.func @div_op_non_constant_lhs_without_fast_math(%arg0: f32, %arg1: f32) -> complex<f32> {
+ %a = complex.create %arg0, %arg1 : complex<f32>
+ %b = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
+ %div = complex.div %a, %b : complex<f32>
+ // CHECK: %[[B:.*]] = complex.constant [1.000000e+00 : f32, 0.000000e+00 : f32] : complex<f32>
+ // CHECK: %[[A:.*]] = complex.create %arg0, %arg1 : complex<f32>
+ // CHECK: %[[DIV:.*]] = complex.div %[[A]], %[[B]] : complex<f32>
+ // CHECK: return %[[DIV]] : complex<f32>
+ return %div: complex<f32>
+}
>From 4a48b67dc0ebd93d8dcbba5d91d82cb7ef97616a Mon Sep 17 00:00:00 2001
From: Amr Hesham <amr96 at programmer.net>
Date: Fri, 20 Feb 2026 21:25:43 +0100
Subject: [PATCH 2/2] Address code review comments
---
mlir/lib/Dialect/Complex/IR/ComplexOps.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index e3633a82074b7..b5323597b7ca4 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -396,10 +396,10 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
// Fold only if RHS is complex.constant<1.0, 0.0>
APFloat rhsImag = cast<FloatAttr>(rhsArrayAttr[1]).getValue();
APFloat rhsReal = cast<FloatAttr>(rhsArrayAttr[0]).getValue();
- if (rhsReal != APFloat(rhsReal.getSemantics(), 1) || !rhsImag.isZero())
+ if (!rhsImag.isZero() || rhsReal != APFloat(rhsReal.getSemantics(), 1))
return {};
- // Fold to LHS if it doesn't contains NaNs or fast math flag nan is exists
+ // Fold to LHS if it is a constant known to contain no NaNs, or if the nnan fast-math flag is set
// complex.div(a, complex.constant<1.0, 0.0>) fastmath<nnan> -> a
if ((lhsArrayAttr && !isLhsComplexHasNan) ||
arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan))
More information about the Mlir-commits
mailing list