[Mlir-commits] [mlir] f5fa633 - [mlir] Lower complex.sqrt and complex.atan2 to Arithmetic dialect.
Alexander Belyaev
llvmlistbot at llvm.org
Mon May 30 00:45:57 PDT 2022
Author: Alexander Belyaev
Date: 2022-05-30T09:44:36+02:00
New Revision: f5fa633b0955a8cee878b384801038fccef11fdc
URL: https://github.com/llvm/llvm-project/commit/f5fa633b0955a8cee878b384801038fccef11fdc
DIFF: https://github.com/llvm/llvm-project/commit/f5fa633b0955a8cee878b384801038fccef11fdc.diff
LOG: [mlir] Lower complex.sqrt and complex.atan2 to Arithmetic dialect.
I don't see a point in the lit tests here since sqrt, mul and other ops
expand as well. I just added "smoke" tests to verify that the conversion works
and does not create any illegal ops.
I will create a patch that adds a simple integration test to
mlir/test/Integration/Dialect/ComplexOps/ that will compare the values.
Differential Revision: https://reviews.llvm.org/D126539
Added:
Modified:
mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index e1eca6181dff9..194e1669a86c9 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -44,6 +44,49 @@ struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
}
};
+// Lowers complex.atan2 through the identity
+//   atan2(y, x) = -i * log((x + i * y) / sqrt(x**2 + y**2))
+// where y is the lhs operand and x is the rhs operand. The pattern emits
+// complex and arith ops only.
+struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
+  using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+
+    auto complexTy = op.getType().cast<ComplexType>();
+    Type scalarTy = complexTy.getElementType();
+
+    Value y = adaptor.getLhs();
+    Value x = adaptor.getRhs();
+
+    // Denominator: sqrt(x**2 + y**2), evaluated with complex arithmetic.
+    Value xSquared = b.create<complex::MulOp>(complexTy, x, x);
+    Value ySquared = b.create<complex::MulOp>(complexTy, y, y);
+    Value sumOfSquares =
+        b.create<complex::AddOp>(complexTy, xSquared, ySquared);
+    Value denominator = b.create<complex::SqrtOp>(complexTy, sumOfSquares);
+
+    // Numerator: x + i * y, with the imaginary unit built as (0, 1).
+    Value zero =
+        b.create<arith::ConstantOp>(scalarTy, b.getZeroAttr(scalarTy));
+    Value one =
+        b.create<arith::ConstantOp>(scalarTy, b.getFloatAttr(scalarTy, 1));
+    Value imaginaryUnit = b.create<complex::CreateOp>(complexTy, zero, one);
+    Value iTimesY = b.create<complex::MulOp>(imaginaryUnit, y);
+    Value numerator = b.create<complex::AddOp>(x, iTimesY);
+
+    Value quotient = b.create<complex::DivOp>(numerator, denominator);
+    Value logOfQuotient = b.create<complex::LogOp>(quotient);
+
+    // Final scaling by -i, built as (0, -1).
+    Value minusOne =
+        b.create<arith::ConstantOp>(scalarTy, b.getFloatAttr(scalarTy, -1));
+    Value minusI = b.create<complex::CreateOp>(complexTy, zero, minusOne);
+
+    rewriter.replaceOpWithNewOp<complex::MulOp>(op, minusI, logOfQuotient);
+    return success();
+  }
+};
+
template <typename ComparisonOp, arith::CmpFPredicate p>
struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
using OpConversionPattern<ComparisonOp>::OpConversionPattern;
@@ -700,6 +743,72 @@ struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
}
};
+// Lowers complex.sqrt to real-valued arith/math ops (plus complex.abs).
+// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
+//
+// For z = re + i * im, let t = sqrt((|re| + |z|) / 2). Then:
+//   re < 0:    sqrt(z) = im / (2 * s) + i * s, with s = (im < 0 ? -t : t)
+//   otherwise: sqrt(z) = t + i * im / (2 * t)
+//   z == 0:    sqrt(z) = 0 (selected explicitly to avoid the 0 / 0 division)
+struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
+  using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+
+    auto type = op.getType().cast<ComplexType>();
+    Type elementType = type.getElementType();
+    Value arg = adaptor.getComplex();
+
+    Value zero =
+        b.create<arith::ConstantOp>(elementType, b.getZeroAttr(elementType));
+
+    Value real = b.create<complex::ReOp>(elementType, arg);
+    Value imag = b.create<complex::ImOp>(elementType, arg);
+
+    // t = sqrt((|re| + |z|) / 2). The halving has to happen *before* the
+    // square root; dividing the square root by two instead yields
+    // sqrt(|re| + |z|) / 2, which is off by a factor of sqrt(2).
+    Value absLhs = b.create<math::AbsOp>(real);
+    Value absArg = b.create<complex::AbsOp>(elementType, arg);
+    Value addAbs = b.create<arith::AddFOp>(absLhs, absArg);
+    Value two = b.create<arith::ConstantOp>(
+        elementType, b.getFloatAttr(elementType, 2));
+    Value halfAddAbs = b.create<arith::DivFOp>(addAbs, two);
+    Value sqrtHalfAddAbs = b.create<math::SqrtOp>(halfAddAbs);
+
+    Value realIsNegative =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, real, zero);
+    Value imagIsNegative =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, imag, zero);
+
+    // Candidate real part for the non-negative-real case.
+    Value resultReal = sqrtHalfAddAbs;
+
+    // Candidate imaginary part for the non-negative-real case: im / (2 * t).
+    Value imagDivTwoResultReal = b.create<arith::DivFOp>(
+        imag, b.create<arith::AddFOp>(resultReal, resultReal));
+
+    Value negativeResultReal = b.create<arith::NegFOp>(resultReal);
+
+    // For a negative real part the imaginary result is +/- t, taking the sign
+    // of the input's imaginary part.
+    Value resultImag = b.create<arith::SelectOp>(
+        realIsNegative,
+        b.create<arith::SelectOp>(imagIsNegative, negativeResultReal,
+                                  resultReal),
+        imagDivTwoResultReal);
+
+    // For a negative real part the real result is im / (2 * resultImag).
+    resultReal = b.create<arith::SelectOp>(
+        realIsNegative,
+        b.create<arith::DivFOp>(
+            imag, b.create<arith::AddFOp>(resultImag, resultImag)),
+        resultReal);
+
+    // sqrt(0) = 0; without this the divisions above would produce 0 / 0.
+    Value realIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, real, zero);
+    Value imagIsZero =
+        b.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, imag, zero);
+    Value argIsZero = b.create<arith::AndIOp>(realIsZero, imagIsZero);
+
+    resultReal = b.create<arith::SelectOp>(argIsZero, zero, resultReal);
+    resultImag = b.create<arith::SelectOp>(argIsZero, zero, resultImag);
+
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
+                                                   resultImag);
+    return success();
+  }
+};
+
struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
using OpConversionPattern<complex::SignOp>::OpConversionPattern;
@@ -735,6 +844,7 @@ void mlir::populateComplexToStandardConversionPatterns(
// clang-format off
patterns.add<
AbsOpConversion,
+ Atan2OpConversion,
ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
@@ -748,7 +858,8 @@ void mlir::populateComplexToStandardConversionPatterns(
MulOpConversion,
NegOpConversion,
SignOpConversion,
- SinOpConversion>(patterns.getContext());
+ SinOpConversion,
+ SqrtOpConversion>(patterns.getContext());
// clang-format on
}
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 6f57e722b520e..bf41028718517 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline="func.func(convert-complex-to-standard)" | FileCheck %s
+// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file | FileCheck %s
// CHECK-LABEL: func @complex_abs
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
@@ -14,6 +14,17 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
// CHECK: return %[[NORM]] : f32
+// -----
+
+// CHECK-LABEL: func @complex_atan2
+func.func @complex_atan2(%lhs: complex<f32>,
+ %rhs: complex<f32>) -> complex<f32> {
+ %atan2 = complex.atan2 %lhs, %rhs : complex<f32>
+ return %atan2 : complex<f32>
+}
+
+// -----
+
// CHECK-LABEL: func @complex_add
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -29,6 +40,8 @@ func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+// -----
+
// CHECK-LABEL: func @complex_cos
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_cos(%arg: complex<f32>) -> complex<f32> {
@@ -50,6 +63,8 @@ func.func @complex_cos(%arg: complex<f32>) -> complex<f32> {
// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]]
+// -----
+
// CHECK-LABEL: func @complex_div
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func.func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -159,6 +174,8 @@ func.func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+// -----
+
// CHECK-LABEL: func @complex_eq
// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
func.func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
@@ -174,6 +191,8 @@ func.func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
// CHECK: %[[EQUAL:.*]] = arith.andi %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1
// CHECK: return %[[EQUAL]] : i1
+// -----
+
// CHECK-LABEL: func @complex_exp
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
@@ -190,6 +209,8 @@ func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+// -----
+
// CHECK-LABEL: func.func @complex_expm1(
// CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
func.func @complex_expm1(%arg: complex<f32>) -> complex<f32> {
@@ -211,6 +232,8 @@ func.func @complex_expm1(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
// CHECK: return %[[RES]] : complex<f32>
+// -----
+
// CHECK-LABEL: func @complex_log
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
@@ -230,6 +253,8 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+// -----
+
// CHECK-LABEL: func @complex_log1p
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
@@ -254,6 +279,8 @@ func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+// -----
+
// CHECK-LABEL: func @complex_mul
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func.func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -372,6 +399,8 @@ func.func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[FINAL_REAL]], %[[FINAL_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+// -----
+
// CHECK-LABEL: func @complex_neg
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_neg(%arg: complex<f32>) -> complex<f32> {
@@ -385,6 +414,8 @@ func.func @complex_neg(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[NEG_REAL]], %[[NEG_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+// -----
+
// CHECK-LABEL: func @complex_neq
// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
func.func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
@@ -400,6 +431,8 @@ func.func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
// CHECK: %[[NOT_EQUAL:.*]] = arith.ori %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1
// CHECK: return %[[NOT_EQUAL]] : i1
+// -----
+
// CHECK-LABEL: func @complex_sin
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_sin(%arg: complex<f32>) -> complex<f32> {
@@ -421,6 +454,8 @@ func.func @complex_sin(%arg: complex<f32>) -> complex<f32> {
// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]]
+// -----
+
// CHECK-LABEL: func @complex_sign
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
@@ -445,6 +480,8 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = arith.select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+// -----
+
// CHECK-LABEL: func @complex_sub
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func.func @complex_sub(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -459,3 +496,11 @@ func.func @complex_sub(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT_IMAG:.*]] = arith.subf %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_sqrt
+func.func @complex_sqrt(%arg: complex<f32>) -> complex<f32> {
+ %sqrt = complex.sqrt %arg : complex<f32>
+ return %sqrt : complex<f32>
+}
More information about the Mlir-commits
mailing list