[Mlir-commits] [mlir] [MLIR][SCF] Add support for loop pipeline peeling for dynamic loops. (PR #106436)

Lei Zhang llvmlistbot at llvm.org
Sun Sep 1 11:58:54 PDT 2024


================
@@ -640,70 +636,113 @@ LogicalResult LoopPipelinerInternal::createKernel(
   return success();
 }
 
-void LoopPipelinerInternal::emitEpilogue(
-    RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
+LogicalResult
+LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
+                                    llvm::SmallVector<Value> &returnValues) {
+  Location loc = forOp.getLoc();
   // Emit different versions of the induction variable. They will be
   // removed by dead code if not used.
+
+  // bounds_range = ub - lb
+  // total_iterations = (bounds_range + step - 1) / step
+  Type t = lb.getType();
+  Value minus1 =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
+  Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
+  Value rangeIncr = rewriter.create<arith::AddIOp>(loc, boundsRange, step);
+  Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
+  Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
+
+  SmallVector<Value> predicates(maxStage + 1);
   for (int64_t i = 0; i < maxStage; i++) {
-    Location loc = forOp.getLoc();
-    Type t = lb.getType();
-    Value minusOne =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
-    // number of iterations = ((ub - 1) - lb) / step
-    Value totalNumIteration = rewriter.create<arith::DivUIOp>(
-        loc,
-        rewriter.create<arith::SubIOp>(
-            loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb),
-        step);
-    // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
+    // iterI = total_iters - 1 - i
+    // May go negative...
     Value minusI =
         rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
+    Value iterI = rewriter.create<arith::AddIOp>(
+        loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minus1),
+        minusI);
+    // newLastIter = lb + step * iterI
     Value newlastIter = rewriter.create<arith::AddIOp>(
-        loc, lb,
-        rewriter.create<arith::MulIOp>(
-            loc, step,
-            rewriter.create<arith::AddIOp>(loc, totalNumIteration, minusI)));
+        loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
+
     setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
+
+    if (dynamicLoop) {
+      // pred = iterI >= lb
+      predicates[i + 1] = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::sge, iterI, lb);
+    }
   }
+
   // Emit `maxStage - 1` epilogue part that includes operations from stages
   // [i; maxStage].
   for (int64_t i = 1; i <= maxStage; i++) {
+    SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size());
     for (Operation *op : opOrder) {
       if (stages[op] < i)
         continue;
+      unsigned currentVersion = maxStage - stages[op] + i;
+      unsigned nextVersion = currentVersion + 1;
       Operation *newOp =
           cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
             auto it = valueMapping.find(newOperand->get());
             if (it != valueMapping.end()) {
-              Value replacement = it->second[maxStage - stages[op] + i];
+              Value replacement = it->second[currentVersion];
               newOperand->set(replacement);
             }
           });
+      if (dynamicLoop) {
+        OpBuilder::InsertionGuard insertGuard(rewriter);
+        newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
+        if (!newOp)
+          return failure();
+      }
       if (annotateFn)
         annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
-      for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
-        setValueMapping(op->getResult(destId), newOp->getResult(destId),
-                        maxStage - stages[op] + i);
+
+      for (auto [opRes, newRes] :
+           llvm::zip(op->getResults(), newOp->getResults())) {
+        setValueMapping(opRes, newRes, currentVersion);
         // If the value is a loop carried dependency update the loop argument
         // mapping and keep track of the last version to replace the original
         // forOp uses.
         for (OpOperand &operand :
              forOp.getBody()->getTerminator()->getOpOperands()) {
-          if (operand.get() != op->getResult(destId))
+          if (operand.get() != opRes)
             continue;
-          unsigned version = maxStage - stages[op] + i + 1;
           // If the version is greater than maxStage it means it maps to the
           // original forOp returned value.
-          if (version > maxStage) {
-            returnValues[operand.getOperandNumber()] = newOp->getResult(destId);
-            continue;
-          }
-          setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
-                          newOp->getResult(destId), version);
+          unsigned ri = operand.getOperandNumber();
+          returnValues[ri] = newRes;
+          Value mapVal = forOp.getRegionIterArgs()[ri];
+          returnMap[ri] = std::make_pair(mapVal, currentVersion);
+          if (nextVersion <= maxStage)
+            setValueMapping(mapVal, newRes, nextVersion);
+        }
+      }
+    }
+    if (dynamicLoop) {
----------------
antiagainst wrote:

Can we add a test to excercise this case?

https://github.com/llvm/llvm-project/pull/106436


More information about the Mlir-commits mailing list