[Mlir-commits] [mlir] 73cbc91 - [mlir] Add ExpOp to Complex dialect.

Adrian Kuegel llvmlistbot at llvm.org
Sun Jun 13 23:09:15 PDT 2021


Author: Adrian Kuegel
Date: 2021-06-14T08:08:53+02:00
New Revision: 73cbc91c93dd5a7ee1b8b1a9d507e194e835b446

URL: https://github.com/llvm/llvm-project/commit/73cbc91c93dd5a7ee1b8b1a9d507e194e835b446
DIFF: https://github.com/llvm/llvm-project/commit/73cbc91c93dd5a7ee1b8b1a9d507e194e835b446.diff

LOG: [mlir] Add ExpOp to Complex dialect.

Also add a conversion pattern from Complex to Standard/Math dialect.

Differential Revision: https://reviews.llvm.org/D104108

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 7e22ebfacfa05..1f71a97aab4ae 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -155,6 +155,27 @@ def EqualOp : Complex_Op<"eq",
   let assemblyFormat = "$lhs `,` $rhs  attr-dict `:` type($lhs)";
 }
 
+//===----------------------------------------------------------------------===//
+// ExpOp
+//===----------------------------------------------------------------------===//
+
+def ExpOp : ComplexUnaryOp<"exp", [SameOperandsAndResultType]> {
+  let summary = "computes exponential of a complex number";
+  let description = [{
+    The `exp` op takes a single complex number and computes the exponential of
+    it, i.e. `exp(x)` or `e^(x)`, where `x` is the input value.
+    `e` denotes Euler's number and is approximately equal to 2.718281.
+
+    Example:
+
+    ```mlir
+    %a = complex.exp %b : complex<f32>
+    ```
+  }];
+
+  let results = (outs Complex<AnyFloat>:$result);
+}
+
 //===----------------------------------------------------------------------===//
 // ImOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index f5c06638c86e2..a90ac06c020e4 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -86,7 +86,7 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
                   ConversionPatternRewriter &rewriter) const override {
     complex::DivOp::Adaptor transformed(operands);
     auto loc = op.getLoc();
-    auto type = transformed.lhs().getType().template cast<ComplexType>();
+    auto type = transformed.lhs().getType().cast<ComplexType>();
     auto elementType = type.getElementType().cast<FloatType>();
 
     Value lhsReal =
@@ -286,6 +286,33 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
     return success();
   }
 };
+
+struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
+  using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::ExpOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::ExpOp::Adaptor transformed(operands);
+    auto loc = op.getLoc();
+    auto type = transformed.complex().getType().cast<ComplexType>();
+    auto elementType = type.getElementType().cast<FloatType>();
+
+    Value real =
+        rewriter.create<complex::ReOp>(loc, elementType, transformed.complex());
+    Value imag =
+        rewriter.create<complex::ImOp>(loc, elementType, transformed.complex());
+    Value expReal = rewriter.create<math::ExpOp>(loc, real);
+    Value cosImag = rewriter.create<math::CosOp>(loc, imag);
+    Value resultReal = rewriter.create<MulFOp>(loc, expReal, cosImag);
+    Value sinImag = rewriter.create<math::SinOp>(loc, imag);
+    Value resultImag = rewriter.create<MulFOp>(loc, expReal, sinImag);
+
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
+                                                   resultImag);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::populateComplexToStandardConversionPatterns(
@@ -293,7 +320,7 @@ void mlir::populateComplexToStandardConversionPatterns(
   patterns.add<AbsOpConversion,
                ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
                ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
-               DivOpConversion>(patterns.getContext());
+               DivOpConversion, ExpOpConversion>(patterns.getContext());
 }
 
 namespace {
@@ -313,7 +340,7 @@ void ConvertComplexToStandardPass::runOnFunction() {
   target.addLegalDialect<StandardOpsDialect, math::MathDialect,
                          complex::ComplexDialect>();
   target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
-                      complex::NotEqualOp>();
+                      complex::ExpOp, complex::NotEqualOp>();
   if (failed(applyPartialConversion(function, target, std::move(patterns))))
     signalPassFailure();
 }

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 998104045720e..91b82c7ef16a7 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -138,6 +138,22 @@ func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
 // CHECK: %[[EQUAL:.*]] = and %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1
 // CHECK: return %[[EQUAL]] : i1
 
+// CHECK-LABEL: func @complex_exp
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func @complex_exp(%arg: complex<f32>) -> complex<f32> {
+  %exp = complex.exp %arg: complex<f32>
+  return %exp : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] : f32
+// CHECK-DAG: %[[EXP_REAL:.*]] = math.exp %[[REAL]] : f32
+// CHECK-DAG: %[[RESULT_REAL:.*]] = mulf %[[EXP_REAL]], %[[COS_IMAG]] : f32
+// CHECK-DAG: %[[SIN_IMAG:.*]] = math.sin %[[IMAG]] : f32
+// CHECK-DAG: %[[RESULT_IMAG:.*]] = mulf %[[EXP_REAL]], %[[SIN_IMAG]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
 // CHECK-LABEL: func @complex_neq
 // CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
 func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {


        


More information about the Mlir-commits mailing list