[Mlir-commits] [mlir] f711785 - [mlir] Add conversion and tests for complex.[sqrt|atan2] to Arith.
Alexander Belyaev
llvmlistbot at llvm.org
Wed Jun 1 11:22:08 PDT 2022
Author: Alexander Belyaev
Date: 2022-06-01T20:21:51+02:00
New Revision: f711785e61e72ca1f483a288c39557e1bdbd1eaa
URL: https://github.com/llvm/llvm-project/commit/f711785e61e72ca1f483a288c39557e1bdbd1eaa
DIFF: https://github.com/llvm/llvm-project/commit/f711785e61e72ca1f483a288c39557e1bdbd1eaa.diff
LOG: [mlir] Add conversion and tests for complex.[sqrt|atan2] to Arith.
Differential Revision: https://reviews.llvm.org/D126799
Added:
mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
Modified:
mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 4211fad79fba0..2e981c097f2d6 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -44,6 +44,49 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
}
};
+// Lowers complex.atan2 through the logarithm identity
+//   atan2(y, x) = -i * log((x + i*y) / sqrt(x^2 + y^2)),
+// where y is the op's lhs operand and x is its rhs operand. The result is
+// expressed with complex mul/add/sqrt/div/log ops.
+struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
+  using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
+
+    auto complexTy = op.getType().cast<ComplexType>();
+    Type scalarTy = complexTy.getElementType();
+
+    // Operand naming follows atan2(y, x).
+    Value y = adaptor.getLhs();
+    Value x = adaptor.getRhs();
+
+    // norm = sqrt(x*x + y*y)
+    Value xSquared = builder.create<complex::MulOp>(complexTy, x, x);
+    Value ySquared = builder.create<complex::MulOp>(complexTy, y, y);
+    Value sumOfSquares =
+        builder.create<complex::AddOp>(complexTy, xSquared, ySquared);
+    Value norm = builder.create<complex::SqrtOp>(complexTy, sumOfSquares);
+
+    // Imaginary unit i = (0, 1).
+    Value scalarZero = builder.create<arith::ConstantOp>(
+        scalarTy, builder.getZeroAttr(scalarTy));
+    Value scalarOne = builder.create<arith::ConstantOp>(
+        scalarTy, builder.getFloatAttr(scalarTy, 1));
+    Value imagUnit =
+        builder.create<complex::CreateOp>(complexTy, scalarZero, scalarOne);
+
+    // log((x + i*y) / norm)
+    Value iTimesY = builder.create<complex::MulOp>(imagUnit, y);
+    Value numerator = builder.create<complex::AddOp>(x, iTimesY);
+    Value normalized = builder.create<complex::DivOp>(numerator, norm);
+    Value logOfNormalized = builder.create<complex::LogOp>(normalized);
+
+    // Multiply by -i = (0, -1) to obtain the final result.
+    Value scalarMinusOne = builder.create<arith::ConstantOp>(
+        scalarTy, builder.getFloatAttr(scalarTy, -1));
+    Value minusImagUnit = builder.create<complex::CreateOp>(
+        complexTy, scalarZero, scalarMinusOne);
+
+    rewriter.replaceOpWithNewOp<complex::MulOp>(op, minusImagUnit,
+                                                logOfNormalized);
+    return success();
+  }
+};
+
template <typename ComparisonOp, arith::CmpFPredicate p>
struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
using OpConversionPattern<ComparisonOp>::OpConversionPattern;
@@ -700,6 +743,73 @@ struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
}
};
+// Lowers complex.sqrt following the algorithm listed in
+// https://dl.acm.org/doi/pdf/10.1145/363717.363780.
+//
+// For z = x + i*y, let w = sqrt((|x| + |z|) / 2). Then
+//   sqrt(z) = (+0, y)                          if z == 0,
+//   sqrt(z) = w + i * y / (2w)                 if x >= 0 (or x is NaN),
+//   sqrt(z) = |y| / (2w) + i * copysign(w, y)  if x < 0.
+// The imaginary sign for x < 0 comes from copysign rather than from a
+// `y < 0` comparison. As a result, y == -0.0 stays on the lower side of the
+// branch cut, as in C99 csqrt / std::sqrt (e.g. sqrt(-4 - 0i) = -2i).
+struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
+  using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+
+    auto type = op.getType().cast<ComplexType>();
+    Type elementType = type.getElementType();
+    Value arg = adaptor.getComplex();
+
+    Value zero =
+        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
+    Value half = b.create<arith::ConstantOp>(
+        elementType, b.getFloatAttr(elementType, 0.5));
+
+    Value real = b.create<complex::ReOp>(elementType, arg);
+    Value imag = b.create<complex::ImOp>(elementType, arg);
+
+    // w = sqrt((|x| + |z|) / 2). Both summands are non-negative, so this
+    // avoids the cancellation that (|z| + x) would suffer for x < 0.
+    Value absReal = b.create<math::AbsOp>(real);
+    Value absArg = b.create<complex::AbsOp>(elementType, arg);
+    Value sumAbs = b.create<arith::AddFOp>(absReal, absArg);
+    Value halfSumAbs = b.create<arith::MulFOp>(sumAbs, half);
+    Value w = b.create<math::SqrtOp>(halfSumAbs);
+    Value twoW = b.create<arith::AddFOp>(w, w);
+
+    // Candidate result for x >= 0: (w, y / (2w)).
+    Value imagOverTwoW = b.create<arith::DivFOp>(imag, twoW);
+    // Candidate result for x < 0: (|y| / (2w), copysign(w, y)).
+    Value absImag = b.create<math::AbsOp>(imag);
+    Value absImagOverTwoW = b.create<arith::DivFOp>(absImag, twoW);
+    Value wWithSignOfImag = b.create<math::CopySignOp>(w, imag);
+
+    Value realIsNegative =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
+    Value resultReal =
+        b.create<arith::SelectOp>(realIsNegative, absImagOverTwoW, w);
+    Value resultImag = b.create<arith::SelectOp>(
+        realIsNegative, wWithSignOfImag, imagOverTwoW);
+
+    // For z == 0, w == 0 and the divisions above yield NaN. Return
+    // (+0, y) instead, which preserves the sign of an imaginary zero.
+    Value realIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
+    Value imagIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
+    Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
+
+    resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal);
+    resultImag = b.create<arith::SelectOp>(argIsZero, imag, resultImag);
+
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
+                                                   resultImag);
+    return success();
+  }
+};
+
struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
using OpConversionPattern<complex::SignOp>::OpConversionPattern;
@@ -782,6 +892,7 @@ void mlir::populateComplexToStandardConversionPatterns(
// clang-format off
patterns.add<
AbsOpConversion,
+ Atan2OpConversion,
ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
@@ -796,6 +907,7 @@ void mlir::populateComplexToStandardConversionPatterns(
NegOpConversion,
SignOpConversion,
SinOpConversion,
+ SqrtOpConversion,
TanOpConversion,
TanhOpConversion>(patterns.getContext());
// clang-format on
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 5670dfa62821a..319a443f5d411 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file |\
+// RUN: FileCheck %s
// CHECK-LABEL: func @complex_abs
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
@@ -16,6 +17,15 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
// -----
+// Checks that complex.atan2 is fully expanded: no complex.atan2 may remain
+// between the function label and its return.
+// CHECK-LABEL: func @complex_atan2
+// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
+// CHECK-NOT: complex.atan2
+// CHECK: return %{{.*}} : complex<f32>
+func.func @complex_atan2(%lhs: complex<f32>,
+                         %rhs: complex<f32>) -> complex<f32> {
+  %atan2 = complex.atan2 %lhs, %rhs : complex<f32>
+  return %atan2 : complex<f32>
+}
+
+// -----
+
// CHECK-LABEL: func @complex_add
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -645,3 +655,11 @@ func.func @complex_tanh(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] : f32
// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
+
+// -----
+
+// Checks that complex.sqrt is expanded into scalar math: math.sqrt must be
+// emitted and no complex.sqrt may remain before the return.
+// CHECK-LABEL: func @complex_sqrt
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+// CHECK-NOT: complex.sqrt
+// CHECK: math.sqrt
+// CHECK-NOT: complex.sqrt
+// CHECK: return %{{.*}} : complex<f32>
+func.func @complex_sqrt(%arg: complex<f32>) -> complex<f32> {
+  %sqrt = complex.sqrt %arg : complex<f32>
+  return %sqrt : complex<f32>
+}
diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
new file mode 100644
index 0000000000000..e3a3aa605f860
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
@@ -0,0 +1,119 @@
+// RUN: mlir-opt %s \
+// RUN: -func-bufferize -tensor-bufferize -arith-bufferize --canonicalize \
+// RUN: -convert-scf-to-cf --convert-complex-to-standard \
+// RUN: -convert-memref-to-llvm -convert-math-to-llvm -convert-math-to-libm \
+// RUN: -convert-vector-to-llvm -convert-complex-to-llvm \
+// RUN: -convert-func-to-llvm -reconcile-unrealized-casts |\
+// RUN: mlir-cpu-runner \
+// RUN: -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext |\
+// RUN: FileCheck %s
+
+// Calls the unary %func on each element of %input. For every result it prints
+// the real part and then the imaginary part, each on its own line.
+func.func @test_unary(%input: tensor<?xcomplex<f32>>,
+                      %func: (complex<f32>) -> complex<f32>) {
+  %zero = arith.constant 0 : index
+  %one = arith.constant 1 : index
+  %len = tensor.dim %input, %zero : tensor<?xcomplex<f32>>
+
+  scf.for %idx = %zero to %len step %one {
+    %x = tensor.extract %input[%idx] : tensor<?xcomplex<f32>>
+    %y = func.call_indirect %func(%x) : (complex<f32>) -> complex<f32>
+    %y_re = complex.re %y : complex<f32>
+    %y_im = complex.im %y : complex<f32>
+    vector.print %y_re : f32
+    vector.print %y_im : f32
+  }
+  func.return
+}
+
+// Wraps complex.sqrt in a function so it can be passed to @test_unary.
+func.func @sqrt(%arg: complex<f32>) -> complex<f32> {
+  %result = complex.sqrt %arg : complex<f32>
+  func.return %result : complex<f32>
+}
+
+// Calls the binary %func on consecutive pairs of %input, laid out as
+// [lhs_0, rhs_0, lhs_1, rhs_1, ...]. For every result it prints the real part
+// and then the imaginary part, each on its own line. If %input has an odd
+// size, the trailing unpaired element is ignored instead of reading past
+// the end of the tensor.
+func.func @test_binary(%input: tensor<?xcomplex<f32>>,
+                       %func: (complex<f32>, complex<f32>) -> complex<f32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %size = tensor.dim %input, %c0 : tensor<?xcomplex<f32>>
+  // Round %size down to an even count so that %i + 1 is always in bounds.
+  %odd = arith.remui %size, %c2 : index
+  %even_size = arith.subi %size, %odd : index
+
+  scf.for %i = %c0 to %even_size step %c2 {
+    %lhs = tensor.extract %input[%i] : tensor<?xcomplex<f32>>
+    %i_next = arith.addi %i, %c1 : index
+    %rhs = tensor.extract %input[%i_next] : tensor<?xcomplex<f32>>
+
+    %val = func.call_indirect %func(%lhs, %rhs)
+        : (complex<f32>, complex<f32>) -> complex<f32>
+    %real = complex.re %val : complex<f32>
+    %imag = complex.im %val : complex<f32>
+    vector.print %real : f32
+    vector.print %imag : f32
+    scf.yield
+  }
+  func.return
+}
+
+// Wraps complex.atan2 in a function so it can be passed to @test_binary.
+func.func @atan2(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %result = complex.atan2 %lhs, %rhs : complex<f32>
+  func.return %result : complex<f32>
+}
+
+
+// Test driver. It runs complex.sqrt on 7 sample values via @test_unary, then
+// complex.atan2 on 3 (lhs, rhs) pairs via @test_binary. The FileCheck lines
+// after each input list the expected real and imaginary parts of its result,
+// in print order. Expected values are given as leading digits of the printed
+// numbers, because FileCheck matches substrings.
+func.func @entry() {
+ // complex.sqrt test
+ // Inputs are (real, imag) pairs. For x < 0 the result's imaginary part
+ // takes the sign of the input's imaginary part, e.g. sqrt(-1 - i) has a
+ // negative imaginary part.
+ %sqrt_test = arith.constant dense<[
+ (-1.0, -1.0),
+ // CHECK: 0.455
+ // CHECK-NEXT: -1.098
+ (-1.0, 1.0),
+ // CHECK-NEXT: 0.455
+ // CHECK-NEXT: 1.098
+ (0.0, 0.0),
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 0
+ (0.0, 1.0),
+ // CHECK-NEXT: 0.707
+ // CHECK-NEXT: 0.707
+ (1.0, -1.0),
+ // CHECK-NEXT: 1.098
+ // CHECK-NEXT: -0.455
+ (1.0, 0.0),
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 0
+ (1.0, 1.0)
+ // CHECK-NEXT: 1.098
+ // CHECK-NEXT: 0.455
+ ]> : tensor<7xcomplex<f32>>
+ // @test_unary takes a dynamically shaped tensor, so erase the static size.
+ %sqrt_test_cast = tensor.cast %sqrt_test
+ : tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>
+
+ %sqrt_func = func.constant @sqrt : (complex<f32>) -> complex<f32>
+ call @test_unary(%sqrt_test_cast, %sqrt_func)
+ : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()
+
+ // complex.atan2 test
+ // Each line holds one (lhs, rhs) pair, i.e. (y, x) in atan2(y, x).
+ %atan2_test = arith.constant dense<[
+ (1.0, 2.0), (2.0, 1.0),
+ // CHECK: 0.785
+ // CHECK-NEXT: 0.346
+ (1.0, 1.0), (1.0, 0.0),
+ // CHECK-NEXT: 1.017
+ // CHECK-NEXT: 0.402
+ (1.0, 1.0), (1.0, 1.0)
+ // CHECK-NEXT: 0.785
+ // CHECK-NEXT: 0
+ ]> : tensor<6xcomplex<f32>>
+ // @test_binary takes a dynamically shaped tensor, so erase the static size.
+ %atan2_test_cast = tensor.cast %atan2_test
+ : tensor<6xcomplex<f32>> to tensor<?xcomplex<f32>>
+
+ %atan2_func = func.constant @atan2 : (complex<f32>, complex<f32>)
+ -> complex<f32>
+ call @test_binary(%atan2_test_cast, %atan2_func)
+ : (tensor<?xcomplex<f32>>, (complex<f32>, complex<f32>)
+ -> complex<f32>) -> ()
+ func.return
+}
More information about the Mlir-commits
mailing list