[Mlir-commits] [mlir] [mlir][Hoisting] Hoisting vector.extract/vector.broadcast pairs (PR #86108)

Steven Varoumas llvmlistbot at llvm.org
Fri Apr 19 03:35:26 PDT 2024


================
@@ -43,6 +43,121 @@ using llvm::dbgs;
 using namespace mlir;
 using namespace mlir::linalg;
 
+/// Rebuilds `loop` as a new scf.for whose init operand at position `index` is
+/// `newInitOperand` and whose terminator yields `newYieldValue` at that same
+/// position. All other inits and yielded values are kept as they are.
+///
+/// The body of `loop` is moved (not cloned) into the new loop, `loop` is
+/// replaced by the results of the new loop, and the new loop is returned.
+/// The rewriter's insertion point is restored on exit.
+///
+/// `index` must be a valid position in the inits of `loop`; it is not checked
+/// here.
+static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
+                                            scf::ForOp loop,
+                                            Value newInitOperand, int index,
+                                            Value newYieldValue) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(loop.getOperation());
+  auto inits = llvm::to_vector(loop.getInits());
+
+  // Replace the init value with the new operand.
+  inits[index] = newInitOperand;
+
+  // Create the new loop with an empty body builder: the body is filled in
+  // below by merging the old loop's block into it.
+  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
+      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
+      inits, [](OpBuilder &, Location, Value, ValueRange) {});
+
+  // Update only the yield operand at `index`. Replacing all uses of the old
+  // yielded value would also rewrite unrelated uses of it (e.g. the same value
+  // yielded at another position, or other users in the body).
+  auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
+  rewriter.modifyOpInPlace(
+      yieldOp, [&] { yieldOp->setOperand(index, newYieldValue); });
+
+  // Move the loop body to the new op.
+  rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
+                       newLoop.getBody()->getArguments().take_front(
+                           loop.getBody()->getNumArguments()));
+
+  // Replace the old loop.
+  rewriter.replaceOp(loop.getOperation(),
+                     newLoop->getResults().take_front(loop.getNumResults()));
+  return newLoop;
+}
+
+// Hoist out a pair of corresponding vector.extract+vector.broadcast
+// operations. This function transforms a loop like this:
+//  %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) {
+//   %e = vector.extract %iarg : t1 to t2
+//   %u = "some_use"(%e) : (t2) -> t2
+//   %b = vector.broadcast %u : t2 to t1
+//   scf.yield %b : t1
+//  }
+// into the following:
+//  %e = vector.extract %v: t1 to t2
+//  %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) {
+//   %u' = "some_use"(%iarg) : (t2) -> t2
+//   scf.yield %u' : t2
+//  }
+//  %res = vector.broadcast %res' : t2 to t1
+void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
+  bool changed = true;
+  while (changed) {
+    changed = false;
+    // First move loop invariant ops outside of their loop. This needs to be
+    // done before as we cannot move ops without interrupting the function walk.
+    root->walk(
+        [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
+
+    root->walk([&](vector::ExtractOp extractOp) {
+      LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
+                        << *extractOp.getOperation() << "\n");
+
+      auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
+      if (!loop)
+        return WalkResult::advance();
+
+      // Check that the vector to extract from is an iter_arg
+      auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
+      if (!blockArg)
+        return WalkResult::advance();
+
+      // If the iter_arg does not have only one use, it won't be possible to
+      // hoist the extractOp out.
+      if (!blockArg.hasOneUse())
+        return WalkResult::advance();
+
+      auto initArg = loop.getTiedLoopInit(blockArg)->get();
+      auto index = blockArg.getArgNumber() - loop.getNumInductionVars();
+
+      // Check that the loop yields a broadcast
+      auto yieldedVal =
+          loop.getTiedLoopYieldedValue(blockArg)->get().getDefiningOp();
+      auto broadcast = dyn_cast<vector::BroadcastOp>(yieldedVal);
+      if (!broadcast)
+        return WalkResult::advance();
+
+      LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
+
+      Type broadcastInputType = broadcast.getSourceType();
+      if (broadcastInputType != extractOp.getType())
+        return WalkResult::advance();
+
+      // The position of the extract must be defined outside of the loop if
+      // it is dynamic
+      for (auto operand : extractOp.getDynamicPosition())
+        if (!loop.isDefinedOutsideOfLoop(operand))
+          return WalkResult::advance();
+
+      IRRewriter rewriter(extractOp.getContext());
+
+      extractOp.getVectorMutable().assign(initArg);
+      loop.moveOutOfLoop(extractOp);
+      rewriter.moveOpAfter(broadcast, loop);
+
+      auto newLoop = replaceWithDifferentYield(
+          rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
+
+      LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
+
+      rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast);
+      rewriter.modifyOpInPlace(broadcast, [&] {
+        broadcast.getSourceMutable().assign(newLoop.getResult(index));
----------------
stevenvar wrote:

No `setSource`, but I missed `setOperand` which does the same thing -> changed!

https://github.com/llvm/llvm-project/pull/86108


More information about the Mlir-commits mailing list