[Mlir-commits] [mlir] 73cbc91 - [mlir] Add ExpOp to Complex dialect.
Adrian Kuegel
llvmlistbot at llvm.org
Sun Jun 13 23:09:15 PDT 2021
Author: Adrian Kuegel
Date: 2021-06-14T08:08:53+02:00
New Revision: 73cbc91c93dd5a7ee1b8b1a9d507e194e835b446
URL: https://github.com/llvm/llvm-project/commit/73cbc91c93dd5a7ee1b8b1a9d507e194e835b446
DIFF: https://github.com/llvm/llvm-project/commit/73cbc91c93dd5a7ee1b8b1a9d507e194e835b446.diff
LOG: [mlir] Add ExpOp to Complex dialect.
Also add a conversion pattern from Complex to Standard/Math dialect.
Differential Revision: https://reviews.llvm.org/D104108
Added:
Modified:
mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 7e22ebfacfa05..1f71a97aab4ae 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -155,6 +155,27 @@ def EqualOp : Complex_Op<"eq",
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)";
}
+//===----------------------------------------------------------------------===//
+// ExpOp
+//===----------------------------------------------------------------------===//
+
+def ExpOp : ComplexUnaryOp<"exp", [SameOperandsAndResultType]> {
+  let summary = "computes exponential of a complex number";
+  let description = [{
+    The `exp` op takes a single complex number and computes the exponential of
+    it, i.e. `exp(x)` or `e^(x)`, where `x` is the input complex number.
+    `e` denotes Euler's number and is approximately equal to 2.718281.
+    For `x = a + bi`, the result is `e^a * (cos(b) + i * sin(b))`.
+
+    Example:
+
+    ```mlir
+    %a = complex.exp %b : complex<f32>
+    ```
+  }];
+
+  let results = (outs Complex<AnyFloat>:$result);
+}
+
//===----------------------------------------------------------------------===//
// ImOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index f5c06638c86e2..a90ac06c020e4 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -86,7 +86,7 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
ConversionPatternRewriter &rewriter) const override {
complex::DivOp::Adaptor transformed(operands);
auto loc = op.getLoc();
- auto type = transformed.lhs().getType().template cast<ComplexType>();
+ auto type = transformed.lhs().getType().cast<ComplexType>();
auto elementType = type.getElementType().cast<FloatType>();
Value lhsReal =
@@ -286,6 +286,33 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
return success();
}
};
+
+/// Lowers `complex.exp` to Standard/Math ops using Euler's formula:
+///   exp(a + bi) = e^a * cos(b) + i * e^a * sin(b)
+/// NOTE(review): this evaluates the formula directly, so IEEE special values
+/// can differ from C99 Annex G `cexp` (e.g. a = +inf, b = 0 gives inf * 0 =
+/// NaN for the imaginary part) -- confirm whether that matters to callers.
+struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
+  using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::ExpOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::ExpOp::Adaptor adaptor(operands);
+    auto loc = op.getLoc();
+    Value input = adaptor.complex();
+    ComplexType complexType = input.getType().cast<ComplexType>();
+    FloatType floatType = complexType.getElementType().cast<FloatType>();
+
+    // Decompose the operand: input = a + bi.
+    Value a = rewriter.create<complex::ReOp>(loc, floatType, input);
+    Value b = rewriter.create<complex::ImOp>(loc, floatType, input);
+
+    // e^a scales both coordinates of the point (cos b, sin b).
+    Value scale = rewriter.create<math::ExpOp>(loc, a);
+    Value cosB = rewriter.create<math::CosOp>(loc, b);
+    Value realPart = rewriter.create<MulFOp>(loc, scale, cosB);
+    Value sinB = rewriter.create<math::SinOp>(loc, b);
+    Value imagPart = rewriter.create<MulFOp>(loc, scale, sinB);
+
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, complexType, realPart,
+                                                   imagPart);
+    return success();
+  }
+};
} // namespace
void mlir::populateComplexToStandardConversionPatterns(
@@ -293,7 +320,7 @@ void mlir::populateComplexToStandardConversionPatterns(
patterns.add<AbsOpConversion,
ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
- DivOpConversion>(patterns.getContext());
+ DivOpConversion, ExpOpConversion>(patterns.getContext());
}
namespace {
@@ -313,7 +340,7 @@ void ConvertComplexToStandardPass::runOnFunction() {
target.addLegalDialect<StandardOpsDialect, math::MathDialect,
complex::ComplexDialect>();
target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
- complex::NotEqualOp>();
+ complex::ExpOp, complex::NotEqualOp>();
if (failed(applyPartialConversion(function, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 998104045720e..91b82c7ef16a7 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -138,6 +138,22 @@ func @complex_eq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
// CHECK: %[[EQUAL:.*]] = and %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1
// CHECK: return %[[EQUAL]] : i1
+// complex.exp(a + bi) is expected to lower to e^a * cos(b) + i * e^a * sin(b).
+// CHECK-LABEL: func @complex_exp
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func @complex_exp(%arg: complex<f32>) -> complex<f32> {
+  %exp = complex.exp %arg : complex<f32>
+  return %exp : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] : f32
+// CHECK-DAG: %[[EXP_REAL:.*]] = math.exp %[[REAL]] : f32
+// CHECK-DAG: %[[RESULT_REAL:.*]] = mulf %[[EXP_REAL]], %[[COS_IMAG]] : f32
+// CHECK-DAG: %[[SIN_IMAG:.*]] = math.sin %[[IMAG]] : f32
+// CHECK-DAG: %[[RESULT_IMAG:.*]] = mulf %[[EXP_REAL]], %[[SIN_IMAG]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
// CHECK-LABEL: func @complex_neq
// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
More information about the Mlir-commits
mailing list