[Mlir-commits] [mlir] 402b837 - Revert "[mlir] Lower complex.sqrt and complex.atan2 to Arithmetic dialect."
Alexander Belyaev
llvmlistbot at llvm.org
Mon May 30 01:50:15 PDT 2022
Author: Alexander Belyaev
Date: 2022-05-30T10:48:58+02:00
New Revision: 402b8373021a69090aae505c7f39d9e750127ff2
URL: https://github.com/llvm/llvm-project/commit/402b8373021a69090aae505c7f39d9e750127ff2
DIFF: https://github.com/llvm/llvm-project/commit/402b8373021a69090aae505c7f39d9e750127ff2.diff
LOG: Revert "[mlir] Lower complex.sqrt and complex.atan2 to Arithmetic dialect."
This reverts commit f5fa633b0955a8cee878b384801038fccef11fdc.
The integration test sparse_complex_ops.mlir breaks because of it.
Added:
Modified:
mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 194e1669a86c..e1eca6181dff 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -44,49 +44,6 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
}
};
-// atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
-struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
- using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-
- auto type = op.getType().cast<ComplexType>();
- Type elementType = type.getElementType();
-
- Value lhs = adaptor.getLhs();
- Value rhs = adaptor.getRhs();
-
- Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs);
- Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs);
- Value rhsSquaredPlusLhsSquared =
- b.create<complex::AddOp>(type, rhsSquared, lhsSquared);
- Value sqrtOfRhsSquaredPlusLhsSquared =
- b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared);
-
- Value zero =
- b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
- Value one = b.create<arith::ConstantOp>(elementType,
- b.getFloatAttr(elementType, 1));
- Value i = b.create<complex::CreateOp>(type, zero, one);
- Value iTimesLhs = b.create<complex::MulOp>(i, lhs);
- Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs);
-
- Value divResult =
- b.create<complex::DivOp>(rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared);
- Value logResult = b.create<complex::LogOp>(divResult);
-
- Value negativeOne = b.create<arith::ConstantOp>(
- elementType, b.getFloatAttr(elementType, -1));
- Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
-
- rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult);
- return success();
- }
-};
-
template <typename ComparisonOp, arith::CmpFPredicate p>
struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
using OpConversionPattern<ComparisonOp>::OpConversionPattern;
@@ -743,72 +700,6 @@ struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
}
};
-// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
-struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
- using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-
- auto type = op.getType().cast<ComplexType>();
- Type elementType = type.getElementType();
- Value arg = adaptor.getComplex();
-
- Value zero =
- b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
-
- Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
- Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
-
- Value absLhs = b.create<math::AbsOp>(real);
- Value absArg = b.create<complex::AbsOp>(elementType, arg);
- Value addAbs = b.create<arith::AddFOp>(absLhs, absArg);
- Value sqrtAddAbs = b.create<math::SqrtOp>(addAbs);
- Value sqrtAddAbsDivTwo = b.create<arith::DivFOp>(
- sqrtAddAbs, b.create<arith::ConstantOp>(
- elementType, b.getFloatAttr(elementType, 2)));
-
- Value realIsNegative =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
- Value imagIsNegative =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero);
-
- Value resultReal = sqrtAddAbsDivTwo;
-
- Value imagDivTwoResultReal = b.create<arith::DivFOp>(
- imag, b.create<arith::AddFOp>(resultReal, resultReal));
-
- Value negativeResultReal = b.create<arith::NegFOp>(resultReal);
-
- Value resultImag = b.create<arith::SelectOp>(
- realIsNegative,
- b.create<arith::SelectOp>(imagIsNegative, negativeResultReal,
- resultReal),
- imagDivTwoResultReal);
-
- resultReal = b.create<arith::SelectOp>(
- realIsNegative,
- b.create<arith::DivFOp>(
- imag, b.create<arith::AddFOp>(resultImag, resultImag)),
- resultReal);
-
- Value realIsZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
- Value imagIsZero =
- b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
- Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
-
- resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal);
- resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag);
-
- rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
- resultImag);
- return success();
- }
-};
-
struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
using OpConversionPattern<complex::SignOp>::OpConversionPattern;
@@ -844,7 +735,6 @@ void mlir::populateComplexToStandardConversionPatterns(
// clang-format off
patterns.add<
AbsOpConversion,
- Atan2OpConversion,
ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
@@ -858,8 +748,7 @@ void mlir::populateComplexToStandardConversionPatterns(
MulOpConversion,
NegOpConversion,
SignOpConversion,
- SinOpConversion,
- SqrtOpConversion>(patterns.getContext());
+ SinOpConversion>(patterns.getContext());
// clang-format on
}
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index bf4102871851..6f57e722b520 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline="func.func(convert-complex-to-standard)" | FileCheck %s
// CHECK-LABEL: func @complex_abs
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
@@ -14,17 +14,6 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
// CHECK: return %[[NORM]] : f32
-// -----
-
-// CHECK-LABEL: func @complex_atan2
-func.func @complex_atan2(%lhs: complex<f32>,
- %rhs: complex<f32>) -> complex<f32> {
- %atan2 = complex.atan2 %lhs, %rhs : complex<f32>
- return %atan2 : complex<f32>
-}
-
-// -----
-
// CHECK-LABEL: func @complex_add
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -40,8 +29,6 @@ func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
-// -----
-
// CHECK-LABEL: func @complex_cos
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_cos(%arg: complex<f32>) -> complex<f32> {
@@ -63,8 +50,6 @@ func.func @complex_cos(%arg: complex<f32>) -> complex<f32> {
// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]]
-// -----
-
// CHECK-LABEL: func @complex_div
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func.func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -174,8 +159,6 @@ func.func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
-// -----
-
// CHECK-LABEL: func @complex_eq
// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
func.func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
@@ -191,8 +174,6 @@ func.func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
// CHECK: %[[EQUAL:.*]] = arith.andi %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1
// CHECK: return %[[EQUAL]] : i1
-// -----
-
// CHECK-LABEL: func @complex_exp
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
@@ -209,8 +190,6 @@ func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
-// -----
-
// CHECK-LABEL: func.func @complex_expm1(
// CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
func.func @complex_expm1(%arg: complex<f32>) -> complex<f32> {
@@ -232,8 +211,6 @@ func.func @complex_expm1(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
// CHECK: return %[[RES]] : complex<f32>
-// -----
-
// CHECK-LABEL: func @complex_log
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
@@ -253,8 +230,6 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
-// -----
-
// CHECK-LABEL: func @complex_log1p
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
@@ -279,8 +254,6 @@ func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
-// -----
-
// CHECK-LABEL: func @complex_mul
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func.func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -399,8 +372,6 @@ func.func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[FINAL_REAL]], %[[FINAL_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
-// -----
-
// CHECK-LABEL: func @complex_neg
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_neg(%arg: complex<f32>) -> complex<f32> {
@@ -414,8 +385,6 @@ func.func @complex_neg(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[NEG_REAL]], %[[NEG_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
-// -----
-
// CHECK-LABEL: func @complex_neq
// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
func.func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
@@ -431,8 +400,6 @@ func.func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
// CHECK: %[[NOT_EQUAL:.*]] = arith.ori %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1
// CHECK: return %[[NOT_EQUAL]] : i1
-// -----
-
// CHECK-LABEL: func @complex_sin
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_sin(%arg: complex<f32>) -> complex<f32> {
@@ -454,8 +421,6 @@ func.func @complex_sin(%arg: complex<f32>) -> complex<f32> {
// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]]
-// -----
-
// CHECK-LABEL: func @complex_sign
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
@@ -480,8 +445,6 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = arith.select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
-// -----
-
// CHECK-LABEL: func @complex_sub
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func.func @complex_sub(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -496,11 +459,3 @@ func.func @complex_sub(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT_IMAG:.*]] = arith.subf %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
-
-// -----
-
-// CHECK-LABEL: func @complex_sqrt
-func.func @complex_sqrt(%arg: complex<f32>) -> complex<f32> {
- %sqrt = complex.sqrt %arg : complex<f32>
- return %sqrt : complex<f32>
-}
More information about the Mlir-commits mailing list