[Mlir-commits] [mlir] f2b94bd - [mlir] check whether region and block visitors are interrupted
Ashay Rane
llvmlistbot at llvm.org
Fri Jul 15 14:51:40 PDT 2022
Author: Ashay Rane
Date: 2022-07-15T14:50:42-07:00
New Revision: f2b94bd7eaa83d853dc7568fac87b1f8bf4ddec6
URL: https://github.com/llvm/llvm-project/commit/f2b94bd7eaa83d853dc7568fac87b1f8bf4ddec6
DIFF: https://github.com/llvm/llvm-project/commit/f2b94bd7eaa83d853dc7568fac87b1f8bf4ddec6.diff
LOG: [mlir] check whether region and block visitors are interrupted
The visitor functions for `Region` and `Block` types did not always
check the value returned by recursive calls. This caused the top-level
visitor invocation to return `WalkResult::advance()` even if one or more
recursive invocations returned `WalkResult::interrupt()`. This patch
fixes the problem by checking whether any recursive call was interrupted
and, if so, returning `WalkResult::interrupt()`.
Reviewed By: dcaballe
Differential Revision: https://reviews.llvm.org/D129718
Added:
mlir/test/IR/generic-block-visitors-interrupt.mlir
mlir/test/IR/generic-region-visitors-interrupt.mlir
Modified:
mlir/lib/IR/Visitors.cpp
mlir/test/lib/IR/TestVisitorsGeneric.cpp
Removed:
################################################################################
diff --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp
index 822d881400ee0..74d6750703331 100644
--- a/mlir/lib/IR/Visitors.cpp
+++ b/mlir/lib/IR/Visitors.cpp
@@ -114,7 +114,8 @@ WalkResult detail::walk(Operation *op,
}
for (auto &block : region) {
for (auto &nestedOp : block)
- walk(&nestedOp, callback, order);
+ if (walk(&nestedOp, callback, order).wasInterrupted())
+ return WalkResult::interrupt();
}
if (order == WalkOrder::PostOrder) {
if (callback(&region).wasInterrupted())
@@ -140,7 +141,8 @@ WalkResult detail::walk(Operation *op,
return WalkResult::interrupt();
}
for (auto &nestedOp : block)
- walk(&nestedOp, callback, order);
+ if (walk(&nestedOp, callback, order).wasInterrupted())
+ return WalkResult::interrupt();
if (order == WalkOrder::PostOrder) {
if (callback(&block).wasInterrupted())
return WalkResult::interrupt();
diff --git a/mlir/test/IR/generic-block-visitors-interrupt.mlir b/mlir/test/IR/generic-block-visitors-interrupt.mlir
new file mode 100644
index 0000000000000..9479288e768ed
--- /dev/null
+++ b/mlir/test/IR/generic-block-visitors-interrupt.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt -test-generic-ir-block-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+
+func.func @main(%arg0: f32) -> f32 {
+ // This op carries `interrupt = true`, so the test block walker interrupts
+ // on the block that holds it. The expected report uses step 0, i.e. no
+ // block is reported as advancing before the interrupt is propagated.
+ %v1 = "foo"() {interrupt = true} : () -> f32
+ %v2 = arith.addf %v1, %arg0 : f32
+ return %v2 : f32
+}
+
+// CHECK: step 0 walk was interrupted
diff --git a/mlir/test/IR/generic-region-visitors-interrupt.mlir b/mlir/test/IR/generic-region-visitors-interrupt.mlir
new file mode 100644
index 0000000000000..8fa1fdf62c327
--- /dev/null
+++ b/mlir/test/IR/generic-region-visitors-interrupt.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt -test-generic-ir-region-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+
+func.func @main(%arg0: f32) -> f32 {
+ // This op carries `interrupt = true`, so the test region walker interrupts
+ // on the function body region whose block holds it. The expected report
+ // uses step 0, i.e. no region is reported as advancing before the
+ // interrupt is propagated.
+ %v1 = "foo"() {interrupt = true} : () -> f32
+ %v2 = arith.addf %v1, %arg0 : f32
+ return %v2 : f32
+}
+
+// CHECK: step 0 walk was interrupted
diff --git a/mlir/test/lib/IR/TestVisitorsGeneric.cpp b/mlir/test/lib/IR/TestVisitorsGeneric.cpp
index c51a0e0d392ad..833d6294db525 100644
--- a/mlir/test/lib/IR/TestVisitorsGeneric.cpp
+++ b/mlir/test/lib/IR/TestVisitorsGeneric.cpp
@@ -113,6 +113,73 @@ struct TestGenericIRVisitorInterruptPass
}
};
+/// Test pass that walks the Blocks nested under the module using a
+/// Block-typed callback. The callback returns `WalkResult::interrupt()` as
+/// soon as the visited block directly contains an operation whose
+/// `interrupt` attribute is a true BoolAttr; otherwise it prints "step N"
+/// and advances. If the overall walk reports an interruption, a final
+/// "step N walk was interrupted" line is printed.
+struct TestGenericIRBlockVisitorInterruptPass
+    : public PassWrapper<TestGenericIRBlockVisitorInterruptPass,
+                         OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestGenericIRBlockVisitorInterruptPass)
+
+  StringRef getArgument() const final {
+    return "test-generic-ir-block-visitors-interrupt";
+  }
+  StringRef getDescription() const final {
+    return "Test generic IR visitors with interrupts, starting with Blocks.";
+  }
+
+  void runOnOperation() override {
+    int stepNo = 0;
+
+    auto walker = [&](Block *block) {
+      // Inspect each operation directly. Going through the defining op of
+      // its results only ever yields the op itself, and it silently skips
+      // operations that have no results. Also honor the attribute's value
+      // so that `interrupt = false` does not stop the walk.
+      for (Operation &op : *block)
+        if (auto interruptAttr = op.getAttrOfType<BoolAttr>("interrupt"))
+          if (interruptAttr.getValue())
+            return WalkResult::interrupt();
+
+      llvm::outs() << "step " << stepNo++ << "\n";
+      return WalkResult::advance();
+    };
+
+    // The top-level result must reflect interrupts raised by nested blocks.
+    auto result = getOperation()->walk(walker);
+    if (result.wasInterrupted())
+      llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
+  }
+};
+
+/// Test pass that walks the Regions nested under the module using a
+/// Region-typed callback. The callback returns `WalkResult::interrupt()` as
+/// soon as one of the visited region's blocks directly contains an operation
+/// whose `interrupt` attribute is a true BoolAttr; otherwise it prints
+/// "step N" and advances. If the overall walk reports an interruption, a
+/// final "step N walk was interrupted" line is printed.
+struct TestGenericIRRegionVisitorInterruptPass
+    : public PassWrapper<TestGenericIRRegionVisitorInterruptPass,
+                         OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestGenericIRRegionVisitorInterruptPass)
+
+  StringRef getArgument() const final {
+    return "test-generic-ir-region-visitors-interrupt";
+  }
+  StringRef getDescription() const final {
+    return "Test generic IR visitors with interrupts, starting with Regions.";
+  }
+
+  void runOnOperation() override {
+    int stepNo = 0;
+
+    auto walker = [&](Region *region) {
+      // Inspect each operation directly. Going through the defining op of
+      // its results only ever yields the op itself, and it silently skips
+      // operations that have no results. Also honor the attribute's value
+      // so that `interrupt = false` does not stop the walk.
+      for (Block &block : *region)
+        for (Operation &op : block)
+          if (auto interruptAttr = op.getAttrOfType<BoolAttr>("interrupt"))
+            if (interruptAttr.getValue())
+              return WalkResult::interrupt();
+
+      llvm::outs() << "step " << stepNo++ << "\n";
+      return WalkResult::advance();
+    };
+
+    // The top-level result must reflect interrupts raised by nested regions.
+    auto result = getOperation()->walk(walker);
+    if (result.wasInterrupted())
+      llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
+  }
+};
+
} // namespace
namespace mlir {
@@ -120,6 +187,8 @@ namespace test {
void registerTestGenericIRVisitorsPass() {
PassRegistration<TestGenericIRVisitorPass>();
PassRegistration<TestGenericIRVisitorInterruptPass>();
+ // Passes checking that interrupts raised by Block- and Region-typed
+ // walker callbacks propagate to the top-level walk result.
+ PassRegistration<TestGenericIRBlockVisitorInterruptPass>();
+ PassRegistration<TestGenericIRRegionVisitorInterruptPass>();
}
} // namespace test
More information about the Mlir-commits
mailing list