[Mlir-commits] [mlir] [mlir][Hoisting] Hoisting vector.extract/vector.broadcast pairs (PR #86108)
Matthias Springer
llvmlistbot at llvm.org
Fri Apr 12 00:23:05 PDT 2024
================
@@ -43,6 +43,120 @@ using llvm::dbgs;
using namespace mlir;
using namespace mlir::linalg;
+/// Rebuilds `loop` as a new scf.for whose `index`-th iter_arg is initialized
+/// with `newInitOperand` and whose `index`-th yielded value is
+/// `newYieldValue`. All other inits and yielded values are kept unchanged; the
+/// body is moved into the new loop and `loop` is replaced by (and its uses
+/// redirected to) the new loop.
+///
+/// The new loop's iter_arg and result at `index` take the type of
+/// `newInitOperand`, which may differ from the original. NOTE(review): callers
+/// appear responsible for ensuring the old iter_arg has no remaining
+/// type-incompatible uses in the body, and for fixing up uses of result
+/// `index` when the type changes.
+static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
+                                             scf::ForOp loop,
+                                             Value newInitOperand, int index,
+                                             Value newYieldValue) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(loop.getOperation());
+  auto inits = llvm::to_vector(loop.getInits());
+
+  // Replace the init value with the new operand.
+  assert(index >= 0 && static_cast<size_t>(index) < inits.size() &&
+         "iter_arg index out of range");
+  inits[index] = newInitOperand;
+
+  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
+      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
+      inits, [](OpBuilder &, Location, Value, ValueRange) {});
+
+  // Change only the yield operand at `index`. Calling replaceAllUsesWith on the
+  // yielded value would also rewrite every other use of that value (possibly
+  // outside the loop, possibly with a mismatched type), and would bypass the
+  // rewriter's notification of the in-place modification.
+  auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
+  rewriter.modifyOpInPlace(
+      yieldOp, [&]() { yieldOp->setOperand(index, newYieldValue); });
+
+  // Move the loop body to the new op. Both loops have the same number of block
+  // arguments (induction variable plus one per init), so all are mapped.
+  rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(),
+                       newLoop.getBody()->getArguments());
+
+  // Replace the old loop; both loops have the same number of results.
+  rewriter.replaceOp(loop.getOperation(), newLoop->getResults());
+  return newLoop;
+}
+
+// Hoist out a pair of corresponding vector.extract+vector.broadcast
+// operations. This function transforms a loop like this:
+// %loop = scf.for _ = _ to _ step _ iter_args(%iterarg = %v) -> (t1) {
+// %e = vector.extract %iterarg : t1 to t2
+// %u = // do something with %e : t2
+// %b = vector.broadcast %u : t2 to t1
+// scf.yield %b : t1
+// }
+// into the following:
+// %e = vector.extract %v: t1 to t2
+// %loop' = scf.for _ = _ to _ step _ iter_args(%iterarg = %e) -> (t2) {
+// %u' = // do something with %iterarg : t2
+// scf.yield %u' : t2
+// }
+// %loop = vector.broadcast %loop' : t2 to t1
+void mlir::linalg::hoistRedundantVectorBroadcasts(Operation *root) {
+ bool changed = true;
+ while (changed) {
+ changed = false;
+ // First move loop invariant ops outside of their loop. This needs to be
+ // done before as we cannot move ops without interrupting the function walk.
+ root->walk(
+ [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
+
+ root->walk([&](vector::ExtractOp extractOp) {
+ LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
+ << *extractOp.getOperation() << "\n");
+
+ auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
+ if (!loop)
+ return WalkResult::advance();
+
+ // Check that the vector to extract from is an iter_arg
+ auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
+ if (!blockArg)
+ return WalkResult::advance();
+
+ // If the iter_arg does not have only one use, it won't be possible to
+ // hoist the extractOp out.
+ if (!blockArg.hasOneUse())
+ return WalkResult::advance();
+
+ auto initArg = loop.getTiedLoopInit(blockArg)->get();
+ auto index = blockArg.getArgNumber() - loop.getNumInductionVars();
+
+ // Check that the loop yields a broadcast
+ auto lastOp = loop.getBody()->getTerminator();
+ auto yieldOp = dyn_cast<scf::YieldOp>(lastOp);
+ if (!yieldOp)
+ return WalkResult::advance();
+
+ auto broadcast = dyn_cast<vector::BroadcastOp>(
+ yieldOp->getOperand(index).getDefiningOp());
+
+ LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");
+
+ Type broadcastInputType = broadcast.getSourceType();
+ if (broadcastInputType != extractOp.getType())
+ return WalkResult::advance();
+
+ // The position of the extract must be defined outside of the loop if
+ // it is dynamic
+ for (auto operand : extractOp.getDynamicPosition())
+ if (!loop.isDefinedOutsideOfLoop(operand))
+ return WalkResult::advance();
+
+ extractOp.getVectorMutable().assign(initArg);
+ loop.moveOutOfLoop(extractOp);
+ broadcast->moveAfter(loop);
+
+ IRRewriter rewriter(extractOp.getContext());
+ auto newLoop = replaceWithDifferentYield(
+ rewriter, loop, extractOp.getResult(), index, broadcast.getSource());
+
+ LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");
+
+ newLoop.getResult(index).replaceAllUsesWith(broadcast);
+ broadcast.getSourceMutable().assign(newLoop.getResult(index));
----------------
matthias-springer wrote:
Should be wrapped in `rewriter.modifyOpInPlace` for consistency. (We may start listening for such IR modifications at some point.)
https://github.com/llvm/llvm-project/pull/86108
More information about the Mlir-commits mailing list