[Mlir-commits] [mlir] cbcb12f - [MLIR] Handle in-place folding properly in greedy pattern rewrite driver
Uday Bondhugula
llvmlistbot at llvm.org
Sat Apr 11 07:28:05 PDT 2020
Author: Uday Bondhugula
Date: 2020-04-11T19:57:29+05:30
New Revision: cbcb12fd44dfdb51bbf4489d213d96f17be3091f
URL: https://github.com/llvm/llvm-project/commit/cbcb12fd44dfdb51bbf4489d213d96f17be3091f
DIFF: https://github.com/llvm/llvm-project/commit/cbcb12fd44dfdb51bbf4489d213d96f17be3091f.diff
LOG: [MLIR] Handle in-place folding properly in greedy pattern rewrite driver
OperationFolder::tryToFold performs both true folding and, in a few
instances, in-place updates through op rewrites. In the latter case, we
should still be applying the supplied pattern rewrites in the same
iteration; however, this wasn't the case since tryToFold returned
success() for both true folding and in-place updates, and the patterns
for the in-place updated ops were being applied only in the next
iteration of the driver's outer loop. This fix makes the greedy driver
converge faster.
Differential Revision: https://reviews.llvm.org/D77485
Added:
Modified:
mlir/include/mlir/Transforms/FoldUtils.h
mlir/lib/Transforms/Utils/FoldUtils.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
index 0bab87c5e4e3..d2ba43339ce3 100644
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -56,11 +56,12 @@ class OperationFolder {
/// folded results, and returns success. `preReplaceAction` is invoked on `op`
/// before it is replaced. 'processGeneratedConstants' is invoked for any new
/// operations generated when folding. If the op was completely folded it is
- /// erased.
+ /// erased. If it is just updated in place, `inPlaceUpdate` is set to true.
LogicalResult
tryToFold(Operation *op,
function_ref<void(Operation *)> processGeneratedConstants = nullptr,
- function_ref<void(Operation *)> preReplaceAction = nullptr);
+ function_ref<void(Operation *)> preReplaceAction = nullptr,
+ bool *inPlaceUpdate = nullptr);
/// Notifies that the given constant `op` should be remove from this
/// OperationFolder's internal bookkeeping.
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index f2099bca75ea..9e67c2b6b348 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -74,7 +74,10 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
LogicalResult OperationFolder::tryToFold(
Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
- function_ref<void(Operation *)> preReplaceAction) {
+ function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {
+ if (inPlaceUpdate)
+ *inPlaceUpdate = false;
+
// If this is a unique'd constant, return failure as we know that it has
// already been folded.
if (referencedDialects.count(op))
@@ -87,8 +90,11 @@ LogicalResult OperationFolder::tryToFold(
return failure();
// Check to see if the operation was just updated in place.
- if (results.empty())
+ if (results.empty()) {
+ if (inPlaceUpdate)
+ *inPlaceUpdate = true;
return success();
+ }
// Constant folding succeeded. We will start replacing this op's uses and
// erase this op. Invoke the callback provided by the caller to perform any
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 80ad143ce0d3..53c8e9fbd1c2 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -104,7 +104,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
// be re-added to the worklist. This function should be called when an
// operation is modified or removed, as it may trigger further
// simplifications.
- template <typename Operands> void addToWorklist(Operands &&operands) {
+ template <typename Operands>
+ void addToWorklist(Operands &&operands) {
for (Value operand : operands) {
// If the use count of this operand is now < 2, we re-add the defining
// operation to the worklist.
@@ -133,7 +134,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
};
} // end anonymous namespace
-/// Perform the rewrites while folding and erasing any dead ops.
+/// Performs the rewrites while folding and erasing any dead ops. Returns true
+/// if the rewrite converges in `maxIterations`.
bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
int maxIterations) {
// Add the given operation to the worklist.
@@ -183,9 +185,12 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
};
// Try to fold this op.
- if (succeeded(folder.tryToFold(op, collectOps, preReplaceAction))) {
+ bool inPlaceUpdate;
+ if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
+ &inPlaceUpdate)))) {
changed = true;
- continue;
+ if (!inPlaceUpdate)
+ continue;
}
// Make sure that any new operations are inserted at this point.
More information about the Mlir-commits
mailing list