[Mlir-commits] [mlir] Enable LICM for ops with only read side effects in scf.for (PR #120302)
donald chen
llvmlistbot at llvm.org
Thu Dec 19 01:24:59 PST 2024
================
@@ -395,6 +395,83 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
/// Returns this loop's results; scf.for always has them available, so this
/// never returns std::nullopt.
std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
+/// Guards this loop with a zero-trip-count check. A new `scf.if` is created
+/// right after the loop. Its "then" region holds the (moved) loop and yields
+/// the loop results. Its "else" region yields the loop's init args unchanged.
+/// Every former use of the loop results is redirected to the `scf.if`
+/// results.
+/// Returns the new `scf.if` operation and this loop's body region.
+FailureOr<std::pair<Operation *, Region *>> ForOp::wrapInTripCountCheck() {
+  Value lowerBound = getLowerBound();
+  Value upperBound = getUpperBound();
+  auto initArgs = getInitArgs();
+  auto results = getResults();
+  Location loc = getLoc();
+
+  // NOTE(review): a locally constructed IRRewriter does not notify listeners
+  // attached to the caller's rewriter (e.g. a greedy pattern driver); consider
+  // threading the caller's RewriterBase through instead.
+  IRRewriter rewriter(getContext());
+  OpBuilder::InsertionGuard insertGuard(rewriter);
+  rewriter.setInsertionPointAfter(getOperation());
+
+  // scf.for requires a positive step, and the bounds are treated as signed
+  // here, so the loop executes at least once iff lowerBound < upperBound.
+  // Comparing the bounds directly avoids computing
+  // ceildivsi(upperBound - lowerBound, step) > 0. In that form the subtraction
+  // wraps on overflow (e.g. lb = INT_MIN, ub = INT_MAX). A loop that does run
+  // would then look empty and yield its init args instead of its results.
+  Value hasIterations = rewriter.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::slt, lowerBound, upperBound);
+
+  scf::YieldOp yieldInThen;
+  auto ifOp = rewriter.create<scf::IfOp>(
+      loc, hasIterations,
+      [&](OpBuilder &builder, Location nestedLoc) {
+        // Yields the loop's own results; the loop itself is moved in front of
+        // this yield below.
+        yieldInThen = builder.create<scf::YieldOp>(nestedLoc, results);
+      },
+      [&](OpBuilder &builder, Location nestedLoc) {
+        // Zero iterations: the loop's results would equal its init args.
+        builder.create<scf::YieldOp>(nestedLoc, initArgs);
+      });
+
+  // Redirect all uses of the loop results to the scf.if results, except the
+  // then-yield, which must keep consuming the loop results.
+  for (auto [forOpResult, ifOpResult] : llvm::zip(results, ifOp.getResults()))
+    rewriter.replaceAllUsesExcept(forOpResult, ifOpResult, yieldInThen);
+  // Move the scf.for into the then block, immediately before its yield.
+  rewriter.moveOpBefore(getOperation(), yieldInThen);
+  return std::make_pair(ifOp.getOperation(), &getRegion());
+}
+
+LogicalResult ForOp::unwrapTripCountCheck() {
+ auto ifOp = (*this)->getParentRegion()->getParentOp();
+ if (!isa<scf::IfOp>(ifOp))
+ return failure();
+
+ auto wrappedForOp = this->getOperation();
+
+ IRRewriter rewriter(ifOp->getContext());
+ OpBuilder::InsertionGuard insertGuard(rewriter);
+ rewriter.setInsertionPoint(ifOp);
----------------
cxy-1993 wrote:
This requires that this function be called immediately after wrapInTripCountCheck. How can this be guaranteed?
https://github.com/llvm/llvm-project/pull/120302
More information about the Mlir-commits mailing list