[Mlir-commits] [mlir] 6c6eddb - [mlir] Lower complex.power and complex.rsqrt to standard dialect.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 8 10:54:03 PDT 2022


Author: bixia1
Date: 2022-06-08T10:53:53-07:00
New Revision: 6c6eddb6172f910c7e38d1327e5c6493b62c2950

URL: https://github.com/llvm/llvm-project/commit/6c6eddb6172f910c7e38d1327e5c6493b62c2950
DIFF: https://github.com/llvm/llvm-project/commit/6c6eddb6172f910c7e38d1327e5c6493b62c2950.diff

LOG: [mlir] Lower complex.power and complex.rsqrt to standard dialect.

Add conversion tests and correctness tests.

Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D127255

Added: 
    

Modified: 
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
    mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index e314f2eebc7d9..0a5124ada7a49 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -906,6 +906,109 @@ struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
   }
 };
 
+/// Converts x^y = (a+bi)^(c+di) to
+///    (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
+///    where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
+///
+/// This is exp(y * log(x)) written out with
+/// log(x) = 0.5*ln(a*a+b*b) + i*atan2(b,a). Two special cases are patched in
+/// with explicit selects because the generic formula yields NaN for them:
+///   * x^0 = 1 for every x (including infinite and NaN x), and
+///   * 0^y = 0 for real y > 0.
+///
+/// `type` is the complex result type; a/b are the real/imaginary parts of the
+/// base and c/d those of the exponent, all of `type`'s float element type.
+///
+/// NOTE(review): a*a+b*b is formed directly, so it can overflow to inf or
+/// underflow to 0 for large/tiny |x| (an underflowed nonzero x is then also
+/// treated as 0 below); a scaled magnitude would be more robust.
+static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
+                                 ComplexType type, Value a, Value b, Value c,
+                                 Value d) {
+  auto elementType = type.getElementType().cast<FloatType>();
+
+  // Compute (a*a+b*b)^(0.5c).
+  Value aaPbb = builder.create<arith::AddFOp>(
+      builder.create<arith::MulFOp>(a, a), builder.create<arith::MulFOp>(b, b));
+  Value half = builder.create<arith::ConstantOp>(
+      elementType, builder.getFloatAttr(elementType, 0.5));
+  Value halfC = builder.create<arith::MulFOp>(half, c);
+  Value aaPbbTohalfC = builder.create<math::PowFOp>(aaPbb, halfC);
+
+  // Compute exp(-d*atan2(b,a)).
+  Value negD = builder.create<arith::NegFOp>(d);
+  Value argX = builder.create<math::Atan2Op>(b, a);
+  Value negDArgX = builder.create<arith::MulFOp>(negD, argX);
+  Value eToNegDArgX = builder.create<math::ExpOp>(negDArgX);
+
+  // Compute (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)).
+  Value coeff = builder.create<arith::MulFOp>(aaPbbTohalfC, eToNegDArgX);
+
+  // Compute c*atan2(b,a)+0.5d*ln(a*a+b*b).
+  Value lnAaPbb = builder.create<math::LogOp>(aaPbb);
+  Value halfD = builder.create<arith::MulFOp>(half, d);
+  Value q = builder.create<arith::AddFOp>(
+      builder.create<arith::MulFOp>(c, argX),
+      builder.create<arith::MulFOp>(halfD, lnAaPbb));
+
+  Value cosQ = builder.create<math::CosOp>(q);
+  Value sinQ = builder.create<math::SinOp>(q);
+  Value zero = builder.create<arith::ConstantOp>(
+      elementType, builder.getFloatAttr(elementType, 0));
+  Value one = builder.create<arith::ConstantOp>(
+      elementType, builder.getFloatAttr(elementType, 1));
+
+  Value xEqZero =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, aaPbb, zero);
+  Value cEqZero =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero);
+  Value dEqZero =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero);
+  // y is a positive real number: Re(y) > 0 and Im(y) == 0.
+  Value yIsPosReal = builder.create<arith::AndIOp>(
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, c, zero),
+      dEqZero);
+  // y == 0 (ordered compares, so -0.0 counts as zero as well).
+  Value yEqZero = builder.create<arith::AndIOp>(cEqZero, dEqZero);
+
+  Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
+  Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
+  Value complexOther = builder.create<complex::CreateOp>(
+      type, builder.create<arith::MulFOp>(coeff, cosQ),
+      builder.create<arith::MulFOp>(coeff, sinQ));
+
+  // x^y is 0 if x is 0 and y is a positive real number. The generic formula
+  // would give NaN here because q contains 0.5d * ln(0) = 0 * -inf.
+  Value zeroBaseResult = builder.create<arith::SelectOp>(
+      builder.create<arith::AndIOp>(xEqZero, yIsPosReal), complexZero,
+      complexOther);
+
+  // x^0 is defined to be 1 for every x, including 0^0 (see Branch Cuts for
+  // Complex Elementary Functions or Much Ado About Nothing's Sign Bit,
+  // W. Kahan, Section 10) and infinite/NaN x. Without this select, e.g.
+  // inf^0 evaluates 0 * ln(inf) = NaN in q.
+  return builder.create<arith::SelectOp>(yEqZero, complexOne, zeroBaseResult);
+}
+
+/// Lowers complex.pow by splitting the base (lhs) and the exponent (rhs) into
+/// their real and imaginary parts and expanding them with powOpConversionImpl.
+struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
+  using OpConversionPattern<complex::PowOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    mlir::ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
+    Value base = adaptor.getLhs();
+    Value exponent = adaptor.getRhs();
+    auto complexTy = base.getType().cast<ComplexType>();
+    auto floatTy = complexTy.getElementType().cast<FloatType>();
+
+    // Parts are materialized in the order re(lhs), im(lhs), re(rhs), im(rhs).
+    Value baseRe = ib.create<complex::ReOp>(floatTy, base);
+    Value baseIm = ib.create<complex::ImOp>(floatTy, base);
+    Value expRe = ib.create<complex::ReOp>(floatTy, exponent);
+    Value expIm = ib.create<complex::ImOp>(floatTy, exponent);
+
+    Value lowered =
+        powOpConversionImpl(ib, complexTy, baseRe, baseIm, expRe, expIm);
+    rewriter.replaceOp(op, {lowered});
+    return success();
+  }
+};
+
+/// Lowers complex.rsqrt(z) as z^(-0.5 + 0i) through powOpConversionImpl.
+/// NOTE(review): with this expansion rsqrt(0) evaluates to NaN (not inf).
+struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
+  using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    mlir::ImplicitLocOpBuilder ib(op.getLoc(), rewriter);
+    Value operand = adaptor.getComplex();
+    auto complexTy = operand.getType().cast<ComplexType>();
+    auto floatTy = complexTy.getElementType().cast<FloatType>();
+
+    Value re = ib.create<complex::ReOp>(floatTy, operand);
+    Value im = ib.create<complex::ImOp>(floatTy, operand);
+    // Constant exponent -0.5 + 0i, since rsqrt(z) == z^(-1/2).
+    Value expRe =
+        ib.create<arith::ConstantOp>(floatTy, ib.getFloatAttr(floatTy, -0.5));
+    Value expIm =
+        ib.create<arith::ConstantOp>(floatTy, ib.getFloatAttr(floatTy, 0));
+
+    rewriter.replaceOp(
+        op, {powOpConversionImpl(ib, complexTy, re, im, expRe, expIm)});
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateComplexToStandardConversionPatterns(
@@ -931,7 +1034,9 @@ void mlir::populateComplexToStandardConversionPatterns(
       SinOpConversion,
       SqrtOpConversion,
       TanOpConversion,
-      TanhOpConversion
+      TanhOpConversion,
+      PowOpConversion,
+      RsqrtOpConversion
   >(patterns.getContext());
   // clang-format on
 }

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 96875239cf9ff..5b37899075a4f 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -676,4 +676,21 @@ func.func @complex_conj(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
 // CHECK: %[[NEG_IMAG:.*]] = arith.negf %[[IMAG]] : f32
 // CHECK: %[[RESULT:.*]] = complex.create %[[REAL]], %[[NEG_IMAG]] : complex<f32>
-// CHECK: return %[[RESULT]] : complex<f32>
\ No newline at end of file
+// CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL:   func.func @complex_pow
+func.func @complex_pow(%lhs: complex<f32>,
+                       %rhs: complex<f32>) -> complex<f32> {
+  %pow = complex.pow %lhs, %rhs : complex<f32>
+  return %pow : complex<f32>
+}
+// The lowering splits both operands and expands exp(rhs * log(lhs)).
+// CHECK: complex.re
+// CHECK: complex.im
+// CHECK: complex.re
+// CHECK: complex.im
+// CHECK: math.powf
+// CHECK: math.atan2
+// CHECK: math.exp
+// CHECK: math.log
+// CHECK: math.cos
+// CHECK: math.sin
+// CHECK: complex.create
+// CHECK: arith.select
+// CHECK: return {{.*}} : complex<f32>
+
+// -----
+
+// CHECK-LABEL:   func.func @complex_rsqrt
+func.func @complex_rsqrt(%arg: complex<f32>) -> complex<f32> {
+  %rsqrt = complex.rsqrt %arg : complex<f32>
+  return %rsqrt : complex<f32>
+}
+// rsqrt(z) is lowered as z^(-0.5 + 0i) via the pow expansion.
+// CHECK: complex.re
+// CHECK: complex.im
+// CHECK: arith.constant -5.000000e-01 : f32
+// CHECK: math.powf
+// CHECK: math.atan2
+// CHECK: math.exp
+// CHECK: math.log
+// CHECK: math.cos
+// CHECK: math.sin
+// CHECK: complex.create
+// CHECK: arith.select
+// CHECK: return {{.*}} : complex<f32>

diff  --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
index 67867f3cbd680..00ab3ed76e278 100644
--- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
+++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
@@ -38,6 +38,11 @@ func.func @tanh(%arg: complex<f32>) -> complex<f32> {
   func.return %tanh : complex<f32>
 }
 
+func.func @rsqrt(%arg: complex<f32>) -> complex<f32> {
+  %rsqrt = complex.rsqrt %arg : complex<f32>
+  func.return %rsqrt : complex<f32>
+}
+
 // %input contains pairs of lhs, rhs, i.e. [lhs_0, rhs_0, lhs_1, rhs_1,...]
 func.func @test_binary(%input: tensor<?xcomplex<f32>>,
                        %func: (complex<f32>, complex<f32>) -> complex<f32>) {
@@ -67,6 +72,10 @@ func.func @atan2(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
   func.return %atan2 : complex<f32>
 }
 
+// Wrapper around complex.pow; @entry passes it to @test_binary as a function
+// value (func.constant @pow).
+func.func @pow(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+  %pow = complex.pow %lhs, %rhs : complex<f32>
+  func.return %pow : complex<f32>
+}
 
 func.func @entry() {
   // complex.sqrt test
@@ -121,6 +130,30 @@ func.func @entry() {
     : (tensor<?xcomplex<f32>>, (complex<f32>, complex<f32>)
     -> complex<f32>) -> ()
 
+  // complex.pow test
+  %pow_test = arith.constant dense<[
+    (0.0, 0.0), (0.0, 0.0),
+    // CHECK:       1
+    // CHECK-NEXT:  0
+    (0.0, 0.0), (1.0, 0.0),
+    // CHECK-NEXT:  0
+    // CHECK-NEXT:  0
+    (0.0, 0.0), (-1.0, 0.0),
+    // CHECK-NEXT:  -nan
+    // CHECK-NEXT:  -nan
+    (1.0, 1.0), (1.0, 1.0)
+    // CHECK-NEXT:  0.273
+    // CHECK-NEXT:  0.583
+  ]> : tensor<8xcomplex<f32>>
+  %pow_test_cast = tensor.cast %pow_test
+    :  tensor<8xcomplex<f32>> to tensor<?xcomplex<f32>>
+
+  %pow_func = func.constant @pow : (complex<f32>, complex<f32>)
+    -> complex<f32>
+  call @test_binary(%pow_test_cast, %pow_func)
+    : (tensor<?xcomplex<f32>>, (complex<f32>, complex<f32>)
+    -> complex<f32>) -> ()
+
   // complex.tanh test
   %tanh_test = arith.constant dense<[
     (-1.0, -1.0),
@@ -152,5 +185,36 @@ func.func @entry() {
   call @test_unary(%tanh_test_cast, %tanh_func)
     : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()
 
+  // complex.rsqrt test
+  %rsqrt_test = arith.constant dense<[
+    (-1.0, -1.0),
+    // CHECK:       0.321
+    // CHECK-NEXT:  0.776
+    (-1.0, 1.0),
+    // CHECK-NEXT:  0.321
+    // CHECK-NEXT:  -0.776
+    (0.0, 0.0),
+    // CHECK-NEXT:  nan
+    // CHECK-NEXT:  nan
+    (0.0, 1.0),
+    // CHECK-NEXT:  0.707
+    // CHECK-NEXT:  -0.707
+    (1.0, -1.0),
+    // CHECK-NEXT:  0.776
+    // CHECK-NEXT:  0.321
+    (1.0, 0.0),
+    // CHECK-NEXT:  1
+    // CHECK-NEXT:  0
+    (1.0, 1.0)
+    // CHECK-NEXT:  0.776
+    // CHECK-NEXT:  -0.321
+  ]> : tensor<7xcomplex<f32>>
+  %rsqrt_test_cast = tensor.cast %rsqrt_test
+    :  tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>
+
+  %rsqrt_func = func.constant @rsqrt : (complex<f32>) -> complex<f32>
+  call @test_unary(%rsqrt_test_cast, %rsqrt_func)
+    : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()
+
   func.return
 }


        


More information about the Mlir-commits mailing list