[Mlir-commits] [mlir] 981f0a1 - [mlir] Add canonicalizer to merge shape.assuming_all ops

Eugene Zhulenev llvmlistbot at llvm.org
Fri Feb 4 15:27:43 PST 2022


Author: Eugene Zhulenev
Date: 2022-02-04T15:27:37-08:00
New Revision: 981f0a14f1de1455abc092c0692e5c78a16f24a7

URL: https://github.com/llvm/llvm-project/commit/981f0a14f1de1455abc092c0692e5c78a16f24a7
DIFF: https://github.com/llvm/llvm-project/commit/981f0a14f1de1455abc092c0692e5c78a16f24a7.diff

LOG: [mlir] Add canonicalizer to merge shape.assuming_all ops

Depends On D119021

Reviewed By: frgossen

Differential Revision: https://reviews.llvm.org/D119025

Added: 
    

Modified: 
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 661f621f1cee1..ecf9ade0b05c5 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -460,6 +460,39 @@ LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
 //===----------------------------------------------------------------------===//
 
 namespace {
+
+// Merge multiple `shape.assuming_all` operations together.
+//
+//   %0 = shape.assuming_all %w0, %w1
+//   %1 = shape.assuming_all %w2, %0
+//
+// is rewritten to (inner operands are spliced in place, in order):
+//
+//   %1 = shape.assuming_all %w2, %w0, %w1
+struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
+  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AssumingAllOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<Value> operands;
+    bool merged = false;
+    for (Value operand : op.getInputs()) {
+      if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>()) {
+        operands.append(assumeAll->operand_begin(), assumeAll->operand_end());
+        merged = true;
+      } else {
+        operands.push_back(operand);
+      }
+    }
+    // Bail if nothing merged (a size check misses one-input inner ops).
+    if (!merged)
+      return failure();
+    // Replace with a new `assuming_all` operation with merged constraints.
+    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
+    return success();
+  }
+};
+
 struct AssumingAllToCstrEqCanonicalization
     : public OpRewritePattern<AssumingAllOp> {
   using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
@@ -506,7 +539,8 @@ struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
 
 void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
-  patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization,
+  patterns.add<MergeAssumingAllOps, // Flattens nested `assuming_all`s.
+               AssumingAllOneOp, AssumingAllToCstrEqCanonicalization,
                RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
 }
 

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 75b92640a998d..425f8fe71a42f 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -463,6 +463,26 @@ func @cstr_require_no_fold(%arg0: i1) {
   return
 }
 
+// -----
+
+// Nested `assuming_all` operations are merged into a single one.
+// CHECK-LABEL: func @merge_assuming_all
+func @merge_assuming_all() {
+  // CHECK-NEXT: %[[W0:.*]] = "test.source"
+  // CHECK-NEXT: %[[W1:.*]] = "test.source"
+  // CHECK-NEXT: %[[W2:.*]] = "test.source"
+  // CHECK-NEXT: shape.assuming_all %[[W0]], %[[W1]], %[[W2]]
+  // CHECK-NEXT: consume.witness
+  // CHECK-NEXT: return
+  %0 = "test.source"() : () -> !shape.witness
+  %1 = "test.source"() : () -> !shape.witness
+  %2 = "test.source"() : () -> !shape.witness
+  %3 = shape.assuming_all %0, %1
+  %4 = shape.assuming_all %3, %2
+  "consume.witness"(%4) : (!shape.witness) -> ()
+  return
+}
+
 // -----
 // `assuming_all` with all `cstr_eq` and shared operands can be collapsed.
 // CHECK-LABEL: func @assuming_all_to_cstr_eq


        


More information about the Mlir-commits mailing list