[Mlir-commits] [mlir] a1e7861 - [mlir][complex] Canonicalize re/im(neg(create))
Lei Zhang
llvmlistbot at llvm.org
Mon May 29 17:54:13 PDT 2023
Author: Lei Zhang
Date: 2023-05-29T17:52:48-07:00
New Revision: a1e78615fb331484e07c2201433ba1e683348c47
URL: https://github.com/llvm/llvm-project/commit/a1e78615fb331484e07c2201433ba1e683348c47
DIFF: https://github.com/llvm/llvm-project/commit/a1e78615fb331484e07c2201433ba1e683348c47.diff
LOG: [mlir][complex] Canonicalize re/im(neg(create))
We can just convert this to arith.negf of the corresponding complex.create operand.
Reviewed By: kuhar
Differential Revision: https://reviews.llvm.org/D151633
Added:
Modified:
mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
mlir/test/Dialect/Complex/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
index 7116bed2763f6..dd7c1a8ca8866 100644
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -290,6 +290,7 @@ def ImOp : ComplexUnaryOp<"im",
let results = (outs AnyFloat:$imaginary);
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
@@ -436,6 +437,7 @@ def ReOp : ComplexUnaryOp<"re",
let results = (outs AnyFloat:$real);
let hasFolder = 1;
+ let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index f2d1a96fa4a28..f8c9b63f12aa2 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -6,9 +6,12 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::complex;
@@ -99,6 +102,36 @@ OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
return {};
}
+namespace {
+/// Folds `complex.re` / `complex.im` of `complex.neg(complex.create(a, b))`
+/// into a single `arith.negf` of the selected `complex.create` operand.
+/// `ComponentIndex` picks that operand: 0 selects `a` (real part) and 1
+/// selects `b` (imaginary part).
+template <typename OpKind, int ComponentIndex>
+struct FoldComponentNeg final : OpRewritePattern<OpKind> {
+  using OpRewritePattern<OpKind>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpKind op,
+                                PatternRewriter &rewriter) const override {
+    // Match the producer chain: op <- complex.neg <- complex.create.
+    NegOp negated = op.getOperand().template getDefiningOp<NegOp>();
+    if (!negated)
+      return failure();
+    CreateOp producer =
+        negated.getComplex().template getDefiningOp<CreateOp>();
+    if (!producer)
+      return failure();
+
+    // Only float element types are expected here; the assert checks this in
+    // debug builds.
+    Type resultType = producer.getType().getElementType();
+    assert(isa<FloatType>(resultType));
+
+    // re/im(-(a + bi)) is -a / -b, so negate the chosen operand directly.
+    Value part = producer.getOperand(ComponentIndex);
+    rewriter.replaceOpWithNewOp<arith::NegFOp>(op, resultType, part);
+    return success();
+  }
+};
+} // namespace
+
+/// Registers canonicalizations for `complex.im`:
+/// `im(neg(create(a, b)))` is rewritten to `arith.negf b`.
+void ImOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                       MLIRContext *context) {
+  using ImagOfNeg = FoldComponentNeg<ImOp, /*ComponentIndex=*/1>;
+  results.add<ImagOfNeg>(context);
+}
+
//===----------------------------------------------------------------------===//
// ReOp
//===----------------------------------------------------------------------===//
@@ -113,6 +146,11 @@ OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
return {};
}
+/// Registers canonicalizations for `complex.re`:
+/// `re(neg(create(a, b)))` is rewritten to `arith.negf a`.
+void ReOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                       MLIRContext *context) {
+  using RealOfNeg = FoldComponentNeg<ReOp, /*ComponentIndex=*/0>;
+  results.add<RealOfNeg>(context);
+}
+
//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Complex/canonicalize.mlir b/mlir/test/Dialect/Complex/canonicalize.mlir
index f0d287fde18aa..2fd2002c5cedf 100644
--- a/mlir/test/Dialect/Complex/canonicalize.mlir
+++ b/mlir/test/Dialect/Complex/canonicalize.mlir
@@ -155,3 +155,25 @@ func.func @complex_sub_zero() -> complex<f32> {
%sub = complex.sub %complex1, %complex2 : complex<f32>
return %sub : complex<f32>
}
+
+// re(neg(create(a, b))) should canonicalize to arith.negf of the first
+// create operand, and that negation should be what the function returns.
+// CHECK-LABEL: func @re_neg
+// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
+func.func @re_neg(%arg0: f32, %arg1: f32) -> f32 {
+  %create = complex.create %arg0, %arg1: complex<f32>
+  // CHECK: %[[NEG:.*]] = arith.negf %[[ARG0]]
+  %neg = complex.neg %create : complex<f32>
+  %re = complex.re %neg : complex<f32>
+  // CHECK-NEXT: return %[[NEG]]
+  return %re : f32
+}
+
+// im(neg(create(a, b))) should canonicalize to arith.negf of the second
+// create operand, and that negation should be what the function returns.
+// CHECK-LABEL: func @im_neg
+// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
+func.func @im_neg(%arg0: f32, %arg1: f32) -> f32 {
+  %create = complex.create %arg0, %arg1: complex<f32>
+  // CHECK: %[[NEG:.*]] = arith.negf %[[ARG1]]
+  %neg = complex.neg %create : complex<f32>
+  %im = complex.im %neg : complex<f32>
+  // CHECK-NEXT: return %[[NEG]]
+  return %im : f32
+}
More information about the Mlir-commits
mailing list