[Mlir-commits] [mlir] ed9194b - [mlir] GreedyPatternRewriter: Add ancestors to worklist
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 13 01:51:42 PST 2023
Author: Matthias Springer
Date: 2023-01-13T10:51:28+01:00
New Revision: ed9194be6d55776cbbf8432b93f0f23ec7b25a46
URL: https://github.com/llvm/llvm-project/commit/ed9194be6d55776cbbf8432b93f0f23ec7b25a46
DIFF: https://github.com/llvm/llvm-project/commit/ed9194be6d55776cbbf8432b93f0f23ec7b25a46.diff
LOG: [mlir] GreedyPatternRewriter: Add ancestors to worklist
When adding an op to the worklist, also add its ancestors to the worklist. This allows for RewritePatterns to match an op `a` based on what is inside of the body of `a`.
This change fixes a problem that became apparent with `vector.warp_execute_on_lane_0`, but could probably be triggered with similar patterns. The pattern extracts an op `b` with `eligible = true` from the body of an op `a`:
```
test.a {
%0 = test.b() {eligible = true}
yield %0
}
```
Afterwards:
```
%0 = test.b() {eligible = true}
test.a {
yield %0
}
```
The pattern is an `OpRewritePattern<OpA>`. For some reason, `test.a` is not on the GreedyPatternRewriter's worklist, e.g., because no pattern could be applied and it was removed. Now, another pattern updates `test.b`, so that `eligible` is changed from `false` to `true`. The `OpRewritePattern<OpA>` could now be applied, but (without this revision) `test.a` is still not on the worklist.
Note: In the above example, an `OpRewritePattern<OpB>` could have been used instead of an `OpRewritePattern<OpA>`. With such a design, we can run into the same problem (when the `eligible` attr is on `test.a` and `test.b` is removed from the worklist because no patterns could be applied).
Note: This change uncovered an unrelated bug in TestSCFUtils.cpp that was triggered due to a change in the order in which ops are processed. A TODO is added to the broken code and test cases are adapted so that the bug is no longer triggered.
Differential Revision: https://reviews.llvm.org/D140304
Added:
Modified:
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
mlir/test/Dialect/Math/expand-math.mlir
mlir/test/Dialect/SCF/loop-pipelining.mlir
mlir/test/IR/greedy-pattern-rewriter-driver.mlir
mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index cdb0b78c7a74e..2cf895e271e60 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -42,7 +42,9 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
/// Simplify the operations within the given regions.
bool simplify(MutableArrayRef<Region> regions);
- /// Add the given operation to the worklist.
+ /// Add the given operation to the worklist. Parent ops may or may not be
+ /// added to the worklist, depending on the type of rewrite driver. By
+ /// default, parent ops are added.
virtual void addToWorklist(Operation *op);
/// Pop the next operation from the worklist.
@@ -56,6 +58,9 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
void finalizeRootUpdate(Operation *op) override;
protected:
+ /// Add the given operation to the worklist.
+ void addSingleOpToWorklist(Operation *op);
+
// Implement the hook for inserting operations, and make sure that newly
// inserted ops are added to the worklist for processing.
void notifyOperationInserted(Operation *op) override;
@@ -101,6 +106,10 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
GreedyRewriteConfig config;
private:
+ /// Only ops within this scope are simplified. This is set at the beginning
+ /// of `simplify()` to the current scope the rewriter operates on.
+ DenseSet<Region *> scope;
+
#ifndef NDEBUG
/// A logger used to emit information during the application process.
llvm::ScopedPrinter logger{llvm::dbgs()};
@@ -119,6 +128,9 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
}
bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
+ for (Region &r : regions)
+ scope.insert(&r);
+
#ifndef NDEBUG
const char *logLineComment =
"//===-------------------------------------------===//\n";
@@ -306,6 +318,24 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions) {
}
void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
+ // Gather potential ancestors while looking for a "scope" parent region.
+ SmallVector<Operation *, 8> ancestors;
+ ancestors.push_back(op);
+ while (Region *region = op->getParentRegion()) {
+ if (scope.contains(region)) {
+ // All gathered ops are in fact ancestors.
+ for (Operation *op : ancestors)
+ addSingleOpToWorklist(op);
+ break;
+ }
+ op = region->getParentOp();
+ if (!op)
+ break;
+ ancestors.push_back(op);
+ }
+}
+
+void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
// Check to see if the worklist already contains this op.
if (worklistMap.count(op))
return;
@@ -540,7 +570,8 @@ namespace {
/// This is a specialized GreedyPatternRewriteDriver to apply patterns and
/// perform folding for a supplied set of ops. It repeatedly simplifies while
/// restricting the rewrites to only the provided set of ops or optionally
-/// to those directly affected by it (result users or operand providers).
+/// to those directly affected by it (result users or operand providers). Parent
+/// ops are not considered.
class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
public:
explicit MultiOpPatternRewriteDriver(MLIRContext *ctx,
@@ -553,7 +584,7 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
void addToWorklist(Operation *op) override {
if (!strictMode || strictModeFilteredOps.contains(op))
- GreedyPatternRewriteDriver::addToWorklist(op);
+ GreedyPatternRewriteDriver::addSingleOpToWorklist(op);
}
private:
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 0d7a63589ee47..bfd33d5040457 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -22,13 +22,13 @@ func.func @tanh(%arg: f32) -> f32 {
// CHECK: %[[RESULT:.+]] = arith.select %[[COND]], %[[RES1]], %[[RES2]] : f32
// CHECK: return %[[RESULT]]
-// ----
+// -----
// CHECK-LABEL: func @ctlz
func.func @ctlz(%arg: i32) -> i32 {
- // CHECK: %[[C0:.+]] = arith.constant 0 : i32
- // CHECK: %[[C32:.+]] = arith.constant 32 : i32
- // CHECK: %[[C1:.+]] = arith.constant 1 : i32
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32
+ // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : i32
+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32
// CHECK: %[[WHILE:.+]]:3 = scf.while (%[[A1:.+]] = %arg0, %[[A2:.+]] = %[[C32]], %[[A3:.+]] = %[[C0]])
// CHECK: %[[CMP:.+]] = arith.cmpi ne, %[[A1]], %[[A3]]
// CHECK: scf.condition(%[[CMP]]) %[[A1]], %[[A2]], %[[A3]]
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index ead1e71cc29dd..68b513362a250 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -417,7 +417,7 @@ func.func @loop_carried(%A: memref<?xf32>, %result: memref<?xf32>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[CSTF:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG: %[[CSTF:.*]] = arith.constant 2.000000e+00 : f32
// Prologue:
// CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
// CHECK-NEXT: %[[ADD0:.*]] = arith.addf %[[L0]], %[[CSTF]] : f32
@@ -426,19 +426,22 @@ func.func @loop_carried(%A: memref<?xf32>, %result: memref<?xf32>) {
// CHECK-NEXT: %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C2]]
// CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
// CHECK-SAME: %[[ADDARG:.*]] = %[[ADD0]], %[[LARG:.*]] = %[[L1]]) -> (f32, f32, f32) {
-// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[LARG]], %[[ADDARG]] : f32
+// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[ADDARG]], %[[CSTF]] : f32
+// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[LARG]], %[[MUL0]] : f32
// CHECK-NEXT: %[[IV2:.*]] = arith.addi %[[IV]], %[[C2]] : index
// CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV2]]] : memref<?xf32>
-// CHECK-NEXT: scf.yield %[[ADDARG]], %[[ADD1]], %[[L2]] : f32, f32, f32
+// CHECK-NEXT: scf.yield %[[MUL0]], %[[ADD1]], %[[L2]] : f32, f32, f32
// CHECK-NEXT: }
// Epilogue:
-// CHECK-NEXT: %[[ADD2:.*]] = arith.addf %[[R]]#2, %[[R]]#1 : f32
-// CHECK-NEXT: return %[[ADD2]] : f32
+// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[R]]#1, %[[CSTF]] : f32
+// CHECK-NEXT: %[[ADD2:.*]] = arith.addf %[[R]]#2, %[[MUL1]] : f32
+// CHECK-NEXT: %[[MUL2:.*]] = arith.mulf %[[ADD2]], %[[CSTF]] : f32
+// CHECK-NEXT: return %[[MUL2]] : f32
func.func @backedge_different_stage(%A: memref<?xf32>) -> f32 {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
- %cf = arith.constant 1.0 : f32
+ %cf = arith.constant 2.0 : f32
%r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
%A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
@@ -455,7 +458,7 @@ func.func @backedge_different_stage(%A: memref<?xf32>) -> f32 {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[CSTF:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG: %[[CSTF:.*]] = arith.constant 2.000000e+00 : f32
// Prologue:
// CHECK: %[[L0:.*]] = scf.execute_region
// CHECK-NEXT: memref.load %[[A]][%[[C0]]] : memref<?xf32>
@@ -467,23 +470,26 @@ func.func @backedge_different_stage(%A: memref<?xf32>) -> f32 {
// CHECK: %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C2]]
// CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
// CHECK-SAME: %[[ADDARG:.*]] = %[[ADD0]], %[[LARG:.*]] = %[[L1]]) -> (f32, f32, f32) {
+// CHECK: %[[MUL0:.*]] = arith.mulf %[[ADDARG]], %[[CSTF]] : f32
// CHECK: %[[ADD1:.*]] = scf.execute_region
-// CHECK-NEXT: arith.addf %[[LARG]], %[[ADDARG]] : f32
+// CHECK-NEXT: arith.addf %[[LARG]], %[[MUL0]] : f32
// CHECK: %[[IV2:.*]] = arith.addi %[[IV]], %[[C2]] : index
// CHECK: %[[L2:.*]] = scf.execute_region
// CHECK-NEXT: memref.load %[[A]][%[[IV2]]] : memref<?xf32>
-// CHECK: scf.yield %[[ADDARG]], %[[ADD1]], %[[L2]] : f32, f32, f32
+// CHECK: scf.yield %[[MUL0]], %[[ADD1]], %[[L2]] : f32, f32, f32
// CHECK-NEXT: }
// Epilogue:
+// CHECK: %[[MUL1:.*]] = arith.mulf %[[R]]#1, %[[CSTF]] : f32
// CHECK: %[[ADD2:.*]] = scf.execute_region
-// CHECK-NEXT: arith.addf %[[R]]#2, %[[R]]#1 : f32
-// CHECK: return %[[ADD2]] : f32
+// CHECK-NEXT: arith.addf %[[R]]#2, %[[MUL1]] : f32
+// CHECK: %[[MUL2:.*]] = arith.mulf %[[ADD2]], %[[CSTF]] : f32
+// CHECK: return %[[MUL2]] : f32
func.func @region_backedge_different_stage(%A: memref<?xf32>) -> f32 {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
- %cf = arith.constant 1.0 : f32
+ %cf = arith.constant 2.0 : f32
%r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
%A_elem = scf.execute_region -> f32 {
%A_elem1 = memref.load %A[%i0] : memref<?xf32>
@@ -507,7 +513,7 @@ func.func @region_backedge_different_stage(%A: memref<?xf32>) -> f32 {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[CSTF:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG: %[[CSTF:.*]] = arith.constant 2.000000e+00 : f32
// Prologue:
// CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
// Kernel:
@@ -515,18 +521,20 @@ func.func @region_backedge_different_stage(%A: memref<?xf32>) -> f32 {
// CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
// CHECK-SAME: %[[LARG:.*]] = %[[L0]]) -> (f32, f32) {
// CHECK-NEXT: %[[ADD0:.*]] = arith.addf %[[LARG]], %[[C]] : f32
+// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[ADD0]], %[[CSTF]] : f32
// CHECK-NEXT: %[[IV1:.*]] = arith.addi %[[IV]], %[[C1]] : index
// CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32>
-// CHECK-NEXT: scf.yield %[[ADD0]], %[[L2]] : f32, f32
+// CHECK-NEXT: scf.yield %[[MUL0]], %[[L2]] : f32, f32
// CHECK-NEXT: }
// Epilogue:
// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[R]]#1, %[[R]]#0 : f32
-// CHECK-NEXT: return %[[ADD1]] : f32
+// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[ADD1]], %[[CSTF]] : f32
+// CHECK-NEXT: return %[[MUL1]] : f32
func.func @backedge_same_stage(%A: memref<?xf32>) -> f32 {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
- %cf = arith.constant 1.0 : f32
+ %cf = arith.constant 2.0 : f32
%r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
%A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
@@ -538,7 +546,7 @@ func.func @backedge_same_stage(%A: memref<?xf32>) -> f32 {
// -----
-// CHECK: @pipeline_op_with_region(%[[ARG0:.+]]: memref<?xf32>, %[[ARG1:.+]]: memref<?xf32>, %[[ARG2:.+]]: memref<?xf32>) {
+// CHECK: @pipeline_op_with_region(%[[ARG0:.+]]: memref<?xf32>, %[[ARG1:.+]]: memref<?xf32>, %[[ARG2:.+]]: memref<?xf32>, %[[CF:.*]]: f32) {
// CHECK: %[[C0:.+]] = arith.constant 0 :
// CHECK: %[[C3:.+]] = arith.constant 3 :
// CHECK: %[[C1:.+]] = arith.constant 1 :
@@ -590,11 +598,10 @@ func.func @backedge_same_stage(%A: memref<?xf32>) -> f32 {
__test_pipelining_stage__ = 1,
__test_pipelining_op_order__ = 2
}
-func.func @pipeline_op_with_region(%A: memref<?xf32>, %B: memref<?xf32>, %result: memref<?xf32>) {
+func.func @pipeline_op_with_region(%A: memref<?xf32>, %B: memref<?xf32>, %result: memref<?xf32>, %cf: f32) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
- %cf = arith.constant 1.0 : f32
%a_buf = memref.alloc() : memref<2x8xf32>
%b_buf = memref.alloc() : memref<2x8xf32>
scf.for %i0 = %c0 to %c4 step %c1 {
diff --git a/mlir/test/IR/greedy-pattern-rewriter-driver.mlir b/mlir/test/IR/greedy-pattern-rewriter-driver.mlir
index 4f1a06fa6cf21..6f4923a9f4f75 100644
--- a/mlir/test/IR/greedy-pattern-rewriter-driver.mlir
+++ b/mlir/test/IR/greedy-pattern-rewriter-driver.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -test-patterns="max-iterations=1" | FileCheck %s
+// RUN: mlir-opt %s -test-patterns="max-iterations=1" \
+// RUN: -allow-unregistered-dialect --split-input-file | FileCheck %s
// CHECK-LABEL: func @add_to_worklist_after_inplace_update()
func.func @add_to_worklist_after_inplace_update() {
@@ -10,3 +11,16 @@ func.func @add_to_worklist_after_inplace_update() {
"test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
return
}
+
+// -----
+
+// CHECK-LABEL: func @add_ancestors_to_worklist()
+func.func @add_ancestors_to_worklist() {
+ // CHECK: "foo.maybe_eligible_op"() {eligible} : () -> index
+ // CHECK-NEXT: "test.one_region_op"()
+ "test.one_region_op"() ({
+ %0 = "foo.maybe_eligible_op" () : () -> (index)
+ "foo.yield"(%0) : (index) -> ()
+ }) {hoist_eligible_ops}: () -> ()
+ return
+}
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 4beccf99670be..151da3554e9f1 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -140,6 +140,8 @@ struct TestSCFPipeliningPass
auto attrCycle =
op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker);
if (attrCycle && attrStage) {
+ // TODO: Index can be out-of-bounds if ops of the loop body disappear
+ // due to folding.
schedule[attrCycle.getInt()] =
std::make_pair(op, unsigned(attrStage.getInt()));
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 2573f76deb691..d3ef160176cc3 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -167,6 +167,38 @@ struct IncrementIntAttribute : public OpRewritePattern<AnyAttrOfOp> {
}
};
+/// This pattern adds an "eligible" attribute to "foo.maybe_eligible_op".
+struct MakeOpEligible : public RewritePattern {
+ MakeOpEligible(MLIRContext *context)
+ : RewritePattern("foo.maybe_eligible_op", /*benefit=*/1, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ if (op->hasAttr("eligible"))
+ return failure();
+ rewriter.updateRootInPlace(
+ op, [&]() { op->setAttr("eligible", rewriter.getUnitAttr()); });
+ return success();
+ }
+};
+
+/// This pattern hoists eligible ops out of a "test.one_region_op".
+struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
+ using OpRewritePattern<test::OneRegionOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(test::OneRegionOp op,
+ PatternRewriter &rewriter) const override {
+ Operation *terminator = op.getRegion().front().getTerminator();
+ Operation *toBeHoisted = terminator->getOperands()[0].getDefiningOp();
+ if (toBeHoisted->getParentOp() != op)
+ return failure();
+ if (!toBeHoisted->hasAttr("eligible"))
+ return failure();
+ toBeHoisted->moveBefore(op);
+ return success();
+ }
+};
+
struct TestPatternDriver
: public PassWrapper<TestPatternDriver, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
@@ -183,7 +215,8 @@ struct TestPatternDriver
// Verify named pattern is generated with expected name.
patterns.add<FoldingPattern, TestNamedPatternRule,
FolderInsertBeforePreviouslyFoldedConstantPattern,
- FolderCommutativeOp2WithConstant>(&getContext());
+ FolderCommutativeOp2WithConstant, HoistEligibleOps,
+ MakeOpEligible>(&getContext());
// Additional patterns for testing the GreedyPatternRewriteDriver.
patterns.insert<IncrementIntAttribute<3>>(&getContext());
More information about the Mlir-commits
mailing list