[Mlir-commits] [mlir] 6e80e3b - Add Log1pOp to complex dialect.
Adrian Kuegel
llvmlistbot at llvm.org
Wed Jul 7 02:34:08 PDT 2021
Author: Adrian Kuegel
Date: 2021-07-07T11:33:54+02:00
New Revision: 6e80e3bd1bef3e7408b29a6d7eda0efbb829a65f
URL: https://github.com/llvm/llvm-project/commit/6e80e3bd1bef3e7408b29a6d7eda0efbb829a65f
DIFF: https://github.com/llvm/llvm-project/commit/6e80e3bd1bef3e7408b29a6d7eda0efbb829a65f.diff
LOG: Add Log1pOp to complex dialect.
Also add a lowering pattern from Complex to Standard/Math dialect.
Differential Revision: https://reviews.llvm.org/D105538
Added:
Modified:
mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
mlir/test/Dialect/Complex/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index a116242dd078..d43b1e5dc1b2 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -216,6 +216,28 @@ def LogOp : ComplexUnaryOp<"log", [SameOperandsAndResultType]> {
let results = (outs Complex<AnyFloat>:$result);
}
+//===----------------------------------------------------------------------===//
+// Log1pOp
+//===----------------------------------------------------------------------===//
+
+def Log1pOp : ComplexUnaryOp<"log1p", [SameOperandsAndResultType]> {
+ let summary = "computes natural logarithm of one plus a complex number";
+ let description = [{
+ The `log1p` op takes a single complex number and computes the natural
+ logarithm of one plus the given value, i.e. `log(1 + x)` or `log_e(1 + x)`,
+ where `x` is the input value. `e` denotes Euler's number and is
+ approximately equal to 2.718281.
+
+ Example:
+
+ ```mlir
+ %a = complex.log1p %b : complex<f32>
+ ```
+ }];
+
+ let results = (outs Complex<AnyFloat>:$result);
+}
+
//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index 018882ae9489..4d3d52213e55 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -337,6 +337,28 @@ struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
}
};
+/// Lowers `complex.log1p` using the identity log1p(x) == log(1 + x): the
+/// constant 1.0 is added to the real part of the operand, the imaginary part
+/// is passed through unchanged, and the op is replaced by a `complex.log` of
+/// the shifted value. The emitted `complex.log` is not expanded here; it is
+/// left for LogOpConversion, which is registered in the same pattern set.
+// NOTE(review): materializing 1 + re(x) and then taking log() rounds away the
+// low-order bits of re(x) when |x| is small, which is exactly the regime
+// log1p is meant to be accurate in. Consider computing the real part as
+// 0.5 * math.log1p(re*re + 2*re + im*im) and the imaginary part as
+// atan2(im, re + 1) instead — TODO confirm and update the FileCheck test too.
+struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
+ using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(complex::Log1pOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ // Read the (possibly already converted) operand through the adaptor.
+ complex::Log1pOp::Adaptor transformed(operands);
+ auto type = transformed.complex().getType().cast<ComplexType>();
+ auto elementType = type.getElementType().cast<FloatType>();
+ // Every op built through `b` inherits the location of the original op.
+ mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+
+ Value real = b.create<complex::ReOp>(elementType, transformed.complex());
+ Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
+ // 1.0 materialized in the operand's float element type.
+ Value one =
+ b.create<ConstantOp>(elementType, b.getFloatAttr(elementType, 1));
+ // 1 + x only shifts the real component.
+ Value realPlusOne = b.create<AddFOp>(real, one);
+ Value newComplex = b.create<complex::CreateOp>(type, realPlusOne, imag);
+ // log1p(x) == log(1 + x); the new complex.log is lowered by its own pattern.
+ rewriter.replaceOpWithNewOp<complex::LogOp>(op, type, newComplex);
+ return success();
+ }
+};
+
struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
using OpConversionPattern<complex::MulOp>::OpConversionPattern;
@@ -535,6 +557,7 @@ void mlir::populateComplexToStandardConversionPatterns(
DivOpConversion,
ExpOpConversion,
LogOpConversion,
+ Log1pOpConversion,
MulOpConversion,
NegOpConversion,
SignOpConversion>(patterns.getContext());
@@ -558,8 +581,9 @@ void ConvertComplexToStandardPass::runOnFunction() {
target.addLegalDialect<StandardOpsDialect, math::MathDialect,
complex::ComplexDialect>();
target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
- complex::ExpOp, complex::LogOp, complex::MulOp,
- complex::NegOp, complex::NotEqualOp, complex::SignOp>();
+ complex::ExpOp, complex::LogOp, complex::Log1pOp,
+ complex::MulOp, complex::NegOp, complex::NotEqualOp,
+ complex::SignOp>();
if (failed(applyPartialConversion(function, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 95e6854ffa43..765d79c0bb8c 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -173,6 +173,30 @@ func @complex_log(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+// complex.log1p is lowered as log(1 + x): the real part is shifted by one and
+// repacked with the untouched imaginary part, and the resulting complex.log is
+// then expanded in the same pass into sqrt/log for the real part and atan2
+// for the imaginary part.
+// CHECK-LABEL: func @complex_log1p
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func @complex_log1p(%arg: complex<f32>) -> complex<f32> {
+ %log1p = complex.log1p %arg: complex<f32>
+ return %log1p : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[ONE:.*]] = constant 1.000000e+00 : f32
+// CHECK: %[[REAL_PLUS_ONE:.*]] = addf %[[REAL]], %[[ONE]] : f32
+// CHECK: %[[NEW_COMPLEX:.*]] = complex.create %[[REAL_PLUS_ONE]], %[[IMAG]] : complex<f32>
+// CHECK: %[[REAL:.*]] = complex.re %[[NEW_COMPLEX]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[NEW_COMPLEX]] : complex<f32>
+// CHECK: %[[SQR_REAL:.*]] = mulf %[[REAL]], %[[REAL]] : f32
+// CHECK: %[[SQR_IMAG:.*]] = mulf %[[IMAG]], %[[IMAG]] : f32
+// CHECK: %[[SQ_NORM:.*]] = addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
+// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32
+// CHECK: %[[REAL2:.*]] = complex.re %[[NEW_COMPLEX]] : complex<f32>
+// CHECK: %[[IMAG2:.*]] = complex.im %[[NEW_COMPLEX]] : complex<f32>
+// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG2]], %[[REAL2]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
// CHECK-LABEL: func @complex_mul
// CHECK-SAME: (%[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>)
func @complex_mul(%lhs: complex<f32>, %rhs: complex<f32>) -> complex<f32> {
diff --git a/mlir/test/Dialect/Complex/ops.mlir b/mlir/test/Dialect/Complex/ops.mlir
index 74b45b8ae230..3fc0e9299c0f 100644
--- a/mlir/test/Dialect/Complex/ops.mlir
+++ b/mlir/test/Dialect/Complex/ops.mlir
@@ -32,6 +32,9 @@ func @ops(%f: f32) {
// CHECK: complex.log %[[C]] : complex<f32>
%log = complex.log %complex : complex<f32>
+ // CHECK: complex.log1p %[[C]] : complex<f32>
+ %log1p = complex.log1p %complex : complex<f32>
+
// CHECK: complex.mul %[[C]], %[[C]] : complex<f32>
%prod = complex.mul %complex, %complex : complex<f32>
More information about the Mlir-commits
mailing list