[Mlir-commits] [mlir] ffb8eec - [mlir][complex] Lowering complex.tanh to standard
Alexander Belyaev
llvmlistbot at llvm.org
Wed Jun 1 02:15:58 PDT 2022
Author: lewuathe
Date: 2022-06-01T11:13:54+02:00
New Revision: ffb8eecdd660eee1784ae2f83a9f26cf317ed4ed
URL: https://github.com/llvm/llvm-project/commit/ffb8eecdd660eee1784ae2f83a9f26cf317ed4ed
DIFF: https://github.com/llvm/llvm-project/commit/ffb8eecdd660eee1784ae2f83a9f26cf317ed4ed.diff
LOG: [mlir][complex] Lowering complex.tanh to standard
Lower complex.tanh to the standard dialects, including math and arith.
Reviewed By: pifon2a
Differential Revision: https://reviews.llvm.org/D126521
Added:
Modified:
mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index a7fb824f302ce..4211fad79fba0 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -736,14 +736,45 @@ struct TanOpConversion : public OpConversionPattern<complex::TanOp> {
matchAndRewrite(complex::TanOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
-
Value cos = rewriter.create<complex::CosOp>(loc, adaptor.getComplex());
Value sin = rewriter.create<complex::SinOp>(loc, adaptor.getComplex());
rewriter.replaceOpWithNewOp<complex::DivOp>(op, sin, cos);
+ return success();
+ }
+};
+
+struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
+ using OpConversionPattern<complex::TanhOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto type = adaptor.getComplex().getType().cast<ComplexType>();
+ auto elementType = type.getElementType().cast<FloatType>();
+ // Lower tanh of a complex z = x + i * y via the identity (note the i in the denominator):
+ // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + i * tanh(x) * tan(y))
+ // See: https://proofwiki.org/wiki/Hyperbolic_Tangent_of_Complex_Number
+ Value real =
+ rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+ Value imag =
+ rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
+ Value tanhA = rewriter.create<math::TanhOp>(loc, real); // tanh(x)
+ Value cosB = rewriter.create<math::CosOp>(loc, imag);
+ Value sinB = rewriter.create<math::SinOp>(loc, imag);
+ Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB); // tan(y) = sin(y) / cos(y); NOTE(review): unbounded where cos(y) == 0, may yield inf/NaN
+ Value numerator =
+ rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB); // tanh(x) + i * tan(y)
+ Value one = rewriter.create<arith::ConstantOp>(
+ loc, elementType, rewriter.getFloatAttr(elementType, 1));
+ Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB);
+ Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul); // 1 + i * tanh(x) * tan(y)
+ rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator); // tanh(z) = numerator / denominator
return success();
}
};
+
} // namespace
void mlir::populateComplexToStandardConversionPatterns(
@@ -765,7 +796,8 @@ void mlir::populateComplexToStandardConversionPatterns(
NegOpConversion,
SignOpConversion,
SinOpConversion,
- TanOpConversion>(patterns.getContext());
+ TanOpConversion,
+ TanhOpConversion>(patterns.getContext());
// clang-format on
}
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 9fb0c8a87078a..5670dfa62821a 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline="func.func(convert-complex-to-standard)" | FileCheck %s
+// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file | FileCheck %s
// CHECK-LABEL: func @complex_abs
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
@@ -14,6 +14,8 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
// CHECK: return %[[NORM]] : f32
+// -----
+
// CHECK-LABEL: func @complex_add
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -29,6 +31,8 @@ func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+// -----
+
// CHECK-LABEL: func @complex_cos
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_cos(%arg: complex<f32>) -> complex<f32> {
@@ -50,6 +54,8 @@ func.func @complex_cos(%arg: complex<f32>) -> complex<f32> {
// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]]
+// -----
+
// CHECK-LABEL: func @complex_div
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func.func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -159,6 +165,8 @@ func.func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+// -----
+
// CHECK-LABEL: func @complex_eq
// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
func.func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
@@ -174,6 +182,8 @@ func.func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
// CHECK: %[[EQUAL:.*]] = arith.andi %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1
// CHECK: return %[[EQUAL]] : i1
+// -----
+
// CHECK-LABEL: func @complex_exp
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
@@ -190,6 +200,8 @@ func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+// -----
+
// CHECK-LABEL: func.func @complex_expm1(
// CHECK-SAME: %[[ARG:.*]]: complex<f32>) -> complex<f32> {
func.func @complex_expm1(%arg: complex<f32>) -> complex<f32> {
@@ -211,6 +223,8 @@ func.func @complex_expm1(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
// CHECK: return %[[RES]] : complex<f32>
+// -----
+
// CHECK-LABEL: func @complex_log
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
@@ -230,6 +244,8 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+// -----
+
// CHECK-LABEL: func @complex_log1p
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
@@ -254,6 +270,8 @@ func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+// -----
+
// CHECK-LABEL: func @complex_mul
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func.func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -372,6 +390,8 @@ func.func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[FINAL_REAL]], %[[FINAL_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+// -----
+
// CHECK-LABEL: func @complex_neg
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_neg(%arg: complex<f32>) -> complex<f32> {
@@ -385,6 +405,8 @@ func.func @complex_neg(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[NEG_REAL]], %[[NEG_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+// -----
+
// CHECK-LABEL: func @complex_neq
// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
func.func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
@@ -400,6 +422,8 @@ func.func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
// CHECK: %[[NOT_EQUAL:.*]] = arith.ori %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1
// CHECK: return %[[NOT_EQUAL]] : i1
+// -----
+
// CHECK-LABEL: func @complex_sin
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_sin(%arg: complex<f32>) -> complex<f32> {
@@ -421,6 +445,8 @@ func.func @complex_sin(%arg: complex<f32>) -> complex<f32> {
// CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]]
+// -----
+
// CHECK-LABEL: func @complex_sign
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
@@ -445,6 +471,8 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = arith.select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+// -----
+
// CHECK-LABEL: func @complex_sub
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func.func @complex_sub(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -460,6 +488,8 @@ func.func @complex_sub(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+// -----
+
// CHECK-LABEL: func @complex_tan
// CHECK-SAME: %[[ARG:.*]]: complex<f32>
func.func @complex_tan(%arg: complex<f32>) -> complex<f32> {
@@ -595,4 +625,23 @@ func.func @complex_tan(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT_REAL_WITH_SPECIAL_CASES:.*]] = arith.select %[[RESULT_IS_NAN]], %[[RESULT_REAL_SPECIAL_CASE_1]], %[[RESULT_REAL]] : f32
// CHECK: %[[RESULT_IMAG_WITH_SPECIAL_CASES:.*]] = arith.select %[[RESULT_IS_NAN]], %[[RESULT_IMAG_SPECIAL_CASE_1]], %[[RESULT_IMAG]] : f32
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
-// CHECK: return %[[RESULT]] : complex<f32>
\ No newline at end of file
+// CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_tanh
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_tanh(%arg: complex<f32>) -> complex<f32> {
+ %tanh = complex.tanh %arg: complex<f32>
+ return %tanh : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[TANH_A:.*]] = math.tanh %[[REAL]] : f32
+// CHECK: %[[COS_B:.*]] = math.cos %[[IMAG]] : f32
+// CHECK: %[[SIN_B:.*]] = math.sin %[[IMAG]] : f32
+// CHECK: %[[TAN_B:.*]] = arith.divf %[[SIN_B]], %[[COS_B]] : f32
+// CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] : f32
+// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>
More information about the Mlir-commits
mailing list