[Mlir-commits] [mlir] 70dc096 - [MLIR][Complex] DivOp check for NaN in the folder (#169724)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 14 10:47:37 PST 2026
Author: Amr Hesham
Date: 2026-01-14T19:47:33+01:00
New Revision: 70dc0961d01392747bf5d98691103a844ce5571a
URL: https://github.com/llvm/llvm-project/commit/70dc0961d01392747bf5d98691103a844ce5571a
DIFF: https://github.com/llvm/llvm-project/commit/70dc0961d01392747bf5d98691103a844ce5571a.diff
LOG: [MLIR][Complex] DivOp check for NaN in the folder (#169724)
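Fold complex.div to a complex of NaNs when both operands are constants, the RHS
imaginary part is zero, and the LHS real or imaginary part is NaN; and no longer
apply the divide-by-one fold unless the LHS is a constant known not to be NaN.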
Added:
Modified:
mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
mlir/test/Dialect/Complex/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 6caec2a2345bd..25af2f3be3067 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -19,9 +19,7 @@ using namespace mlir::complex;
// ConstantOp
//===----------------------------------------------------------------------===//
-OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
- return getValue();
-}
+OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
void ConstantOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
@@ -374,21 +372,34 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
auto rhs = adaptor.getRhs();
- if (!rhs)
+ auto lhs = adaptor.getLhs();
+
+ // We can't fold without knowing that LHS isn't NaN
+ if (!rhs || !lhs)
return {};
- ArrayAttr arrayAttr = dyn_cast<ArrayAttr>(rhs);
- if (!arrayAttr || arrayAttr.size() != 2)
+ ArrayAttr rhsArrayAttr = dyn_cast<ArrayAttr>(rhs);
+ if (!rhsArrayAttr || rhsArrayAttr.size() != 2)
return {};
- APFloat real = cast<FloatAttr>(arrayAttr[0]).getValue();
- APFloat imag = cast<FloatAttr>(arrayAttr[1]).getValue();
+ ArrayAttr lhsArrayAttr = dyn_cast<ArrayAttr>(lhs);
+ if (!lhsArrayAttr || lhsArrayAttr.size() != 2)
+ return {};
- if (!imag.isZero())
+ APFloat rhsImag = cast<FloatAttr>(rhsArrayAttr[1]).getValue();
+ if (!rhsImag.isZero())
return {};
+ APFloat lhsReal = cast<FloatAttr>(lhsArrayAttr[0]).getValue();
+ APFloat lhsImag = cast<FloatAttr>(lhsArrayAttr[1]).getValue();
+ if (lhsReal.isNaN() || lhsImag.isNaN()) {
+ Attribute nanValue = lhsReal.isNaN() ? lhsArrayAttr[0] : lhsArrayAttr[1];
+ return ArrayAttr::get(getContext(), {nanValue, nanValue});
+ }
+
// complex.div(a, complex.constant<1.0, 0.0>) -> a
- if (real == APFloat(real.getSemantics(), 1))
+ APFloat rhsReal = cast<FloatAttr>(rhsArrayAttr[0]).getValue();
+ if (rhsReal == APFloat(rhsReal.getSemantics(), 1))
return getLhs();
return {};
diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
index 16fa0fdf56a1b..b3f49eb3f44c1 100644
--- a/mlir/test/Dialect/Complex/canonicalize.mlir
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -282,58 +282,57 @@ func.func @double_reverse_bitcast(%arg0 : complex<f32>) -> f64 {
func.return %1 : f64
}
-
// CHECK-LABEL: func @div_one_f16
-// CHECK-SAME: (%[[ARG0:.*]]: f16, %[[ARG1:.*]]: f16) -> complex<f16>
-func.func @div_one_f16(%arg0: f16, %arg1: f16) -> complex<f16> {
- %create = complex.create %arg0, %arg1: complex<f16>
- %one = complex.constant [1.0 : f16, 0.0 : f16] : complex<f16>
- %div = complex.div %create, %one : complex<f16>
- // CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f16>
- // CHECK-NEXT: return %[[CREATE]]
+func.func @div_one_f16() -> complex<f16> {
+ %one = complex.constant [1.0 : f16, 2.0 : f16] : complex<f16>
+ %two = complex.constant [1.0 : f16, 0.0 : f16] : complex<f16>
+ %div = complex.div %one, %two : complex<f16>
+ // CHECK: %[[DIV:.*]] = complex.constant [1.000000e+00 : f16, 2.000000e+00 : f16] : complex<f16>
return %div : complex<f16>
}
// CHECK-LABEL: func @div_one_f32
-// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) -> complex<f32>
-func.func @div_one_f32(%arg0: f32, %arg1: f32) -> complex<f32> {
- %create = complex.create %arg0, %arg1: complex<f32>
- %one = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
- %div = complex.div %create, %one : complex<f32>
- // CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f32>
- // CHECK-NEXT: return %[[CREATE]]
+func.func @div_one_f32() -> complex<f32> {
+ %one = complex.constant [1.0 : f32, 2.0 : f32] : complex<f32>
+ %two = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
+ %div = complex.div %one, %two : complex<f32>
+ // CHECK: %[[DIV:.*]] = complex.constant [1.000000e+00 : f32, 2.000000e+00 : f32] : complex<f32>
return %div : complex<f32>
}
// CHECK-LABEL: func @div_one_f64
-// CHECK-SAME: (%[[ARG0:.*]]: f64, %[[ARG1:.*]]: f64) -> complex<f64>
-func.func @div_one_f64(%arg0: f64, %arg1: f64) -> complex<f64> {
- %create = complex.create %arg0, %arg1: complex<f64>
- %one = complex.constant [1.0 : f64, 0.0 : f64] : complex<f64>
- %div = complex.div %create, %one : complex<f64>
- // CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f64>
- // CHECK-NEXT: return %[[CREATE]]
+func.func @div_one_f64() -> complex<f64> {
+ %one = complex.constant [1.0 : f64, 2.0 : f64] : complex<f64>
+ %two = complex.constant [1.0 : f64, 0.0 : f64] : complex<f64>
+ %div = complex.div %one, %two : complex<f64>
+ // CHECK: %[[DIV:.*]] = complex.constant [1.000000e+00, 2.000000e+00] : complex<f64>
return %div : complex<f64>
}
// CHECK-LABEL: func @div_one_f80
-// CHECK-SAME: (%[[ARG0:.*]]: f80, %[[ARG1:.*]]: f80) -> complex<f80>
-func.func @div_one_f80(%arg0: f80, %arg1: f80) -> complex<f80> {
- %create = complex.create %arg0, %arg1: complex<f80>
- %one = complex.constant [1.0 : f80, 0.0 : f80] : complex<f80>
- %div = complex.div %create, %one : complex<f80>
- // CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f80>
- // CHECK-NEXT: return %[[CREATE]]
+func.func @div_one_f80() -> complex<f80> {
+ %one = complex.constant [1.0 : f80, 2.0 : f80] : complex<f80>
+ %two = complex.constant [1.0 : f80, 0.0 : f80] : complex<f80>
+ %div = complex.div %one, %two : complex<f80>
+ // CHECK: %[[DIV:.*]] = complex.constant [1.000000e+00 : f80, 2.000000e+00 : f80] : complex<f80>
return %div : complex<f80>
}
// CHECK-LABEL: func @div_one_f128
-// CHECK-SAME: (%[[ARG0:.*]]: f128, %[[ARG1:.*]]: f128) -> complex<f128>
-func.func @div_one_f128(%arg0: f128, %arg1: f128) -> complex<f128> {
- %create = complex.create %arg0, %arg1: complex<f128>
- %one = complex.constant [1.0 : f128, 0.0 : f128] : complex<f128>
- %div = complex.div %create, %one : complex<f128>
- // CHECK: %[[CREATE:.*]] = complex.create %[[ARG0]], %[[ARG1]] : complex<f128>
- // CHECK-NEXT: return %[[CREATE]]
+func.func @div_one_f128() -> complex<f128> {
+ %one = complex.constant [1.0 : f128, 2.0 : f128] : complex<f128>
+ %two = complex.constant [1.0 : f128, 0.0 : f128] : complex<f128>
+ %div = complex.div %one, %two : complex<f128>
+ // CHECK: %[[DIV:.*]] = complex.constant [1.000000e+00 : f128, 2.000000e+00 : f128] : complex<f128>
return %div : complex<f128>
}
+
+// CHECK-LABEL: div_op_with_rhs_has_nan
+func.func @div_op_with_rhs_has_nan() -> complex<f32> {
+ %a = complex.constant [0x7fffffff : f32, 1.0 : f32]: complex<f32>
+ %b = complex.constant [1.0: f32, 0.0 : f32]: complex<f32>
+ %div = complex.div %a, %b : complex<f32>
+ // CHECK: %[[DIV:.*]] = complex.constant [0x7FFFFFFF : f32, 0x7FFFFFFF : f32] : complex<f32>
+ // CHECK: return %[[DIV]] : complex<f32>
+ return %div : complex<f32>
+}
More information about the Mlir-commits
mailing list