[Mlir-commits] [mlir] 662e074 - [mlir] Add NegOp to complex dialect.
Adrian Kuegel
llvmlistbot at llvm.org
Tue Jun 15 03:16:38 PDT 2021
Author: Adrian Kuegel
Date: 2021-06-15T12:16:22+02:00
New Revision: 662e074d9043949eea4e360e47bf9e39959694b8
URL: https://github.com/llvm/llvm-project/commit/662e074d9043949eea4e360e47bf9e39959694b8
DIFF: https://github.com/llvm/llvm-project/commit/662e074d9043949eea4e360e47bf9e39959694b8.diff
LOG: [mlir] Add NegOp to complex dialect.
Also add a lowering pattern from complex dialect to standard dialect.
Differential Revision: https://reviews.llvm.org/D104284
Added:
Modified:
mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 63b21b86ad79b..989df0e7d368d 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -210,6 +210,25 @@ def MulOp : ComplexArithmeticOp<"mul"> {
}];
}
+//===----------------------------------------------------------------------===//
+// NegOp
+//===----------------------------------------------------------------------===//
+
+def NegOp : ComplexUnaryOp<"neg", [SameOperandsAndResultType]> {
+ let summary = "Negation operator";
+ let description = [{
+ The `neg` op takes a single complex number `complex` and returns `-complex`.
+
+ Example:
+
+ ```mlir
+ %a = complex.neg %b : complex<f32>
+ ```
+ }];
+ // -(a + bi) = (-a) + (-b)i; SameOperandsAndResultType keeps the result type equal to the operand's complex<float> type.
+ let results = (outs Complex<AnyFloat>:$result);
+}
+
//===----------------------------------------------------------------------===//
// NotEqualOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index a90ac06c020e4..51aa671d995c2 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -313,6 +313,28 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
return success();
}
};
+// Lowers complex.neg to standard ops: -(a + bi) is rebuilt as (-a) + (-b)i.
+struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
+ using OpConversionPattern<complex::NegOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(complex::NegOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ complex::NegOp::Adaptor transformed(operands); // named access to `operands`, e.g. transformed.complex()
+ auto loc = op.getLoc();
+ auto type = transformed.complex().getType().cast<ComplexType>();
+ auto elementType = type.getElementType().cast<FloatType>(); // NegOp is declared on Complex<AnyFloat>, so both parts are floats
+ // Extract both parts, negate each with NegFOp, then rebuild a value of the operand's complex type.
+ Value real =
+ rewriter.create<complex::ReOp>(loc, elementType, transformed.complex());
+ Value imag =
+ rewriter.create<complex::ImOp>(loc, elementType, transformed.complex());
+ Value negReal = rewriter.create<NegFOp>(loc, real);
+ Value negImag = rewriter.create<NegFOp>(loc, imag);
+ rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
+ return success();
+ }
+};
} // namespace
void mlir::populateComplexToStandardConversionPatterns(
@@ -320,7 +342,8 @@ void mlir::populateComplexToStandardConversionPatterns(
patterns.add<AbsOpConversion,
ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
- DivOpConversion, ExpOpConversion>(patterns.getContext());
+ DivOpConversion, ExpOpConversion, NegOpConversion>(
+ patterns.getContext());
}
namespace {
@@ -340,7 +363,7 @@ void ConvertComplexToStandardPass::runOnFunction() {
target.addLegalDialect<StandardOpsDialect, math::MathDialect,
complex::ComplexDialect>();
target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
- complex::ExpOp, complex::NotEqualOp>();
+ complex::ExpOp, complex::NotEqualOp, complex::NegOp>();
if (failed(applyPartialConversion(function, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 91b82c7ef16a7..fe75575b59cf3 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -154,6 +154,19 @@ func @complex_exp(%arg: complex<f32>) -> complex<f32> {
// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
// CHECK: return %[[RESULT]] : complex<f32>
+// CHECK-LABEL: func @complex_neg
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func @complex_neg(%arg: complex<f32>) -> complex<f32> { // expect re/im extraction, two negf ops, and a complex.create
+ %neg = complex.neg %arg: complex<f32>
+ return %neg : complex<f32>
+} // the two negf lines below use -DAG so their relative order is not pinned
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK-DAG: %[[NEG_REAL:.*]] = negf %[[REAL]] : f32
+// CHECK-DAG: %[[NEG_IMAG:.*]] = negf %[[IMAG]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[NEG_REAL]], %[[NEG_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
// CHECK-LABEL: func @complex_neq
// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
More information about the Mlir-commits mailing list