[Mlir-commits] [mlir] 6c6eddb - [mlir] Lower complex.power and complex.rsqrt to standard dialect.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jun 8 10:54:03 PDT 2022
Author: bixia1
Date: 2022-06-08T10:53:53-07:00
New Revision: 6c6eddb6172f910c7e38d1327e5c6493b62c2950
URL: https://github.com/llvm/llvm-project/commit/6c6eddb6172f910c7e38d1327e5c6493b62c2950
DIFF: https://github.com/llvm/llvm-project/commit/6c6eddb6172f910c7e38d1327e5c6493b62c2950.diff
LOG: [mlir] Lower complex.power and complex.rsqrt to standard dialect.
Add conversion tests and correctness tests.
Reviewed By: pifon2a
Differential Revision: https://reviews.llvm.org/D127255
Added:
Modified:
mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index e314f2eebc7d9..0a5124ada7a49 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -906,6 +906,109 @@ struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
}
};
+/// Coverts x^y = (a+bi)^(c+di) to
+/// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
+/// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
+static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
+ ComplexType type, Value a, Value b, Value c,
+ Value d) {
+ auto elementType = type.getElementType().cast<FloatType>();
+
+ // Compute (a*a+b*b)^(0.5c).
+ Value aaPbb = builder.create<arith::AddFOp>(
+ builder.create<arith::MulFOp>(a, a), builder.create<arith::MulFOp>(b, b));
+ Value half = builder.create<arith::ConstantOp>(
+ elementType, builder.getFloatAttr(elementType, 0.5));
+ Value halfC = builder.create<arith::MulFOp>(half, c);
+ Value aaPbbTohalfC = builder.create<math::PowFOp>(aaPbb, halfC);
+
+ // Compute exp(-d*atan2(b,a)).
+ Value negD = builder.create<arith::NegFOp>(d);
+ Value argX = builder.create<math::Atan2Op>(b, a);
+ Value negDArgX = builder.create<arith::MulFOp>(negD, argX);
+ Value eToNegDArgX = builder.create<math::ExpOp>(negDArgX);
+
+ // Compute (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)).
+ Value coeff = builder.create<arith::MulFOp>(aaPbbTohalfC, eToNegDArgX);
+
+ // Compute c*atan2(b,a)+0.5d*ln(a*a+b*b).
+ Value lnAaPbb = builder.create<math::LogOp>(aaPbb);
+ Value halfD = builder.create<arith::MulFOp>(half, d);
+ Value q = builder.create<arith::AddFOp>(
+ builder.create<arith::MulFOp>(c, argX),
+ builder.create<arith::MulFOp>(halfD, lnAaPbb));
+
+ Value cosQ = builder.create<math::CosOp>(q);
+ Value sinQ = builder.create<math::SinOp>(q);
+ Value zero = builder.create<arith::ConstantOp>(
+ elementType, builder.getFloatAttr(elementType, 0));
+ Value one = builder.create<arith::ConstantOp>(
+ elementType, builder.getFloatAttr(elementType, 1));
+
+ Value xEqZero =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, aaPbb, zero);
+ Value yGeZero = builder.create<arith::AndIOp>(
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, c, zero),
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, d, zero));
+ Value cEqZero =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, c, zero);
+ Value complexZero = builder.create<complex::CreateOp>(type, zero, zero);
+ Value complexOne = builder.create<complex::CreateOp>(type, one, zero);
+ Value complexOther = builder.create<complex::CreateOp>(
+ type, builder.create<arith::MulFOp>(coeff, cosQ),
+ builder.create<arith::MulFOp>(coeff, sinQ));
+
+ // x^y is 0 if x is 0 and y > 0. 0^0 is defined to be 1.0, see
+ // Branch Cuts for Complex Elementary Functions or Much Ado About
+ // Nothing's Sign Bit, W. Kahan, Section 10.
+ return builder.create<arith::SelectOp>(
+ builder.create<arith::AndIOp>(xEqZero, yGeZero),
+ builder.create<arith::SelectOp>(cEqZero, complexOne, complexZero),
+ complexOther);
+}
+
+struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
+ using OpConversionPattern<complex::PowOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
+ auto type = adaptor.getLhs().getType().cast<ComplexType>();
+ auto elementType = type.getElementType().cast<FloatType>();
+
+ Value a = builder.create<complex::ReOp>(elementType, adaptor.getLhs());
+ Value b = builder.create<complex::ImOp>(elementType, adaptor.getLhs());
+ Value c = builder.create<complex::ReOp>(elementType, adaptor.getRhs());
+ Value d = builder.create<complex::ImOp>(elementType, adaptor.getRhs());
+
+ rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)});
+ return success();
+ }
+};
+
+struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
+ using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
+ auto type = adaptor.getComplex().getType().cast<ComplexType>();
+ auto elementType = type.getElementType().cast<FloatType>();
+
+ Value a = builder.create<complex::ReOp>(elementType, adaptor.getComplex());
+ Value b = builder.create<complex::ImOp>(elementType, adaptor.getComplex());
+ Value c = builder.create<arith::ConstantOp>(
+ elementType, builder.getFloatAttr(elementType, -0.5));
+ Value d = builder.create<arith::ConstantOp>(
+ elementType, builder.getFloatAttr(elementType, 0));
+
+ rewriter.replaceOp(op, {powOpConversionImpl(builder, type, a, b, c, d)});
+ return success();
+ }
+};
+
} // namespace
void mlir::populateComplexToStandardConversionPatterns(
@@ -931,7 +1034,9 @@ void mlir::populateComplexToStandardConversionPatterns(
SinOpConversion,
SqrtOpConversion,
TanOpConversion,
- TanhOpConversion
+ TanhOpConversion,
+ PowOpConversion,
+ RsqrtOpConversion
>(patterns.getContext());
// clang-format on
}
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 96875239cf9ff..5b37899075a4f 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -676,4 +676,21 @@ func.func @complex_conj(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
// CHECK: %[[NEG_IMAG:.*]] = arith.negf %[[IMAG]] : f32
// CHECK: %[[RESULT:.*]] = complex.create %[[REAL]], %[[NEG_IMAG]] : complex<f32>
-// CHECK: return %[[RESULT]] : complex<f32>
\ No newline at end of file
+// CHECK: return %[[RESULT]] : complex<f32>
+
+// -----
+
+// CHECK-LABEL: func.func @complex_pow
+func.func @complex_pow(%lhs: complex<f32>,
+ %rhs: complex<f32>) -> complex<f32> {
+ %pow = complex.pow %lhs, %rhs : complex<f32>
+ return %pow : complex<f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @complex_rsqrt
+func.func @complex_rsqrt(%arg: complex<f32>) -> complex<f32> {
+ %rsqrt = complex.rsqrt %arg : complex<f32>
+ return %rsqrt : complex<f32>
+}
diff --git a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
index 67867f3cbd680..00ab3ed76e278 100644
--- a/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
+++ b/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
@@ -38,6 +38,11 @@ func.func @tanh(%arg: complex<f32>) -> complex<f32> {
func.return %tanh : complex<f32>
}
+func.func @rsqrt(%arg: complex<f32>) -> complex<f32> {
+ %sqrt = complex.rsqrt %arg : complex<f32>
+ func.return %sqrt : complex<f32>
+}
+
// %input contains pairs of lhs, rhs, i.e. [lhs_0, rhs_0, lhs_1, rhs_1,...]
func.func @test_binary(%input: tensor<?xcomplex<f32>>,
%func: (complex<f32>, complex<f32>) -> complex<f32>) {
@@ -67,6 +72,10 @@ func.func @atan2(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
func.return %atan2 : complex<f32>
}
+func.func @pow(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
+ %pow = complex.pow %lhs, %rhs : complex<f32>
+ func.return %pow : complex<f32>
+}
func.func @entry() {
// complex.sqrt test
@@ -121,6 +130,30 @@ func.func @entry() {
: (tensor<?xcomplex<f32>>, (complex<f32>, complex<f32>)
-> complex<f32>) -> ()
+ // complex.pow test
+ %pow_test = arith.constant dense<[
+ (0.0, 0.0), (0.0, 0.0),
+ // CHECK: 1
+ // CHECK-NEXT: 0
+ (0.0, 0.0), (1.0, 0.0),
+ // CHECK-NEXT: 0
+ // CHECK-NEXT: 0
+ (0.0, 0.0), (-1.0, 0.0),
+ // CHECK-NEXT: -nan
+ // CHECK-NEXT: -nan
+ (1.0, 1.0), (1.0, 1.0)
+ // CHECK-NEXT: 0.273
+ // CHECK-NEXT: 0.583
+ ]> : tensor<8xcomplex<f32>>
+ %pow_test_cast = tensor.cast %pow_test
+ : tensor<8xcomplex<f32>> to tensor<?xcomplex<f32>>
+
+ %pow_func = func.constant @pow : (complex<f32>, complex<f32>)
+ -> complex<f32>
+ call @test_binary(%pow_test_cast, %pow_func)
+ : (tensor<?xcomplex<f32>>, (complex<f32>, complex<f32>)
+ -> complex<f32>) -> ()
+
// complex.tanh test
%tanh_test = arith.constant dense<[
(-1.0, -1.0),
@@ -152,5 +185,36 @@ func.func @entry() {
call @test_unary(%tanh_test_cast, %tanh_func)
: (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()
+ // complex.rsqrt test
+ %rsqrt_test = arith.constant dense<[
+ (-1.0, -1.0),
+ // CHECK: 0.321
+ // CHECK-NEXT: 0.776
+ (-1.0, 1.0),
+ // CHECK-NEXT: 0.321
+ // CHECK-NEXT: -0.776
+ (0.0, 0.0),
+ // CHECK-NEXT: nan
+ // CHECK-NEXT: nan
+ (0.0, 1.0),
+ // CHECK-NEXT: 0.707
+ // CHECK-NEXT: -0.707
+ (1.0, -1.0),
+ // CHECK-NEXT: 0.776
+ // CHECK-NEXT: 0.321
+ (1.0, 0.0),
+ // CHECK-NEXT: 1
+ // CHECK-NEXT: 0
+ (1.0, 1.0)
+ // CHECK-NEXT: 0.776
+ // CHECK-NEXT: -0.321
+ ]> : tensor<7xcomplex<f32>>
+ %rsqrt_test_cast = tensor.cast %rsqrt_test
+ : tensor<7xcomplex<f32>> to tensor<?xcomplex<f32>>
+
+ %rsqrt_func = func.constant @rsqrt : (complex<f32>) -> complex<f32>
+ call @test_unary(%rsqrt_test_cast, %rsqrt_func)
+ : (tensor<?xcomplex<f32>>, (complex<f32>) -> complex<f32>) -> ()
+
func.return
}
More information about the Mlir-commits
mailing list