[Mlir-commits] [mlir] 655e08c - [mlir] Canonicalization of shape.assuming

Tres Popp llvmlistbot at llvm.org
Fri Jun 5 02:01:03 PDT 2020


Author: Tres Popp
Date: 2020-06-05T11:00:20+02:00
New Revision: 655e08ceeb7bf908cc5460279acbe2882bd47c91

URL: https://github.com/llvm/llvm-project/commit/655e08ceeb7bf908cc5460279acbe2882bd47c91
DIFF: https://github.com/llvm/llvm-project/commit/655e08ceeb7bf908cc5460279acbe2882bd47c91.diff

LOG: [mlir] Canonicalization of shape.assuming

Summary:
This will inline the region of a shape.assuming in the case that the
input witness is found to be statically true.

Differential Revision: https://reviews.llvm.org/D80302

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
    mlir/lib/Dialect/Shape/IR/Shape.cpp
    mlir/test/Dialect/Shape/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 6fb7cbf1f3b7..88ee5f4d520b 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -509,6 +509,8 @@ def Shape_AssumingOp : Shape_Op<"assuming",
 
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
+
+  let hasCanonicalizer = 1;
 }
 
 def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",

diff  --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index e12e23ba128c..5866b0ac2680 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -159,6 +159,44 @@ static void print(OpAsmPrinter &p, AssumingOp op) {
   p.printOptionalAttrDict(op.getAttrs());
 }
 
+namespace {
+// Inlines the body of a `shape.assuming` whose witness is a
+// `shape.const_witness` with a true `passing` value, then erases the op.
+struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
+  using OpRewritePattern<AssumingOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AssumingOp op,
+                                PatternRewriter &rewriter) const override {
+    auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
+    if (!witness || !witness.passingAttr().getValue())
+      return failure();
+    // Split at the op itself; don't rely on the driver's insertion point.
+    auto *blockBeforeAssuming = op.getOperation()->getBlock();
+    auto *assumingBlock = op.getBody();
+    auto initPosition = Block::iterator(op.getOperation());
+    auto *blockAfterAssuming =
+        rewriter.splitBlock(blockBeforeAssuming, initPosition);
+
+    // Forward the yielded values to the op's users, then drop both ops.
+    auto &yieldOp = assumingBlock->back();
+    rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
+    rewriter.replaceOp(op, yieldOp.getOperands());
+    rewriter.eraseOp(&yieldOp);
+
+    // No branching in the AssumingOp: rejoin body and tail to the parent.
+    rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
+    rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
+    return success();
+  }
+};
+} // namespace
+
+void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+                                             MLIRContext *context) {
+  // Inline the body of an assuming op guarded by a constant passing witness.
+  patterns.insert<AssumingWithTrue>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // AssumingAllOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 32fa496e7347..7c90753e255e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -324,6 +324,42 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
   return %1 : !shape.shape
 }
 
+// -----
+// An assuming op with a constant passing witness is inlined and removed.
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: source
+  // CHECK-NEXT: sink
+  // CHECK-NEXT: return
+  %0 = shape.const_witness true
+  %1 = shape.assuming %0 -> index {
+    %2 = "test.source"() : () -> (index)
+    shape.assuming_yield %2 : index
+  }
+  "test.sink"(%1) : (index) -> ()
+  return
+}
+
+// -----
+// An assuming op without a known passing witness cannot be removed.
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: test.source
+  // CHECK-NEXT: shape.assuming
+  // CHECK-NEXT:   test.source
+  // CHECK-NEXT:   shape.assuming_yield
+  // CHECK-NEXT: }
+  // CHECK-NEXT: test.sink
+  // CHECK-NEXT: return
+  %0 = "test.source"() : () -> (!shape.witness)
+  %1 = shape.assuming %0 -> index {
+    %2 = "test.source"() : () -> (index)
+    shape.assuming_yield %2 : index
+  }
+  "test.sink"(%1) : (index) -> ()
+  return
+}
+
 // -----
 // Broadcastable with broadcastable constant shapes can be removed.
 // CHECK-LABEL: func @f


        


More information about the Mlir-commits mailing list