[Mlir-commits] [mlir] bf17ee1 - Add MulOp lowering from Complex dialect to Standard/Math dialect.

Adrian Kuegel llvmlistbot at llvm.org
Mon Jul 5 03:52:08 PDT 2021


Author: Adrian Kuegel
Date: 2021-07-05T12:51:51+02:00
New Revision: bf17ee1950ef7df95e160c4869de8398bc2b67d2

URL: https://github.com/llvm/llvm-project/commit/bf17ee1950ef7df95e160c4869de8398bc2b67d2
DIFF: https://github.com/llvm/llvm-project/commit/bf17ee1950ef7df95e160c4869de8398bc2b67d2.diff

LOG: Add MulOp lowering from Complex dialect to Standard/Math dialect.

The lowering handles special cases with NaN or infinity like C++.

Differential Revision: https://reviews.llvm.org/D105270

Added: 
    

Modified: 
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index a5aa4cbda9c6..018882ae9489 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -337,6 +337,144 @@ struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
   }
 };
 
+struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
+  using OpConversionPattern<complex::MulOp>::OpConversionPattern;
+  // Expands complex.mul into float ops on the operands' real/imag parts.
+  LogicalResult
+  matchAndRewrite(complex::MulOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::MulOp::Adaptor transformed(operands);
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    auto type = transformed.lhs().getType().cast<ComplexType>();
+    auto elementType = type.getElementType().cast<FloatType>();
+
+    Value lhsReal = b.create<complex::ReOp>(elementType, transformed.lhs());
+    Value lhsRealAbs = b.create<AbsFOp>(lhsReal);
+    Value lhsImag = b.create<complex::ImOp>(elementType, transformed.lhs());
+    Value lhsImagAbs = b.create<AbsFOp>(lhsImag);
+    Value rhsReal = b.create<complex::ReOp>(elementType, transformed.rhs());
+    Value rhsRealAbs = b.create<AbsFOp>(rhsReal);
+    Value rhsImag = b.create<complex::ImOp>(elementType, transformed.rhs());
+    Value rhsImagAbs = b.create<AbsFOp>(rhsImag);
+
+    Value lhsRealTimesRhsReal = b.create<MulFOp>(lhsReal, rhsReal);
+    Value lhsRealTimesRhsRealAbs = b.create<AbsFOp>(lhsRealTimesRhsReal);
+    Value lhsImagTimesRhsImag = b.create<MulFOp>(lhsImag, rhsImag);
+    Value lhsImagTimesRhsImagAbs = b.create<AbsFOp>(lhsImagTimesRhsImag);
+    Value real = b.create<SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
+
+    Value lhsImagTimesRhsReal = b.create<MulFOp>(lhsImag, rhsReal);
+    Value lhsImagTimesRhsRealAbs = b.create<AbsFOp>(lhsImagTimesRhsReal);
+    Value lhsRealTimesRhsImag = b.create<MulFOp>(lhsReal, rhsImag);
+    Value lhsRealTimesRhsImagAbs = b.create<AbsFOp>(lhsRealTimesRhsImag);
+    Value imag = b.create<AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
+
+    // Patch the naive result only if both parts are NaN and a case below applies.
+    Value realIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, real, real);
+    Value imagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, imag, imag);
+    Value isNan = b.create<AndOp>(realIsNan, imagIsNan);
+
+    Value inf = b.create<ConstantOp>(
+        elementType,
+        b.getFloatAttr(elementType,
+                       APFloat::getInf(elementType.getFloatSemantics())));
+
+    // Case 1. lhs has an infinite part: lhs parts -> +-1 (inf) or +-0; NaN rhs parts -> +-0.
+    Value lhsRealIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealAbs, inf);
+    Value lhsImagIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagAbs, inf);
+    Value lhsIsInf = b.create<OrOp>(lhsRealIsInf, lhsImagIsInf);
+    Value rhsRealIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, rhsReal, rhsReal);
+    Value rhsImagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, rhsImag, rhsImag);
+    Value zero = b.create<ConstantOp>(elementType, b.getZeroAttr(elementType));
+    Value one =
+        b.create<ConstantOp>(elementType, b.getFloatAttr(elementType, 1));
+    Value lhsRealIsInfFloat = b.create<SelectOp>(lhsRealIsInf, one, zero);
+    lhsReal = b.create<SelectOp>(
+        lhsIsInf, b.create<CopySignOp>(lhsRealIsInfFloat, lhsReal), lhsReal);
+    Value lhsImagIsInfFloat = b.create<SelectOp>(lhsImagIsInf, one, zero);
+    lhsImag = b.create<SelectOp>(
+        lhsIsInf, b.create<CopySignOp>(lhsImagIsInfFloat, lhsImag), lhsImag);
+    Value lhsIsInfAndRhsRealIsNan = b.create<AndOp>(lhsIsInf, rhsRealIsNan);
+    rhsReal = b.create<SelectOp>(lhsIsInfAndRhsRealIsNan,
+                                 b.create<CopySignOp>(zero, rhsReal), rhsReal);
+    Value lhsIsInfAndRhsImagIsNan = b.create<AndOp>(lhsIsInf, rhsImagIsNan);
+    rhsImag = b.create<SelectOp>(lhsIsInfAndRhsImagIsNan,
+                                 b.create<CopySignOp>(zero, rhsImag), rhsImag);
+
+    // Case 2. rhs has an infinite part: same fix-up as case 1 with lhs and rhs swapped.
+    Value rhsRealIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, rhsRealAbs, inf);
+    Value rhsImagIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, rhsImagAbs, inf);
+    Value rhsIsInf = b.create<OrOp>(rhsRealIsInf, rhsImagIsInf);
+    Value lhsRealIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, lhsReal, lhsReal);
+    Value lhsImagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, lhsImag, lhsImag);
+    Value rhsRealIsInfFloat = b.create<SelectOp>(rhsRealIsInf, one, zero);
+    rhsReal = b.create<SelectOp>(
+        rhsIsInf, b.create<CopySignOp>(rhsRealIsInfFloat, rhsReal), rhsReal);
+    Value rhsImagIsInfFloat = b.create<SelectOp>(rhsImagIsInf, one, zero);
+    rhsImag = b.create<SelectOp>(
+        rhsIsInf, b.create<CopySignOp>(rhsImagIsInfFloat, rhsImag), rhsImag);
+    Value rhsIsInfAndLhsRealIsNan = b.create<AndOp>(rhsIsInf, lhsRealIsNan);
+    lhsReal = b.create<SelectOp>(rhsIsInfAndLhsRealIsNan,
+                                 b.create<CopySignOp>(zero, lhsReal), lhsReal);
+    Value rhsIsInfAndLhsImagIsNan = b.create<AndOp>(rhsIsInf, lhsImagIsNan);
+    lhsImag = b.create<SelectOp>(rhsIsInfAndLhsImagIsNan,
+                                 b.create<CopySignOp>(zero, lhsImag), lhsImag);
+    Value recalc = b.create<OrOp>(lhsIsInf, rhsIsInf); // Case 1 or case 2 applied.
+
+    // Case 3. Cases 1 and 2 do not apply, but one of the four pairwise products
+    // is infinite in absolute value: NaN operand parts become +-0.
+    Value lhsRealTimesRhsRealIsInf =
+        b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
+    Value lhsImagTimesRhsImagIsInf =
+        b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
+    Value isSpecialCase =
+        b.create<OrOp>(lhsRealTimesRhsRealIsInf, lhsImagTimesRhsImagIsInf);
+    Value lhsRealTimesRhsImagIsInf =
+        b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
+    isSpecialCase = b.create<OrOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
+    Value lhsImagTimesRhsRealIsInf =
+        b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
+    isSpecialCase = b.create<OrOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
+    Type i1Type = b.getI1Type(); // `recalc xor true` below computes `!recalc`.
+    Value notRecalc = b.create<XOrOp>(
+        recalc, b.create<ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
+    isSpecialCase = b.create<AndOp>(isSpecialCase, notRecalc);
+    Value isSpecialCaseAndLhsRealIsNan =
+        b.create<AndOp>(isSpecialCase, lhsRealIsNan);
+    lhsReal = b.create<SelectOp>(isSpecialCaseAndLhsRealIsNan,
+                                 b.create<CopySignOp>(zero, lhsReal), lhsReal);
+    Value isSpecialCaseAndLhsImagIsNan =
+        b.create<AndOp>(isSpecialCase, lhsImagIsNan);
+    lhsImag = b.create<SelectOp>(isSpecialCaseAndLhsImagIsNan,
+                                 b.create<CopySignOp>(zero, lhsImag), lhsImag);
+    Value isSpecialCaseAndRhsRealIsNan =
+        b.create<AndOp>(isSpecialCase, rhsRealIsNan);
+    rhsReal = b.create<SelectOp>(isSpecialCaseAndRhsRealIsNan,
+                                 b.create<CopySignOp>(zero, rhsReal), rhsReal);
+    Value isSpecialCaseAndRhsImagIsNan =
+        b.create<AndOp>(isSpecialCase, rhsImagIsNan);
+    rhsImag = b.create<SelectOp>(isSpecialCaseAndRhsImagIsNan,
+                                 b.create<CopySignOp>(zero, rhsImag), rhsImag);
+    recalc = b.create<OrOp>(recalc, isSpecialCase);
+    recalc = b.create<AndOp>(isNan, recalc); // Requires both naive parts NaN.
+
+    // Recalculate real part from the adjusted operands; inf * it is used if `recalc`.
+    lhsRealTimesRhsReal = b.create<MulFOp>(lhsReal, rhsReal);
+    lhsImagTimesRhsImag = b.create<MulFOp>(lhsImag, rhsImag);
+    Value newReal = b.create<SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
+    real = b.create<SelectOp>(recalc, b.create<MulFOp>(inf, newReal), real);
+
+    // Recalculate imag part from the adjusted operands; inf * it is used if `recalc`.
+    lhsImagTimesRhsReal = b.create<MulFOp>(lhsImag, rhsReal);
+    lhsRealTimesRhsImag = b.create<MulFOp>(lhsReal, rhsImag);
+    Value newImag = b.create<AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
+    imag = b.create<SelectOp>(recalc, b.create<MulFOp>(inf, newImag), imag);
+
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
+    return success();
+  }
+};
+
 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
   using OpConversionPattern<complex::NegOp>::OpConversionPattern;
 
@@ -397,6 +535,7 @@ void mlir::populateComplexToStandardConversionPatterns(
       DivOpConversion,
       ExpOpConversion,
       LogOpConversion,
+      MulOpConversion,
       NegOpConversion,
       SignOpConversion>(patterns.getContext());
   // clang-format on
@@ -419,8 +558,8 @@ void ConvertComplexToStandardPass::runOnFunction() {
   target.addLegalDialect<StandardOpsDialect, math::MathDialect,
                          complex::ComplexDialect>();
   target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
-                      complex::ExpOp, complex::LogOp, complex::NotEqualOp,
-                      complex::NegOp, complex::SignOp>();
+                      complex::ExpOp, complex::LogOp, complex::MulOp,
+                      complex::NegOp, complex::NotEqualOp, complex::SignOp>();
   if (failed(applyPartialConversion(function, target, std::move(patterns))))
     signalPassFailure();
 }

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index a99585b6f2e5..95e6854ffa43 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -173,6 +173,124 @@ func @complex_log(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// CHECK-LABEL: func @complex_mul
+// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
+func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %mul = complex.mul %lhs, %rhs : complex<f32>
+  return %mul : complex<f32>
+}
+// CHECK: %[[LHS_REAL:.*]] = complex.re %[[LHS]] : complex<f32>
+// CHECK: %[[LHS_REAL_ABS:.*]] = absf %[[LHS_REAL]] : f32
+// CHECK: %[[LHS_IMAG:.*]] = complex.im %[[LHS]] : complex<f32>
+// CHECK: %[[LHS_IMAG_ABS:.*]] = absf %[[LHS_IMAG]] : f32
+// CHECK: %[[RHS_REAL:.*]] = complex.re %[[RHS]] : complex<f32>
+// CHECK: %[[RHS_REAL_ABS:.*]] = absf %[[RHS_REAL]] : f32
+// CHECK: %[[RHS_IMAG:.*]] = complex.im %[[RHS]] : complex<f32>
+// CHECK: %[[RHS_IMAG_ABS:.*]] = absf %[[RHS_IMAG]] : f32
+
+// CHECK: %[[LHS_REAL_TIMES_RHS_REAL:.*]] = mulf %[[LHS_REAL]], %[[RHS_REAL]] : f32
+// CHECK: %[[LHS_REAL_TIMES_RHS_REAL_ABS:.*]] = absf %[[LHS_REAL_TIMES_RHS_REAL]] : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG:.*]] = mulf %[[LHS_IMAG]], %[[RHS_IMAG]] : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG_ABS:.*]] = absf %[[LHS_IMAG_TIMES_RHS_IMAG]] : f32
+// CHECK: %[[REAL:.*]] = subf %[[LHS_REAL_TIMES_RHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG]] : f32
+
+// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL:.*]] = mulf %[[LHS_IMAG]], %[[RHS_REAL]] : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL_ABS:.*]] = absf %[[LHS_IMAG_TIMES_RHS_REAL]] : f32
+// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG:.*]] = mulf %[[LHS_REAL]], %[[RHS_IMAG]] : f32
+// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG_ABS:.*]] = absf %[[LHS_REAL_TIMES_RHS_IMAG]] : f32
+// CHECK: %[[IMAG:.*]] = addf %[[LHS_IMAG_TIMES_RHS_REAL]], %[[LHS_REAL_TIMES_RHS_IMAG]] : f32
+
+// Handle cases where the "naive" calculation results in NaN values.
+// CHECK: %[[REAL_IS_NAN:.*]] = cmpf uno, %[[REAL]], %[[REAL]] : f32
+// CHECK: %[[IMAG_IS_NAN:.*]] = cmpf uno, %[[IMAG]], %[[IMAG]] : f32
+// CHECK: %[[IS_NAN:.*]] = and %[[REAL_IS_NAN]], %[[IMAG_IS_NAN]] : i1
+// CHECK: %[[INF:.*]] = constant 0x7F800000 : f32
+
+// Case 1. LHS_REAL or LHS_IMAG are infinite.
+// CHECK: %[[LHS_REAL_IS_INF:.*]] = cmpf oeq, %[[LHS_REAL_ABS]], %[[INF]] : f32
+// CHECK: %[[LHS_IMAG_IS_INF:.*]] = cmpf oeq, %[[LHS_IMAG_ABS]], %[[INF]] : f32
+// CHECK: %[[LHS_IS_INF:.*]] = or %[[LHS_REAL_IS_INF]], %[[LHS_IMAG_IS_INF]] : i1
+// CHECK:  %[[RHS_REAL_IS_NAN:.*]] = cmpf uno, %[[RHS_REAL]], %[[RHS_REAL]] : f32
+// CHECK: %[[RHS_IMAG_IS_NAN:.*]] = cmpf uno, %[[RHS_IMAG]], %[[RHS_IMAG]] : f32
+// CHECK: %[[ZERO:.*]] = constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = constant 1.000000e+00 : f32
+// CHECK: %[[LHS_REAL_IS_INF_FLOAT:.*]] = select %[[LHS_REAL_IS_INF]], %[[ONE]], %[[ZERO]] : f32
+// CHECK: %[[TMP:.*]] = copysign %[[LHS_REAL_IS_INF_FLOAT]], %[[LHS_REAL]] : f32
+// CHECK: %[[LHS_REAL1:.*]] = select %[[LHS_IS_INF]], %[[TMP]], %[[LHS_REAL]] : f32
+// CHECK: %[[LHS_IMAG_IS_INF_FLOAT:.*]] = select %[[LHS_IMAG_IS_INF]], %[[ONE]], %[[ZERO]] : f32
+// CHECK: %[[TMP:.*]] = copysign %[[LHS_IMAG_IS_INF_FLOAT]], %[[LHS_IMAG]] : f32
+// CHECK: %[[LHS_IMAG1:.*]] = select %[[LHS_IS_INF]], %[[TMP]], %[[LHS_IMAG]] : f32
+// CHECK: %[[LHS_IS_INF_AND_RHS_REAL_IS_NAN:.*]] = and %[[LHS_IS_INF]], %[[RHS_REAL_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = copysign %[[ZERO]], %[[RHS_REAL]] : f32
+// CHECK: %[[RHS_REAL1:.*]] = select %[[LHS_IS_INF_AND_RHS_REAL_IS_NAN]], %[[TMP]], %[[RHS_REAL]] : f32
+// CHECK: %[[LHS_IS_INF_AND_RHS_IMAG_IS_NAN:.*]] = and %[[LHS_IS_INF]], %[[RHS_IMAG_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = copysign %[[ZERO]], %[[RHS_IMAG]] : f32
+// CHECK: %[[RHS_IMAG1:.*]] = select %[[LHS_IS_INF_AND_RHS_IMAG_IS_NAN]], %[[TMP]], %[[RHS_IMAG]] : f32
+
+// Case 2. RHS_REAL or RHS_IMAG are infinite.
+// CHECK: %[[RHS_REAL_IS_INF:.*]] = cmpf oeq, %[[RHS_REAL_ABS]], %[[INF]] : f32
+// CHECK: %[[RHS_IMAG_IS_INF:.*]] = cmpf oeq, %[[RHS_IMAG_ABS]], %[[INF]] : f32
+// CHECK: %[[RHS_IS_INF:.*]] = or %[[RHS_REAL_IS_INF]], %[[RHS_IMAG_IS_INF]] : i1
+// CHECK: %[[LHS_REAL_IS_NAN:.*]] = cmpf uno, %[[LHS_REAL1]], %[[LHS_REAL1]] : f32
+// CHECK: %[[LHS_IMAG_IS_NAN:.*]] = cmpf uno, %[[LHS_IMAG1]], %[[LHS_IMAG1]] : f32
+// CHECK: %[[RHS_REAL_IS_INF_FLOAT:.*]] = select %[[RHS_REAL_IS_INF]], %[[ONE]], %[[ZERO]] : f32
+// CHECK: %[[TMP:.*]] = copysign %[[RHS_REAL_IS_INF_FLOAT]], %[[RHS_REAL1]] : f32
+// CHECK: %[[RHS_REAL2:.*]] = select %[[RHS_IS_INF]], %[[TMP]], %[[RHS_REAL1]] : f32
+// CHECK: %[[RHS_IMAG_IS_INF_FLOAT:.*]] = select %[[RHS_IMAG_IS_INF]], %[[ONE]], %[[ZERO]] : f32
+// CHECK: %[[TMP:.*]] = copysign %[[RHS_IMAG_IS_INF_FLOAT]], %[[RHS_IMAG1]] : f32
+// CHECK: %[[RHS_IMAG2:.*]] = select %[[RHS_IS_INF]], %[[TMP]], %[[RHS_IMAG1]] : f32
+// CHECK: %[[RHS_IS_INF_AND_LHS_REAL_IS_NAN:.*]] = and %[[RHS_IS_INF]], %[[LHS_REAL_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = copysign %[[ZERO]], %[[LHS_REAL1]] : f32
+// CHECK: %[[LHS_REAL2:.*]] = select %[[RHS_IS_INF_AND_LHS_REAL_IS_NAN]], %[[TMP]], %[[LHS_REAL1]] : f32
+// CHECK: %[[RHS_IS_INF_AND_LHS_IMAG_IS_NAN:.*]] = and %[[RHS_IS_INF]], %[[LHS_IMAG_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = copysign %[[ZERO]], %[[LHS_IMAG1]] : f32
+// CHECK: %[[LHS_IMAG2:.*]] = select %[[RHS_IS_INF_AND_LHS_IMAG_IS_NAN]], %[[TMP]], %[[LHS_IMAG1]] : f32
+// CHECK: %[[RECALC:.*]] = or %[[LHS_IS_INF]], %[[RHS_IS_INF]] : i1
+
+// Case 3. One of the pairwise products of left hand side with right hand side
+// is infinite.
+// CHECK: %[[LHS_REAL_TIMES_RHS_REAL_IS_INF:.*]] = cmpf oeq, %[[LHS_REAL_TIMES_RHS_REAL_ABS]], %[[INF]] : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG_IS_INF:.*]] = cmpf oeq, %[[LHS_IMAG_TIMES_RHS_IMAG_ABS]], %[[INF]] : f32
+// CHECK: %[[IS_SPECIAL_CASE:.*]] = or %[[LHS_REAL_TIMES_RHS_REAL_IS_INF]], %[[LHS_IMAG_TIMES_RHS_IMAG_IS_INF]] : i1
+// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG_IS_INF:.*]] = cmpf oeq, %[[LHS_REAL_TIMES_RHS_IMAG_ABS]], %[[INF]] : f32
+// CHECK: %[[IS_SPECIAL_CASE1:.*]] = or %[[IS_SPECIAL_CASE]], %[[LHS_REAL_TIMES_RHS_IMAG_IS_INF]] : i1
+// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL_IS_INF:.*]] = cmpf oeq, %[[LHS_IMAG_TIMES_RHS_REAL_ABS]], %[[INF]] : f32
+// CHECK: %[[IS_SPECIAL_CASE2:.*]] = or %[[IS_SPECIAL_CASE1]], %[[LHS_IMAG_TIMES_RHS_REAL_IS_INF]] : i1
+// CHECK: %[[TRUE:.*]] = constant true
+// CHECK: %[[NOT_RECALC:.*]] = xor %[[RECALC]], %[[TRUE]] : i1
+// CHECK: %[[IS_SPECIAL_CASE3:.*]] = and %[[IS_SPECIAL_CASE2]], %[[NOT_RECALC]] : i1
+// CHECK: %[[IS_SPECIAL_CASE_AND_LHS_REAL_IS_NAN:.*]] = and %[[IS_SPECIAL_CASE3]], %[[LHS_REAL_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = copysign %[[ZERO]], %[[LHS_REAL2]] : f32
+// CHECK: %[[LHS_REAL3:.*]] = select %[[IS_SPECIAL_CASE_AND_LHS_REAL_IS_NAN]], %[[TMP]], %[[LHS_REAL2]] : f32
+// CHECK: %[[IS_SPECIAL_CASE_AND_LHS_IMAG_IS_NAN:.*]] = and %[[IS_SPECIAL_CASE3]], %[[LHS_IMAG_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = copysign %[[ZERO]], %[[LHS_IMAG2]] : f32
+// CHECK: %[[LHS_IMAG3:.*]] = select %[[IS_SPECIAL_CASE_AND_LHS_IMAG_IS_NAN]], %[[TMP]], %[[LHS_IMAG2]] : f32
+// CHECK: %[[IS_SPECIAL_CASE_AND_RHS_REAL_IS_NAN:.*]] = and %[[IS_SPECIAL_CASE3]], %[[RHS_REAL_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = copysign %[[ZERO]], %[[RHS_REAL2]] : f32
+// CHECK: %[[RHS_REAL3:.*]] = select %[[IS_SPECIAL_CASE_AND_RHS_REAL_IS_NAN]], %[[TMP]], %[[RHS_REAL2]] : f32
+// CHECK: %[[IS_SPECIAL_CASE_AND_RHS_IMAG_IS_NAN:.*]] = and %[[IS_SPECIAL_CASE3]], %[[RHS_IMAG_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = copysign %[[ZERO]], %[[RHS_IMAG2]] : f32
+// CHECK: %[[RHS_IMAG3:.*]] = select %[[IS_SPECIAL_CASE_AND_RHS_IMAG_IS_NAN]], %[[TMP]], %[[RHS_IMAG2]] : f32
+// CHECK: %[[RECALC2:.*]] = or %[[RECALC]], %[[IS_SPECIAL_CASE3]] : i1
+// CHECK: %[[RECALC3:.*]] = and %[[IS_NAN]], %[[RECALC2]] : i1
+
+// Recalculate real part.
+// CHECK: %[[LHS_REAL_TIMES_RHS_REAL:.*]] = mulf %[[LHS_REAL3]], %[[RHS_REAL3]] : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG:.*]] = mulf %[[LHS_IMAG3]], %[[RHS_IMAG3]] : f32
+// CHECK: %[[NEW_REAL:.*]] = subf %[[LHS_REAL_TIMES_RHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG]] : f32
+// CHECK: %[[NEW_REAL_TIMES_INF:.*]] = mulf %[[INF]], %[[NEW_REAL]] : f32
+// CHECK: %[[FINAL_REAL:.*]] = select %[[RECALC3]], %[[NEW_REAL_TIMES_INF]], %[[REAL]] : f32
+
+// Recalculate imag part.
+// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL:.*]] = mulf %[[LHS_IMAG3]], %[[RHS_REAL3]] : f32
+// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG:.*]] = mulf %[[LHS_REAL3]], %[[RHS_IMAG3]] : f32
+// CHECK: %[[NEW_IMAG:.*]] = addf %[[LHS_IMAG_TIMES_RHS_REAL]], %[[LHS_REAL_TIMES_RHS_IMAG]] : f32
+// CHECK: %[[NEW_IMAG_TIMES_INF:.*]] = mulf %[[INF]], %[[NEW_IMAG]] : f32
+// CHECK: %[[FINAL_IMAG:.*]] = select %[[RECALC3]], %[[NEW_IMAG_TIMES_INF]], %[[IMAG]] : f32
+
+// CHECK: %[[RESULT:.*]] = complex.create %[[FINAL_REAL]], %[[FINAL_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
 // CHECK-LABEL: func @complex_neg
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func @complex_neg(%arg: complex<f32>) -> complex<f32> {


        


More information about the Mlir-commits mailing list