[Mlir-commits] [mlir] 402b837 - Revert "[mlir] Lower complex.sqrt and complex.atan2 to Arithmetic dialect."

Alexander Belyaev llvmlistbot at llvm.org
Mon May 30 01:50:15 PDT 2022


Author: Alexander Belyaev
Date: 2022-05-30T10:48:58+02:00
New Revision: 402b8373021a69090aae505c7f39d9e750127ff2

URL: https://github.com/llvm/llvm-project/commit/402b8373021a69090aae505c7f39d9e750127ff2
DIFF: https://github.com/llvm/llvm-project/commit/402b8373021a69090aae505c7f39d9e750127ff2.diff

LOG: Revert "[mlir] Lower complex.sqrt and complex.atan2 to Arithmetic dialect."

This reverts commit f5fa633b0955a8cee878b384801038fccef11fdc.

Reverted because the integration test sparse_complex_ops.mlir fails with this change.

Added: 
    

Modified: 
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 194e1669a86c..e1eca6181dff 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -44,49 +44,6 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
   }
 };
 
-// atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
-struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
-  using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-
-    auto type = op.getType().cast<ComplexType>();
-    Type elementType = type.getElementType();
-
-    Value lhs = adaptor.getLhs();
-    Value rhs = adaptor.getRhs();
-
-    Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs);
-    Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs);
-    Value rhsSquaredPlusLhsSquared =
-        b.create<complex::AddOp>(type, rhsSquared, lhsSquared);
-    Value sqrtOfRhsSquaredPlusLhsSquared =
-        b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared);
-
-    Value zero =
-        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
-    Value one = b.create<arith::ConstantOp>(elementType,
-                                            b.getFloatAttr(elementType, 1));
-    Value i = b.create<complex::CreateOp>(type, zero, one);
-    Value iTimesLhs = b.create<complex::MulOp>(i, lhs);
-    Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs);
-
-    Value divResult =
-        b.create<complex::DivOp>(rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared);
-    Value logResult = b.create<complex::LogOp>(divResult);
-
-    Value negativeOne = b.create<arith::ConstantOp>(
-        elementType, b.getFloatAttr(elementType, -1));
-    Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne);
-
-    rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult);
-    return success();
-  }
-};
-
 template <typename ComparisonOp, arith::CmpFPredicate p>
 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
@@ -743,72 +700,6 @@ struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
   }
 };
 
-// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
-struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
-  using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
-
-    auto type = op.getType().cast<ComplexType>();
-    Type elementType = type.getElementType();
-    Value arg = adaptor.getComplex();
-
-    Value zero =
-        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
-
-    Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex());
-    Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex());
-
-    Value absLhs = b.create<math::AbsOp>(real);
-    Value absArg = b.create<complex::AbsOp>(elementType, arg);
-    Value addAbs = b.create<arith::AddFOp>(absLhs, absArg);
-    Value sqrtAddAbs = b.create<math::SqrtOp>(addAbs);
-    Value sqrtAddAbsDivTwo = b.create<arith::DivFOp>(
-        sqrtAddAbs, b.create<arith::ConstantOp>(
-                        elementType, b.getFloatAttr(elementType, 2)));
-
-    Value realIsNegative =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
-    Value imagIsNegative =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero);
-
-    Value resultReal = sqrtAddAbsDivTwo;
-
-    Value imagDivTwoResultReal = b.create<arith::DivFOp>(
-        imag, b.create<arith::AddFOp>(resultReal, resultReal));
-
-    Value negativeResultReal = b.create<arith::NegFOp>(resultReal);
-
-    Value resultImag = b.create<arith::SelectOp>(
-        realIsNegative,
-        b.create<arith::SelectOp>(imagIsNegative, negativeResultReal,
-                                  resultReal),
-        imagDivTwoResultReal);
-
-    resultReal = b.create<arith::SelectOp>(
-        realIsNegative,
-        b.create<arith::DivFOp>(
-            imag, b.create<arith::AddFOp>(resultImag, resultImag)),
-        resultReal);
-
-    Value realIsZero =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
-    Value imagIsZero =
-        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
-    Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
-
-    resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal);
-    resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag);
-
-    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
-                                                   resultImag);
-    return success();
-  }
-};
-
 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
 
@@ -844,7 +735,6 @@ void mlir::populateComplexToStandardConversionPatterns(
   // clang-format off
   patterns.add<
       AbsOpConversion,
-      Atan2OpConversion,
       ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
       ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
       BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
@@ -858,8 +748,7 @@ void mlir::populateComplexToStandardConversionPatterns(
       MulOpConversion,
       NegOpConversion,
       SignOpConversion,
-      SinOpConversion,
-      SqrtOpConversion>(patterns.getContext());
+      SinOpConversion>(patterns.getContext());
   // clang-format on
 }
 

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index bf4102871851..6f57e722b520 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline="func.func(convert-complex-to-standard)" | FileCheck %s
 
 // CHECK-LABEL: func @complex_abs
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
@@ -14,17 +14,6 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
 // CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
 // CHECK: return %[[NORM]] : f32
 
-// -----
-
-// CHECK-LABEL: func @complex_atan2
-func.func @complex_atan2(%lhs: complex<f32>,
-                         %rhs: complex<f32>) -> complex<f32> {
-  %atan2 = complex.atan2 %lhs, %rhs : complex<f32>
-  return %atan2 : complex<f32>
-}
-
-// -----
-
 // CHECK-LABEL: func @complex_add
 // CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
 func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -40,8 +29,6 @@ func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
-// -----
-
 // CHECK-LABEL: func @complex_cos
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_cos(%arg: complex<f32>) -> complex<f32> {
@@ -63,8 +50,6 @@ func.func @complex_cos(%arg: complex<f32>) -> complex<f32> {
 // CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK:     return %[[RESULT]]
 
-// -----
-
 // CHECK-LABEL: func @complex_div
 // CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
 func.func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -174,8 +159,6 @@ func.func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
-// -----
-
 // CHECK-LABEL: func @complex_eq
 // CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
 func.func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
@@ -191,8 +174,6 @@ func.func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
 // CHECK: %[[EQUAL:.*]] = arith.andi %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1
 // CHECK: return %[[EQUAL]] : i1
 
-// -----
-
 // CHECK-LABEL: func @complex_exp
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
@@ -209,8 +190,6 @@ func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
-// -----
-
 // CHECK-LABEL:   func.func @complex_expm1(
 // CHECK-SAME:                             %[[ARG:.*]]: complex<f32>) -> complex<f32> {
 func.func @complex_expm1(%arg: complex<f32>) -> complex<f32> {
@@ -232,8 +211,6 @@ func.func @complex_expm1(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
 // CHECK: return %[[RES]] : complex<f32>
 
-// -----
-
 // CHECK-LABEL: func @complex_log
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
@@ -253,8 +230,6 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
-// -----
-
 // CHECK-LABEL: func @complex_log1p
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
@@ -279,8 +254,6 @@ func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
-// -----
-
 // CHECK-LABEL: func @complex_mul
 // CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
 func.func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -399,8 +372,6 @@ func.func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[FINAL_REAL]], %[[FINAL_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
-// -----
-
 // CHECK-LABEL: func @complex_neg
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_neg(%arg: complex<f32>) -> complex<f32> {
@@ -414,8 +385,6 @@ func.func @complex_neg(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[NEG_REAL]], %[[NEG_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
-// -----
-
 // CHECK-LABEL: func @complex_neq
 // CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
 func.func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
@@ -431,8 +400,6 @@ func.func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
 // CHECK: %[[NOT_EQUAL:.*]] = arith.ori %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1
 // CHECK: return %[[NOT_EQUAL]] : i1
 
-// -----
-
 // CHECK-LABEL: func @complex_sin
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_sin(%arg: complex<f32>) -> complex<f32> {
@@ -454,8 +421,6 @@ func.func @complex_sin(%arg: complex<f32>) -> complex<f32> {
 // CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK:     return %[[RESULT]]
 
-// -----
-
 // CHECK-LABEL: func @complex_sign
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
@@ -480,8 +445,6 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = arith.select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
-// -----
-
 // CHECK-LABEL: func @complex_sub
 // CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
 func.func @complex_sub(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -496,11 +459,3 @@ func.func @complex_sub(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT_IMAG:.*]] = arith.subf %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
-
-// -----
-
-// CHECK-LABEL: func @complex_sqrt
-func.func @complex_sqrt(%arg: complex<f32>) -> complex<f32> {
-  %sqrt = complex.sqrt %arg : complex<f32>
-  return %sqrt : complex<f32>
-}


        


More information about the Mlir-commits mailing list