[Mlir-commits] [mlir] f2b94bd - [mlir] check whether region and block visitors are interrupted

Ashay Rane llvmlistbot at llvm.org
Fri Jul 15 14:51:40 PDT 2022


Author: Ashay Rane
Date: 2022-07-15T14:50:42-07:00
New Revision: f2b94bd7eaa83d853dc7568fac87b1f8bf4ddec6

URL: https://github.com/llvm/llvm-project/commit/f2b94bd7eaa83d853dc7568fac87b1f8bf4ddec6
DIFF: https://github.com/llvm/llvm-project/commit/f2b94bd7eaa83d853dc7568fac87b1f8bf4ddec6.diff

LOG: [mlir] check whether region and block visitors are interrupted

The visitor functions for `Region` and `Block` types did not always
check the value returned by recursive calls.  This caused the top-level
visitor invocation to return `WalkResult::advance()` even if one or more
recursive invocations returned `WalkResult::interrupt()`.  This patch
fixes the problem by checking whether any recursive call was interrupted
and, if so, returning `WalkResult::interrupt()`.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D129718

Added: 
    mlir/test/IR/generic-block-visitors-interrupt.mlir
    mlir/test/IR/generic-region-visitors-interrupt.mlir

Modified: 
    mlir/lib/IR/Visitors.cpp
    mlir/test/lib/IR/TestVisitorsGeneric.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp
index 822d881400ee0..74d6750703331 100644
--- a/mlir/lib/IR/Visitors.cpp
+++ b/mlir/lib/IR/Visitors.cpp
@@ -114,7 +114,8 @@ WalkResult detail::walk(Operation *op,
     }
     for (auto &block : region) {
       for (auto &nestedOp : block)
-        walk(&nestedOp, callback, order);
+        if (walk(&nestedOp, callback, order).wasInterrupted())
+          return WalkResult::interrupt();
     }
     if (order == WalkOrder::PostOrder) {
       if (callback(&region).wasInterrupted())
@@ -140,7 +141,8 @@ WalkResult detail::walk(Operation *op,
           return WalkResult::interrupt();
       }
       for (auto &nestedOp : block)
-        walk(&nestedOp, callback, order);
+        if (walk(&nestedOp, callback, order).wasInterrupted())
+          return WalkResult::interrupt();
       if (order == WalkOrder::PostOrder) {
         if (callback(&block).wasInterrupted())
           return WalkResult::interrupt();

diff  --git a/mlir/test/IR/generic-block-visitors-interrupt.mlir b/mlir/test/IR/generic-block-visitors-interrupt.mlir
new file mode 100644
index 0000000000000..9479288e768ed
--- /dev/null
+++ b/mlir/test/IR/generic-block-visitors-interrupt.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt -test-generic-ir-block-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+
+func.func @main(%arg0: f32) -> f32 {
+  %v1 = "foo"() {interrupt = true} : () -> f32
+  %v2 = arith.addf %v1, %arg0 : f32
+  return %v2 : f32
+}
+
+// CHECK: step 0 walk was interrupted

diff  --git a/mlir/test/IR/generic-region-visitors-interrupt.mlir b/mlir/test/IR/generic-region-visitors-interrupt.mlir
new file mode 100644
index 0000000000000..8fa1fdf62c327
--- /dev/null
+++ b/mlir/test/IR/generic-region-visitors-interrupt.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt -test-generic-ir-region-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+
+func.func @main(%arg0: f32) -> f32 {
+  %v1 = "foo"() {interrupt = true} : () -> f32
+  %v2 = arith.addf %v1, %arg0 : f32
+  return %v2 : f32
+}
+
+// CHECK: step 0 walk was interrupted

diff  --git a/mlir/test/lib/IR/TestVisitorsGeneric.cpp b/mlir/test/lib/IR/TestVisitorsGeneric.cpp
index c51a0e0d392ad..833d6294db525 100644
--- a/mlir/test/lib/IR/TestVisitorsGeneric.cpp
+++ b/mlir/test/lib/IR/TestVisitorsGeneric.cpp
@@ -113,6 +113,73 @@ struct TestGenericIRVisitorInterruptPass
   }
 };
 
+struct TestGenericIRBlockVisitorInterruptPass
+    : public PassWrapper<TestGenericIRBlockVisitorInterruptPass,
+                         OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestGenericIRBlockVisitorInterruptPass)
+
+  StringRef getArgument() const final {
+    return "test-generic-ir-block-visitors-interrupt";
+  }
+  StringRef getDescription() const final {
+    return "Test generic IR visitors with interrupts, starting with Blocks.";
+  }
+
+  void runOnOperation() override {
+    int stepNo = 0;
+
+    auto walker = [&](Block *block) {
+      for (Operation &op : *block)
+        for (OpResult result : op.getResults())
+          if (Operation *definingOp = result.getDefiningOp())
+            if (definingOp->getAttrOfType<BoolAttr>("interrupt"))
+              return WalkResult::interrupt();
+
+      llvm::outs() << "step " << stepNo++ << "\n";
+      return WalkResult::advance();
+    };
+
+    auto result = getOperation()->walk(walker);
+    if (result.wasInterrupted())
+      llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
+  }
+};
+
+struct TestGenericIRRegionVisitorInterruptPass
+    : public PassWrapper<TestGenericIRRegionVisitorInterruptPass,
+                         OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestGenericIRRegionVisitorInterruptPass)
+
+  StringRef getArgument() const final {
+    return "test-generic-ir-region-visitors-interrupt";
+  }
+  StringRef getDescription() const final {
+    return "Test generic IR visitors with interrupts, starting with Regions.";
+  }
+
+  void runOnOperation() override {
+    int stepNo = 0;
+
+    auto walker = [&](Region *region) {
+      for (Block &block : *region)
+        for (Operation &op : block)
+          for (OpResult result : op.getResults())
+            if (Operation *definingOp = result.getDefiningOp())
+              if (definingOp->getAttrOfType<BoolAttr>("interrupt"))
+                return WalkResult::interrupt();
+
+      llvm::outs() << "step " << stepNo++ << "\n";
+      return WalkResult::advance();
+    };
+
+    auto result = getOperation()->walk(walker);
+    if (result.wasInterrupted())
+      llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -120,6 +187,8 @@ namespace test {
 void registerTestGenericIRVisitorsPass() {
   PassRegistration<TestGenericIRVisitorPass>();
   PassRegistration<TestGenericIRVisitorInterruptPass>();
+  PassRegistration<TestGenericIRBlockVisitorInterruptPass>();
+  PassRegistration<TestGenericIRRegionVisitorInterruptPass>();
 }
 
 } // namespace test


        


More information about the Mlir-commits mailing list