[Mlir-commits] [mlir] [MLIR][Complex] DivOp check for NaN in the folder (PR #169724)

Amr Hesham llvmlistbot at llvm.org
Tue Jan 13 10:00:52 PST 2026


https://github.com/AmrDeveloper updated https://github.com/llvm/llvm-project/pull/169724

>From ed480c8280151a5d873fe3fc566af7fbff7fbc69 Mon Sep 17 00:00:00 2001
From: Amr Hesham <amr96 at programmer.net>
Date: Wed, 26 Nov 2025 21:23:22 +0100
Subject: [PATCH 1/4] [MLIR][Complex] DivOp check for NaN in the folder

---
 mlir/lib/Dialect/Complex/IR/ComplexOps.cpp  | 23 +++++---
 mlir/test/Dialect/Complex/canonicalize.mlir | 61 +++++++++------------
 mlir/test/Dialect/Complex/div-fold.mlir     | 25 +++++++++
 3 files changed, 66 insertions(+), 43 deletions(-)
 create mode 100644 mlir/test/Dialect/Complex/div-fold.mlir

diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 6caec2a2345bd..7838dce137bd2 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -374,21 +374,30 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
 
 OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
   auto rhs = adaptor.getRhs();
-  if (!rhs)
+  auto lhs = adaptor.getLhs();
+  if (!rhs || !lhs)
     return {};
 
-  ArrayAttr arrayAttr = dyn_cast<ArrayAttr>(rhs);
-  if (!arrayAttr || arrayAttr.size() != 2)
+  ArrayAttr rhsArrayAttr = dyn_cast<ArrayAttr>(rhs);
+  if (!rhsArrayAttr || rhsArrayAttr.size() != 2)
     return {};
 
-  APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
-  APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
+  ArrayAttr lhsArrayAttr = dyn_cast<ArrayAttr>(lhs);
+  if (!lhsArrayAttr || lhsArrayAttr.size() != 2)
+    return {};
 
-  if (!imag.isZero())
+  APFloat rhsImag = cast<FloatAttr>(rhsArrayAttr[1]).getValue();
+  if (!rhsImag.isZero())
+    return {};
+
+  APFloat lhsReal = cast<FloatAttr>(lhsArrayAttr[0]).getValue();
+  APFloat lhsImag = cast<FloatAttr>(lhsArrayAttr[1]).getValue();
+  if (lhsReal.isNaN() || lhsImag.isNaN())
     return {};
 
   // complex.div(a, complex.constant<1.0, 0.0>) -> a
-  if (real == APFloat(real.getSemantics(), 1))
+  APFloat rhsReal = cast<FloatAttr>(rhsArrayAttr[0]).getValue();
+  if (rhsReal == APFloat(rhsReal.getSemantics(), 1))
     return getLhs();
 
   return {};
diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
index 16fa0fdf56a1b..0dede36c73f1f 100644
--- a/mlir/test/Dialect/Complex/canonicalize.mlir
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -282,58 +282,47 @@ func.func @double_reverse_bitcast(%arg0 : complex<f32>) -> f64 {
   func.return %1 : f64
 }
 
-
 // CHECK-LABEL: func @div_one_f16
-//  CHECK-SAME: (%[[ARG0:.*]]: f16, %[[ARG1:.*]]: f16) -> complex<f16>
-func.func @div_one_f16(%arg0: f16, %arg1: f16) -> complex<f16> {
-  %create = complex.create %arg0, %arg1: complex<f16>  
-  %one = complex.constant [1.0 : f16, 0.0 : f16] : complex<f16>
-  %div = complex.div %create, %one : complex<f16>
-  // CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f16>
-  // CHECK-NEXT: return %[[CREATE]]
+func.func @div_one_f16() -> complex<f16> {
+  %lhs = complex.constant [1.0 : f16, 2.0 : f16] : complex<f16>
+  %one = complex.constant [1.0 : f16, 0.0 : f16] : complex<f16>
+  %div = complex.div %lhs, %one : complex<f16>
+  // CHECK: %[[DIV:.*]] = complex.constant [1.000000e+00 : f16, 2.000000e+00 : f16] : complex<f16>
   return %div : complex<f16>
 }
 
 // CHECK-LABEL: func @div_one_f32
-//  CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) -> complex<f32>
-func.func @div_one_f32(%arg0: f32, %arg1: f32) -> complex<f32> {
-  %create = complex.create %arg0, %arg1: complex<f32>  
-  %one = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
-  %div = complex.div %create, %one : complex<f32>
-  // CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f32>
-  // CHECK-NEXT: return %[[CREATE]]
+func.func @div_one_f32() -> complex<f32> {
+  %lhs = complex.constant [1.0 : f32, 2.0 : f32] : complex<f32>
+  %one = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
+  %div = complex.div %lhs, %one : complex<f32>
+  // CHECK: %[[DIV:.*]] = complex.constant [1.000000e+00 : f32, 2.000000e+00 : f32] : complex<f32>
   return %div : complex<f32>
 }
 
 // CHECK-LABEL: func @div_one_f64
-//  CHECK-SAME: (%[[ARG0:.*]]: f64, %[[ARG1:.*]]: f64) -> complex<f64>
-func.func @div_one_f64(%arg0: f64, %arg1: f64) -> complex<f64> {
-  %create = complex.create %arg0, %arg1: complex<f64>  
-  %one = complex.constant [1.0 : f64, 0.0 : f64] : complex<f64>
-  %div = complex.div %create, %one : complex<f64>
-  // CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f64>
-  // CHECK-NEXT: return %[[CREATE]]
+func.func @div_one_f64() -> complex<f64> {
+  %lhs = complex.constant [1.0 : f64, 2.0 : f64] : complex<f64>
+  %one = complex.constant [1.0 : f64, 0.0 : f64] : complex<f64>
+  %div = complex.div %lhs, %one : complex<f64>
+  // CHECK: %[[DIV:.*]] = complex.constant [1.000000e+00, 2.000000e+00] : complex<f64>
   return %div : complex<f64>
 }
 
 // CHECK-LABEL: func @div_one_f80
-//  CHECK-SAME: (%[[ARG0:.*]]: f80, %[[ARG1:.*]]: f80) -> complex<f80>
-func.func @div_one_f80(%arg0: f80, %arg1: f80) -> complex<f80> {
-  %create = complex.create %arg0, %arg1: complex<f80>  
-  %one = complex.constant [1.0 : f80, 0.0 : f80] : complex<f80>
-  %div = complex.div %create, %one : complex<f80>
-  // CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f80>
-  // CHECK-NEXT: return %[[CREATE]]
+func.func @div_one_f80() -> complex<f80> {
+  %lhs = complex.constant [1.0 : f80, 2.0 : f80] : complex<f80>
+  %one = complex.constant [1.0 : f80, 0.0 : f80] : complex<f80>
+  %div = complex.div %lhs, %one : complex<f80>
+  // CHECK: %[[DIV:.*]] = complex.constant [1.000000e+00 : f80, 2.000000e+00 : f80] : complex<f80>
   return %div : complex<f80>
 }
 
 // CHECK-LABEL: func @div_one_f128
-//  CHECK-SAME: (%[[ARG0:.*]]: f128, %[[ARG1:.*]]: f128) -> complex<f128>
-func.func @div_one_f128(%arg0: f128, %arg1: f128) -> complex<f128> {
-  %create = complex.create %arg0, %arg1: complex<f128>  
-  %one = complex.constant [1.0 : f128, 0.0 : f128] : complex<f128>
-  %div = complex.div %create, %one : complex<f128>
-  // CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f128>
-  // CHECK-NEXT: return %[[CREATE]]
+func.func @div_one_f128() -> complex<f128> {
+  %lhs = complex.constant [1.0 : f128, 2.0 : f128] : complex<f128>
+  %one = complex.constant [1.0 : f128, 0.0 : f128] : complex<f128>
+  %div = complex.div %lhs, %one : complex<f128>
+  // CHECK: %[[DIV:.*]] = complex.constant [1.000000e+00 : f128, 2.000000e+00 : f128] : complex<f128>
   return %div : complex<f128>
 }
diff --git a/mlir/test/Dialect/Complex/div-fold.mlir b/mlir/test/Dialect/Complex/div-fold.mlir
new file mode 100644
index 0000000000000..8564b378a2c96
--- /dev/null
+++ b/mlir/test/Dialect/Complex/div-fold.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s -split-input-file -test-single-fold | FileCheck %s
+
+// CHECK-LABEL: div_op_fold
+func.func @div_op_fold() -> complex<f32> {
+  %a = complex.constant [2.0 : f32, 1.0 : f32]: complex<f32>
+  %b = complex.constant [1.0: f32, 0.0 : f32]: complex<f32>
+  %div = complex.div %a, %b : complex<f32>
+  // CHECK: %[[DIV:.*]] = complex.constant [2.000000e+00 : f32, 1.000000e+00 : f32] : complex<f32>
+  // CHECK: return %[[DIV]] : complex<f32>
+  return %div : complex<f32>
+}
+
+// CHECK-LABEL: div_op_not_fold_if_rhs_has_nan
+func.func @div_op_not_fold_if_rhs_has_nan() -> complex<f32> {
+  %a = complex.constant [0x7fffffff : f32, 1.0 : f32]: complex<f32>
+  %b = complex.constant [1.0: f32, 0.0 : f32]: complex<f32>
+  %div = complex.div %a, %b : complex<f32>
+  // CHECK: %[[A:.*]] = complex.constant [0x7FFFFFFF : f32, 1.000000e+00 : f32] : complex<f32>
+  // CHECK: %[[B:.*]] = complex.constant [1.000000e+00 : f32, 0.000000e+00 : f32] : complex<f32>
+  // CHECK: %[[DIV:.*]] = complex.div %[[A]], %[[B]] : complex<f32>
+  // CHECK: return %[[DIV]] : complex<f32>
+  return %div : complex<f32>
+}
+
+

>From 78cb0d3dd26590197a6887cf29ee9d541bffc53f Mon Sep 17 00:00:00 2001
From: Amr Hesham <amr96 at programmer.net>
Date: Tue, 2 Dec 2025 20:07:16 +0100
Subject: [PATCH 2/4] Add Canonicalize for ComplexDivOp

---
 .../mlir/Dialect/Complex/IR/ComplexOps.td     |  1 +
 mlir/lib/Dialect/Complex/IR/ComplexOps.cpp    | 28 +++++++++++++++++++
 mlir/test/Dialect/Complex/canonicalize.mlir   | 10 +++++++
 mlir/test/Dialect/Complex/div-fold.mlir       | 25 -----------------
 4 files changed, 39 insertions(+), 25 deletions(-)
 delete mode 100644 mlir/test/Dialect/Complex/div-fold.mlir

diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 828379ded14b3..607921f952be3 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -227,6 +227,7 @@ def DivOp : ComplexArithmeticOp<"div"> {
   }];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 7838dce137bd2..4d914770d0b00 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -169,6 +169,34 @@ void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<MergeComplexBitcast, MergeArithBitcast>(context);
 }
 
+struct FoldComplexDivWithNaN final : OpRewritePattern<complex::DivOp> {
+  using OpRewritePattern<complex::DivOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(complex::DivOp op,
+                                PatternRewriter &rewriter) const override {
+    if (auto constant = op->getOperand(0).getDefiningOp<ConstantOp>()) {
+      mlir::ArrayAttr arrayAttr = constant.getValue();
+      if (!arrayAttr || arrayAttr.size() != 2)
+        return failure();
+
+      APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
+      APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
+      if (real.isNaN() || imag.isNaN()) {
+        Attribute nanValue = real.isNaN() ? arrayAttr[0] : arrayAttr[1];
+        rewriter.replaceOpWithNewOp<complex::ConstantOp>(
+            op, op.getType(), rewriter.getArrayAttr({nanValue, nanValue}));
+        return success();
+      }
+    }
+    return failure();
+  };
+};
+
+void DivOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                        MLIRContext *context) {
+  results.add<FoldComplexDivWithNaN>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // CreateOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
index 0dede36c73f1f..b3f49eb3f44c1 100644
--- a/mlir/test/Dialect/Complex/canonicalize.mlir
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -326,3 +326,13 @@ func.func @div_one_f128() -> complex<f128> {
   // CHECK: %[[DIV:.*]] = complex.constant [1.000000e+00 : f128, 2.000000e+00 : f128] : complex<f128>
   return %div : complex<f128>
 }
+
+// CHECK-LABEL: func @div_op_with_lhs_has_nan
+func.func @div_op_with_lhs_has_nan() -> complex<f32> {
+  %a = complex.constant [0x7fffffff : f32, 1.0 : f32] : complex<f32>
+  %b = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
+  %div = complex.div %a, %b : complex<f32>
+  // CHECK: %[[DIV:.*]] = complex.constant [0x7FFFFFFF : f32, 0x7FFFFFFF : f32] : complex<f32>
+  // CHECK: return %[[DIV]] : complex<f32>
+  return %div : complex<f32>
+}
diff --git a/mlir/test/Dialect/Complex/div-fold.mlir b/mlir/test/Dialect/Complex/div-fold.mlir
deleted file mode 100644
index 8564b378a2c96..0000000000000
--- a/mlir/test/Dialect/Complex/div-fold.mlir
+++ /dev/null
@@ -1,25 +0,0 @@
-// RUN: mlir-opt %s -split-input-file -test-single-fold | FileCheck %s
-
-// CHECK-LABEL: div_op_fold
-func.func @div_op_fold() -> complex<f32> {
-  %a = complex.constant [2.0 : f32, 1.0 : f32]: complex<f32>
-  %b = complex.constant [1.0: f32, 0.0 : f32]: complex<f32>
-  %div = complex.div %a, %b : complex<f32>
-  // CHECK: %[[DIV:.*]] = complex.constant [2.000000e+00 : f32, 1.000000e+00 : f32] : complex<f32>
-  // CHECK: return %[[DIV]] : complex<f32>
-  return %div : complex<f32>
-}
-
-// CHECK-LABEL: div_op_not_fold_if_rhs_has_nan
-func.func @div_op_not_fold_if_rhs_has_nan() -> complex<f32> {
-  %a = complex.constant [0x7fffffff : f32, 1.0 : f32]: complex<f32>
-  %b = complex.constant [1.0: f32, 0.0 : f32]: complex<f32>
-  %div = complex.div %a, %b : complex<f32>
-  // CHECK: %[[A:.*]] = complex.constant [0x7FFFFFFF : f32, 1.000000e+00 : f32] : complex<f32>
-  // CHECK: %[[B:.*]] = complex.constant [1.000000e+00 : f32, 0.000000e+00 : f32] : complex<f32>
-  // CHECK: %[[DIV:.*]] = complex.div %[[A]], %[[B]] : complex<f32>
-  // CHECK: return %[[DIV]] : complex<f32>
-  return %div : complex<f32>
-}
-
-

>From ab2d3e01e82f4ffc4349757f47bd2e156d8e9dae Mon Sep 17 00:00:00 2001
From: Amr Hesham <amr96 at programmer.net>
Date: Sat, 13 Dec 2025 15:27:28 +0100
Subject: [PATCH 3/4] Return Attr directly from the folder

---
 .../mlir/Dialect/Complex/IR/ComplexOps.td     |  1 -
 mlir/lib/Dialect/Complex/IR/ComplexOps.cpp    | 38 +++----------------
 2 files changed, 5 insertions(+), 34 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 607921f952be3..828379ded14b3 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -227,7 +227,6 @@ def DivOp : ComplexArithmeticOp<"div"> {
   }];
 
   let hasFolder = 1;
-  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 4d914770d0b00..7c2ec86fbe8c6 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -19,9 +19,7 @@ using namespace mlir::complex;
 // ConstantOp
 //===----------------------------------------------------------------------===//
 
-OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
-  return getValue();
-}
+OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
 
 void ConstantOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
@@ -169,34 +167,6 @@ void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<MergeComplexBitcast, MergeArithBitcast>(context);
 }
 
-struct FoldComplexDivWithNaN final : OpRewritePattern<complex::DivOp> {
-  using OpRewritePattern<complex::DivOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(complex::DivOp op,
-                                PatternRewriter &rewriter) const override {
-    if (auto constant = op->getOperand(0).getDefiningOp<ConstantOp>()) {
-      mlir::ArrayAttr arrayAttr = constant.getValue();
-      if (!arrayAttr || arrayAttr.size() != 2)
-        return failure();
-
-      APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
-      APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
-      if (real.isNaN() || imag.isNaN()) {
-        Attribute nanValue = real.isNaN() ? arrayAttr[0] : arrayAttr[1];
-        rewriter.replaceOpWithNewOp<complex::ConstantOp>(
-            op, op.getType(), rewriter.getArrayAttr({nanValue, nanValue}));
-        return success();
-      }
-    }
-    return failure();
-  };
-};
-
-void DivOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                        MLIRContext *context) {
-  results.add<FoldComplexDivWithNaN>(context);
-}
-
 //===----------------------------------------------------------------------===//
 // CreateOp
 //===----------------------------------------------------------------------===//
@@ -420,8 +390,10 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
 
   APFloat lhsReal = cast<FloatAttr>(lhsArrayAttr[0]).getValue();
   APFloat lhsImag = cast<FloatAttr>(lhsArrayAttr[1]).getValue();
-  if (lhsReal.isNaN() || lhsImag.isNaN())
-    return {};
+  if (lhsReal.isNaN() || lhsImag.isNaN()) {
+    Attribute nanValue = lhsReal.isNaN() ? lhsArrayAttr[0] : lhsArrayAttr[1];
+    return ArrayAttr::get(getContext(), {nanValue, nanValue});
+  }
 
   // complex.div(a, complex.constant<1.0, 0.0>) -> a
   APFloat rhsReal = cast<FloatAttr>(rhsArrayAttr[0]).getValue();

>From bee665b92cbaa368162943f677531112eb909f0c Mon Sep 17 00:00:00 2001
From: Amr Hesham <amr96 at programmer.net>
Date: Tue, 13 Jan 2026 18:47:16 +0100
Subject: [PATCH 4/4] Address code review comment

---
 mlir/lib/Dialect/Complex/IR/ComplexOps.cpp | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 7c2ec86fbe8c6..25af2f3be3067 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -373,6 +373,8 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
 OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
   auto rhs = adaptor.getRhs();
   auto lhs = adaptor.getLhs();
+
+  // Both operands must be constant so that LHS can be checked for NaN.
   if (!rhs || !lhs)
     return {};
 



More information about the Mlir-commits mailing list