[Mlir-commits] [mlir] a1e7861 - [mlir][complex] Canonicalize re/im(neg(create))

Lei Zhang llvmlistbot at llvm.org
Mon May 29 17:54:13 PDT 2023


Author: Lei Zhang
Date: 2023-05-29T17:52:48-07:00
New Revision: a1e78615fb331484e07c2201433ba1e683348c47

URL: https://github.com/llvm/llvm-project/commit/a1e78615fb331484e07c2201433ba1e683348c47
DIFF: https://github.com/llvm/llvm-project/commit/a1e78615fb331484e07c2201433ba1e683348c47.diff

LOG: [mlir][complex] Canonicalize re/im(neg(create))

We can just convert this to arith.negf of the corresponding component.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D151633

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
    mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
    mlir/test/Dialect/Complex/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 7116bed2763f6..dd7c1a8ca8866 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -290,6 +290,7 @@ def ImOp : ComplexUnaryOp<"im",
 
   let results = (outs AnyFloat:$imaginary);
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -436,6 +437,7 @@ def ReOp : ComplexUnaryOp<"re",
 
   let results = (outs AnyFloat:$real);
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index f2d1a96fa4a28..f8c9b63f12aa2 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -6,9 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
 
 using namespace mlir;
 using namespace mlir::complex;
@@ -99,6 +102,36 @@ OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+namespace {
+template <typename OpKind, int ComponentIndex> // 0 = real, 1 = imaginary
+struct FoldComponentNeg final : OpRewritePattern<OpKind> {
+  using OpRewritePattern<OpKind>::OpRewritePattern;
+  // re/im(neg(create(a, b))) -> arith.negf of operand #ComponentIndex.
+  LogicalResult matchAndRewrite(OpKind op,
+                                PatternRewriter &rewriter) const override {
+    auto negOp = op.getOperand().template getDefiningOp<NegOp>();
+    if (!negOp)
+      return failure();
+    // Only fire when the negated value comes straight from complex.create.
+    auto createOp = negOp.getComplex().template getDefiningOp<CreateOp>();
+    if (!createOp)
+      return failure();
+    // The new op's result type is the complex element (float) type.
+    Type elementType = createOp.getType().getElementType();
+    assert(isa<FloatType>(elementType));
+    // re(-z) == -re(z) and im(-z) == -im(z): negate just that component.
+    rewriter.replaceOpWithNewOp<arith::NegFOp>(
+        op, elementType, createOp.getOperand(ComponentIndex));
+    return success();
+  }
+};
+} // namespace
+
+void ImOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                       MLIRContext *context) {
+  results.add<FoldComponentNeg<ImOp, 1>>(context); // im(-(a+bi)) -> negf(b)
+}
+
 //===----------------------------------------------------------------------===//
 // ReOp
 //===----------------------------------------------------------------------===//
@@ -113,6 +146,11 @@ OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+void ReOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                       MLIRContext *context) {
+  results.add<FoldComponentNeg<ReOp, 0>>(context); // re(-(a+bi)) -> negf(a)
+}
+
 //===----------------------------------------------------------------------===//
 // AddOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
index f0d287fde18aa..2fd2002c5cedf 100644
--- a/mlir/test/Dialect/Complex/canonicalize.mlir
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -155,3 +155,25 @@ func.func @complex_sub_zero() -> complex<f32> {
   %sub = complex.sub %complex1, %complex2 : complex<f32>
   return %sub : complex<f32>
 }
+
+// CHECK-LABEL: func @re_neg
+//  CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
+func.func @re_neg(%arg0: f32, %arg1: f32) -> f32 {
+  %create = complex.create %arg0, %arg1: complex<f32>
+  // CHECK: %[[NEG:.*]] = arith.negf %[[ARG0]]
+  %neg = complex.neg %create : complex<f32>
+  %re = complex.re %neg : complex<f32> // re(-(a+bi)) == -a
+  // CHECK-NEXT: return %[[NEG]]
+  return %re : f32
+}
+
+// CHECK-LABEL: func @im_neg
+//  CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
+func.func @im_neg(%arg0: f32, %arg1: f32) -> f32 {
+  %create = complex.create %arg0, %arg1: complex<f32>
+  // CHECK: %[[NEG:.*]] = arith.negf %[[ARG1]]
+  %neg = complex.neg %create : complex<f32>
+  %im = complex.im %neg : complex<f32> // im(-(a+bi)) == -b
+  // CHECK-NEXT: return %[[NEG]]
+  return %im : f32
+}


        


More information about the Mlir-commits mailing list