[Mlir-commits] [mlir] ffb8eec - [mlir][complex] Lowering complex.tanh to standard

Alexander Belyaev llvmlistbot at llvm.org
Wed Jun 1 02:15:58 PDT 2022


Author: lewuathe
Date: 2022-06-01T11:13:54+02:00
New Revision: ffb8eecdd660eee1784ae2f83a9f26cf317ed4ed

URL: https://github.com/llvm/llvm-project/commit/ffb8eecdd660eee1784ae2f83a9f26cf317ed4ed
DIFF: https://github.com/llvm/llvm-project/commit/ffb8eecdd660eee1784ae2f83a9f26cf317ed4ed.diff

LOG: [mlir][complex] Lowering complex.tanh to standard

Lower complex.tanh to standard dialects, including math and arith.

Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D126521

Added: 
    

Modified: 
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index a7fb824f302ce..4211fad79fba0 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -736,14 +736,45 @@ struct TanOpConversion : public OpConversionPattern<complex::TanOp> {
   matchAndRewrite(complex::TanOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
-
     Value cos = rewriter.create<complex::CosOp>(loc, adaptor.getComplex());
     Value sin = rewriter.create<complex::SinOp>(loc, adaptor.getComplex());
     rewriter.replaceOpWithNewOp<complex::DivOp>(op, sin, cos);
+    return success();
+  }
+};
+
+struct TanhOpConversion : public OpConversionPattern<complex::TanhOp> {
+  using OpConversionPattern<complex::TanhOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto type = adaptor.getComplex().getType().cast<ComplexType>();
+    auto elementType = type.getElementType().cast<FloatType>();
 
+    // The hyperbolic tangent of a complex number can be calculated as follows.
+    // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + i * tanh(x) * tan(y))
+    // See: https://proofwiki.org/wiki/Hyperbolic_Tangent_of_Complex_Number
+    Value real =
+        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+    Value imag =
+        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
+    Value tanhA = rewriter.create<math::TanhOp>(loc, real);
+    Value cosB = rewriter.create<math::CosOp>(loc, imag);
+    Value sinB = rewriter.create<math::SinOp>(loc, imag);
+    Value tanB = rewriter.create<arith::DivFOp>(loc, sinB, cosB);
+    Value numerator =
+        rewriter.create<complex::CreateOp>(loc, type, tanhA, tanB);
+    Value one = rewriter.create<arith::ConstantOp>(
+        loc, elementType, rewriter.getFloatAttr(elementType, 1));
+    Value mul = rewriter.create<arith::MulFOp>(loc, tanhA, tanB);
+    Value denominator = rewriter.create<complex::CreateOp>(loc, type, one, mul);
+    rewriter.replaceOpWithNewOp<complex::DivOp>(op, numerator, denominator);
     return success();
   }
 };
+
 } // namespace
 
 void mlir::populateComplexToStandardConversionPatterns(
@@ -765,7 +796,8 @@ void mlir::populateComplexToStandardConversionPatterns(
       NegOpConversion,
       SignOpConversion,
       SinOpConversion,
-      TanOpConversion>(patterns.getContext());
+      TanOpConversion,
+      TanhOpConversion>(patterns.getContext());
   // clang-format on
 }
 

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 9fb0c8a87078a..5670dfa62821a 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline="func.func(convert-complex-to-standard)" | FileCheck %s
+// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @complex_abs
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
@@ -14,6 +14,8 @@ func.func @complex_abs(%arg: complex<f32>) -> f32 {
 // CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
 // CHECK: return %[[NORM]] : f32
 
+// -----
+
 // CHECK-LABEL: func @complex_add
 // CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
 func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -29,6 +31,8 @@ func.func @complex_add(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// -----
+
 // CHECK-LABEL: func @complex_cos
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_cos(%arg: complex<f32>) -> complex<f32> {
@@ -50,6 +54,8 @@ func.func @complex_cos(%arg: complex<f32>) -> complex<f32> {
 // CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK:     return %[[RESULT]]
 
+// -----
+
 // CHECK-LABEL: func @complex_div
 // CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
 func.func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -159,6 +165,8 @@ func.func @complex_div(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// -----
+
 // CHECK-LABEL: func @complex_eq
 // CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
 func.func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
@@ -174,6 +182,8 @@ func.func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
 // CHECK: %[[EQUAL:.*]] = arith.andi %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1
 // CHECK: return %[[EQUAL]] : i1
 
+// -----
+
 // CHECK-LABEL: func @complex_exp
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
@@ -190,6 +200,8 @@ func.func @complex_exp(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// -----
+
 // CHECK-LABEL:   func.func @complex_expm1(
 // CHECK-SAME:                             %[[ARG:.*]]: complex<f32>) -> complex<f32> {
 func.func @complex_expm1(%arg: complex<f32>) -> complex<f32> {
@@ -211,6 +223,8 @@ func.func @complex_expm1(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex<f32>
 // CHECK: return %[[RES]] : complex<f32>
 
+// -----
+
 // CHECK-LABEL: func @complex_log
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
@@ -230,6 +244,8 @@ func.func @complex_log(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// -----
+
 // CHECK-LABEL: func @complex_log1p
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
@@ -254,6 +270,8 @@ func.func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// -----
+
 // CHECK-LABEL: func @complex_mul
 // CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
 func.func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -372,6 +390,8 @@ func.func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[FINAL_REAL]], %[[FINAL_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// -----
+
 // CHECK-LABEL: func @complex_neg
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_neg(%arg: complex<f32>) -> complex<f32> {
@@ -385,6 +405,8 @@ func.func @complex_neg(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[NEG_REAL]], %[[NEG_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// -----
+
 // CHECK-LABEL: func @complex_neq
 // CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
 func.func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
@@ -400,6 +422,8 @@ func.func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
 // CHECK: %[[NOT_EQUAL:.*]] = arith.ori %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1
 // CHECK: return %[[NOT_EQUAL]] : i1
 
+// -----
+
 // CHECK-LABEL: func @complex_sin
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_sin(%arg: complex<f32>) -> complex<f32> {
@@ -421,6 +445,8 @@ func.func @complex_sin(%arg: complex<f32>) -> complex<f32> {
 // CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK:     return %[[RESULT]]
 
+// -----
+
 // CHECK-LABEL: func @complex_sign
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
@@ -445,6 +471,8 @@ func.func @complex_sign(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = arith.select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// -----
+
 // CHECK-LABEL: func @complex_sub
 // CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
 func.func @complex_sub(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
@@ -460,6 +488,8 @@ func.func @complex_sub(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// -----
+
 // CHECK-LABEL: func @complex_tan
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func.func @complex_tan(%arg: complex<f32>) -> complex<f32> {
@@ -595,4 +625,23 @@ func.func @complex_tan(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT_REAL_WITH_SPECIAL_CASES:.*]] = arith.select %[[RESULT_IS_NAN]], %[[RESULT_REAL_SPECIAL_CASE_1]], %[[RESULT_REAL]] : f32
 // CHECK: %[[RESULT_IMAG_WITH_SPECIAL_CASES:.*]] = arith.select %[[RESULT_IS_NAN]], %[[RESULT_IMAG_SPECIAL_CASE_1]], %[[RESULT_IMAG]] : f32
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex<f32>
-// CHECK: return %[[RESULT]] : complex<f32>
\ No newline at end of file
+// CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func @complex_tanh
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_tanh(%arg: complex<f32>) -> complex<f32> {
+  %tanh = complex.tanh %arg: complex<f32>
+  return %tanh : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[TANH_A:.*]] = math.tanh %[[REAL]] : f32
+// CHECK: %[[COS_B:.*]] = math.cos %[[IMAG]] : f32
+// CHECK: %[[SIN_B:.*]] = math.sin %[[IMAG]] : f32
+// CHECK: %[[TAN_B:.*]] = arith.divf %[[SIN_B]], %[[COS_B]] : f32
+// CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex<f32>
+// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] : f32
+// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex<f32>


        


More information about the Mlir-commits mailing list