[Mlir-commits] [mlir] d01ea0e - [mlir] Drop reliance of SliceAnalysis on specific ops.

Nicolas Vasilache llvmlistbot at llvm.org
Mon Feb 15 22:40:46 PST 2021


Author: Nicolas Vasilache
Date: 2021-02-16T06:34:32Z
New Revision: d01ea0edaa2e38e1345dc484f8b74e0e53d3245b

URL: https://github.com/llvm/llvm-project/commit/d01ea0edaa2e38e1345dc484f8b74e0e53d3245b
DIFF: https://github.com/llvm/llvm-project/commit/d01ea0edaa2e38e1345dc484f8b74e0e53d3245b.diff

LOG: [mlir] Drop reliance of SliceAnalysis on specific ops.

SliceAnalysis was originally developed in the context of affine.for within mlfunc.
It predates the notion of region.
This revision updates it so that it no longer hardcodes specific ops such as scf::ForOp.
When rooted at an op, the slice computation now recurses into the regions of that op. As a consequence, rooting a slice at a loop op no longer yields just the values that transitively depend on the loop induction variable.
Additional variants rooted at a Value are added so that the previous behavior remains available (e.g. by rooting the slice at the loop induction variable itself).

Differential revision: https://reviews.llvm.org/D96702

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/SliceAnalysis.h
    mlir/lib/Analysis/SliceAnalysis.cpp
    mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
    mlir/lib/Transforms/Utils/LoopUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h
index 8d1bff5f7db7..f418684b6319 100644
--- a/mlir/include/mlir/Analysis/SliceAnalysis.h
+++ b/mlir/include/mlir/Analysis/SliceAnalysis.h
@@ -19,12 +19,13 @@
 namespace mlir {
 
 class Operation;
+class Value;
 
 /// Type of the condition to limit the propagation of transitive use-defs.
 /// This can be used in particular to limit the propagation to a given Scope or
 /// to avoid passing through certain types of operation in a configurable
 /// manner.
-using TransitiveFilter = std::function<bool(Operation *)>;
+using TransitiveFilter = llvm::function_ref<bool(Operation *)>;
 
 /// Fills `forwardSlice` with the computed forward slice (i.e. all
 /// the transitive uses of op), **without** including that operation.
@@ -67,10 +68,13 @@ using TransitiveFilter = std::function<bool(Operation *)>;
 /// 2. reversing the result of 1. gives:
 ///      {4, 3, 6, 2, 1, 5, 8, 7, 9}
 ///
-void getForwardSlice(
-    Operation *op, llvm::SetVector<Operation *> *forwardSlice,
-    TransitiveFilter filter = /* pass-through*/
-    [](Operation *) { return true; });
+void getForwardSlice(Operation *op, llvm::SetVector<Operation *> *forwardSlice,
+                     TransitiveFilter filter = nullptr /* pass-through*/);
+
+/// Value-rooted version of `getForwardSlice`. Return the union of all forward
+/// slices for the uses of the value `root`.
+void getForwardSlice(Value root, llvm::SetVector<Operation *> *forwardSlice,
+                     TransitiveFilter filter = nullptr /* pass-through*/);
 
 /// Fills `backwardSlice` with the computed backward slice (i.e.
 /// all the transitive defs of op), **without** including that operation.
@@ -106,10 +110,14 @@ void getForwardSlice(
 /// Assuming all local orders match the numbering order:
 ///    {1, 2, 5, 3, 4, 6}
 ///
-void getBackwardSlice(
-    Operation *op, llvm::SetVector<Operation *> *backwardSlice,
-    TransitiveFilter filter = /* pass-through*/
-    [](Operation *) { return true; });
+void getBackwardSlice(Operation *op,
+                      llvm::SetVector<Operation *> *backwardSlice,
+                      TransitiveFilter filter = nullptr /* pass-through*/);
+
+/// Value-rooted version of `getBackwardSlice`. Return the union of all backward
+/// slices for the op defining or owning the value `root`.
+void getBackwardSlice(Value root, llvm::SetVector<Operation *> *backwardSlice,
+                      TransitiveFilter filter = nullptr /* pass-through*/);
 
 /// Iteratively computes backward slices and forward slices until
 /// a fixed point is reached. Returns an `llvm::SetVector<Operation *>` which
@@ -188,12 +196,10 @@ void getBackwardSlice(
 /// and keep things ordered but this is still hand-wavy and not worth the
 /// trouble for now: punt to a simple worklist-based solution.
 ///
-llvm::SetVector<Operation *> getSlice(
-    Operation *op,
-    TransitiveFilter backwardFilter = /* pass-through*/
-    [](Operation *) { return true; },
-    TransitiveFilter forwardFilter = /* pass-through*/
-    [](Operation *) { return true; });
+llvm::SetVector<Operation *>
+getSlice(Operation *op,
+         TransitiveFilter backwardFilter = nullptr /* pass-through*/,
+         TransitiveFilter forwardFilter = nullptr /* pass-through*/);
 
 /// Multi-root DAG topological sort.
 /// Performs a topological sort of the Operation in the `toSort` SetVector.

diff  --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
index 07cbca8298c4..47d12582c10c 100644
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -30,36 +30,24 @@ using llvm::SetVector;
 static void getForwardSliceImpl(Operation *op,
                                 SetVector<Operation *> *forwardSlice,
                                 TransitiveFilter filter) {
-  if (!op) {
+  if (!op)
     return;
-  }
 
   // Evaluate whether we should keep this use.
   // This is useful in particular to implement scoping; i.e. return the
   // transitive forwardSlice in the current scope.
-  if (!filter(op)) {
+  if (filter && !filter(op))
     return;
-  }
 
-  if (auto forOp = dyn_cast<AffineForOp>(op)) {
-    for (Operation *userOp : forOp.getInductionVar().getUsers())
+  for (Region &region : op->getRegions())
+    for (Block &block : region)
+      for (Operation &blockOp : block)
+        if (forwardSlice->count(&blockOp) == 0)
+          getForwardSliceImpl(&blockOp, forwardSlice, filter);
+  for (Value result : op->getResults()) {
+    for (Operation *userOp : result.getUsers())
       if (forwardSlice->count(userOp) == 0)
         getForwardSliceImpl(userOp, forwardSlice, filter);
-  } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
-    for (Operation *userOp : forOp.getInductionVar().getUsers())
-      if (forwardSlice->count(userOp) == 0)
-        getForwardSliceImpl(userOp, forwardSlice, filter);
-    for (Value result : forOp.getResults())
-      for (Operation *userOp : result.getUsers())
-        if (forwardSlice->count(userOp) == 0)
-          getForwardSliceImpl(userOp, forwardSlice, filter);
-  } else {
-    assert(op->getNumRegions() == 0 && "unexpected generic op with regions");
-    for (Value result : op->getResults()) {
-      for (Operation *userOp : result.getUsers())
-        if (forwardSlice->count(userOp) == 0)
-          getForwardSliceImpl(userOp, forwardSlice, filter);
-    }
   }
 
   forwardSlice->insert(op);
@@ -79,45 +67,47 @@ void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
   forwardSlice->insert(v.rbegin(), v.rend());
 }
 
+void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
+                           TransitiveFilter filter) {
+  for (Operation *user : root.getUsers())
+    getForwardSliceImpl(user, forwardSlice, filter);
+
+  // Reverse to get back the actual topological order.
+  // std::reverse does not work out of the box on SetVector and I want an
+  // in-place swap based thing (the real std::reverse, not the LLVM adapter).
+  std::vector<Operation *> v(forwardSlice->takeVector());
+  forwardSlice->insert(v.rbegin(), v.rend());
+}
+
 static void getBackwardSliceImpl(Operation *op,
                                  SetVector<Operation *> *backwardSlice,
                                  TransitiveFilter filter) {
-  if (!op)
+  if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
     return;
 
-  assert((op->getNumRegions() == 0 ||
-          isa<AffineForOp, scf::ForOp, linalg::LinalgOp, linalg::PadTensorOp>(
-              op)) &&
-         "unexpected generic op with regions");
-
   // Evaluate whether we should keep this def.
   // This is useful in particular to implement scoping; i.e. return the
-  // transitive forwardSlice in the current scope.
-  if (!filter(op)) {
+  // transitive backwardSlice in the current scope.
+  if (filter && !filter(op))
     return;
-  }
 
   for (auto en : llvm::enumerate(op->getOperands())) {
     auto operand = en.value();
-    if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
-      if (auto affIv = getForInductionVarOwner(operand)) {
-        auto *affOp = affIv.getOperation();
-        if (backwardSlice->count(affOp) == 0)
-          getBackwardSliceImpl(affOp, backwardSlice, filter);
-      } else if (auto loopIv = scf::getForInductionVarOwner(operand)) {
-        auto *loopOp = loopIv.getOperation();
-        if (backwardSlice->count(loopOp) == 0)
-          getBackwardSliceImpl(loopOp, backwardSlice, filter);
-      } else if (blockArg.getOwner() !=
-                 &op->getParentOfType<FuncOp>().getBody().front()) {
-        op->emitError("unsupported CF for operand ") << en.index();
-        llvm_unreachable("Unsupported control flow");
-      }
-      continue;
-    }
-    auto *op = operand.getDefiningOp();
-    if (backwardSlice->count(op) == 0) {
-      getBackwardSliceImpl(op, backwardSlice, filter);
+    if (auto *definingOp = operand.getDefiningOp()) {
+      if (backwardSlice->count(definingOp) == 0)
+        getBackwardSliceImpl(definingOp, backwardSlice, filter);
+    } else if (auto blockArg = operand.dyn_cast<BlockArgument>()) {
+      Block *block = blockArg.getOwner();
+      Operation *parentOp = block->getParentOp();
+      // TODO: determine whether we want to recurse backward into the other
+      // blocks of parentOp, which are not technically backward unless they flow
+      // into us. For now, just bail.
+      assert(parentOp->getNumRegions() == 1 &&
+             parentOp->getRegion(0).getBlocks().size() == 1);
+      if (backwardSlice->count(parentOp) == 0)
+        getBackwardSliceImpl(parentOp, backwardSlice, filter);
+    } else {
+      llvm_unreachable("No definingOp and not a block argument.");
     }
   }
 
@@ -134,6 +124,16 @@ void mlir::getBackwardSlice(Operation *op,
   backwardSlice->remove(op);
 }
 
+void mlir::getBackwardSlice(Value root, SetVector<Operation *> *backwardSlice,
+                            TransitiveFilter filter) {
+  if (Operation *definingOp = root.getDefiningOp()) {
+    getBackwardSlice(definingOp, backwardSlice, filter);
+    return;
+  }
+  Operation *bbAargOwner = root.cast<BlockArgument>().getOwner()->getParentOp();
+  getBackwardSlice(bbAargOwner, backwardSlice, filter);
+}
+
 SetVector<Operation *> mlir::getSlice(Operation *op,
                                       TransitiveFilter backwardFilter,
                                       TransitiveFilter forwardFilter) {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index f3d98f634788..c3bc73aea720 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -243,7 +243,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
                         << "\n");
 
       llvm::SetVector<Operation *> forwardSlice;
-      getForwardSlice(transferRead, &forwardSlice);
+      getForwardSlice(transferRead.getOperation(), &forwardSlice);
 
       // Look for the last TransferWriteOp in the forwardSlice of
       // `transferRead` that operates on the same memref.
@@ -381,9 +381,10 @@ hoistPaddingOnTensorsPrerequisites(linalg::PadTensorOp padTensorOp, int nLevels,
   // Get the backwards slice from `padTensorOp` that is dominated by the
   // outermost enclosing loop.
   DominanceInfo domInfo(outermostEnclosingForOp);
-  getBackwardSlice(padTensorOp, &backwardSlice, [&](Operation *op) {
-    return domInfo.dominates(outermostEnclosingForOp, op);
-  });
+  getBackwardSlice(padTensorOp.getOperation(), &backwardSlice,
+                   [&](Operation *op) {
+                     return domInfo.dominates(outermostEnclosingForOp, op);
+                   });
 
   // Bail on any op with a region that is not a LoopLikeInterface or a LinalgOp.
   if (llvm::any_of(backwardSlice, [](Operation *op) {

diff  --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index d8bc6e0466c5..a8c32c84d124 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -1830,9 +1830,9 @@ Loops mlir::tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes) {
 // Return failure when any op fails to hoist.
 static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
   SetVector<Operation *> forwardSlice;
-  getForwardSlice(outer.getOperation(), &forwardSlice, [&inner](Operation *op) {
-    return op != inner.getOperation();
-  });
+  getForwardSlice(
+      outer.getInductionVar(), &forwardSlice,
+      [&inner](Operation *op) { return op != inner.getOperation(); });
   LogicalResult status = success();
   SmallVector<Operation *, 8> toHoist;
   for (auto &op : outer.getBody()->without_terminator()) {
@@ -1844,8 +1844,8 @@ static LogicalResult hoistOpsBetween(scf::ForOp outer, scf::ForOp inner) {
       status = failure();
       continue;
     }
-    // Skip scf::ForOp, these are not considered a failure.
-    if (op.getNumRegions() > 0)
+    // Skip intermediate scf::ForOp, these are not considered a failure.
+    if (isa<scf::ForOp>(op))
       continue;
     // Skip other ops with regions.
     if (op.getNumRegions() > 0) {


        


More information about the Mlir-commits mailing list