[Mlir-commits] [mlir] f711785 - [mlir] Add conversion and tests for complex.[sqrt|atan2] to Arith.

Alexander Belyaev llvmlistbot at llvm.org
Wed Jun 1 11:22:08 PDT 2022


Author: Alexander Belyaev
Date: 2022-06-01T20:21:51+02:00
New Revision: f711785e61e72ca1f483a288c39557e1bdbd1eaa

URL: https://github.com/llvm/llvm-project/commit/f711785e61e72ca1f483a288c39557e1bdbd1eaa
DIFF: https://github.com/llvm/llvm-project/commit/f711785e61e72ca1f483a288c39557e1bdbd1eaa.diff

LOG: [mlir] Add conversion and tests for complex.[sqrt|atan2] to Arith.

Differential Revision: https://reviews.llvm.org/D126799

Added: 
    mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir

Modified: 
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 4211fad79fba0..2e981c097f2d6 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -44,6 +44,49 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
   }
 };
 
+// Lowers complex.atan2 (lhs = y, rhs = x) using atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
+struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
+  using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); // ops created via `b` below take op's location
+
+    auto type = op.getType().cast<ComplexType>();
+    Type elementType = type.getElementType(); // scalar type of the real/imaginary parts
+
+    Value lhs = adaptor.getLhs(); // y in the formula above
+    Value rhs = adaptor.getRhs(); // x in the formula above
+
+    Value rhsSquared = b.create<complex::MulOp>(type, rhs, rhs); // x**2 (complex multiply)
+    Value lhsSquared = b.create<complex::MulOp>(type, lhs, lhs); // y**2 (complex multiply)
+    Value rhsSquaredPlusLhsSquared =
+        b.create<complex::AddOp>(type, rhsSquared, lhsSquared);
+    Value sqrtOfRhsSquaredPlusLhsSquared =
+        b.create<complex::SqrtOp>(type, rhsSquaredPlusLhsSquared); // sqrt(x**2 + y**2), the denominator
+
+    Value zero =
+        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
+    Value one = b.create<arith::ConstantOp>(elementType,
+                                            b.getFloatAttr(elementType, 1));
+    Value i = b.create<complex::CreateOp>(type, zero, one); // imaginary unit: (0, 1)
+    Value iTimesLhs = b.create<complex::MulOp>(i, lhs); // i * y
+    Value rhsPlusILhs = b.create<complex::AddOp>(rhs, iTimesLhs); // x + i * y, the numerator
+
+    Value divResult =
+        b.create<complex::DivOp>(rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared);
+    Value logResult = b.create<complex::LogOp>(divResult);
+
+    Value negativeOne = b.create<arith::ConstantOp>(
+        elementType, b.getFloatAttr(elementType, -1));
+    Value negativeI = b.create<complex::CreateOp>(type, zero, negativeOne); // -i: (0, -1)
+
+    rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult); // result = -i * log(...)
+    return success(); // the rewrite is unconditional; there is no failure path
+  }
+};
+
 template <typename ComparisonOp, arith::CmpFPredicate p>
 struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
   using OpConversionPattern<ComparisonOp>::OpConversionPattern;
@@ -700,6 +743,73 @@ struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
   }
 };
 
+// Expands complex.sqrt into scalar arith/math ops (plus complex.re/im/abs/create). The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
+struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
+  using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); // ops created via `b` below take op's location
+
+    auto type = op.getType().cast<ComplexType>();
+    Type elementType = type.getElementType(); // scalar type of the real/imaginary parts
+    Value arg = adaptor.getComplex();
+
+    Value zero =
+        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
+
+    Value real = b.create<complex::ReOp>(elementType, adaptor.getComplex()); // re(arg)
+    Value imag = b.create<complex::ImOp>(elementType, adaptor.getComplex()); // im(arg)
+
+    Value absLhs = b.create<math::AbsOp>(real); // |re(arg)| (despite the name, no "lhs" operand exists here)
+    Value absArg = b.create<complex::AbsOp>(elementType, arg); // |arg|, via complex.abs
+    Value addAbs = b.create<arith::AddFOp>(absLhs, absArg);
+
+    Value half = b.create<arith::ConstantOp>(
+                        elementType, b.getFloatAttr(elementType, 0.5));
+    Value halfAddAbs = b.create<arith::MulFOp>(addAbs, half);
+    Value sqrtAddAbs = b.create<math::SqrtOp>(halfAddAbs); // sqrt((|re| + |arg|) / 2)
+
+    Value realIsNegative =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
+    Value imagIsNegative =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero);
+
+    Value resultReal = sqrtAddAbs; // real part used when realIsNegative is false
+
+    Value imagDivTwoResultReal = b.create<arith::DivFOp>(
+        imag, b.create<arith::AddFOp>(resultReal, resultReal)); // im / (2 * sqrtAddAbs)
+
+    Value negativeResultReal = b.create<arith::NegFOp>(resultReal);
+
+    Value resultImag = b.create<arith::SelectOp>( // re < 0: -sqrtAddAbs if im < 0, else +sqrtAddAbs
+        realIsNegative,
+        b.create<arith::SelectOp>(imagIsNegative, negativeResultReal,
+                                  resultReal),
+        imagDivTwoResultReal); // otherwise: im / (2 * sqrtAddAbs)
+
+    resultReal = b.create<arith::SelectOp>( // re < 0: im / (2 * resultImag); otherwise keep sqrtAddAbs
+        realIsNegative,
+        b.create<arith::DivFOp>(
+            imag, b.create<arith::AddFOp>(resultImag, resultImag)), // this division is emitted for every input; the select only picks the value
+        resultReal);
+
+    Value realIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
+    Value imagIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
+    Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
+
+    resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal); // zero input: sqrtAddAbs is 0, so the divisions above are 0/0;
+    resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag); // force the result to (0, 0) instead
+
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
+                                                   resultImag);
+    return success(); // the rewrite is unconditional; there is no failure path
+  }
+};
+
 struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
   using OpConversionPattern<complex::SignOp>::OpConversionPattern;
 
@@ -782,6 +892,7 @@ void mlir::populateComplexToStandardConversionPatterns(
   // clang-format off
   patterns.add<
       AbsOpConversion,
+      Atan2OpConversion,
       ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
       ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
       BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
@@ -796,6 +907,7 @@ void mlir::populateComplexToStandardConversionPatterns(
       NegOpConversion,
       SignOpConversion,
       SinOpConversion,
+      SqrtOpConversion,
       TanOpConversion,
       TanhOpConversion>(patterns.getContext());
   // clang-format on

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 5670dfa62821a..319a443f5d411 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file |\
+// RUN: FileCheck %s
 
 // CHECK-LABEL: func @complex_abs
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
@@ -16,6 +17,15 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
 
 // -----
 
+// CHECK-LABEL: func @complex_atan2
+func.func @complex_atan2(%lhs: complex<f32>, // Smoke test: only the function label is matched, not the lowered IR.
+                         %rhs: complex<f32>) -> complex<f32> {
+  %atan2 = complex.atan2 %lhs, %rhs : complex<f32>
+  return %atan2 : complex<f32>
+}
+
+// -----
+
 // CHECK-LABEL: func @complex_add
 // CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
 func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -645,3 +655,11 @@ func.func @complex_tanh(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
 // CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] : f32
 // CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_sqrt
+func.func @complex_sqrt(%arg: complex<f32>) -> complex<f32> { // Smoke test: only the function label is matched, not the lowered IR.
+  %sqrt = complex.sqrt %arg : complex<f32>
+  return %sqrt : complex<f32>
+}

diff  --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
new file mode 100644
index 0000000000000..e3a3aa605f860
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
@@ -0,0 +1,119 @@
+// RUN: mlir-opt %s \
+// RUN:   -func-bufferize -tensor-bufferize -arith-bufferize --canonicalize \
+// RUN:   -convert-scf-to-cf --convert-complex-to-standard \
+// RUN:   -convert-memref-to-llvm -convert-math-to-llvm -convert-math-to-libm \
+// RUN:   -convert-vector-to-llvm -convert-complex-to-llvm \
+// RUN:   -convert-func-to-llvm -reconcile-unrealized-casts |\
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void  \
+// RUN:  -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext |\
+// RUN: FileCheck %s
+
+func.func @test_unary(%input: tensor<?xcomplex<f32>>, // Calls %func on every element of %input and prints each result.
+                      %func: (complex<f32>) -> complex<f32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %size = tensor.dim %input, %c0: tensor<?xcomplex<f32>> // runtime length of the 1-D input
+
+  scf.for %i = %c0 to %size step %c1 {
+    %elem = tensor.extract %input[%i]: tensor<?xcomplex<f32>>
+
+    %val = func.call_indirect %func(%elem) : (complex<f32>) -> complex<f32>
+    %real = complex.re %val : complex<f32>
+    %imag = complex.im %val: complex<f32>
+    vector.print %real : f32 // the real part is printed first,
+    vector.print %imag : f32 // then the imaginary part
+    scf.yield
+  }
+  func.return
+}
+
+func.func @sqrt(%arg: complex<f32>) -> complex<f32> { // Wraps complex.sqrt in a function so @entry can pass it to @test_unary.
+  %sqrt = complex.sqrt %arg : complex<f32>
+  func.return %sqrt : complex<f32>
+}
+
+// Calls %func on each (lhs, rhs) pair and prints each result. %input contains pairs of lhs, rhs, i.e. [lhs_0, rhs_0, lhs_1, rhs_1,...]
+func.func @test_binary(%input: tensor<?xcomplex<f32>>,
+                       %func: (complex<f32>, complex<f32>) -> complex<f32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %size = tensor.dim %input, %c0: tensor<?xcomplex<f32>> // runtime length of the 1-D input
+
+  scf.for %i = %c0 to %size step %c2 { // step 2: one (lhs, rhs) pair per iteration
+    %lhs = tensor.extract %input[%i]: tensor<?xcomplex<f32>>
+    %i_next = arith.addi %i, %c1 : index
+    %rhs = tensor.extract %input[%i_next]: tensor<?xcomplex<f32>> // reads index %i + 1, so %size is expected to be even
+
+    %val = func.call_indirect %func(%lhs, %rhs)
+      : (complex<f32>, complex<f32>) -> complex<f32>
+    %real = complex.re %val : complex<f32>
+    %imag = complex.im %val: complex<f32>
+    vector.print %real : f32 // the real part is printed first,
+    vector.print %imag : f32 // then the imaginary part
+    scf.yield
+  }
+  func.return
+}
+
+func.func @atan2(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> { // Wraps complex.atan2 in a function so @entry can pass it to @test_binary.
+  %atan2 = complex.atan2 %lhs, %rhs : complex<f32>
+  func.return %atan2 : complex<f32>
+}
+
+
+func.func @entry() { // Test driver: runs the sqrt cases through @test_unary, then the atan2 cases through @test_binary.
+  // complex.sqrt test: each (re, im) input is followed by the expected printed real and imaginary parts.
+  %sqrt_test = arith.constant dense<[
+    (-1.0, -1.0),
+    // CHECK:       0.455
+    // CHECK-NEXT: -1.098
+    (-1.0, 1.0),
+    // CHECK-NEXT:  0.455
+    // CHECK-NEXT:  1.098
+    (0.0, 0.0),
+    // CHECK-NEXT:  0
+    // CHECK-NEXT:  0
+    (0.0, 1.0),
+    // CHECK-NEXT:  0.707
+    // CHECK-NEXT:  0.707
+    (1.0, -1.0),
+    // CHECK-NEXT:  1.098
+    // CHECK-NEXT:  -0.455
+    (1.0, 0.0),
+    // CHECK-NEXT:  1
+    // CHECK-NEXT:  0
+    (1.0, 1.0)
+    // CHECK-NEXT:  1.098
+    // CHECK-NEXT:  0.455
+  ]> : tensor<7xcomplex<f32>>
+  %sqrt_test_cast = tensor.cast %sqrt_test // drop the static size to match @test_unary's dynamically sized argument
+    :  tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>
+
+  %sqrt_func = func.constant @sqrt : (complex<f32>) -> complex<f32>
+  call @test_unary(%sqrt_test_cast, %sqrt_func)
+    : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()
+
+  // complex.atan2 test: each line holds one (lhs, rhs) pair, followed by the expected printed real and imaginary parts.
+  %atan2_test = arith.constant dense<[
+    (1.0, 2.0), (2.0, 1.0),
+    // CHECK:       0.785
+    // CHECK-NEXT:  0.346
+    (1.0, 1.0), (1.0, 0.0),
+    // CHECK-NEXT:  1.017
+    // CHECK-NEXT:  0.402
+    (1.0, 1.0), (1.0, 1.0)
+    // CHECK-NEXT:  0.785
+    // CHECK-NEXT:  0
+  ]> : tensor<6xcomplex<f32>>
+  %atan2_test_cast = tensor.cast %atan2_test // drop the static size to match @test_binary's dynamically sized argument
+    :  tensor<6xcomplex<f32>> to tensor<?xcomplex<f32>>
+
+  %atan2_func = func.constant @atan2 : (complex<f32>, complex<f32>)
+    -> complex<f32>
+  call @test_binary(%atan2_test_cast, %atan2_func)
+    : (tensor<?xcomplex<f32>>, (complex<f32>, complex<f32>)
+    -> complex<f32>) -> ()
+  func.return
+}


        


More information about the Mlir-commits mailing list