[Mlir-commits] [mlir] [MLIR][Complex] Check for FastMathFlag in DivOp folder (PR #176249)

Amr Hesham llvmlistbot at llvm.org
Fri Feb 20 12:26:13 PST 2026


https://github.com/AmrDeveloper updated https://github.com/llvm/llvm-project/pull/176249

>From e44f984c63c42779ad8dc8eb842237d202de9761 Mon Sep 17 00:00:00 2001
From: Amr Hesham <amr96 at programmer.net>
Date: Thu, 15 Jan 2026 21:46:48 +0100
Subject: [PATCH 1/2] [MLIR][Complex] Check for FastMathFlag in DivOp folder

---
 mlir/lib/Dialect/Complex/IR/ComplexOps.cpp  | 47 +++++++++++----------
 mlir/test/Dialect/Complex/canonicalize.mlir | 46 +++++++++++++++++++-
 2 files changed, 69 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 25af2f3be3067..e3633a82074b7 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -371,35 +371,38 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
-  auto rhs = adaptor.getRhs();
-  auto lhs = adaptor.getLhs();
-
-  // We can't fold without knowing that LHS isn't NaN
-  if (!rhs || !lhs)
-    return {};
+  Attribute rhs = adaptor.getRhs();
+  Attribute lhs = adaptor.getLhs();
+
+  // complex.div(complex.constant<NaN, NaN>, a) -> complex.constant<NaN, NaN>
+  // complex.div(complex.constant<NaN, a>, b) -> complex.constant<NaN, NaN>
+  // complex.div(complex.constant<a, NaN>, b) -> complex.constant<NaN, NaN>
+  bool isLhsComplexHasNan = false;
+  ArrayAttr lhsArrayAttr = dyn_cast_if_present<ArrayAttr>(lhs);
+  if (lhsArrayAttr && lhsArrayAttr.size() == 2) {
+    APFloat lhsReal = cast<FloatAttr>(lhsArrayAttr[0]).getValue();
+    APFloat lhsImag = cast<FloatAttr>(lhsArrayAttr[1]).getValue();
+    isLhsComplexHasNan = lhsReal.isNaN() || lhsImag.isNaN();
+    if (isLhsComplexHasNan) {
+      Attribute nanValue = lhsReal.isNaN() ? lhsArrayAttr[0] : lhsArrayAttr[1];
+      return ArrayAttr::get(getContext(), {nanValue, nanValue});
+    }
+  }
 
-  ArrayAttr rhsArrayAttr = dyn_cast<ArrayAttr>(rhs);
+  ArrayAttr rhsArrayAttr = dyn_cast_if_present<ArrayAttr>(rhs);
   if (!rhsArrayAttr || rhsArrayAttr.size() != 2)
     return {};
 
-  ArrayAttr lhsArrayAttr = dyn_cast<ArrayAttr>(lhs);
-  if (!lhsArrayAttr || lhsArrayAttr.size() != 2)
-    return {};
-
+  // Fold only if RHS is complex.constant<1.0, 0.0>
   APFloat rhsImag = cast<FloatAttr>(rhsArrayAttr[1]).getValue();
-  if (!rhsImag.isZero())
+  APFloat rhsReal = cast<FloatAttr>(rhsArrayAttr[0]).getValue();
+  if (rhsReal != APFloat(rhsReal.getSemantics(), 1) || !rhsImag.isZero())
     return {};
 
-  APFloat lhsReal = cast<FloatAttr>(lhsArrayAttr[0]).getValue();
-  APFloat lhsImag = cast<FloatAttr>(lhsArrayAttr[1]).getValue();
-  if (lhsReal.isNaN() || lhsImag.isNaN()) {
-    Attribute nanValue = lhsReal.isNaN() ? lhsArrayAttr[0] : lhsArrayAttr[1];
-    return ArrayAttr::get(getContext(), {nanValue, nanValue});
-  }
-
-  // complex.div(a, complex.constant<1.0, 0.0>) -> a
-  APFloat rhsReal = cast<FloatAttr>(rhsArrayAttr[0]).getValue();
-  if (rhsReal == APFloat(rhsReal.getSemantics(), 1))
+  // Fold to LHS if it doesn't contains NaNs or fast math flag nan is exists
+  // complex.div(a, complex.constant<1.0, 0.0>) fastmath<nnan> -> a
+  if ((lhsArrayAttr && !isLhsComplexHasNan) ||
+      arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan))
     return getLhs();
 
   return {};
diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
index b3f49eb3f44c1..1c5216c82e5c3 100644
--- a/mlir/test/Dialect/Complex/canonicalize.mlir
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -327,8 +327,8 @@ func.func @div_one_f128() -> complex<f128> {
   return %div : complex<f128>
 }
 
-// CHECK-LABEL: div_op_with_rhs_has_nan
-func.func @div_op_with_rhs_has_nan() -> complex<f32> {
+// CHECK-LABEL: div_op_with_lhs_has_nan_real
+func.func @div_op_with_lhs_has_nan_real() -> complex<f32> {
   %a = complex.constant [0x7fffffff : f32, 1.0 : f32]: complex<f32>
   %b = complex.constant [1.0: f32, 0.0 : f32]: complex<f32>
   %div = complex.div %a, %b : complex<f32>
@@ -336,3 +336,45 @@ func.func @div_op_with_rhs_has_nan() -> complex<f32> {
   // CHECK: return %[[DIV]] : complex<f32>
   return %div : complex<f32>
 }
+
+// CHECK-LABEL: div_op_with_lhs_has_nan_imag
+func.func @div_op_with_lhs_has_nan_imag() -> complex<f32> {
+  %a = complex.constant [1.0 : f32, 0x7fffffff : f32]: complex<f32>
+  %b = complex.constant [1.0: f32, 0.0 : f32]: complex<f32>
+  %div = complex.div %a, %b : complex<f32>
+  // CHECK: %[[DIV:.*]] = complex.constant [0x7FFFFFFF : f32, 0x7FFFFFFF : f32] : complex<f32>
+  // CHECK: return %[[DIV]] : complex<f32>
+  return %div : complex<f32>
+}
+
+// CHECK-LABEL: div_op_with_lhs_has_nan_real_imag
+func.func @div_op_with_lhs_has_nan_real_imag() -> complex<f32> {
+  %a = complex.constant [0x7fffffff : f32, 0x7fffffff : f32]: complex<f32>
+  %b = complex.constant [1.0: f32, 0.0 : f32]: complex<f32>
+  %div = complex.div %a, %b : complex<f32>
+  // CHECK: %[[DIV:.*]] = complex.constant [0x7FFFFFFF : f32, 0x7FFFFFFF : f32] : complex<f32>
+  // CHECK: return %[[DIV]] : complex<f32>
+  return %div : complex<f32>
+}
+
+// CHECK-LABEL: div_op_non_constant_lhs_with_fast_math
+func.func @div_op_non_constant_lhs_with_fast_math(%arg0: f32, %arg1: f32) -> complex<f32> {
+  %a = complex.create %arg0, %arg1 : complex<f32>
+  %b = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
+  %div = complex.div %a, %b fastmath<nnan> : complex<f32>
+  // CHECK: %[[COMPLEX:.*]] = complex.create %arg0, %arg1 : complex<f32>
+  // CHECK: return %[[COMPLEX]] : complex<f32>
+  return %div : complex<f32>
+}
+
+// CHECK-LABEL: div_op_non_constant_lhs_without_fast_math
+func.func @div_op_non_constant_lhs_without_fast_math(%arg0: f32, %arg1: f32) -> complex<f32> {
+  %a = complex.create %arg0, %arg1 : complex<f32>
+  %b = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
+  %div = complex.div %a, %b : complex<f32>
+  // CHECK: %[[B:.*]] = complex.constant [1.000000e+00 : f32, 0.000000e+00 : f32] : complex<f32>
+  // CHECK: %[[A:.*]] = complex.create %arg0, %arg1 : complex<f32>
+  // CHECK: %[[DIV:.*]] = complex.div %[[A]], %[[B]] : complex<f32>
+  // CHECK: return %[[DIV]] : complex<f32>
+  return %div : complex<f32>
+}

>From 4a48b67dc0ebd93d8dcbba5d91d82cb7ef97616a Mon Sep 17 00:00:00 2001
From: Amr Hesham <amr96 at programmer.net>
Date: Fri, 20 Feb 2026 21:25:43 +0100
Subject: [PATCH 2/2] Address code review comments

---
 mlir/lib/Dialect/Complex/IR/ComplexOps.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index e3633a82074b7..b5323597b7ca4 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -396,10 +396,10 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
   // Fold only if RHS is complex.constant<1.0, 0.0>
   APFloat rhsImag = cast<FloatAttr>(rhsArrayAttr[1]).getValue();
   APFloat rhsReal = cast<FloatAttr>(rhsArrayAttr[0]).getValue();
-  if (rhsReal != APFloat(rhsReal.getSemantics(), 1) || !rhsImag.isZero())
+  if (!rhsImag.isZero() || rhsReal != APFloat(rhsReal.getSemantics(), 1))
     return {};
 
-  // Fold to LHS if it doesn't contains NaNs or fast math flag nan is exists
+  // Fold to LHS if it is a NaN-free constant or the nnan fast-math flag is set
   // complex.div(a, complex.constant<1.0, 0.0>) fastmath<nnan> -> a
   if ((lhsArrayAttr && !isLhsComplexHasNan) ||
       arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan))



More information about the Mlir-commits mailing list