[Mlir-commits] [mlir] Enable LICM for ops with only read side effects in scf.for (PR #120302)

donald chen llvmlistbot at llvm.org
Thu Dec 19 01:24:59 PST 2024


================
@@ -395,6 +395,83 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {
 
 /// Returns all results of this scf.for op; the returned optional is always
 /// engaged for scf.for.
 std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
 
+/// Wraps this loop in an `scf.if` so that the loop only executes when it has
+/// at least one iteration, i.e. when `lowerBound < upperBound` (signed).
+///
+/// The `scf.if` is created right after the loop, and the loop is then moved
+/// into its `then` region. The `then` region yields the loop results; the
+/// `else` region yields the loop's init args. Every other use of the loop
+/// results is redirected to the corresponding `scf.if` result.
+///
+/// Returns the new `scf.if` operation and the loop's body region. This
+/// implementation never returns failure.
+FailureOr<std::pair<Operation *, Region *>> ForOp::wrapInTripCountCheck() {
+  auto lowerBound = this->getLowerBound();
+  auto upperBound = this->getUpperBound();
+  auto initArgs = this->getInitArgs();
+  auto results = this->getResults();
+  auto loc = this->getLoc();
+
+  // The rewriter is local to this call, so its insertion point does not need
+  // to be restored afterwards.
+  IRRewriter rewriter(this->getContext());
+  rewriter.setInsertionPointAfter(this->getOperation());
+
+  // scf.for requires a positive step, so the loop runs at least once iff
+  // lowerBound < upperBound. Comparing the bounds directly avoids computing
+  // `upperBound - lowerBound`, which can wrap around (e.g. lb = INT_MIN,
+  // ub = INT_MAX on i32 gives -1) and would make the check skip a loop that
+  // actually executes.
+  auto cmpIOp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
+                                               lowerBound, upperBound);
+  scf::YieldOp yieldInThen;
+  // Create the trip-count check.
+  auto ifOp = rewriter.create<scf::IfOp>(
+      loc, cmpIOp,
+      [&](OpBuilder &builder, Location loc) {
+        yieldInThen = builder.create<scf::YieldOp>(loc, results);
+      },
+      [&](OpBuilder &builder, Location loc) {
+        builder.create<scf::YieldOp>(loc, initArgs);
+      });
+
+  // Redirect all uses of the loop results to the scf.if results, except the
+  // then-yield, which must keep forwarding the loop results.
+  for (auto [forOpResult, ifOpResult] : llvm::zip(results, ifOp.getResults()))
+    rewriter.replaceAllUsesExcept(forOpResult, ifOpResult, yieldInThen);
+  // Move the scf.for into the then block, right before its terminator.
+  rewriter.moveOpBefore(this->getOperation(), yieldInThen);
+  return std::make_pair(ifOp.getOperation(), &this->getRegion());
+}
+
+LogicalResult ForOp::unwrapTripCountCheck() {
+  auto ifOp = (*this)->getParentRegion()->getParentOp();
+  if (!isa<scf::IfOp>(ifOp))
+    return failure();
+
+  auto wrappedForOp = this->getOperation();
+
+  IRRewriter rewriter(ifOp->getContext());
+  OpBuilder::InsertionGuard insertGuard(rewriter);
+  rewriter.setInsertionPoint(ifOp);
----------------
cxy-1993 wrote:

This requires that this function be called immediately after wrapInTripCountCheck. How can this be guaranteed?

https://github.com/llvm/llvm-project/pull/120302


More information about the Mlir-commits mailing list