[Mlir-commits] [mlir] bf17ee1 - Add MulOp lowering from Complex dialect to Standard/Math dialect.
Adrian Kuegel
llvmlistbot at llvm.org
Mon Jul 5 03:52:08 PDT 2021
Author: Adrian Kuegel
Date: 2021-07-05T12:51:51+02:00
New Revision: bf17ee1950ef7df95e160c4869de8398bc2b67d2
URL: https://github.com/llvm/llvm-project/commit/bf17ee1950ef7df95e160c4869de8398bc2b67d2
DIFF: https://github.com/llvm/llvm-project/commit/bf17ee1950ef7df95e160c4869de8398bc2b67d2.diff
LOG: Add MulOp lowering from Complex dialect to Standard/Math dialect.
The lowering handles special cases with NaN or infinity like C++.
Differential Revision: https://reviews.llvm.org/D105270
Added:
Modified:
mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index a5aa4cbda9c6..018882ae9489 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -337,6 +337,144 @@ struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
}
};
+// Conversion pattern that rewrites complex.mul into scalar floating-point ops
+// on the real and imaginary parts, followed by a complex.create. Besides the
+// "naive" product it emits a recovery path for operands containing NaN or
+// infinity (per the commit log, the special cases are handled like C++). All
+// ops are created unconditionally; select ops pick between the naive and the
+// recovered result. The op creation order below is mirrored by the ordered
+// CHECK lines of the complex_mul test in convert-to-standard.mlir.
+struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
+ using OpConversionPattern<complex::MulOp>::OpConversionPattern;
+
+ // Replaces `op` with a complex.create of the computed real/imaginary parts.
+ // `operands` holds the already-converted operands of `op`; all new ops are
+ // created through `rewriter`. Always returns success().
+ LogicalResult
+ matchAndRewrite(complex::MulOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // The adaptor gives named access (lhs()/rhs()) to the converted operands.
+ complex::MulOp::Adaptor transformed(operands);
+ // Builder constructed from op's location, so no location is passed to the
+ // individual create<> calls below.
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+ auto type = transformed.lhs().getType().cast<ComplexType>();
+ auto elementType = type.getElementType().cast<FloatType>();
+
+ // Split both operands into components. The absolute values are only used
+ // for the infinity tests in Case 1 and Case 2.
+ Value lhsReal = b.create<complex::ReOp>(elementType, transformed.lhs());
+ Value lhsRealAbs = b.create<AbsFOp>(lhsReal);
+ Value lhsImag = b.create<complex::ImOp>(elementType, transformed.lhs());
+ Value lhsImagAbs = b.create<AbsFOp>(lhsImag);
+ Value rhsReal = b.create<complex::ReOp>(elementType, transformed.rhs());
+ Value rhsRealAbs = b.create<AbsFOp>(rhsReal);
+ Value rhsImag = b.create<complex::ImOp>(elementType, transformed.rhs());
+ Value rhsImagAbs = b.create<AbsFOp>(rhsImag);
+
+ // Naive real part: lhsReal * rhsReal - lhsImag * rhsImag. The absolute
+ // values of the partial products feed the infinity tests in Case 3.
+ Value lhsRealTimesRhsReal = b.create<MulFOp>(lhsReal, rhsReal);
+ Value lhsRealTimesRhsRealAbs = b.create<AbsFOp>(lhsRealTimesRhsReal);
+ Value lhsImagTimesRhsImag = b.create<MulFOp>(lhsImag, rhsImag);
+ Value lhsImagTimesRhsImagAbs = b.create<AbsFOp>(lhsImagTimesRhsImag);
+ Value real = b.create<SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
+
+ // Naive imaginary part: lhsImag * rhsReal + lhsReal * rhsImag.
+ Value lhsImagTimesRhsReal = b.create<MulFOp>(lhsImag, rhsReal);
+ Value lhsImagTimesRhsRealAbs = b.create<AbsFOp>(lhsImagTimesRhsReal);
+ Value lhsRealTimesRhsImag = b.create<MulFOp>(lhsReal, rhsImag);
+ Value lhsRealTimesRhsImagAbs = b.create<AbsFOp>(lhsRealTimesRhsImag);
+ Value imag = b.create<AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
+
+ // Handle cases where the "naive" calculation results in NaN values.
+ // An unordered (UNO) comparison of a value with itself is true exactly when
+ // the value is NaN. `isNan` requires BOTH result parts to be NaN.
+ Value realIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, real, real);
+ Value imagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, imag, imag);
+ Value isNan = b.create<AndOp>(realIsNan, imagIsNan);
+
+ // +infinity constant of the element type; compared against absolute values
+ // to detect +/-infinity and later used to scale the recomputed result.
+ Value inf = b.create<ConstantOp>(
+ elementType,
+ b.getFloatAttr(elementType,
+ APFloat::getInf(elementType.getFloatSemantics())));
+
+ // Case 1. `lhsReal` or `lhsImag` are infinite.
+ // When selected: each lhs component becomes copysign(1, x) if it is
+ // infinite and copysign(0, x) otherwise, and each NaN rhs component becomes
+ // copysign(0, x). The rhs NaN flags are taken from the original rhs values.
+ Value lhsRealIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealAbs, inf);
+ Value lhsImagIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagAbs, inf);
+ Value lhsIsInf = b.create<OrOp>(lhsRealIsInf, lhsImagIsInf);
+ Value rhsRealIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, rhsReal, rhsReal);
+ Value rhsImagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, rhsImag, rhsImag);
+ Value zero = b.create<ConstantOp>(elementType, b.getZeroAttr(elementType));
+ Value one =
+ b.create<ConstantOp>(elementType, b.getFloatAttr(elementType, 1));
+ Value lhsRealIsInfFloat = b.create<SelectOp>(lhsRealIsInf, one, zero);
+ lhsReal = b.create<SelectOp>(
+ lhsIsInf, b.create<CopySignOp>(lhsRealIsInfFloat, lhsReal), lhsReal);
+ Value lhsImagIsInfFloat = b.create<SelectOp>(lhsImagIsInf, one, zero);
+ lhsImag = b.create<SelectOp>(
+ lhsIsInf, b.create<CopySignOp>(lhsImagIsInfFloat, lhsImag), lhsImag);
+ Value lhsIsInfAndRhsRealIsNan = b.create<AndOp>(lhsIsInf, rhsRealIsNan);
+ rhsReal = b.create<SelectOp>(lhsIsInfAndRhsRealIsNan,
+ b.create<CopySignOp>(zero, rhsReal), rhsReal);
+ Value lhsIsInfAndRhsImagIsNan = b.create<AndOp>(lhsIsInf, rhsImagIsNan);
+ rhsImag = b.create<SelectOp>(lhsIsInfAndRhsImagIsNan,
+ b.create<CopySignOp>(zero, rhsImag), rhsImag);
+
+ // Case 2. `rhsReal` or `rhsImag` are infinite.
+ // Mirror image of Case 1. The infinity tests use the absolute values of
+ // the original rhs components, while the lhs NaN flags are computed from
+ // the lhs components as already adjusted by Case 1.
+ Value rhsRealIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, rhsRealAbs, inf);
+ Value rhsImagIsInf = b.create<CmpFOp>(CmpFPredicate::OEQ, rhsImagAbs, inf);
+ Value rhsIsInf = b.create<OrOp>(rhsRealIsInf, rhsImagIsInf);
+ Value lhsRealIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, lhsReal, lhsReal);
+ Value lhsImagIsNan = b.create<CmpFOp>(CmpFPredicate::UNO, lhsImag, lhsImag);
+ Value rhsRealIsInfFloat = b.create<SelectOp>(rhsRealIsInf, one, zero);
+ rhsReal = b.create<SelectOp>(
+ rhsIsInf, b.create<CopySignOp>(rhsRealIsInfFloat, rhsReal), rhsReal);
+ Value rhsImagIsInfFloat = b.create<SelectOp>(rhsImagIsInf, one, zero);
+ rhsImag = b.create<SelectOp>(
+ rhsIsInf, b.create<CopySignOp>(rhsImagIsInfFloat, rhsImag), rhsImag);
+ Value rhsIsInfAndLhsRealIsNan = b.create<AndOp>(rhsIsInf, lhsRealIsNan);
+ lhsReal = b.create<SelectOp>(rhsIsInfAndLhsRealIsNan,
+ b.create<CopySignOp>(zero, lhsReal), lhsReal);
+ Value rhsIsInfAndLhsImagIsNan = b.create<AndOp>(rhsIsInf, lhsImagIsNan);
+ lhsImag = b.create<SelectOp>(rhsIsInfAndLhsImagIsNan,
+ b.create<CopySignOp>(zero, lhsImag), lhsImag);
+ // A recomputation is wanted if either operand has an infinite component.
+ Value recalc = b.create<OrOp>(lhsIsInf, rhsIsInf);
+
+ // Case 3. One of the pairwise products of left hand side with right hand
+ // side is infinite.
+ // Only applies when neither Case 1 nor Case 2 matched (`notRecalc`); the
+ // products tested are the naive ones computed from the original operands.
+ Value lhsRealTimesRhsRealIsInf =
+ b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealTimesRhsRealAbs, inf);
+ Value lhsImagTimesRhsImagIsInf =
+ b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagTimesRhsImagAbs, inf);
+ Value isSpecialCase =
+ b.create<OrOp>(lhsRealTimesRhsRealIsInf, lhsImagTimesRhsImagIsInf);
+ Value lhsRealTimesRhsImagIsInf =
+ b.create<CmpFOp>(CmpFPredicate::OEQ, lhsRealTimesRhsImagAbs, inf);
+ isSpecialCase = b.create<OrOp>(isSpecialCase, lhsRealTimesRhsImagIsInf);
+ Value lhsImagTimesRhsRealIsInf =
+ b.create<CmpFOp>(CmpFPredicate::OEQ, lhsImagTimesRhsRealAbs, inf);
+ isSpecialCase = b.create<OrOp>(isSpecialCase, lhsImagTimesRhsRealIsInf);
+ // XOR with the i1 constant 1 (true) is a logical negation of `recalc`.
+ Type i1Type = b.getI1Type();
+ Value notRecalc = b.create<XOrOp>(
+ recalc, b.create<ConstantOp>(i1Type, b.getIntegerAttr(i1Type, 1)));
+ isSpecialCase = b.create<AndOp>(isSpecialCase, notRecalc);
+ // In the special case every NaN operand component becomes copysign(0, x).
+ Value isSpecialCaseAndLhsRealIsNan =
+ b.create<AndOp>(isSpecialCase, lhsRealIsNan);
+ lhsReal = b.create<SelectOp>(isSpecialCaseAndLhsRealIsNan,
+ b.create<CopySignOp>(zero, lhsReal), lhsReal);
+ Value isSpecialCaseAndLhsImagIsNan =
+ b.create<AndOp>(isSpecialCase, lhsImagIsNan);
+ lhsImag = b.create<SelectOp>(isSpecialCaseAndLhsImagIsNan,
+ b.create<CopySignOp>(zero, lhsImag), lhsImag);
+ Value isSpecialCaseAndRhsRealIsNan =
+ b.create<AndOp>(isSpecialCase, rhsRealIsNan);
+ rhsReal = b.create<SelectOp>(isSpecialCaseAndRhsRealIsNan,
+ b.create<CopySignOp>(zero, rhsReal), rhsReal);
+ Value isSpecialCaseAndRhsImagIsNan =
+ b.create<AndOp>(isSpecialCase, rhsImagIsNan);
+ rhsImag = b.create<SelectOp>(isSpecialCaseAndRhsImagIsNan,
+ b.create<CopySignOp>(zero, rhsImag), rhsImag);
+ // Final condition: the naive result is NaN in both parts AND one of the
+ // three cases above applies.
+ recalc = b.create<OrOp>(recalc, isSpecialCase);
+ recalc = b.create<AndOp>(isNan, recalc);
+
+ // Recalculate real part.
+ // Uses the adjusted operand components; the new value is multiplied by
+ // `inf` and replaces the naive result only when `recalc` is set.
+ lhsRealTimesRhsReal = b.create<MulFOp>(lhsReal, rhsReal);
+ lhsImagTimesRhsImag = b.create<MulFOp>(lhsImag, rhsImag);
+ Value newReal = b.create<SubFOp>(lhsRealTimesRhsReal, lhsImagTimesRhsImag);
+ real = b.create<SelectOp>(recalc, b.create<MulFOp>(inf, newReal), real);
+
+ // Recalculate imag part.
+ // Same scheme as for the real part.
+ lhsImagTimesRhsReal = b.create<MulFOp>(lhsImag, rhsReal);
+ lhsRealTimesRhsImag = b.create<MulFOp>(lhsReal, rhsImag);
+ Value newImag = b.create<AddFOp>(lhsImagTimesRhsReal, lhsRealTimesRhsImag);
+ imag = b.create<SelectOp>(recalc, b.create<MulFOp>(inf, newImag), imag);
+
+ // Reassemble the complex result from the selected parts.
+ rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
+ return success();
+ }
+};
+
struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
using OpConversionPattern<complex::NegOp>::OpConversionPattern;
@@ -397,6 +535,7 @@ void mlir::populateComplexToStandardConversionPatterns(
DivOpConversion,
ExpOpConversion,
LogOpConversion,
+ MulOpConversion,
NegOpConversion,
SignOpConversion>(patterns.getContext());
// clang-format on
@@ -419,8 +558,8 @@ void ConvertComplexToStandardPass::runOnFunction() {
target.addLegalDialect<StandardOpsDialect, math::MathDialect,
complex::ComplexDialect>();
target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
- complex::ExpOp, complex::LogOp, complex::NotEqualOp,
- complex::NegOp, complex::SignOp>();
+ complex::ExpOp, complex::LogOp, complex::MulOp,
+ complex::NegOp, complex::NotEqualOp, complex::SignOp>();
if (failed(applyPartialConversion(function, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index a99585b6f2e5..95e6854ffa43 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -173,6 +173,124 @@ func @complex_log(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+// Lowering of complex.mul on complex<f32>: the expectations that follow the
+// function spell out the naive product, the three NaN/infinity recovery cases
+// and the final selects, in the order the ops are emitted.
+// CHECK-LABEL: func @complex_mul
+// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
+func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+ %mul = complex.mul %lhs, %rhs : complex<f32>
+ return %mul : complex<f32>
+}
+// CHECK: %[[LHS_REAL:.*]] = complex.re %[[LHS]] : complex<f32>
+// CHECK: %[[LHS_REAL_ABS:.*]] = absf %[[LHS_REAL]] : f32
+// CHECK: %[[LHS_IMAG:.*]] = complex.im %[[LHS]] : complex<f32>
+// CHECK: %[[LHS_IMAG_ABS:.*]] = absf %[[LHS_IMAG]] : f32
+// CHECK: %[[RHS_REAL:.*]] = complex.re %[[RHS]] : complex<f32>
+// CHECK: %[[RHS_REAL_ABS:.*]] = absf %[[RHS_REAL]] : f32
+// CHECK: %[[RHS_IMAG:.*]] = complex.im %[[RHS]] : complex<f32>
+// CHECK: %[[RHS_IMAG_ABS:.*]] = absf %[[RHS_IMAG]] : f32
+
+// CHECK: %[[LHS_REAL_TIMES_RHS_REAL:.*]] = mulf %[[LHS_REAL]], %[[RHS_REAL]] : f32
+// CHECK: %[[LHS_REAL_TIMES_RHS_REAL_ABS:.*]] = absf %[[LHS_REAL_TIMES_RHS_REAL]] : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG:.*]] = mulf %[[LHS_IMAG]], %[[RHS_IMAG]] : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG_ABS:.*]] = absf %[[LHS_IMAG_TIMES_RHS_IMAG]] : f32
+// CHECK: %[[REAL:.*]] = subf %[[LHS_REAL_TIMES_RHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG]] : f32
+
+// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL:.*]] = mulf %[[LHS_IMAG]], %[[RHS_REAL]] : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL_ABS:.*]] = absf %[[LHS_IMAG_TIMES_RHS_REAL]] : f32
+// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG:.*]] = mulf %[[LHS_REAL]], %[[RHS_IMAG]] : f32
+// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG_ABS:.*]] = absf %[[LHS_REAL_TIMES_RHS_IMAG]] : f32
+// CHECK: %[[IMAG:.*]] = addf %[[LHS_IMAG_TIMES_RHS_REAL]], %[[LHS_REAL_TIMES_RHS_IMAG]] : f32
+
+// Handle cases where the "naive" calculation results in NaN values.
+// CHECK: %[[REAL_IS_NAN:.*]] = cmpf uno, %[[REAL]], %[[REAL]] : f32
+// CHECK: %[[IMAG_IS_NAN:.*]] = cmpf uno, %[[IMAG]], %[[IMAG]] : f32
+// CHECK: %[[IS_NAN:.*]] = and %[[REAL_IS_NAN]], %[[IMAG_IS_NAN]] : i1
+// CHECK: %[[INF:.*]] = constant 0x7F800000 : f32
+
+// Case 1. LHS_REAL or LHS_IMAG are infinite.
+// CHECK: %[[LHS_REAL_IS_INF:.*]] = cmpf oeq, %[[LHS_REAL_ABS]], %[[INF]] : f32
+// CHECK: %[[LHS_IMAG_IS_INF:.*]] = cmpf oeq, %[[LHS_IMAG_ABS]], %[[INF]] : f32
+// CHECK: %[[LHS_IS_INF:.*]] = or %[[LHS_REAL_IS_INF]], %[[LHS_IMAG_IS_INF]] : i1
+// CHECK: %[[RHS_REAL_IS_NAN:.*]] = cmpf uno, %[[RHS_REAL]], %[[RHS_REAL]] : f32
+// CHECK: %[[RHS_IMAG_IS_NAN:.*]] = cmpf uno, %[[RHS_IMAG]], %[[RHS_IMAG]] : f32
+// CHECK: %[[ZERO:.*]] = constant 0.000000e+00 : f32
+// CHECK: %[[ONE:.*]] = constant 1.000000e+00 : f32
+// CHECK: %[[LHS_REAL_IS_INF_FLOAT:.*]] = select %[[LHS_REAL_IS_INF]], %[[ONE]], %[[ZERO]] : f32
+// CHECK: %[[TMP:.*]] = copysign %[[LHS_REAL_IS_INF_FLOAT]], %[[LHS_REAL]] : f32
+// CHECK: %[[LHS_REAL1:.*]] = select %[[LHS_IS_INF]], %[[TMP]], %[[LHS_REAL]] : f32
+// CHECK: %[[LHS_IMAG_IS_INF_FLOAT:.*]] = select %[[LHS_IMAG_IS_INF]], %[[ONE]], %[[ZERO]] : f32
+// CHECK: %[[TMP:.*]] = copysign %[[LHS_IMAG_IS_INF_FLOAT]], %[[LHS_IMAG]] : f32
+// CHECK: %[[LHS_IMAG1:.*]] = select %[[LHS_IS_INF]], %[[TMP]], %[[LHS_IMAG]] : f32
+// CHECK: %[[LHS_IS_INF_AND_RHS_REAL_IS_NAN:.*]] = and %[[LHS_IS_INF]], %[[RHS_REAL_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = copysign %[[ZERO]], %[[RHS_REAL]] : f32
+// CHECK: %[[RHS_REAL1:.*]] = select %[[LHS_IS_INF_AND_RHS_REAL_IS_NAN]], %[[TMP]], %[[RHS_REAL]] : f32
+// CHECK: %[[LHS_IS_INF_AND_RHS_IMAG_IS_NAN:.*]] = and %[[LHS_IS_INF]], %[[RHS_IMAG_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = copysign %[[ZERO]], %[[RHS_IMAG]] : f32
+// CHECK: %[[RHS_IMAG1:.*]] = select %[[LHS_IS_INF_AND_RHS_IMAG_IS_NAN]], %[[TMP]], %[[RHS_IMAG]] : f32
+
+// Case 2. RHS_REAL or RHS_IMAG are infinite.
+// CHECK: %[[RHS_REAL_IS_INF:.*]] = cmpf oeq, %[[RHS_REAL_ABS]], %[[INF]] : f32
+// CHECK: %[[RHS_IMAG_IS_INF:.*]] = cmpf oeq, %[[RHS_IMAG_ABS]], %[[INF]] : f32
+// CHECK: %[[RHS_IS_INF:.*]] = or %[[RHS_REAL_IS_INF]], %[[RHS_IMAG_IS_INF]] : i1
+// CHECK: %[[LHS_REAL_IS_NAN:.*]] = cmpf uno, %[[LHS_REAL1]], %[[LHS_REAL1]] : f32
+// CHECK: %[[LHS_IMAG_IS_NAN:.*]] = cmpf uno, %[[LHS_IMAG1]], %[[LHS_IMAG1]] : f32
+// CHECK: %[[RHS_REAL_IS_INF_FLOAT:.*]] = select %[[RHS_REAL_IS_INF]], %[[ONE]], %[[ZERO]] : f32
+// CHECK: %[[TMP:.*]] = copysign %[[RHS_REAL_IS_INF_FLOAT]], %[[RHS_REAL1]] : f32
+// CHECK: %[[RHS_REAL2:.*]] = select %[[RHS_IS_INF]], %[[TMP]], %[[RHS_REAL1]] : f32
+// CHECK: %[[RHS_IMAG_IS_INF_FLOAT:.*]] = select %[[RHS_IMAG_IS_INF]], %[[ONE]], %[[ZERO]] : f32
+// CHECK: %[[TMP:.*]] = copysign %[[RHS_IMAG_IS_INF_FLOAT]], %[[RHS_IMAG1]] : f32
+// CHECK: %[[RHS_IMAG2:.*]] = select %[[RHS_IS_INF]], %[[TMP]], %[[RHS_IMAG1]] : f32
+// CHECK: %[[RHS_IS_INF_AND_LHS_REAL_IS_NAN:.*]] = and %[[RHS_IS_INF]], %[[LHS_REAL_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = copysign %[[ZERO]], %[[LHS_REAL1]] : f32
+// CHECK: %[[LHS_REAL2:.*]] = select %[[RHS_IS_INF_AND_LHS_REAL_IS_NAN]], %[[TMP]], %[[LHS_REAL1]] : f32
+// CHECK: %[[RHS_IS_INF_AND_LHS_IMAG_IS_NAN:.*]] = and %[[RHS_IS_INF]], %[[LHS_IMAG_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = copysign %[[ZERO]], %[[LHS_IMAG1]] : f32
+// CHECK: %[[LHS_IMAG2:.*]] = select %[[RHS_IS_INF_AND_LHS_IMAG_IS_NAN]], %[[TMP]], %[[LHS_IMAG1]] : f32
+// CHECK: %[[RECALC:.*]] = or %[[LHS_IS_INF]], %[[RHS_IS_INF]] : i1
+
+// Case 3. One of the pairwise products of left hand side with right hand side
+// is infinite.
+// CHECK: %[[LHS_REAL_TIMES_RHS_REAL_IS_INF:.*]] = cmpf oeq, %[[LHS_REAL_TIMES_RHS_REAL_ABS]], %[[INF]] : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG_IS_INF:.*]] = cmpf oeq, %[[LHS_IMAG_TIMES_RHS_IMAG_ABS]], %[[INF]] : f32
+// CHECK: %[[IS_SPECIAL_CASE:.*]] = or %[[LHS_REAL_TIMES_RHS_REAL_IS_INF]], %[[LHS_IMAG_TIMES_RHS_IMAG_IS_INF]] : i1
+// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG_IS_INF:.*]] = cmpf oeq, %[[LHS_REAL_TIMES_RHS_IMAG_ABS]], %[[INF]] : f32
+// CHECK: %[[IS_SPECIAL_CASE1:.*]] = or %[[IS_SPECIAL_CASE]], %[[LHS_REAL_TIMES_RHS_IMAG_IS_INF]] : i1
+// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL_IS_INF:.*]] = cmpf oeq, %[[LHS_IMAG_TIMES_RHS_REAL_ABS]], %[[INF]] : f32
+// CHECK: %[[IS_SPECIAL_CASE2:.*]] = or %[[IS_SPECIAL_CASE1]], %[[LHS_IMAG_TIMES_RHS_REAL_IS_INF]] : i1
+// CHECK: %[[TRUE:.*]] = constant true
+// CHECK: %[[NOT_RECALC:.*]] = xor %[[RECALC]], %[[TRUE]] : i1
+// CHECK: %[[IS_SPECIAL_CASE3:.*]] = and %[[IS_SPECIAL_CASE2]], %[[NOT_RECALC]] : i1
+// CHECK: %[[IS_SPECIAL_CASE_AND_LHS_REAL_IS_NAN:.*]] = and %[[IS_SPECIAL_CASE3]], %[[LHS_REAL_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = copysign %[[ZERO]], %[[LHS_REAL2]] : f32
+// CHECK: %[[LHS_REAL3:.*]] = select %[[IS_SPECIAL_CASE_AND_LHS_REAL_IS_NAN]], %[[TMP]], %[[LHS_REAL2]] : f32
+// CHECK: %[[IS_SPECIAL_CASE_AND_LHS_IMAG_IS_NAN:.*]] = and %[[IS_SPECIAL_CASE3]], %[[LHS_IMAG_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = copysign %[[ZERO]], %[[LHS_IMAG2]] : f32
+// CHECK: %[[LHS_IMAG3:.*]] = select %[[IS_SPECIAL_CASE_AND_LHS_IMAG_IS_NAN]], %[[TMP]], %[[LHS_IMAG2]] : f32
+// CHECK: %[[IS_SPECIAL_CASE_AND_RHS_REAL_IS_NAN:.*]] = and %[[IS_SPECIAL_CASE3]], %[[RHS_REAL_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = copysign %[[ZERO]], %[[RHS_REAL2]] : f32
+// CHECK: %[[RHS_REAL3:.*]] = select %[[IS_SPECIAL_CASE_AND_RHS_REAL_IS_NAN]], %[[TMP]], %[[RHS_REAL2]] : f32
+// CHECK: %[[IS_SPECIAL_CASE_AND_RHS_IMAG_IS_NAN:.*]] = and %[[IS_SPECIAL_CASE3]], %[[RHS_IMAG_IS_NAN]] : i1
+// CHECK: %[[TMP:.*]] = copysign %[[ZERO]], %[[RHS_IMAG2]] : f32
+// CHECK: %[[RHS_IMAG3:.*]] = select %[[IS_SPECIAL_CASE_AND_RHS_IMAG_IS_NAN]], %[[TMP]], %[[RHS_IMAG2]] : f32
+// CHECK: %[[RECALC2:.*]] = or %[[RECALC]], %[[IS_SPECIAL_CASE3]] : i1
+// CHECK: %[[RECALC3:.*]] = and %[[IS_NAN]], %[[RECALC2]] : i1
+
+// Recalculate real part.
+// CHECK: %[[LHS_REAL_TIMES_RHS_REAL:.*]] = mulf %[[LHS_REAL3]], %[[RHS_REAL3]] : f32
+// CHECK: %[[LHS_IMAG_TIMES_RHS_IMAG:.*]] = mulf %[[LHS_IMAG3]], %[[RHS_IMAG3]] : f32
+// CHECK: %[[NEW_REAL:.*]] = subf %[[LHS_REAL_TIMES_RHS_REAL]], %[[LHS_IMAG_TIMES_RHS_IMAG]] : f32
+// CHECK: %[[NEW_REAL_TIMES_INF:.*]] = mulf %[[INF]], %[[NEW_REAL]] : f32
+// CHECK: %[[FINAL_REAL:.*]] = select %[[RECALC3]], %[[NEW_REAL_TIMES_INF]], %[[REAL]] : f32
+
+// Recalculate imag part.
+// CHECK: %[[LHS_IMAG_TIMES_RHS_REAL:.*]] = mulf %[[LHS_IMAG3]], %[[RHS_REAL3]] : f32
+// CHECK: %[[LHS_REAL_TIMES_RHS_IMAG:.*]] = mulf %[[LHS_REAL3]], %[[RHS_IMAG3]] : f32
+// CHECK: %[[NEW_IMAG:.*]] = addf %[[LHS_IMAG_TIMES_RHS_REAL]], %[[LHS_REAL_TIMES_RHS_IMAG]] : f32
+// CHECK: %[[NEW_IMAG_TIMES_INF:.*]] = mulf %[[INF]], %[[NEW_IMAG]] : f32
+// CHECK: %[[FINAL_IMAG:.*]] = select %[[RECALC3]], %[[NEW_IMAG_TIMES_INF]], %[[IMAG]] : f32
+
+// CHECK: %[[RESULT:.*]] = complex.create %[[FINAL_REAL]], %[[FINAL_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
// CHECK-LABEL: func @complex_neg
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func @complex_neg(%arg: complex<f32>) -> complex<f32> {
More information about the Mlir-commits
mailing list