[Mlir-commits] [mlir] 3324598 - [mlir] Add a pass to remove all shape.cstr_ and assuming_ ops.
Tres Popp
llvmlistbot at llvm.org
Thu Jun 18 04:32:33 PDT 2020
Author: Tres Popp
Date: 2020-06-18T13:31:30+02:00
New Revision: 3324598844a28850527b8abc8b83579ad7ab94a2
URL: https://github.com/llvm/llvm-project/commit/3324598844a28850527b8abc8b83579ad7ab94a2
DIFF: https://github.com/llvm/llvm-project/commit/3324598844a28850527b8abc8b83579ad7ab94a2.diff
LOG: [mlir] Add a pass to remove all shape.cstr_ and assuming_ ops.
Summary:
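This provides a utility for removing unsupported constraints, e.g. in
pipelines that receive these ops but cannot lower them because they do not
support assertions. The pass replaces cstr_ ops with passing witnesses; a
subsequent canonicalization then folds away the dependent assuming ops.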
Differential Revision: https://reviews.llvm.org/D81560
Added:
mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
mlir/test/Dialect/Shape/remove-shape-constraints.mlir
Modified:
mlir/include/mlir/Dialect/Shape/IR/Shape.h
mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
index 994975834d5d..b78b7304592d 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -22,6 +22,8 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
namespace mlir {
+class PatternRewriter;
+
namespace shape {
namespace ShapeTypes {
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index a6f579cec505..49714c96b1cd 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -375,7 +375,7 @@ def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
let hasFolder = 1;
}
-def Shape_YieldOp : Shape_Op<"yield",
+def Shape_YieldOp : Shape_Op<"yield",
[HasParent<"ReduceOp">,
NoSideEffect,
ReturnLike,
@@ -533,6 +533,14 @@ def Shape_AssumingOp : Shape_Op<"assuming",
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
+ let extraClassDeclaration = [{
+ // Inline the region into the region containing the AssumingOp and delete
+ // the AssumingOp.
+ //
+ // This does no checks on the inputs to the AssumingOp.
+ static void inlineRegionIntoParent(AssumingOp &op, PatternRewriter &rewriter);
+ }];
+
let hasCanonicalizer = 1;
}
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
index 7e6065341608..e8d2167916d0 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -18,6 +18,7 @@
namespace mlir {
+class FunctionPass;
class MLIRContext;
class OwningRewritePatternList;
class Pass;
@@ -30,6 +31,17 @@ std::unique_ptr<Pass> createShapeToShapeLowering();
/// Collects a set of patterns to rewrite ops within the Shape dialect.
void populateShapeRewritePatterns(MLIRContext *context,
OwningRewritePatternList &patterns);
+
+/// Collects a set of patterns that replace shape constraint ops (currently
+/// shape.cstr_broadcastable and shape.cstr_eq) with passing witnesses
+/// (shape.const_witness true). This is intended to leave all
+/// ShapeConstraint-related ops and data without effects, so that they can be
+/// freely removed, e.g. by canonicalization and dead code elimination.
+///
+/// Once these patterns have been applied to a fixed point, no
+/// cstr_broadcastable or cstr_eq operations remain.
+void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns,
+ MLIRContext *ctx);
+/// Creates a function pass that greedily applies the patterns above.
+std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
+
} // end namespace mlir
#endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
index 46dc4dc37160..022bd3773ce2 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
@@ -11,6 +11,11 @@
include "mlir/Pass/PassBase.td"
+// Function pass that greedily applies populateRemoveShapeConstraintsPatterns,
+// replacing shape.cstr_broadcastable and shape.cstr_eq ops with passing
+// witnesses. Run -canonicalize afterwards to fold away shape.assuming ops that
+// are guarded by those now-constant witnesses.
+def RemoveShapeConstraints : FunctionPass<"remove-shape-constraints"> {
+ let summary = "Replace all cstr_ ops with a true witness";
+ let constructor = "mlir::createRemoveShapeConstraintsPass()";
+}
+
def ShapeToShapeLowering : FunctionPass<"shape-to-shape-lowering"> {
let summary = "Legalize Shape dialect to be convertible to Standard";
let constructor = "mlir::createShapeToShapeLowering()";
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 4a876e16bcec..664c0cb05b80 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -168,22 +168,7 @@ struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
if (!witness || !witness.passingAttr())
return failure();
- auto *blockBeforeAssuming = rewriter.getInsertionBlock();
- auto *assumingBlock = op.getBody();
- auto initPosition = rewriter.getInsertionPoint();
- auto *blockAfterAssuming =
- rewriter.splitBlock(blockBeforeAssuming, initPosition);
-
- // Remove the AssumingOp and AssumingYieldOp.
- auto &yieldOp = assumingBlock->back();
- rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
- rewriter.replaceOp(op, yieldOp.getOperands());
- rewriter.eraseOp(&yieldOp);
-
- // Merge blocks together as there was no branching behavior from the
- // AssumingOp.
- rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
- rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
+ AssumingOp::inlineRegionIntoParent(op, rewriter);
return success();
}
};
@@ -191,10 +176,30 @@ struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
void AssumingOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
MLIRContext *context) {
- // If taking a passing witness, inline region
+ // If the witness is known to pass (see AssumingWithTrue), inline the region.
patterns.insert<AssumingWithTrue>(context);
}
+// Splices the body of `op`'s region into the block that contains `op`:
+// the enclosing block is split at `op`, the region's single block is moved
+// between the two halves, the operands of the terminating assuming_yield
+// replace `op`'s results, and the three blocks are merged back into one.
+// No checks are performed on the op's witness operand.
+void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
+                                        PatternRewriter &rewriter) {
+  // Derive the split point from the op itself rather than from the rewriter's
+  // current insertion point: this is a public static utility, so callers are
+  // not guaranteed to have positioned the rewriter at `op`.
+  Operation *assumingOp = op.getOperation();
+  Block *blockBeforeAssuming = assumingOp->getBlock();
+  Block *assumingBlock = op.getBody();
+  Block *blockAfterAssuming =
+      rewriter.splitBlock(blockBeforeAssuming, Block::iterator(assumingOp));
+
+  // Remove the AssumingOp and AssumingYieldOp.
+  Operation &yieldOp = assumingBlock->back();
+  rewriter.inlineRegionBefore(op.doRegion(), blockAfterAssuming);
+  rewriter.replaceOp(op, yieldOp.getOperands());
+  rewriter.eraseOp(&yieldOp);
+
+  // Merge blocks together as there was no branching behavior from the
+  // AssumingOp.
+  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
+  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
+}
+
//===----------------------------------------------------------------------===//
// AssumingAllOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
index 6f812b60658e..987f9c544b33 100644
--- a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRShapeOpsTransforms
+ RemoveShapeConstraints.cpp
ShapeToShapeLowering.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
new file mode 100644
index 000000000000..641b4bc38e43
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
@@ -0,0 +1,64 @@
+//===-- RemoveShapeConstraints.cpp - Remove Shape Cstr and Assuming Ops ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// Rewrites any constraint op of type `CstrOpTy` into
+/// `shape.const_witness true`, i.e. a witness that always passes.
+template <typename CstrOpTy>
+class RemoveCstrOp : public OpRewritePattern<CstrOpTy> {
+public:
+  using OpRewritePattern<CstrOpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CstrOpTy op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op.getOperation(), true);
+    return success();
+  }
+};
+
+/// Removal patterns, one instantiation per supported constraint op.
+using RemoveCstrBroadcastableOp = RemoveCstrOp<shape::CstrBroadcastableOp>;
+using RemoveCstrEqOp = RemoveCstrOp<shape::CstrEqOp>;
+
+/// Function pass that greedily applies the removal patterns to the function.
+class RemoveShapeConstraintsPass
+    : public RemoveShapeConstraintsBase<RemoveShapeConstraintsPass> {
+  void runOnFunction() override {
+    OwningRewritePatternList removalPatterns;
+    populateRemoveShapeConstraintsPatterns(removalPatterns, &getContext());
+    applyPatternsAndFoldGreedily(getFunction(), removalPatterns);
+  }
+};
+
+} // namespace
+
+/// Registers one removal pattern per supported constraint op, in the order
+/// cstr_broadcastable, cstr_eq.
+void mlir::populateRemoveShapeConstraintsPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<RemoveCstrBroadcastableOp>(ctx);
+  patterns.insert<RemoveCstrEqOp>(ctx);
+}
+
+/// Creates a fresh instance of the remove-shape-constraints pass.
+std::unique_ptr<FunctionPass> mlir::createRemoveShapeConstraintsPass() {
+  std::unique_ptr<FunctionPass> pass =
+      std::make_unique<RemoveShapeConstraintsPass>();
+  return pass;
+}
diff --git a/mlir/test/Dialect/Shape/remove-shape-constraints.mlir b/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
new file mode 100644
index 000000000000..69887c6994f4
--- /dev/null
+++ b/mlir/test/Dialect/Shape/remove-shape-constraints.mlir
@@ -0,0 +1,56 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -remove-shape-constraints -canonicalize <%s | FileCheck %s --dump-input=fail --check-prefixes=CANON,CHECK-BOTH
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -remove-shape-constraints <%s | FileCheck %s --dump-input=fail --check-prefixes=REPLACE,CHECK-BOTH
+
+// -----
+// Check that cstr_broadcastable is replaced by a passing witness (REPLACE) and
+// that canonicalization then inlines the guarded assuming region (CANON).
+//
+// CHECK-BOTH: func @f
+func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
+  // REPLACE-NEXT: %[[WITNESS:.+]] = shape.const_witness true
+  // REPLACE-NOT: shape.cstr_broadcastable
+  // REPLACE: shape.assuming %[[WITNESS]]
+  // CANON-NEXT: test.source
+  // CANON-NEXT: return
+  %0 = shape.cstr_broadcastable %arg0, %arg1
+  %1 = shape.assuming %0 -> index {
+    %2 = "test.source"() : () -> (index)
+    shape.assuming_yield %2 : index
+  }
+  return %1 : index
+}
+
+// -----
+// Check that cstr_eq is replaced by a passing witness (REPLACE) and that
+// canonicalization then inlines the guarded assuming region (CANON).
+//
+// CHECK-BOTH: func @f
+func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
+ // REPLACE-NEXT: %[[WITNESS:.+]] = shape.const_witness true
+ // REPLACE-NOT: shape.cstr_eq
+ // REPLACE: shape.assuming %[[WITNESS]]
+ // CANON-NEXT: test.source
+ // CANON-NEXT: return
+ %0 = shape.cstr_eq %arg0, %arg1
+ %1 = shape.assuming %0 -> index {
+ %2 = "test.source"() : () -> (index)
+ shape.assuming_yield %2 : index
+ }
+ return %1 : index
+}
+
+// -----
+// With a non-const value, we cannot fold away the code, but all constraints
+// should be removed still.
+//
+// CHECK-BOTH: func @f
+func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
+  // CANON-NEXT: test.source
+  // CANON-NEXT: return
+  // Even without canonicalization, no cstr_ op may survive the pass.
+  // REPLACE-NOT: shape.cstr_
+  %0 = shape.cstr_broadcastable %arg0, %arg1
+  %1 = shape.cstr_eq %arg0, %arg1
+  %2 = shape.assuming_all %0, %1
+  %3 = shape.assuming %0 -> index {
+    %4 = "test.source"() : () -> (index)
+    shape.assuming_yield %4 : index
+  }
+  return %3 : index
+}
More information about the Mlir-commits
mailing list