[Mlir-commits] [mlir] 655e08c - [mlir] Canonicalization of shape.assuming
Tres Popp
llvmlistbot at llvm.org
Fri Jun 5 02:01:03 PDT 2020
Author: Tres Popp
Date: 2020-06-05T11:00:20+02:00
New Revision: 655e08ceeb7bf908cc5460279acbe2882bd47c91
URL: https://github.com/llvm/llvm-project/commit/655e08ceeb7bf908cc5460279acbe2882bd47c91
DIFF: https://github.com/llvm/llvm-project/commit/655e08ceeb7bf908cc5460279acbe2882bd47c91.diff
LOG: [mlir] Canonicalization of shape.assuming
Summary:
This will inline the region of a shape.assuming op in the case that its
input witness is found to be statically true.
Differential Revision: https://reviews.llvm.org/D80302
Added:
Modified:
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 6fb7cbf1f3b7..88ee5f4d520b 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -509,6 +509,8 @@ def Shape_AssumingOp : Shape_Op<"assuming",
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
+
+ let hasCanonicalizer = 1;
}
def Shape_AssumingYieldOp : Shape_Op<"assuming_yield",
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index e12e23ba128c..5866b0ac2680 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -159,6 +159,44 @@ static void print(OpAsmPrinter &p, AssumingOp op) {
p.printOptionalAttrDict(op.getAttrs());
}
+namespace {
+// Removes an AssumingOp whose witness is produced by a `shape.const_witness`
+// with a `true` passing value, and inlines the op's region in its place. The
+// AssumingOp's results are replaced by the operands of the terminator of its
+// body block (the AssumingYieldOp).
+struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
+  using OpRewritePattern<AssumingOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AssumingOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only fire when the witness is statically known to pass. Note that
+    // `passingAttr()` returns a BoolAttr whose conversion to bool tests for a
+    // null attribute, not for the stored value, so the value accessor
+    // `passing()` must be used; otherwise `const_witness false` would match.
+    auto witness = op.witness().getDefiningOp<ConstWitnessOp>();
+    if (!witness || !witness.passing())
+      return failure();
+
+    // Anchor the split on `op` itself rather than on the rewriter's current
+    // insertion point, so the rewrite does not depend on driver state.
+    Operation *assumingOp = op.getOperation();
+    Block *blockBeforeAssuming = assumingOp->getBlock();
+    Block *assumingBlock = op.getBody();
+    // `op` and every operation after it move into `blockAfterAssuming`.
+    Block *blockAfterAssuming =
+        rewriter.splitBlock(blockBeforeAssuming, assumingOp->getIterator());
+
+    // Move the region's block between the two halves, forward the yielded
+    // values to the users of the AssumingOp results, then drop both the
+    // AssumingOp and its AssumingYieldOp terminator.
+    Operation &yieldOp = assumingBlock->back();
+    rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
+    rewriter.replaceOp(op, yieldOp.getOperands());
+    rewriter.eraseOp(&yieldOp);
+
+    // The AssumingOp had no branching behavior, so splice the inlined body
+    // and the trailing operations back into the original block.
+    rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
+    rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
+    return success();
+  }
+};
+} // namespace
+
+void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+                                             MLIRContext *context) {
+  // Register the pattern that replaces an assuming op guarded by a constant
+  // witness with the contents of its region.
+  patterns.insert<AssumingWithTrue>(context);
+}
+
//===----------------------------------------------------------------------===//
// AssumingAllOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 32fa496e7347..7c90753e255e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -324,6 +324,42 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
return %1 : !shape.shape
}
+// -----
+// Assuming with a known passing witness can be removed.
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: test.source
+  // CHECK-NEXT: test.sink
+  // CHECK-NEXT: return
+  %0 = shape.const_witness true
+  %1 = shape.assuming %0 -> index {
+    %2 = "test.source"() : () -> (index)
+    shape.assuming_yield %2 : index
+  }
+  "test.sink"(%1) : (index) -> ()
+  return
+}
+
+// -----
+// Assuming without a known passing witness cannot be removed.
+// CHECK-LABEL: func @f
+func @f() {
+  // CHECK-NEXT: test.source
+  // CHECK-NEXT: shape.assuming
+  // CHECK-NEXT: test.source
+  // CHECK-NEXT: shape.assuming_yield
+  // CHECK-NEXT: }
+  // CHECK-NEXT: test.sink
+  // CHECK-NEXT: return
+  %0 = "test.source"() : () -> (!shape.witness)
+  %1 = shape.assuming %0 -> index {
+    %2 = "test.source"() : () -> (index)
+    shape.assuming_yield %2 : index
+  }
+  "test.sink"(%1) : (index) -> ()
+  return
+}
+
// -----
// Broadcastable with broadcastable constant shapes can be removed.
// CHECK-LABEL: func @f
More information about the Mlir-commits
mailing list