[Mlir-commits] [mlir] 981f0a1 - [mlir] Add canonicalizer to merge shape.assuming_all ops
Eugene Zhulenev
llvmlistbot at llvm.org
Fri Feb 4 15:27:43 PST 2022
Author: Eugene Zhulenev
Date: 2022-02-04T15:27:37-08:00
New Revision: 981f0a14f1de1455abc092c0692e5c78a16f24a7
URL: https://github.com/llvm/llvm-project/commit/981f0a14f1de1455abc092c0692e5c78a16f24a7
DIFF: https://github.com/llvm/llvm-project/commit/981f0a14f1de1455abc092c0692e5c78a16f24a7.diff
LOG: [mlir] Add canonicalizer to merge shape.assuming_all ops
Depends On D119021
Reviewed By: frgossen
Differential Revision: https://reviews.llvm.org/D119025
Added:
Modified:
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 661f621f1cee1..ecf9ade0b05c5 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -460,6 +460,39 @@ LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
//===----------------------------------------------------------------------===//
namespace {
+
+// Merge multiple `shape.assuming_all` operations together.
+//
+// %0 = shape.assuming_all %w0, %w1
+// %1 = shape.assuming_all %w2, %0
+//
+// to:
+//
+// %1 = shape.assuming_all %w2, %w0, %w1
+struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
+ using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(AssumingAllOp op,
+ PatternRewriter &rewriter) const override {
+ // Inputs of `op`, with each input produced by another `assuming_all`
+ // replaced in place by that op's own operands.
+ SmallVector<Value> flattened;
+ for (Value input : op.getInputs()) {
+ auto nested = input.getDefiningOp<AssumingAllOp>();
+ if (!nested)
+ flattened.push_back(input);
+ else
+ flattened.append(nested->operand_begin(), nested->operand_end());
+ }
+
+ // Same operand count as before: report that there is nothing to merge.
+ if (flattened.size() == op.getNumOperands())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<AssumingAllOp>(op, flattened);
+ return success();
+ }
+};
+
struct AssumingAllToCstrEqCanonicalization
: public OpRewritePattern<AssumingAllOp> {
using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
@@ -506,7 +539,8 @@ struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization,
+ patterns.add<MergeAssumingAllOps, AssumingAllOneOp,
+ AssumingAllToCstrEqCanonicalization,
RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 75b92640a998d..425f8fe71a42f 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -463,6 +463,26 @@ func @cstr_require_no_fold(%arg0: i1) {
return
}
+// -----
+
+// An `assuming_all` fed by another `assuming_all` is merged into a single op.
+// CHECK-LABEL: func @f
+func @f() {
+ // CHECK-NEXT: %[[W0:.*]] = "test.source"
+ // CHECK-NEXT: %[[W1:.*]] = "test.source"
+ // CHECK-NEXT: %[[W2:.*]] = "test.source"
+ // CHECK-NEXT: shape.assuming_all %[[W0]], %[[W1]], %[[W2]]
+ // CHECK-NEXT: consume.witness
+ // CHECK-NEXT: return
+ %0 = "test.source"() : () -> !shape.witness
+ %1 = "test.source"() : () -> !shape.witness
+ %2 = "test.source"() : () -> !shape.witness
+ %3 = shape.assuming_all %0, %1
+ %4 = shape.assuming_all %3, %2
+ "consume.witness"(%4) : (!shape.witness) -> ()
+ return
+}
+
// -----
// `assuming_all` with all `cstr_eq` and shared operands can be collapsed.
// CHECK-LABEL: func @assuming_all_to_cstr_eq
More information about the Mlir-commits
mailing list