[Mlir-commits] [mlir] 380fa71 - [mlir] Add LogOp lowering from Complex dialect to Standard/Math dialect.

Adrian Kuegel llvmlistbot at llvm.org
Mon Jul 5 00:34:03 PDT 2021


Author: Adrian Kuegel
Date: 2021-07-05T09:33:45+02:00
New Revision: 380fa71fb00998332ee5dd97f82aaf3eadd282ac

URL: https://github.com/llvm/llvm-project/commit/380fa71fb00998332ee5dd97f82aaf3eadd282ac
DIFF: https://github.com/llvm/llvm-project/commit/380fa71fb00998332ee5dd97f82aaf3eadd282ac.diff

LOG: [mlir] Add LogOp lowering from Complex dialect to Standard/Math dialect.

Differential Revision: https://reviews.llvm.org/D105342

Added: 
    

Modified: 
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index bcd9572a471f7..a5aa4cbda9c63 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -315,6 +315,28 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
   }
 };
 
+struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
+  using OpConversionPattern<complex::LogOp>::OpConversionPattern;
+  // Rewrites log(z) as complex.create(log(|z|), atan2(Im(z), Re(z))).
+  LogicalResult
+  matchAndRewrite(complex::LogOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::LogOp::Adaptor transformed(operands);
+    auto type = transformed.complex().getType().cast<ComplexType>();
+    auto elementType = type.getElementType().cast<FloatType>();
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    // complex.abs is marked illegal, so its own pattern expands it further.
+    Value abs = b.create<complex::AbsOp>(elementType, transformed.complex());
+    Value resultReal = b.create<math::LogOp>(elementType, abs);
+    Value real = b.create<complex::ReOp>(elementType, transformed.complex());
+    Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
+    Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
+                                                   resultImag);
+    return success();
+  }
+};
+
 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
   using OpConversionPattern<complex::NegOp>::OpConversionPattern;
 
@@ -374,6 +396,7 @@ void mlir::populateComplexToStandardConversionPatterns(
       ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
       DivOpConversion,
       ExpOpConversion,
+      LogOpConversion,
       NegOpConversion,
       SignOpConversion>(patterns.getContext());
   // clang-format on
@@ -396,8 +419,8 @@ void ConvertComplexToStandardPass::runOnFunction() {
   target.addLegalDialect<StandardOpsDialect, math::MathDialect,
                          complex::ComplexDialect>();
   target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
-                      complex::ExpOp, complex::NotEqualOp, complex::NegOp,
-                      complex::SignOp>();
+                      complex::ExpOp, complex::LogOp, complex::NotEqualOp,
+                      complex::NegOp, complex::SignOp>();
   if (failed(applyPartialConversion(function, target, std::move(patterns))))
     signalPassFailure();
 }

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 62edf1e74a4a9..a99585b6f2e5c 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -154,6 +154,25 @@ func @complex_exp(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// CHECK-LABEL: func @complex_log
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func @complex_log(%arg: complex<f32>) -> complex<f32> {
+  %log = complex.log %arg: complex<f32>
+  return %log : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[SQR_REAL:.*]] = mulf %[[REAL]], %[[REAL]] : f32
+// CHECK: %[[SQR_IMAG:.*]] = mulf %[[IMAG]], %[[IMAG]] : f32
+// CHECK: %[[SQ_NORM:.*]] = addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
+// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32
+// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG2]], %[[REAL2]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
 // CHECK-LABEL: func @complex_neg
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func @complex_neg(%arg: complex<f32>) -> complex<f32> {


        


More information about the Mlir-commits mailing list