[Mlir-commits] [mlir] 6e80e3b - Add Log1pOp to complex dialect.

Adrian Kuegel llvmlistbot at llvm.org
Wed Jul 7 02:34:08 PDT 2021


Author: Adrian Kuegel
Date: 2021-07-07T11:33:54+02:00
New Revision: 6e80e3bd1bef3e7408b29a6d7eda0efbb829a65f

URL: https://github.com/llvm/llvm-project/commit/6e80e3bd1bef3e7408b29a6d7eda0efbb829a65f
DIFF: https://github.com/llvm/llvm-project/commit/6e80e3bd1bef3e7408b29a6d7eda0efbb829a65f.diff

LOG: Add Log1pOp to complex dialect.

Also add a lowering pattern from Complex to Standard/Math dialect.

Differential Revision: https://reviews.llvm.org/D105538

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
    mlir/test/Dialect/Complex/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index a116242dd078..d43b1e5dc1b2 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -216,6 +216,28 @@ def LogOp : ComplexUnaryOp<"log", [SameOperandsAndResultType]> {
   let results = (outs Complex<AnyFloat>:$result);
 }
 
+//===----------------------------------------------------------------------===//
+// Log1pOp
+//===----------------------------------------------------------------------===//
+
+def Log1pOp : ComplexUnaryOp<"log1p", [SameOperandsAndResultType]> {
+  let summary = "computes natural logarithm of one plus a complex number";
+  let description = [{
+    The `log1p` op takes a single complex number and computes the natural
+    logarithm of one plus the given value, i.e. `log(1 + x)` or `log_e(1 + x)`,
+    where `x` is the input value. `e` denotes Euler's number and is
+    approximately equal to 2.718281.
+
+    Example:
+
+    ```mlir
+    %a = complex.log1p %b : complex<f32>
+    ```
+  }];
+
+  let results = (outs Complex<AnyFloat>:$result);
+}
+
 //===----------------------------------------------------------------------===//
 // MulOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 018882ae9489..4d3d52213e55 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -337,6 +337,28 @@ struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
   }
 };
 
+struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
+  using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::Log1pOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::Log1pOp::Adaptor transformed(operands);
+    auto type = transformed.complex().getType().cast<ComplexType>();
+    auto elementType = type.getElementType().cast<FloatType>();
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+
+    Value real = b.create<complex::ReOp>(elementType, transformed.complex());
+    Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
+    Value one =
+        b.create<ConstantOp>(elementType, b.getFloatAttr(elementType, 1));
+    Value realPlusOne = b.create<AddFOp>(real, one);
+    Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
+    rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
+    return success();
+  }
+};
+
 struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
   using OpConversionPattern<complex::MulOp>::OpConversionPattern;
 
@@ -535,6 +557,7 @@ void mlir::populateComplexToStandardConversionPatterns(
       DivOpConversion,
       ExpOpConversion,
       LogOpConversion,
+      Log1pOpConversion,
       MulOpConversion,
       NegOpConversion,
       SignOpConversion>(patterns.getContext());
@@ -558,8 +581,9 @@ void ConvertComplexToStandardPass::runOnFunction() {
   target.addLegalDialect<StandardOpsDialect, math::MathDialect,
                          complex::ComplexDialect>();
   target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
-                      complex::ExpOp, complex::LogOp, complex::MulOp,
-                      complex::NegOp, complex::NotEqualOp, complex::SignOp>();
+                      complex::ExpOp, complex::LogOp, complex::Log1pOp,
+                      complex::MulOp, complex::NegOp, complex::NotEqualOp,
+                      complex::SignOp>();
   if (failed(applyPartialConversion(function, target, std::move(patterns))))
     signalPassFailure();
 }

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 95e6854ffa43..765d79c0bb8c 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -173,6 +173,30 @@ func @complex_log(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// CHECK-LABEL: func @complex_log1p
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
+  %log1p = complex.log1p %arg: complex<f32>
+  return %log1p : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE:.*]] = constant 1.000000e+00 : f32
+// CHECK: %[[REAL_PLUS_ONE:.*]] = addf %[[REAL]], %[[ONE]] : f32
+// CHECK: %[[NEW_COMPLEX:.*]] = complex.create %[[REAL_PLUS_ONE]], %[[IMAG]] : complex<f32>
+// CHECK: %[[REAL:.*]] = complex.re %[[NEW_COMPLEX]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[NEW_COMPLEX]] : complex<f32>
+// CHECK: %[[SQR_REAL:.*]] = mulf %[[REAL]], %[[REAL]] : f32
+// CHECK: %[[SQR_IMAG:.*]] = mulf %[[IMAG]], %[[IMAG]] : f32
+// CHECK: %[[SQ_NORM:.*]] = addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
+// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32
+// CHECK: %[[REAL2:.*]] = complex.re %[[NEW_COMPLEX]] : complex<f32>
+// CHECK: %[[IMAG2:.*]] = complex.im %[[NEW_COMPLEX]] : complex<f32>
+// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG2]], %[[REAL2]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
 // CHECK-LABEL: func @complex_mul
 // CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
 func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {

diff  --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir
index 74b45b8ae230..3fc0e9299c0f 100644
--- a/mlir/test/Dialect/Complex/ops.mlir
+++ b/mlir/test/Dialect/Complex/ops.mlir
@@ -32,6 +32,9 @@ func @ops(%f: f32) {
   // CHECK: complex.log %[[C]] : complex<f32>
   %log = complex.log %complex : complex<f32>
 
+  // CHECK: complex.log1p %[[C]] : complex<f32>
+  %log1p = complex.log1p %complex : complex<f32>
+
   // CHECK: complex.mul %[[C]], %[[C]] : complex<f32>
   %prod = complex.mul %complex, %complex : complex<f32>
 


        


More information about the Mlir-commits mailing list