[Mlir-commits] [mlir] 7932d21 - [MLIR] Introduce a new rewrite driver to simplify supplied list of ops
Uday Bondhugula
llvmlistbot at llvm.org
Wed Jul 21 07:56:02 PDT 2021
Author: Uday Bondhugula
Date: 2021-07-21T20:25:16+05:30
New Revision: 7932d21f5d795230ce9f8e74415fddcc29d91642
URL: https://github.com/llvm/llvm-project/commit/7932d21f5d795230ce9f8e74415fddcc29d91642
DIFF: https://github.com/llvm/llvm-project/commit/7932d21f5d795230ce9f8e74415fddcc29d91642.diff
LOG: [MLIR] Introduce a new rewrite driver to simplify supplied list of ops
Introduce a new rewrite driver (MultiOpPatternRewriteDriver) to rewrite
a supplied list of ops and other ops. Provide a knob to restrict
rewrites strictly to those ops or also to affected ops (but still not to
completely unrelated ops).
This rewrite driver is commonly needed to run any simplification and
cleanup at the end of a transforms pass or transforms utility in a way
that only simplifies relevant IR. This makes it easy to write test cases
while not performing unrelated whole IR simplification that may
invalidate other state at the caller.
The introduced utility provides more freedom to developers of transforms
and transform utilities to perform focussed and local simplification. In
several cases, it provides greater efficiency as well as more
simplification when compared to repeatedly calling
`applyOpPatternsAndFold`; in other cases, it avoids the need to
undesirably call `applyPatternsAndFoldGreedily` to do unrelated
simplification in a FuncOp.
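The difference between the strict and non-strict modes can be sketched
with a toy worklist model. This is not the MLIR implementation: ToyOp,
makeExample and the free function simplifyLocally below are hypothetical
names used only for illustration, the only "pattern" is constant folding
of adds, and only result users (not operand providers) are revisited.

```cpp
#include <deque>
#include <set>
#include <vector>

// Toy stand-in for an operation: a constant, or an add of other ops.
struct ToyOp {
  bool isConst;
  int value;                 // meaningful only when isConst is true
  std::vector<int> operands; // indices into the op table when !isConst
};

// Folds adds whose operands are all constants, starting from `roots`.
// With strict == true only ops in `roots` are ever visited; otherwise
// users of a folded op are visited as well. Ops that are neither roots
// nor users of a folded op are never visited in either mode.
// Returns true if anything changed.
bool simplifyLocally(std::vector<ToyOp> &ops, const std::vector<int> &roots,
                     bool strict) {
  std::set<int> allowed(roots.begin(), roots.end());
  std::deque<int> worklist(roots.begin(), roots.end());
  bool changed = false;
  while (!worklist.empty()) {
    int id = worklist.front();
    worklist.pop_front();
    if (ops[id].isConst)
      continue;
    int sum = 0;
    bool foldable = true;
    for (int operand : ops[id].operands) {
      if (!ops[operand].isConst) {
        foldable = false;
        break;
      }
      sum += ops[operand].value;
    }
    if (!foldable)
      continue;
    ops[id] = ToyOp{true, sum, {}};
    changed = true;
    // Revisit users of the folded op, subject to the strict filter.
    for (int user = 0; user < static_cast<int>(ops.size()); ++user) {
      if (ops[user].isConst)
        continue;
      bool uses = false;
      for (int operand : ops[user].operands)
        uses |= (operand == id);
      if (uses && (!strict || allowed.count(user)))
        worklist.push_back(user);
    }
  }
  return changed;
}

// Small example IR used to compare the two modes.
std::vector<ToyOp> makeExample() {
  return {
      {true, 1, {}},      // %0 = const 1
      {true, 2, {}},      // %1 = const 2
      {false, 0, {0, 1}}, // %2 = add %0, %1   (the supplied op)
      {false, 0, {2, 2}}, // %3 = add %2, %2   (user of %2)
      {false, 0, {0, 1}}, // %4 = add %0, %1   (unrelated to %2)
  };
}
```

Calling simplifyLocally(ops, {2}, /*strict=*/true) on this example folds
only %2. With strict set to false, %3 is folded as well because it uses
the result of %2, while %4 is left alone in both modes even though it is
foldable.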
Update a few transformations that were earlier using
applyOpPatternsAndFold (SimplifyAffineStructures,
affineDataCopyGenerate, a linalg transform).
TODO:
- OpPatternRewriteDriver can be removed as it's a special case of
MultiOpPatternRewriteDriver, i.e., both can be merged.
Differential Revision: https://reviews.llvm.org/D106232
Added:
Modified:
mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/Dialect/Affine/simplify-affine-structures.mlir
mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir
mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
index 2ef8501bdc0c5..6be3b3fd384b8 100644
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -77,6 +77,20 @@ LogicalResult applyOpPatternsAndFold(Operation *op,
const FrozenRewritePatternSet &patterns,
bool *erased = nullptr);
+/// Applies the specified rewrite patterns on `ops` while also trying to fold
+/// these ops as well as any other ops that were in turn created due to such
+/// rewrites. Furthermore, any pre-existing ops in the IR outside of `ops`
+/// remain completely unmodified if `strict` is set to true. If `strict` is
+/// false, other operations that use results of rewritten ops or supply operands
+/// to such ops are in turn simplified; any other ops still remain unmodified
+/// (i.e., regardless of `strict`). Note that ops in `ops` could be erased as a
+/// result of folding, becoming dead, or via pattern rewrites. If more far
+/// reaching simplification is desired, applyPatternsAndFoldGreedily should be
+/// used. Returns true if at all any IR was rewritten.
+bool applyOpPatternsAndFold(ArrayRef<Operation *> ops,
+ const FrozenRewritePatternSet &patterns,
+ bool strict);
+
} // end namespace mlir
#endif // MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 851ec5051a6be..52d3884d90915 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -231,6 +231,5 @@ void AffineDataCopyGeneration::runOnFunction() {
AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
- for (Operation *op : copyOps)
- (void)applyOpPatternsAndFold(op, frozenPatterns);
+ (void)applyOpPatternsAndFold(copyOps, frozenPatterns, /*strict=*/true);
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
index 8f59074e6b791..cc1e89f93dc22 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -80,10 +80,14 @@ void SimplifyAffineStructures::runOnFunction() {
auto func = getFunction();
simplifiedAttributes.clear();
RewritePatternSet patterns(func.getContext());
+ AffineApplyOp::getCanonicalizationPatterns(patterns, func.getContext());
AffineForOp::getCanonicalizationPatterns(patterns, func.getContext());
AffineIfOp::getCanonicalizationPatterns(patterns, func.getContext());
- AffineApplyOp::getCanonicalizationPatterns(patterns, func.getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+
+ // The simplification of affine attributes will likely simplify the op. Try to
+ // fold/apply canonicalization patterns when we have affine dialect ops.
+ SmallVector<Operation *> opsToSimplify;
func.walk([&](Operation *op) {
for (auto attr : op->getAttrs()) {
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>())
@@ -92,9 +96,8 @@ void SimplifyAffineStructures::runOnFunction() {
simplifyAndUpdateAttribute(op, attr.first, setAttr);
}
- // The simplification of the attribute will likely simplify the op. Try to
- // fold / apply canonicalization patterns when we have affine dialect ops.
if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))
- (void)applyOpPatternsAndFold(op, frozenPatterns);
+ opsToSimplify.push_back(op);
});
+ (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, /*strict=*/true);
}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 5b028d63b2379..350b19bae37e9 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -81,6 +81,25 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
// inserted ops are added to the worklist for processing.
void notifyOperationInserted(Operation *op) override { addToWorklist(op); }
+ // Look over the provided operands for any defining operations that should
+ // be re-added to the worklist. This function should be called when an
+ // operation is modified or removed, as it may trigger further
+ // simplifications.
+ template <typename Operands>
+ void addToWorklist(Operands &&operands) {
+ for (Value operand : operands) {
+ // If the use count of this operand is now < 2, we re-add the defining
+ // operation to the worklist.
+ // TODO: This is based on the fact that zero use operations
+ // may be deleted, and that single use values often have more
+ // canonicalization opportunities.
+ if (!operand || (!operand.use_empty() && !operand.hasOneUse()))
+ continue;
+ if (auto *defOp = operand.getDefiningOp())
+ addToWorklist(defOp);
+ }
+ }
+
// If an operation is about to be removed, make sure it is not in our
// worklist anymore because we'd get dangling references to it.
void notifyOperationRemoved(Operation *op) override {
@@ -100,26 +119,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
addToWorklist(user);
}
-private:
- // Look over the provided operands for any defining operations that should
- // be re-added to the worklist. This function should be called when an
- // operation is modified or removed, as it may trigger further
- // simplifications.
- template <typename Operands>
- void addToWorklist(Operands &&operands) {
- for (Value operand : operands) {
- // If the use count of this operand is now < 2, we re-add the defining
- // operation to the worklist.
- // TODO: This is based on the fact that zero use operations
- // may be deleted, and that single use values often have more
- // canonicalization opportunities.
- if (!operand || (!operand.use_empty() && !operand.hasOneUse()))
- continue;
- if (auto *defInst = operand.getDefiningOp())
- addToWorklist(defInst);
- }
- }
-
/// The low-level pattern applicator.
PatternApplicator matcher;
@@ -133,6 +132,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
/// Non-pattern based folder for operations.
OperationFolder folder;
+private:
/// Configuration information for how to simplify.
GreedyRewriteConfig config;
};
@@ -277,11 +277,6 @@ class OpPatternRewriteDriver : public PatternRewriter {
matcher.applyDefaultCostModel();
}
- /// Performs the rewrites and folding only on `op`. The simplification
- /// converges if the op is erased as a result of being folded, replaced, or
- /// dead, or no more changes happen in an iteration. Returns success if the
- /// rewrite converges in `maxIterations`. `erased` is set to true if `op` gets
- /// erased.
LogicalResult simplifyLocally(Operation *op, int maxIterations, bool &erased);
// These are hooks implemented for PatternRewriter.
@@ -309,13 +304,18 @@ class OpPatternRewriteDriver : public PatternRewriter {
} // anonymous namespace
+/// Performs the rewrites and folding only on `op`. The simplification
+/// converges if the op is erased as a result of being folded, replaced, or
+/// becoming dead, or no more changes happen in an iteration. Returns success if
+/// the rewrite converges in `maxIterations`. `erased` is set to true if `op`
+/// gets erased.
LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
int maxIterations,
bool &erased) {
bool changed = false;
erased = false;
opErasedViaPatternRewrites = false;
- int i = 0;
+ int iterations = 0;
// Iterate until convergence or until maxIterations. Deletion of the op as
// a result of being dead or folded is convergence.
do {
@@ -345,12 +345,162 @@ LogicalResult OpPatternRewriteDriver::simplifyLocally(Operation *op,
changed |= succeeded(matcher.matchAndRewrite(op, *this));
if ((erased = opErasedViaPatternRewrites))
return success();
- } while (changed && ++i < maxIterations);
+ } while (changed && ++iterations < maxIterations);
// Whether the rewrite converges, i.e. wasn't changed in the last iteration.
return failure(changed);
}
+//===----------------------------------------------------------------------===//
+// MultiOpPatternRewriteDriver
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// This is a specialized GreedyPatternRewriteDriver to apply patterns and
+/// perform folding for a supplied set of ops. It repeatedly simplifies while
+/// restricting the rewrites to only the provided set of ops or optionally
+/// to those directly affected by it (result users or operand providers).
+class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
+public:
+ explicit MultiOpPatternRewriteDriver(MLIRContext *ctx,
+ const FrozenRewritePatternSet &patterns,
+ bool strict)
+ : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()),
+ strictMode(strict) {}
+
+ bool simplifyLocally(ArrayRef<Operation *> op);
+
+private:
+ // Look over the provided operands for any defining operations that should
+ // be re-added to the worklist. This function should be called when an
+ // operation is modified or removed, as it may trigger further
+ // simplifications. If `strict` is set to true, only ops in
+ // `strictModeFilteredOps` are considered.
+ template <typename Operands>
+ void addOperandsToWorklist(Operands &&operands) {
+ for (Value operand : operands) {
+ if (auto *defOp = operand.getDefiningOp()) {
+ if (!strictMode || strictModeFilteredOps.contains(defOp))
+ addToWorklist(defOp);
+ }
+ }
+ }
+
+ void notifyOperationRemoved(Operation *op) override {
+ GreedyPatternRewriteDriver::notifyOperationRemoved(op);
+ if (strictMode)
+ strictModeFilteredOps.erase(op);
+ }
+
+ /// If `strictMode` is true, any pre-existing ops outside of
+ /// `strictModeFilteredOps` remain completely untouched by the rewrite driver.
+ /// If `strictMode` is false, operations that use results of (or supply
+ /// operands to) any rewritten ops stemming from the simplification of the
+ /// provided ops are in turn simplified; any other ops still remain untouched
+ /// (i.e., regardless of `strictMode`).
+ bool strictMode = false;
+
+ /// The list of ops we are restricting our rewrites to if `strictMode` is on.
+ /// These include the supplied set of ops as well as new ops created while
+ /// rewriting those ops. This set is not maintained when strictMode is off.
+ llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
+};
+
+} // end anonymous namespace
+
+/// Performs the specified rewrites on `ops` while also trying to fold these ops
+/// as well as any other ops that were in turn created due to these rewrite
+/// patterns. Any pre-existing ops outside of `ops` remain completely
+/// unmodified if `strictMode` is true. If `strictMode` is false, other
+/// operations that use results of rewritten ops or supply operands to such ops
+/// are in turn simplified; any other ops still remain unmodified (i.e.,
+/// regardless of `strictMode`). Note that ops in `ops` could be erased as a
+/// result of folding, becoming dead, or via pattern rewrites. Returns true if
+/// at all any changes happened.
+// Unlike `OpPatternRewriteDriver::simplifyLocally` which works on a single op
+// or GreedyPatternRewriteDriver::simplify, this method just iterates until
+// the worklist is empty. As our objective is to keep simplification "local",
+// there is no strong rationale to re-add all operations into the worklist and
+// rerun until an iteration changes nothing. If more widereaching simplification
+// is desired, GreedyPatternRewriteDriver should be used.
+bool MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops) {
+ if (strictMode) {
+ strictModeFilteredOps.clear();
+ strictModeFilteredOps.insert(ops.begin(), ops.end());
+ }
+
+ bool changed = false;
+ worklist.clear();
+ worklistMap.clear();
+ for (Operation *op : ops)
+ addToWorklist(op);
+
+ // These are scratch vectors used in the folding loop below.
+ SmallVector<Value, 8> originalOperands, resultValues;
+ while (!worklist.empty()) {
+ Operation *op = popFromWorklist();
+
+ // Nulls get added to the worklist when operations are removed, ignore
+ // them.
+ if (op == nullptr)
+ continue;
+
+ // If the operation is trivially dead - remove it.
+ if (isOpTriviallyDead(op)) {
+ notifyOperationRemoved(op);
+ op->erase();
+ changed = true;
+ continue;
+ }
+
+ // Collects all the operands and result uses of the given `op` into work
+ // list. Also remove `op` and nested ops from worklist.
+ originalOperands.assign(op->operand_begin(), op->operand_end());
+ auto preReplaceAction = [&](Operation *op) {
+ // Add the operands to the worklist for visitation.
+ addOperandsToWorklist(originalOperands);
+
+ // Add all the users of the result to the worklist so we make sure
+ // to revisit them.
+ for (Value result : op->getResults())
+ for (Operation *userOp : result.getUsers()) {
+ if (!strictMode || strictModeFilteredOps.contains(userOp))
+ addToWorklist(userOp);
+ }
+ notifyOperationRemoved(op);
+ };
+
+ // Add the given operation generated by the folder to the worklist.
+ auto processGeneratedConstants = [this](Operation *op) {
+ // Newly created ops are also simplified -- these are also "local".
+ addToWorklist(op);
+ // When strict mode is off, we don't need to maintain
+ // strictModeFilteredOps.
+ if (strictMode)
+ strictModeFilteredOps.insert(op);
+ };
+
+ // Try to fold this op.
+ bool inPlaceUpdate;
+ if (succeeded(folder.tryToFold(op, processGeneratedConstants,
+ preReplaceAction, &inPlaceUpdate))) {
+ changed = true;
+ if (!inPlaceUpdate) {
+ // Op has been erased.
+ continue;
+ }
+ }
+
+ // Try to match one of the patterns. The rewriter is automatically
+ // notified of any necessary changes, so there is nothing else to do
+ // here.
+ changed |= succeeded(matcher.matchAndRewrite(op, *this));
+ }
+
+ return changed;
+}
+
/// Rewrites only `op` using the supplied canonicalization patterns and
/// folding. `erased` is set to true if the op is erased as a result of being
/// folded, replaced, or dead.
@@ -370,3 +520,15 @@ LogicalResult mlir::applyOpPatternsAndFold(
});
return converged;
}
+
+bool mlir::applyOpPatternsAndFold(ArrayRef<Operation *> ops,
+ const FrozenRewritePatternSet &patterns,
+ bool strict) {
+ if (ops.empty())
+ return false;
+
+ // Start the pattern driver.
+ MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
+ strict);
+ return driver.simplifyLocally(ops);
+}
diff --git a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
index 6d1b806441738..6867e16d3c440 100644
--- a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir
@@ -261,12 +261,12 @@ func @simplify_set(%a : index, %b : index) {
// CHECK-DAG: -> (s0 * 2 + 1)
// Test "op local" simplification on affine.apply. DCE on addi will not happen.
-func @affine.apply(%N : index) {
+func @affine.apply(%N : index) -> index {
%v = affine.apply affine_map<(d0, d1) -> (d0 + d1 + 1)>(%N, %N)
- addi %v, %v : index
+ %res = addi %v, %v : index
// CHECK: affine.apply #map{{.*}}()[%arg0]
// CHECK-NEXT: addi
- return
+ return %res: index
}
// -----
diff --git a/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir b/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir
index 200ed84759505..89eca4c49ac98 100644
--- a/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir
+++ b/mlir/test/Dialect/Linalg/fold-affine-min-scf.mlir
@@ -10,10 +10,9 @@ func @scf_for(%A : memref<i64>, %step : index) {
%c16 = constant 16 : index
%c1024 = constant 1024 : index
+ // CHECK: %[[C2:.*]] = constant 2 : i64
// CHECK: scf.for
- // CHECK-NEXT: %[[C2:.*]] = constant 2 : index
- // CHECK-NEXT: %[[C2I64:.*]] = index_cast %[[C2:.*]]
- // CHECK-NEXT: memref.store %[[C2I64]], %{{.*}}[] : memref<i64>
+ // CHECK-NEXT: memref.store %[[C2]], %{{.*}}[] : memref<i64>
scf.for %i = %c0 to %c4 step %c2 {
%1 = affine.min affine_map<(d0, d1)[] -> (2, d1 - d0)> (%i, %c4)
%2 = index_cast %1: index to i64
@@ -21,9 +20,7 @@ func @scf_for(%A : memref<i64>, %step : index) {
}
// CHECK: scf.for
- // CHECK-NEXT: %[[C2:.*]] = constant 2 : index
- // CHECK-NEXT: %[[C2I64:.*]] = index_cast %[[C2:.*]]
- // CHECK-NEXT: memref.store %[[C2I64]], %{{.*}}[] : memref<i64>
+ // CHECK-NEXT: memref.store %[[C2]], %{{.*}}[] : memref<i64>
scf.for %i = %c1 to %c7 step %c2 {
%1 = affine.min affine_map<(d0)[s0] -> (s0 - d0, 2)> (%i)[%c7]
%2 = index_cast %1: index to i64
@@ -93,10 +90,9 @@ func @scf_parallel(%A : memref<i64>, %step : index) {
%c7 = constant 7 : index
%c4 = constant 4 : index
+ // CHECK: %[[C2:.*]] = constant 2 : i64
// CHECK: scf.parallel
- // CHECK-NEXT: %[[C2:.*]] = constant 2 : index
- // CHECK-NEXT: %[[C2I64:.*]] = index_cast %[[C2:.*]]
- // CHECK-NEXT: memref.store %[[C2I64]], %{{.*}}[] : memref<i64>
+ // CHECK-NEXT: memref.store %[[C2]], %{{.*}}[] : memref<i64>
scf.parallel (%i) = (%c0) to (%c4) step (%c2) {
%1 = affine.min affine_map<(d0, d1)[] -> (2, d1 - d0)> (%i, %c4)
%2 = index_cast %1: index to i64
@@ -104,9 +100,7 @@ func @scf_parallel(%A : memref<i64>, %step : index) {
}
// CHECK: scf.parallel
- // CHECK-NEXT: %[[C2:.*]] = constant 2 : index
- // CHECK-NEXT: %[[C2I64:.*]] = index_cast %[[C2:.*]]
- // CHECK-NEXT: memref.store %[[C2I64]], %{{.*}}[] : memref<i64>
+ // CHECK-NEXT: memref.store %[[C2]], %{{.*}}[] : memref<i64>
scf.parallel (%i) = (%c1) to (%c7) step (%c2) {
%1 = affine.min affine_map<(d0)[s0] -> (2, s0 - d0)> (%i)[%c7]
%2 = index_cast %1: index to i64
diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
index 15e3d299e0e4f..332f1fe635d67 100644
--- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
@@ -126,8 +126,8 @@ void TestAffineDataCopy::runOnFunction() {
assert(isa<AffineStoreOp>(op) && "expected affine store op");
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
}
- (void)applyOpPatternsAndFold(op, std::move(patterns));
}
+ (void)applyOpPatternsAndFold(copyOps, std::move(patterns), /*strict=*/true);
}
namespace mlir {
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 8a8ce298cc6af..93ae2be682d86 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -551,11 +551,11 @@ static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
foldPattern.add<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
FrozenRewritePatternSet frozenPatterns(std::move(foldPattern));
- // Explicitly walk and apply the pattern locally to avoid more general folding
+ // Explicitly apply the pattern on affected ops to avoid more general folding
// on the rest of the IR.
- funcOp.walk([&frozenPatterns](AffineMinOp minOp) {
- (void)applyOpPatternsAndFold(minOp, frozenPatterns);
- });
+ SmallVector<Operation *, 4> minOps;
+ funcOp.walk([&](AffineMinOp minOp) { minOps.push_back(minOp); });
+ (void)applyOpPatternsAndFold(minOps, frozenPatterns, /*strict=*/false);
}
// For now, just assume it is the zero of type.