[llvm] 6569cf2 - [mlir] Add a ThreadPool to MLIRContext and refactor MLIR threading usage

River Riddle via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 22 18:33:06 PDT 2021


Author: River Riddle
Date: 2021-06-23T01:29:24Z
New Revision: 6569cf2a44bf95106e7168bdb79c4674742708fa

URL: https://github.com/llvm/llvm-project/commit/6569cf2a44bf95106e7168bdb79c4674742708fa
DIFF: https://github.com/llvm/llvm-project/commit/6569cf2a44bf95106e7168bdb79c4674742708fa.diff

LOG: [mlir] Add a ThreadPool to MLIRContext and refactor MLIR threading usage

This revision refactors MLIR's use of multithreading utilities so that they run on a
common thread pool owned by the MLIR context, and adds a new utility that makes writing
multithreaded code in MLIR less error-prone. Using a unified thread pool brings several
advantages:

* Better thread usage and more control
We currently use the static LLVM threading utilities, which do not allow multiple
levels of asynchronous scheduling (even when threads are available). This is due to
how the current TaskGroup structure works: it only allows one truly multithreaded
instance at a time. Having our own ThreadPool gives us more control and flexibility
over job/thread scheduling, and a follow-up can enable threading in more parts of
the compiler.

* The static nature of TaskGroup causes issues in certain configurations
Due to the static nature of TaskGroup, there have been quite a few problems related to
destruction that have caused several downstream projects to disable threading. See
D104207 for discussion on some related fallout. By having a ThreadPool scoped to
the context, we don't have to worry about destruction and can ensure that any
additional MLIR thread usage ends when the context is destroyed.

Differential Revision: https://reviews.llvm.org/D104516

Added: 
    mlir/include/mlir/IR/Threading.h

Modified: 
    llvm/include/llvm/Support/ThreadPool.h
    llvm/lib/Support/ThreadPool.cpp
    mlir/include/mlir/IR/MLIRContext.h
    mlir/lib/IR/MLIRContext.cpp
    mlir/lib/IR/Verifier.cpp
    mlir/lib/Pass/Pass.cpp
    mlir/lib/Transforms/Inliner.cpp
    mlir/test/Dialect/Affine/SuperVectorize/compose_maps.mlir
    mlir/test/Dialect/Affine/slicing-utils.mlir
    mlir/test/IR/diagnostic-handler-filter.mlir
    mlir/test/Pass/pass-timing.mlir
    mlir/test/Pass/pipeline-parsing.mlir

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Support/ThreadPool.h b/llvm/include/llvm/Support/ThreadPool.h
index 9d319eb71bbea..4c41b88d60438 100644
--- a/llvm/include/llvm/Support/ThreadPool.h
+++ b/llvm/include/llvm/Support/ThreadPool.h
@@ -70,6 +70,9 @@ class ThreadPool {
 
   unsigned getThreadCount() const { return ThreadCount; }
 
+  /// Returns true if the calling thread is one of this pool's worker threads.
+  bool isWorkerThread() const;
+
 private:
   bool workCompletedUnlocked() { return !ActiveThreads && Tasks.empty(); }
 

diff  --git a/llvm/lib/Support/ThreadPool.cpp b/llvm/lib/Support/ThreadPool.cpp
index 46a1990cd7196..f442b3b0bc980 100644
--- a/llvm/lib/Support/ThreadPool.cpp
+++ b/llvm/lib/Support/ThreadPool.cpp
@@ -72,6 +72,17 @@ void ThreadPool::wait() {
   CompletionCondition.wait(LockGuard, [&] { return workCompletedUnlocked(); });
 }
 
+// The calling thread belongs to this pool exactly when its id matches the id
+// of one of the pool's worker threads.
+bool ThreadPool::isWorkerThread() const {
+  const std::thread::id Self = std::this_thread::get_id();
+  for (const std::thread &Worker : Threads) {
+    if (Worker.get_id() == Self)
+      return true;
+  }
+  return false;
+}
+
 std::shared_future<void> ThreadPool::asyncImpl(TaskTy Task) {
   /// Wrap the Task in a packaged_task to return a future object.
   PackagedTaskTy PackagedTask(std::move(Task));

diff  --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index 8d53c0bd225e0..7b0fcd66dc2b0 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -15,6 +15,10 @@
 #include <memory>
 #include <vector>
 
+namespace llvm {
+class ThreadPool;
+} // end namespace llvm
+
 namespace mlir {
 class AbstractOperation;
 class DebugActionManager;
@@ -114,6 +118,12 @@ class MLIRContext {
     disableMultithreading(!enable);
   }
 
+  /// Return the thread pool owned by this context. Multithreading must be
+  /// enabled on the context (this is asserted). Most code should not use the
+  /// pool directly and should instead use the parallel iteration utilities
+  /// in mlir/IR/Threading.h, which also handle diagnostic ordering.
+  llvm::ThreadPool &getThreadPool();
+
   /// Return true if we should attach the operation to diagnostics emitted via
   /// Operation::emit.
   bool shouldPrintOpOnDiagnostic();

diff  --git a/mlir/include/mlir/IR/Threading.h b/mlir/include/mlir/IR/Threading.h
new file mode 100644
index 0000000000000..384223161014f
--- /dev/null
+++ b/mlir/include/mlir/IR/Threading.h
@@ -0,0 +1,153 @@
+//===- Threading.h - MLIR Threading Utilities -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines various utilities for multithreaded processing within MLIR.
+// These utilities automatically handle many of the necessary threading
+// conditions, such as properly ordering diagnostics, observing if threading is
+// disabled, etc. These utilities should be used over other threading utilities
+// whenever feasible.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_THREADING_H
+#define MLIR_IR_THREADING_H
+
+#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/Support/ThreadPool.h"
+#include <atomic>
+
+namespace mlir {
+
+/// Invoke the given function on each element in [begin, end), potentially in
+/// parallel, and return once all processing has finished. If the function
+/// returns failure for any element, processing stops early and failure is
+/// returned, so not every element is guaranteed to be processed. Diagnostics
+/// emitted during processing are ordered by element position in [begin, end).
+/// Elements are processed sequentially on the calling thread when `context`
+/// has multi-threading disabled, when there are fewer than two elements, or
+/// when the caller is already a worker thread of the context's thread pool.
+template <typename IteratorT, typename FuncT>
+LogicalResult failableParallelForEach(MLIRContext *context, IteratorT begin,
+                                      IteratorT end, FuncT &&func) {
+  unsigned numElements = static_cast<unsigned>(std::distance(begin, end));
+  if (numElements == 0)
+    return success();
+
+  // Process the elements directly on this thread if multithreading is
+  // disabled, there is at most one element, or we are already running on a
+  // pool worker (waiting on nested pool tasks from a worker could deadlock).
+  // FIXME: ThreadPool should allow work stealing to lift that restriction.
+  if (!context->isMultithreadingEnabled() || numElements <= 1 ||
+      context->getThreadPool().isWorkerThread()) {
+    for (; begin != end; ++begin)
+      if (failed(func(*begin)))
+        return failure();
+    return success();
+  }
+
+  // Each participating thread repeatedly claims the next unprocessed index and
+  // tags diagnostics with it, so the handler can emit them in element order.
+  ParallelDiagnosticHandler handler(context);
+  std::atomic<unsigned> curIndex(0);
+  std::atomic<bool> processingFailed(false);
+  auto processFn = [&] {
+    while (!processingFailed) {
+      unsigned index = curIndex++;
+      if (index >= numElements)
+        break;
+      handler.setOrderIDForThread(index);
+      if (failed(func(*std::next(begin, index))))
+        processingFailed = true;
+      handler.eraseOrderIDForThread();
+    }
+  };
+
+  // Run processFn on up to getThreadCount() threads, this one included.
+  llvm::ThreadPool &threadPool = context->getThreadPool();
+  size_t numActions = std::min(numElements, threadPool.getThreadCount());
+  SmallVector<std::shared_future<void>> threadFutures;
+  threadFutures.reserve(numActions - 1);
+  for (unsigned i = 1; i < numActions; ++i)
+    threadFutures.emplace_back(threadPool.async(processFn));
+  processFn();
+
+  // Wait for the pool tasks; this thread's share has already completed.
+  for (std::shared_future<void> &future : threadFutures)
+    future.wait();
+  return failure(processingFailed);
+}
+
+/// Invoke the given function on each element of `range`, potentially in
+/// parallel, and return once all processing has finished. If the function
+/// returns failure for any element, processing stops early and failure is
+/// returned, so not every element is guaranteed to be processed. Diagnostics
+/// emitted during processing are ordered by element position in the range.
+/// This forwards to the iterator overload, so elements are processed
+/// sequentially under the same conditions (e.g. when `context` has
+/// multi-threading disabled).
+template <typename RangeT, typename FuncT>
+LogicalResult failableParallelForEach(MLIRContext *context, RangeT &&range,
+                                      FuncT &&func) {
+  return failableParallelForEach(context, std::begin(range), std::end(range),
+                                 std::forward<FuncT>(func));
+}
+
+/// Invoke the given function on each index in [begin, end), potentially in
+/// parallel, and return once all processing has finished. The function is
+/// passed the index itself (a size_t). If it returns failure for any index,
+/// processing stops early and failure is returned, so not every index is
+/// guaranteed to be processed. Diagnostics emitted during processing are
+/// ordered by index. Indices are processed sequentially under the same
+/// conditions as the iterator overload (e.g. when `context` has
+/// multi-threading disabled).
+template <typename FuncT>
+LogicalResult failableParallelForEachN(MLIRContext *context, size_t begin,
+                                       size_t end, FuncT &&func) {
+  return failableParallelForEach(context, llvm::seq(begin, end),
+                                 std::forward<FuncT>(func));
+}
+
+/// Invoke the given function on every element in [begin, end), potentially in
+/// parallel, and return once all elements have been processed. The function's
+/// return value is discarded. Diagnostics are ordered by element position in
+/// [begin, end). Processing is sequential when `context` has multi-threading
+/// disabled (see failableParallelForEach for the other sequential cases).
+template <typename IteratorT, typename FuncT>
+void parallelForEach(MLIRContext *context, IteratorT begin, IteratorT end,
+                     FuncT &&func) {
+  (void)failableParallelForEach(context, begin, end, [&](auto &&value) {
+    return func(std::forward<decltype(value)>(value)), success();
+  });
+}
+
+/// Invoke the given function on every element of `range`, potentially in
+/// parallel, and return once all elements have been processed. The function's
+/// return value is discarded. Diagnostics are ordered by element position in
+/// the range. Processing is sequential when `context` has multi-threading
+/// disabled (see failableParallelForEach for the other sequential cases).
+template <typename RangeT, typename FuncT>
+void parallelForEach(MLIRContext *context, RangeT &&range, FuncT &&func) {
+  parallelForEach(context, std::begin(range), std::end(range),
+                  std::forward<FuncT>(func));
+}
+
+/// Invoke the given function on every index in [begin, end), potentially in
+/// parallel, and return once all indices have been processed. The function
+/// receives each index as a size_t; its return value is discarded.
+/// Diagnostics are ordered by index. Processing is sequential when `context`
+/// has multi-threading disabled (see failableParallelForEach for other cases).
+template <typename FuncT>
+void parallelForEachN(MLIRContext *context, size_t begin, size_t end,
+                      FuncT &&func) {
+  parallelForEach(context, llvm::seq(begin, end), std::forward<FuncT>(func));
+}
+
+} // end namespace mlir
+
+#endif // MLIR_IR_THREADING_H

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 8aa4fe7db3cbd..1ae3e6c21cc51 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -34,6 +34,7 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/RWMutex.h"
+#include "llvm/Support/ThreadPool.h"
 #include "llvm/Support/raw_ostream.h"
 #include <memory>
 
@@ -260,6 +261,9 @@ class MLIRContextImpl {
   // Other
   //===--------------------------------------------------------------------===//
 
+  /// Thread pool handed out by getThreadPool() to run MLIR work in parallel.
+  llvm::ThreadPool threadPool;
+
   /// This is a list of dialects that are created referring to this context.
   /// The MLIRContext owns the objects.
   DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
@@ -571,6 +575,12 @@ void MLIRContext::disableMultithreading(bool disable) {
   impl->typeUniquer.disableMultithreading(disable);
 }
 
+llvm::ThreadPool &MLIRContext::getThreadPool() {
+  assert(isMultithreadingEnabled() &&
+         "expected multi-threading to be enabled within the context");
+  return impl->threadPool;
+}
+
 void MLIRContext::enterMultiThreadedExecution() {
 #ifndef NDEBUG
   ++impl->multiThreadedExecutionContext;

diff  --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp
index d1d9ffb5e61a5..42068677c1ced 100644
--- a/mlir/lib/IR/Verifier.cpp
+++ b/mlir/lib/IR/Verifier.cpp
@@ -30,6 +30,7 @@
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/RegionKindInterface.h"
+#include "mlir/IR/Threading.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/Parallel.h"
@@ -43,11 +44,6 @@ namespace {
 /// This class encapsulates all the state used to verify an operation region.
 class OperationVerifier {
 public:
-  explicit OperationVerifier(MLIRContext *context)
-      // TODO: Re-enable parallelism once deadlocks found in D104207 are
-      // resolved.
-      : parallelismEnabled(false) {}
-
   /// Verify the given operation.
   LogicalResult verifyOpAndDominance(Operation &op);
 
@@ -66,9 +62,6 @@ class OperationVerifier {
   /// Operation.
   LogicalResult verifyDominanceOfContainedRegions(Operation &op,
                                                   DominanceInfo &domInfo);
-
-  /// This is true if parallelism is enabled on the MLIRContext.
-  const bool parallelismEnabled;
 };
 } // end anonymous namespace
 
@@ -91,28 +84,11 @@ LogicalResult OperationVerifier::verifyOpAndDominance(Operation &op) {
 
   // Check the dominance properties and invariants of any operations in the
   // regions contained by the 'opsWithIsolatedRegions' operations.
-  if (!parallelismEnabled || opsWithIsolatedRegions.size() <= 1) {
-    // If parallelism is disabled or if there is only 0/1 operation to do, use
-    // a simple non-parallel loop.
-    for (Operation *op : opsWithIsolatedRegions) {
-      if (failed(verifyOpAndDominance(*op)))
-        return failure();
-    }
-  } else {
-    // Otherwise, verify the operations and their bodies in parallel.
-    ParallelDiagnosticHandler handler(op.getContext());
-    std::atomic<bool> passFailed(false);
-    llvm::parallelForEachN(0, opsWithIsolatedRegions.size(), [&](size_t opIdx) {
-      handler.setOrderIDForThread(opIdx);
-      if (failed(verifyOpAndDominance(*opsWithIsolatedRegions[opIdx])))
-        passFailed = true;
-      handler.eraseOrderIDForThread();
-    });
-    if (passFailed)
-      return failure();
-  }
-
-  return success();
+  auto verifyNested = [&](Operation *nestedOp) {
+    return verifyOpAndDominance(*nestedOp);
+  };
+  return failableParallelForEach(op.getContext(), opsWithIsolatedRegions,
+                                 verifyNested);
 }
 
 /// Returns true if this block may be valid without terminator. That is if:
@@ -378,5 +354,5 @@ OperationVerifier::verifyDominanceOfContainedRegions(Operation &op,
 /// compiler bugs.  On error, this reports the error through the MLIRContext and
 /// returns failure.
 LogicalResult mlir::verify(Operation *op) {
-  return OperationVerifier(op->getContext()).verifyOpAndDominance(*op);
+  return OperationVerifier().verifyOpAndDominance(*op);
 }

diff  --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index bae172200a764..3b9437d81407b 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -14,6 +14,7 @@
 #include "PassDetail.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/Threading.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Support/FileUtilities.h"
 #include "llvm/ADT/STLExtras.h"
@@ -580,61 +581,44 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
     }
   }
 
-  // A parallel diagnostic handler that provides deterministic diagnostic
-  // ordering.
-  ParallelDiagnosticHandler diagHandler(&getContext());
-
-  // An index for the current operation/analysis manager pair.
-  std::atomic<unsigned> opIt(0);
-
   // Get the current thread for this adaptor.
   PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
                                                         this};
   auto *instrumentor = am.getPassInstrumentor();
 
-  // An atomic failure variable for the async executors.
-  std::atomic<bool> passFailed(false);
-  llvm::parallelForEach(
-      asyncExecutors.begin(),
-      std::next(asyncExecutors.begin(),
-                std::min(asyncExecutors.size(), opAMPairs.size())),
-      [&](MutableArrayRef<OpPassManager> pms) {
-        for (auto e = opAMPairs.size(); !passFailed && opIt < e;) {
-          // Get the next available operation index.
-          unsigned nextID = opIt++;
-          if (nextID >= e)
-            break;
-
-          // Set the order id for this thread in the diagnostic handler.
-          diagHandler.setOrderIDForThread(nextID);
-
-          // Get the pass manager for this operation and execute it.
-          auto &it = opAMPairs[nextID];
-          auto *pm = findPassManagerFor(
-              pms, it.first->getName().getIdentifier(), getContext());
-          assert(pm && "expected valid pass manager for operation");
-
-          unsigned initGeneration = pm->impl->initializationGeneration;
-          LogicalResult pipelineResult =
-              runPipeline(pm->getPasses(), it.first, it.second, verifyPasses,
-                          initGeneration, instrumentor, &parentInfo);
-
-          // Drop this thread from being tracked by the diagnostic handler.
-          // After this task has finished, the thread may be used outside of
-          // this pass manager context meaning that we don't want to track
-          // diagnostics from it anymore.
-          diagHandler.eraseOrderIDForThread();
-
-          // Handle a failed pipeline result.
-          if (failed(pipelineResult)) {
-            passFailed = true;
-            break;
-          }
-        }
-      });
+  // One "in use" flag per async executor. Each task claims a free executor
+  // for the duration of its pipeline run and releases it afterwards.
+  std::vector<std::atomic<bool>> activePMs(asyncExecutors.size());
+  std::fill(activePMs.begin(), activePMs.end(), false);
+  auto processFn = [&](auto &opPMPair) {
+    // Claim a pass manager that no other thread is currently using.
+    auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
+      bool expectedInactive = false;
+      return isActive.compare_exchange_strong(expectedInactive, true);
+    });
+    // The number of concurrently running tasks is bounded by the context's
+    // thread count, which must not exceed the number of async executors.
+    assert(it != activePMs.end() && "expected an inactive pass manager");
+    unsigned pmIndex = it - activePMs.begin();
+
+    // Get the pass manager for this operation and execute it.
+    auto *pm = findPassManagerFor(asyncExecutors[pmIndex],
+                                  opPMPair.first->getName().getIdentifier(),
+                                  getContext());
+    assert(pm && "expected valid pass manager for operation");
+
+    unsigned initGeneration = pm->impl->initializationGeneration;
+    LogicalResult pipelineResult =
+        runPipeline(pm->getPasses(), opPMPair.first, opPMPair.second,
+                    verifyPasses, initGeneration, instrumentor, &parentInfo);
+
+    // Release the pass manager so that another task can claim it.
+    activePMs[pmIndex].store(false);
+    return pipelineResult;
+  };
 
   // Signal a failure if any of the executors failed.
-  if (passFailed)
+  if (failed(failableParallelForEach(&getContext(), opAMPairs, processFn)))
     signalPassFailure();
 }
 

diff  --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index bbcb1b2e0bc63..b5a20dc0124bb 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -15,6 +15,7 @@
 
 #include "PassDetail.h"
 #include "mlir/Analysis/CallGraph.h"
+#include "mlir/IR/Threading.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -662,21 +663,9 @@ LogicalResult InlinerPass::optimizeSCC(CallGraph &cg, CGUseList &useList,
     return success();
 
   // Optimize each of the nodes within the SCC in parallel.
-  // NOTE: This is simple now, because we don't enable optimizing nodes within
-  // children. When we remove this restriction, this logic will need to be
-  // reworked.
-  if (context->isMultithreadingEnabled() && nodesToVisit.size() > 1) {
-    if (failed(optimizeSCCAsync(nodesToVisit, context)))
-      return failure();
+  if (failed(optimizeSCCAsync(nodesToVisit, context)))
+    return failure();
 
-    // Otherwise, we are optimizing within a single thread.
-  } else {
-    for (CallGraphNode *node : nodesToVisit) {
-      if (failed(optimizeCallable(node, opPipelines[0])))
-        return failure();
-    }
-  }
-
   // Recompute the uses held by each of the nodes.
   for (CallGraphNode *node : nodesToVisit)
     useList.recomputeUses(node, cg);
@@ -685,7 +674,7 @@ LogicalResult InlinerPass::optimizeSCC(CallGraph &cg, CGUseList &useList,
 
 LogicalResult
 InlinerPass::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
-                              MLIRContext *context) {
+                              MLIRContext *ctx) {
   // Ensure that there are enough pipeline maps for the optimizer to run in
   // parallel. Note: The number of pass managers here needs to remain constant
   // to prevent issues with pass instrumentations that rely on having the same
@@ -703,35 +692,28 @@ InlinerPass::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
   for (CallGraphNode *node : nodesToVisit)
     getAnalysisManager().nest(node->getCallableRegion()->getParentOp());
 
-  // An index for the current node to optimize.
-  std::atomic<unsigned> nodeIt(0);
-
-  // Optimize the nodes of the SCC in parallel.
-  ParallelDiagnosticHandler optimizerHandler(context);
-  std::atomic<bool> passFailed(false);
-  llvm::parallelForEach(
-      opPipelines.begin(), std::next(opPipelines.begin(), numThreads),
-      [&](llvm::StringMap<OpPassManager> &pipelines) {
-        for (auto e = nodesToVisit.size(); !passFailed && nodeIt < e;) {
-          // Get the next available operation index.
-          unsigned nextID = nodeIt++;
-          if (nextID >= e)
-            break;
-
-          // Set the order for this thread so that diagnostics will be
-          // properly ordered, and reset after optimization has finished.
-          optimizerHandler.setOrderIDForThread(nextID);
-          LogicalResult pipelineResult =
-              optimizeCallable(nodesToVisit[nextID], pipelines);
-          optimizerHandler.eraseOrderIDForThread();
-
-          if (failed(pipelineResult)) {
-            passFailed = true;
-            break;
-          }
-        }
-      });
-  return failure(passFailed);
+  // One "in use" flag per pipeline map in opPipelines. Each task claims a free
+  // pipeline map for the duration of its optimization and releases it after.
+  std::vector<std::atomic<bool>> activePMs(opPipelines.size());
+  std::fill(activePMs.begin(), activePMs.end(), false);
+  return failableParallelForEach(ctx, nodesToVisit, [&](CallGraphNode *node) {
+    // Claim a pipeline map that no other thread is currently using.
+    auto it = llvm::find_if(activePMs, [](std::atomic<bool> &isActive) {
+      bool expectedInactive = false;
+      return isActive.compare_exchange_strong(expectedInactive, true);
+    });
+    // The number of concurrently running tasks is bounded by the context's
+    // thread count, which must not exceed the number of pipeline maps.
+    assert(it != activePMs.end() && "expected an inactive pass manager");
+    unsigned pmIndex = it - activePMs.begin();
+
+    // Optimize this callable node.
+    LogicalResult result = optimizeCallable(node, opPipelines[pmIndex]);
+
+    // Release the pipeline map so that another task can claim it.
+    activePMs[pmIndex].store(false);
+    return result;
+  });
 }
 
 LogicalResult

diff  --git a/mlir/test/Dialect/Affine/SuperVectorize/compose_maps.mlir b/mlir/test/Dialect/Affine/SuperVectorize/compose_maps.mlir
index 126a176fa8362..8b3fc8b07bbb4 100644
--- a/mlir/test/Dialect/Affine/SuperVectorize/compose_maps.mlir
+++ b/mlir/test/Dialect/Affine/SuperVectorize/compose_maps.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -affine-super-vectorizer-test -compose-maps 2>&1 |  FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -affine-super-vectorizer-test -compose-maps -split-input-file 2>&1 |  FileCheck %s
 
 // For all these cases, the test traverses the `test_affine_map` ops and
 // composes them in order one-by-one.
@@ -16,6 +16,8 @@ func @simple1() {
   return
 }
 
+// -----
+
 func @simple2() {
   // CHECK: Composed map: (d0)[s0, s1] -> (d0 - s0 + s1)
   "test_affine_map"() { affine_map = affine_map<(d0)[s0] -> (d0 + s0 - 1)> } : () -> ()
@@ -23,6 +25,8 @@ func @simple2() {
   return
 }
 
+// -----
+
 func @simple3a() {
   // CHECK: Composed map: (d0, d1)[s0, s1, s2, s3] -> ((d0 ceildiv s2) * s0, (d1 ceildiv s3) * s1)
   "test_affine_map"() { affine_map = affine_map<(d0, d1)[s0, s1] -> (d0 ceildiv s0, d1 ceildiv s1)> } : () -> ()
@@ -30,12 +34,16 @@ func @simple3a() {
   return
 }
 
+// -----
+
 func @simple3b() {
   // CHECK: Composed map: (d0, d1)[s0, s1] -> (d0 mod s0, d1 mod s1)
   "test_affine_map"() { affine_map = affine_map<(d0, d1)[s0, s1] -> (d0 mod s0, d1 mod s1)> } : () -> ()
   return
 }
 
+// -----
+
 func @simple3c() {
   // CHECK: Composed map: (d0, d1)[s0, s1, s2, s3, s4, s5] -> ((d0 ceildiv s4) * s4 + d0 mod s2, (d1 ceildiv s5) * s5 + d1 mod s3)
   "test_affine_map"() { affine_map = affine_map<(d0, d1)[s0, s1] -> ((d0 ceildiv s0) * s0, (d1 ceildiv s1) * s1, d0, d1)> } : () -> ()
@@ -43,6 +51,8 @@ func @simple3c() {
   return
 }
 
+// -----
+
 func @simple4() {
   // CHECK: Composed map: (d0, d1)[s0, s1] -> (d1 * s1, d0 ceildiv s0)
   "test_affine_map"() { affine_map = affine_map<(d0, d1) -> (d1, d0)> } : () -> ()
@@ -50,6 +60,8 @@ func @simple4() {
   return
 }
 
+// -----
+
 func @simple5a() {
   // CHECK: Composed map: (d0) -> (d0 * 3 + 18)
   "test_affine_map"() { affine_map = affine_map<(d0) -> (d0 - 1)> } : () -> ()
@@ -59,6 +71,8 @@ func @simple5a() {
   return
 }
 
+// -----
+
 func @simple5b() {
   // CHECK: Composed map: (d0) -> ((d0 + 6) ceildiv 2)
   "test_affine_map"() { affine_map = affine_map<(d0) -> (d0 - 1)> } : () -> ()
@@ -68,6 +82,8 @@ func @simple5b() {
   return
 }
 
+// -----
+
 func @simple5c() {
   // CHECK: Composed map: (d0) -> (d0 * 8 + 48)
   "test_affine_map"() { affine_map = affine_map<(d0) -> (d0 - 1)> } : () -> ()
@@ -77,6 +93,8 @@ func @simple5c() {
   return
 }
 
+// -----
+
 func @simple5d() {
   // CHECK: Composed map: (d0) -> ((d0 * 4) floordiv 3 + 8)
   "test_affine_map"() { affine_map = affine_map<(d0) -> (d0 - 1)> } : () -> ()
@@ -86,6 +104,8 @@ func @simple5d() {
   return
 }
 
+// -----
+
 func @simple5e() {
   // CHECK: Composed map: (d0) -> ((d0 + 6) ceildiv 8)
   "test_affine_map"() { affine_map = affine_map<(d0) -> (d0 - 1)> } : () -> ()
@@ -94,6 +114,8 @@ func @simple5e() {
   return
 }
 
+// -----
+
 func @simple5f() {
   // CHECK: Composed map: (d0) -> ((d0 * 4 - 4) floordiv 3)
   "test_affine_map"() { affine_map = affine_map<(d0) -> (d0 - 1)> } : () -> ()
@@ -102,6 +124,8 @@ func @simple5f() {
   return
 }
 
+// -----
+
 func @perm_and_proj() {
   // CHECK: Composed map: (d0, d1, d2, d3) -> (d1, d3, d0)
   "test_affine_map"() { affine_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2, d0)> } : () -> ()
@@ -109,6 +133,8 @@ func @perm_and_proj() {
   return
 }
 
+// -----
+
 func @symbols1() {
   // CHECK: Composed map: (d0)[s0] -> (d0 + s0 + 1, d0 - s0 - 1)
   "test_affine_map"() { affine_map = affine_map<(d0)[s0] -> (d0 + s0, d0 - s0)> } : () -> ()
@@ -116,6 +142,8 @@ func @symbols1() {
   return
 }
 
+// -----
+
 func @drop() {
   // CHECK: Composed map: (d0, d1, d2)[s0, s1] -> (d0 * 2 + d1 + d2 + s1)
   "test_affine_map"() { affine_map = affine_map<(d0, d1, d2)[s0, s1] -> (d0 + s1, d1 + s0, d0 + d1 + d2)> } : () -> ()
@@ -123,6 +151,8 @@ func @drop() {
   return
 }
 
+// -----
+
 func @multi_symbols() {
   // CHECK: Composed map: (d0)[s0, s1, s2] -> (d0 + s1 + s2 + 1, d0 - s0 - s2 - 1)
   "test_affine_map"() { affine_map = affine_map<(d0)[s0] -> (d0 + s0, d0 - s0)> } : () -> ()

diff  --git a/mlir/test/Dialect/Affine/slicing-utils.mlir b/mlir/test/Dialect/Affine/slicing-utils.mlir
index 9ce5cb1fdd727..500aab1919e87 100644
--- a/mlir/test/Dialect/Affine/slicing-utils.mlir
+++ b/mlir/test/Dialect/Affine/slicing-utils.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -affine-super-vectorizer-test -forward-slicing=true 2>&1 | FileCheck %s --check-prefix=FWD
-// RUN: mlir-opt -allow-unregistered-dialect %s -affine-super-vectorizer-test -backward-slicing=true 2>&1 | FileCheck %s --check-prefix=BWD
-// RUN: mlir-opt -allow-unregistered-dialect %s -affine-super-vectorizer-test -slicing=true 2>&1 | FileCheck %s --check-prefix=FWDBWD
+// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -affine-super-vectorizer-test -forward-slicing=true 2>&1 | FileCheck %s --check-prefix=FWD
+// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -affine-super-vectorizer-test -backward-slicing=true 2>&1 | FileCheck %s --check-prefix=BWD
+// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -affine-super-vectorizer-test -slicing=true 2>&1 | FileCheck %s --check-prefix=FWDBWD
 
 ///   1       2      3      4
 ///   |_______|      |______|
@@ -217,6 +217,8 @@ func @slicing_test() {
   return
 }
 
+// -----
+
 // FWD-LABEL: slicing_test_2
 // BWD-LABEL: slicing_test_2
 // FWDBWD-LABEL: slicing_test_2
@@ -250,6 +252,8 @@ func @slicing_test_2() {
   return
 }
 
+// -----
+
 // FWD-LABEL: slicing_test_3
 // BWD-LABEL: slicing_test_3
 // FWDBWD-LABEL: slicing_test_3
@@ -265,6 +269,8 @@ func @slicing_test_3() {
   return
 }
 
+// -----
+
 // FWD-LABEL: slicing_test_function_argument
 // BWD-LABEL: slicing_test_function_argument
 // FWDBWD-LABEL: slicing_test_function_argument
@@ -274,6 +280,8 @@ func @slicing_test_function_argument(%arg0: index) -> index {
   return %0 : index
 }
 
+// -----
+
 // FWD-LABEL: slicing_test_multiple_return
 // BWD-LABEL: slicing_test_multiple_return
 // FWDBWD-LABEL: slicing_test_multiple_return

diff  --git a/mlir/test/IR/diagnostic-handler-filter.mlir b/mlir/test/IR/diagnostic-handler-filter.mlir
index d193c1ac0bd58..40141d441a765 100644
--- a/mlir/test/IR/diagnostic-handler-filter.mlir
+++ b/mlir/test/IR/diagnostic-handler-filter.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-diagnostic-filter='filters=mysource1' -o - 2>&1 | FileCheck %s
+// RUN: mlir-opt %s -test-diagnostic-filter='filters=mysource1' -split-input-file -o - 2>&1 | FileCheck %s
 // This test verifies that diagnostic handler can emit the call stack successfully.
 
 // CHECK-LABEL: Test 'test1'
@@ -8,6 +8,8 @@ func private @test1() attributes {
   test.loc = loc(callsite("foo"("mysource1":0:0) at callsite("mysource2":1:0 at "mysource3":2:0)))
 }
 
+// -----
+
 // CHECK-LABEL: Test 'test2'
 // CHECK-NEXT: mysource1:0:0: error: test diagnostic
 func private @test2() attributes {

diff  --git a/mlir/test/Pass/pass-timing.mlir b/mlir/test/Pass/pass-timing.mlir
index baef12638e695..ab6ae8162815e 100644
--- a/mlir/test/Pass/pass-timing.mlir
+++ b/mlir/test/Pass/pass-timing.mlir
@@ -2,7 +2,7 @@
 // RUN: mlir-opt %s -mlir-disable-threading=true -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -mlir-timing -mlir-timing-display=tree 2>&1 | FileCheck -check-prefix=PIPELINE %s
 // RUN: mlir-opt %s -mlir-disable-threading=false -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -mlir-timing -mlir-timing-display=list 2>&1 | FileCheck -check-prefix=MT_LIST %s
 // RUN: mlir-opt %s -mlir-disable-threading=false -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -mlir-timing -mlir-timing-display=tree 2>&1 | FileCheck -check-prefix=MT_PIPELINE %s
-// RUN: mlir-opt %s -mlir-disable-threading=false -verify-each=false -test-pm-nested-pipeline -mlir-timing -mlir-timing-display=tree 2>&1 | FileCheck -check-prefix=NESTED_MT_PIPELINE %s
+// RUN: mlir-opt %s -mlir-disable-threading=true -verify-each=false -test-pm-nested-pipeline -mlir-timing -mlir-timing-display=tree 2>&1 | FileCheck -check-prefix=NESTED_PIPELINE %s
 
 // LIST: Execution time report
 // LIST: Total Execution Time:
@@ -48,20 +48,20 @@
 // MT_PIPELINE-NEXT: Rest
 // MT_PIPELINE-NEXT: Total
 
-// NESTED_MT_PIPELINE: Execution time report
-// NESTED_MT_PIPELINE: Total Execution Time:
-// NESTED_MT_PIPELINE: Name
-// NESTED_MT_PIPELINE-NEXT: Parser
-// NESTED_MT_PIPELINE-NEXT: Pipeline Collection : ['func', 'module']
-// NESTED_MT_PIPELINE-NEXT:   'func' Pipeline
-// NESTED_MT_PIPELINE-NEXT:     TestFunctionPass
-// NESTED_MT_PIPELINE-NEXT:   'module' Pipeline
-// NESTED_MT_PIPELINE-NEXT:     TestModulePass
-// NESTED_MT_PIPELINE-NEXT:     'func' Pipeline
-// NESTED_MT_PIPELINE-NEXT:       TestFunctionPass
-// NESTED_MT_PIPELINE-NEXT: Output
-// NESTED_MT_PIPELINE-NEXT: Rest
-// NESTED_MT_PIPELINE-NEXT: Total
+// NESTED_PIPELINE: Execution time report
+// NESTED_PIPELINE: Total Execution Time:
+// NESTED_PIPELINE: Name
+// NESTED_PIPELINE-NEXT: Parser
+// NESTED_PIPELINE-NEXT: Pipeline Collection : ['func', 'module']
+// NESTED_PIPELINE-NEXT:   'func' Pipeline
+// NESTED_PIPELINE-NEXT:     TestFunctionPass
+// NESTED_PIPELINE-NEXT:   'module' Pipeline
+// NESTED_PIPELINE-NEXT:     TestModulePass
+// NESTED_PIPELINE-NEXT:     'func' Pipeline
+// NESTED_PIPELINE-NEXT:       TestFunctionPass
+// NESTED_PIPELINE-NEXT: Output
+// NESTED_PIPELINE-NEXT: Rest
+// NESTED_PIPELINE-NEXT: Total
 
 func @foo() {
   return

diff  --git a/mlir/test/Pass/pipeline-parsing.mlir b/mlir/test/Pass/pipeline-parsing.mlir
index f83de945a3120..d311db2657c7c 100644
--- a/mlir/test/Pass/pipeline-parsing.mlir
+++ b/mlir/test/Pass/pipeline-parsing.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -pass-pipeline='module(test-module-pass,func(test-function-pass)),func(test-function-pass)' -pass-pipeline="func(cse,canonicalize)" -verify-each=false -mlir-timing -mlir-timing-display=tree 2>&1 | FileCheck %s
-// RUN: mlir-opt %s -test-textual-pm-nested-pipeline -verify-each=false -mlir-timing -mlir-timing-display=tree 2>&1 | FileCheck %s --check-prefix=TEXTUAL_CHECK
+// RUN: mlir-opt %s -mlir-disable-threading -pass-pipeline='module(test-module-pass,func(test-function-pass)),func(test-function-pass)' -pass-pipeline="func(cse,canonicalize)" -verify-each=false -mlir-timing -mlir-timing-display=tree 2>&1 | FileCheck %s
+// RUN: mlir-opt %s -mlir-disable-threading -test-textual-pm-nested-pipeline -verify-each=false -mlir-timing -mlir-timing-display=tree 2>&1 | FileCheck %s --check-prefix=TEXTUAL_CHECK
 // RUN: not mlir-opt %s -pass-pipeline='module(test-module-pass' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_1 %s
 // RUN: not mlir-opt %s -pass-pipeline='module(test-module-pass))' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_2 %s
 // RUN: not mlir-opt %s -pass-pipeline='module()(' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_3 %s


        


More information about the llvm-commits mailing list