[Mlir-commits] [mlir] 662e074 - [mlir] Add NegOp to complex dialect.

Adrian Kuegel llvmlistbot at llvm.org
Tue Jun 15 03:16:38 PDT 2021


Author: Adrian Kuegel
Date: 2021-06-15T12:16:22+02:00
New Revision: 662e074d9043949eea4e360e47bf9e39959694b8

URL: https://github.com/llvm/llvm-project/commit/662e074d9043949eea4e360e47bf9e39959694b8
DIFF: https://github.com/llvm/llvm-project/commit/662e074d9043949eea4e360e47bf9e39959694b8.diff

LOG: [mlir] Add NegOp to complex dialect.

Also add a lowering pattern from the complex dialect to the standard dialect.

Differential Revision: https://reviews.llvm.org/D104284

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
    mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
    mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 63b21b86ad79b..989df0e7d368d 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -210,6 +210,25 @@ def MulOp : ComplexArithmeticOp<"mul"> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// NegOp
+//===----------------------------------------------------------------------===//
+
+def NegOp : ComplexUnaryOp<"neg", [SameOperandsAndResultType]> {
+  let summary = "Negation operator";
+  let description = [{
+    The `neg` op takes a single complex number `complex` and returns `-complex`.
+
+    Example:
+
+    ```mlir
+    %a = complex.neg %b : complex<f32>
+    ```
+  }];
+
+  let results = (outs Complex<AnyFloat>:$result);
+}
+
 //===----------------------------------------------------------------------===//
 // NotEqualOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
index a90ac06c020e4..51aa671d995c2 100644
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -313,6 +313,28 @@ struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
     return success();
   }
 };
+
+struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
+  using OpConversionPattern<complex::NegOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::NegOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::NegOp::Adaptor transformed(operands);
+    auto loc = op.getLoc();
+    auto type = transformed.complex().getType().cast<ComplexType>();
+    auto elementType = type.getElementType().cast<FloatType>();
+
+    Value real =
+        rewriter.create<complex::ReOp>(loc, elementType, transformed.complex());
+    Value imag =
+        rewriter.create<complex::ImOp>(loc, elementType, transformed.complex());
+    Value negReal = rewriter.create<NegFOp>(loc, real);
+    Value negImag = rewriter.create<NegFOp>(loc, imag);
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::populateComplexToStandardConversionPatterns(
@@ -320,7 +342,8 @@ void mlir::populateComplexToStandardConversionPatterns(
   patterns.add<AbsOpConversion,
                ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
                ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
-               DivOpConversion, ExpOpConversion>(patterns.getContext());
+               DivOpConversion, ExpOpConversion, NegOpConversion>(
+      patterns.getContext());
 }
 
 namespace {
@@ -340,7 +363,7 @@ void ConvertComplexToStandardPass::runOnFunction() {
   target.addLegalDialect<StandardOpsDialect, math::MathDialect,
                          complex::ComplexDialect>();
   target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
-                      complex::ExpOp, complex::NotEqualOp>();
+                      complex::ExpOp, complex::NotEqualOp, complex::NegOp>();
   if (failed(applyPartialConversion(function, target, std::move(patterns))))
     signalPassFailure();
 }

diff  --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
index 91b82c7ef16a7..fe75575b59cf3 100644
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -154,6 +154,19 @@ func @complex_exp(%arg: complex<f32>) -> complex<f32> {
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// CHECK-LABEL: func @complex_neg
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func @complex_neg(%arg: complex<f32>) -> complex<f32> {
+  %neg = complex.neg %arg: complex<f32>
+  return %neg : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK-DAG: %[[NEG_REAL:.*]] = negf %[[REAL]] : f32
+// CHECK-DAG: %[[NEG_IMAG:.*]] = negf %[[IMAG]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[NEG_REAL]], %[[NEG_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
 // CHECK-LABEL: func @complex_neq
 // CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
 func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {


        


More information about the Mlir-commits mailing list