[Mlir-commits] [mlir] f5fa633 - [mlir] Lower complex.sqrt and complex.atan2 to Arithmetic dialect.

Alexander Belyaev llvmlistbot at llvm.org
Mon May 30 00:45:57 PDT 2022


Author: Alexander Belyaev
Date: 2022-05-30T09:44:36+02:00
New Revision: f5fa633b0955a8cee878b384801038fccef11fdc

URL: https://github.com/llvm/llvm-project/commit/f5fa633b0955a8cee878b384801038fccef11fdc
DIFF: https://github.com/llvm/llvm-project/commit/f5fa633b0955a8cee878b384801038fccef11fdc.diff

LOG: [mlir] Lower complex.sqrt and complex.atan2 to Arithmetic dialect.

I don't see much point in detailed lit tests here, since sqrt, mul and the
other ops used by these lowerings are expanded as well. I just added "smoke"
tests to verify that the conversion works and does not create any illegal ops.

I will create a follow-up patch that adds a simple integration test to
mlir/test/Integration/Dialect/ComplexOps/ that compares the computed values
against expected results.

Differential Revision: https://reviews.llvm.org/D126539

Added: 
    

Modified: 
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index e1eca6181dff9..194e1669a86c9 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -44,6 +44,49 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
   }
 };
 
+// Lowers complex.atan2(y = lhs, x = rhs) to simpler complex ops using
+//   atan2(y, x) = -i * log((x + i * y) / sqrt(x**2 + y**2)).
+struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
+  using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
+    auto complexTy = op.getType().cast<ComplexType>();
+    Type floatTy = complexTy.getElementType();
+    Value y = adaptor.getLhs();
+    Value x = adaptor.getRhs();
+
+    // Denominator: sqrt(x * x + y * y), evaluated in complex arithmetic.
+    Value xSq = builder.create<complex::MulOp>(complexTy, x, x);
+    Value ySq = builder.create<complex::MulOp>(complexTy, y, y);
+    Value sumSq = builder.create<complex::AddOp>(complexTy, xSq, ySq);
+    Value denominator = builder.create<complex::SqrtOp>(complexTy, sumSq);
+
+    // Materializes a scalar constant of the element type.
+    auto scalar = [&](double value) -> Value {
+      return builder.create<arith::ConstantOp>(
+          floatTy, builder.getFloatAttr(floatTy, value));
+    };
+    Value zeroScalar = builder.create<arith::ConstantOp>(
+        floatTy, builder.getZeroAttr(floatTy));
+    // Numerator: x + i * y.
+    Value imagUnit =
+        builder.create<complex::CreateOp>(complexTy, zeroScalar, scalar(1));
+    Value iy = builder.create<complex::MulOp>(imagUnit, y);
+    Value numerator = builder.create<complex::AddOp>(x, iy);
+
+    // Result: -i * log(numerator / denominator).
+    Value logOfRatio = builder.create<complex::LogOp>(
+        builder.create<complex::DivOp>(numerator, denominator));
+    Value minusI =
+        builder.create<complex::CreateOp>(complexTy, zeroScalar, scalar(-1));
+    rewriter.replaceOpWithNewOp<complex::MulOp>(op, minusI, logOfRatio);
+    return success();
+  }
+};
+
 template <typename ComparisonOp, arith::CmpFPredicate p>
 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
@@ -700,6 +743,72 @@ struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
   }
 };
 
+// Lowers complex.sqrt with Friedland's algorithm (CACM Algorithm 312,
+// https://dl.acm.org/doi/pdf/10.1145/363717.363780). For z = x + i*y, let
+// w = sqrt((|x| + |z|) / 2); then sqrt(z) = w + i*y/(2w) when x >= 0, and
+// |y|/(2w) + i*(y < 0 ? -w : w) when x < 0. sqrt(0) is forced to 0.
+struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
+  using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+
+    auto type = op.getType().cast<ComplexType>();
+    Type elementType = type.getElementType();
+    Value arg = adaptor.getComplex();
+
+    Value zero =
+        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
+    Value real = b.create<complex::ReOp>(elementType, arg);
+    Value imag = b.create<complex::ImOp>(elementType, arg);
+
+    // w = sqrt((|x| + |z|) / 2). The halving must happen inside the square
+    // root: sqrt(|x| + |z|) / 2 is too small by a factor of sqrt(2).
+    Value absReal = b.create<math::AbsOp>(real);
+    Value absArg = b.create<complex::AbsOp>(elementType, arg);
+    Value addAbs = b.create<arith::AddFOp>(absReal, absArg);
+    Value half = b.create<arith::ConstantOp>(elementType,
+                                             b.getFloatAttr(elementType, 0.5));
+    Value w = b.create<math::SqrtOp>(b.create<arith::MulFOp>(addAbs, half));
+
+    Value realIsNegative =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
+    Value imagIsNegative =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero);
+
+    // x >= 0: (w, y / (2w)). x < 0: the imaginary part is -w if y < 0 and w
+    // otherwise, and the real part is y / (2 * imaginary part), which is
+    // |y| / (2w) for nonzero y.
+    Value imagDivTwoW =
+        b.create<arith::DivFOp>(imag, b.create<arith::AddFOp>(w, w));
+    Value negativeW = b.create<arith::NegFOp>(w);
+    Value resultImag = b.create<arith::SelectOp>(
+        realIsNegative,
+        b.create<arith::SelectOp>(imagIsNegative, negativeW, w),
+        imagDivTwoW);
+    Value resultReal = b.create<arith::SelectOp>(
+        realIsNegative,
+        b.create<arith::DivFOp>(
+            imag, b.create<arith::AddFOp>(resultImag, resultImag)),
+        w);
+
+    // sqrt(0) = 0; the formulas above would evaluate 0 / 0 = NaN there.
+    Value realIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
+    Value imagIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
+    Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
+    resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal);
+    resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag);
+
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
+                                                   resultImag);
+    return success();
+  }
+};
+
 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
 
@@ -735,6 +844,7 @@ void mlir::populateComplexToStandardConversionPatterns(
   // clang-format off
   patterns.add<
       AbsOpConversion,
+      Atan2OpConversion,
       ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
       ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
       BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
@@ -748,7 +858,8 @@ void mlir::populateComplexToStandardConversionPatterns(
       MulOpConversion,
       NegOpConversion,
       SignOpConversion,
-      SinOpConversion>(patterns.getContext());
+      SinOpConversion,
+      SqrtOpConversion>(patterns.getContext());
   // clang-format on
 }
 

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 6f57e722b520e..bf41028718517 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline="func.func(convert-complex-to-standard)" | FileCheck %s
+// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @complex_abs
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
@@ -14,6 +14,17 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
 // CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
 // CHECK: return %[[NORM]] : f32
 
+// -----
+
+// CHECK-LABEL: func @complex_atan2
+// CHECK-NOT: complex.{{(abs|add|atan2|div|log|mul|sqrt)}}
+func.func @complex_atan2(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %atan2 = complex.atan2 %lhs, %rhs : complex<f32>
+  return %atan2 : complex<f32>
+}
+
+// -----
+
 // CHECK-LABEL: func @complex_add
 // CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
 func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -29,6 +40,8 @@ func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// -----
+
 // CHECK-LABEL: func @complex_cos
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_cos(%arg: complex<f32>) -> complex<f32> {
@@ -50,6 +63,8 @@ func.func @complex_cos(%arg: complex<f32>) -> complex<f32> {
 // CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK:     return %[[RESULT]]
 
+// -----
+
 // CHECK-LABEL: func @complex_div
 // CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
 func.func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -159,6 +174,8 @@ func.func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// -----
+
 // CHECK-LABEL: func @complex_eq
 // CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
 func.func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
@@ -174,6 +191,8 @@ func.func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
 // CHECK: %[[EQUAL:.*]] = arith.andi %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1
 // CHECK: return %[[EQUAL]] : i1
 
+// -----
+
 // CHECK-LABEL: func @complex_exp
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
@@ -190,6 +209,8 @@ func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// -----
+
 // CHECK-LABEL:   func.func @complex_expm1(
 // CHECK-SAME:                             %[[ARG:.*]]: complex<f32>) -> complex<f32> {
 func.func @complex_expm1(%arg: complex<f32>) -> complex<f32> {
@@ -211,6 +232,8 @@ func.func @complex_expm1(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
 // CHECK: return %[[RES]] : complex<f32>
 
+// -----
+
 // CHECK-LABEL: func @complex_log
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
@@ -230,6 +253,8 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// -----
+
 // CHECK-LABEL: func @complex_log1p
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
@@ -254,6 +279,8 @@ func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// -----
+
 // CHECK-LABEL: func @complex_mul
 // CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
 func.func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -372,6 +399,8 @@ func.func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[FINAL_REAL]], %[[FINAL_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// -----
+
 // CHECK-LABEL: func @complex_neg
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_neg(%arg: complex<f32>) -> complex<f32> {
@@ -385,6 +414,8 @@ func.func @complex_neg(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[NEG_REAL]], %[[NEG_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// -----
+
 // CHECK-LABEL: func @complex_neq
 // CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
 func.func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
@@ -400,6 +431,8 @@ func.func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
 // CHECK: %[[NOT_EQUAL:.*]] = arith.ori %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1
 // CHECK: return %[[NOT_EQUAL]] : i1
 
+// -----
+
 // CHECK-LABEL: func @complex_sin
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_sin(%arg: complex<f32>) -> complex<f32> {
@@ -421,6 +454,8 @@ func.func @complex_sin(%arg: complex<f32>) -> complex<f32> {
 // CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK:     return %[[RESULT]]
 
+// -----
+
 // CHECK-LABEL: func @complex_sign
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
@@ -445,6 +480,8 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = arith.select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// -----
+
 // CHECK-LABEL: func @complex_sub
 // CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
 func.func @complex_sub(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -459,3 +496,11 @@ func.func @complex_sub(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT_IMAG:.*]] = arith.subf %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+// CHECK-LABEL: func @complex_sqrt
+// CHECK-NOT: complex.{{(abs|sqrt)}}
+func.func @complex_sqrt(%arg: complex<f32>) -> complex<f32> {
+  %sqrt = complex.sqrt %arg : complex<f32>
+  return %sqrt : complex<f32>
+}


        


More information about the Mlir-commits mailing list