[Mlir-commits] [mlir] [mlir][Hoisting] Hoisting vector.extract/vector.broadcast pairs (PR #86108)

Steven Varoumas llvmlistbot at llvm.org
Fri Apr 12 03:43:01 PDT 2024


================
@@ -43,6 +43,120 @@ using llvm::dbgs;
 using namespace mlir;
 using namespace mlir::linalg;
 
+/// Replace `loop` with a new scf.for whose `index`-th iter_arg is initialized
+/// with `newInitOperand` and whose `index`-th yielded value is
+/// `newYieldValue`. The original loop body is moved (not cloned) into the new
+/// loop, and all uses of the old loop's results are redirected to the new
+/// loop's results.
+///
+/// NOTE: the new loop's iter_arg/result type at `index` follows the type of
+/// `newInitOperand`. If it differs from the original type, the old iter_arg's
+/// uses inside the body and the old result's uses outside the loop are rewired
+/// to values of the new type; the caller is responsible for repairing them.
+///
+/// Returns the newly created loop; `loop` itself is replaced and erased.
+static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
+                                            scf::ForOp loop,
+                                            Value newInitOperand, int index,
+                                            Value newYieldValue) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(loop.getOperation());
+  auto inits = llvm::to_vector(loop.getInits());
+  assert(index >= 0 && static_cast<size_t>(index) < inits.size() &&
+         "iter_arg index out of range");
+
+  // Replace the init value with the new operand.
+  inits[index] = newInitOperand;
+
+  // Create the new loop with an empty body; the old body is moved in below.
+  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
+      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
+      inits, [](OpBuilder &, Location, Value, ValueRange) {});
+
+  // Update only the `index`-th operand of the yield. replaceAllUsesWith would
+  // also rewrite every other use of the yielded value (other ops in the body,
+  // other yield operands) with a value of a possibly different type, and would
+  // bypass the rewriter's change notifications.
+  auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
+  rewriter.modifyOpInPlace(yieldOp,
+                           [&] { yieldOp->setOperand(index, newYieldValue); });
+
+  // Move the loop body to the new op. Both loops have the same number of
+  // block arguments (induction variable plus one per iter_arg).
+  rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
+                       newLoop.getBody()->getArguments());
+
+  // Replace the old loop; both loops have the same number of results.
+  rewriter.replaceOp(loop.getOperation(), newLoop->getResults());
+  return newLoop;
+}
+
+// Hoist out a pair of corresponding vector.extract+vector.broadcast
+// operations. This function transforms a loop like this:
+//  %loop = scf.for _ = _ to _ step _ iter_args(%iterarg = %v) -> (t1) {
+//   %e = vector.extract %iterarg : t1 to t2
+//   %u = // do something with %e : t2
+//   %b = vector.broadcast %u : t2 to t1
+//   scf.yield %b : t1
+//  }
+// into the following:
+//  %e = vector.extract %v: t1 to t2
+//  %loop' = scf.for _ = _ to _ step _ iter_args(%iterarg = %e) -> (t2) {
+//   %u' = // do something with %iterarg : t2
+//   scf.yield %u' : t2
+//  }
+//  %loop = vector.broadcast %loop' : t2 to t1
+void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
+  bool changed = true;
+  while (changed) {
+    changed = false;
+    // First move loop invariant ops outside of their loop. This needs to be
+    // done before as we cannot move ops without interrupting the function walk.
+    root->walk(
+        [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
+
+    root->walk([&](vector::ExtractOp extractOp) {
+      LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
+                        << *extractOp.getOperation() << "\n");
+
+      auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
+      if (!loop)
+        return WalkResult::advance();
+
+      // Check that the vector to extract from is an iter_arg
+      auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
+      if (!blockArg)
+        return WalkResult::advance();
+
+      // If the iter_arg does not have only one use, it won't be possible to
+      // hoist the extractOp out.
+      if (!blockArg.hasOneUse())
+        return WalkResult::advance();
+
+      auto initArg = loop.getTiedLoopInit(blockArg)->get();
+      auto index = blockArg.getArgNumber() - loop.getNumInductionVars();
+
+      // Check that the loop yields a broadcast
+      auto lastOp = loop.getBody()->getTerminator();
+      auto yieldOp = dyn_cast<scf::YieldOp>(lastOp);
+      if (!yieldOp)
+        return WalkResult::advance();
+
+      auto broadcast = dyn_cast<vector::BroadcastOp>(
+          yieldOp->getOperand(index).getDefiningOp());
+
+      LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
+
+      Type broadcastInputType = broadcast.getSourceType();
+      if (broadcastInputType != extractOp.getType())
+        return WalkResult::advance();
+
+      // The position of the extract must be defined outside of the loop if
----------------
stevenvar wrote:

We need this check because otherwise we would hoist the extract out of the loop while its parameter (the position) is still defined inside the loop, which would violate the invariant that all operands of an op must be defined before that op.

https://github.com/llvm/llvm-project/pull/86108


More information about the Mlir-commits mailing list