[Mlir-commits] [mlir] [MLIR][SCF] Add support for loop pipeline peeling for dynamic loops. (PR #106436)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 3 16:27:17 PDT 2024


================
@@ -640,70 +636,113 @@ LogicalResult LoopPipelinerInternal::createKernel(
   return success();
 }
 
-void LoopPipelinerInternal::emitEpilogue(
-    RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
+LogicalResult
+LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
+                                    llvm::SmallVector<Value> &returnValues) {
+  Location loc = forOp.getLoc();
   // Emit different versions of the induction variable. They will be
   // removed by dead code if not used.
+
+  // bounds_range = ub - lb
+  // total_iterations = (bounds_range + step - 1) / step
+  Type t = lb.getType();
+  Value minus1 =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
+  Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
+  Value rangeIncr = rewriter.create<arith::AddIOp>(loc, boundsRange, step);
+  Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
+  Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
+
+  SmallVector<Value> predicates(maxStage + 1);
   for (int64_t i = 0; i < maxStage; i++) {
-    Location loc = forOp.getLoc();
-    Type t = lb.getType();
-    Value minusOne =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
-    // number of iterations = ((ub - 1) - lb) / step
-    Value totalNumIteration = rewriter.create<arith::DivUIOp>(
-        loc,
-        rewriter.create<arith::SubIOp>(
-            loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb),
-        step);
-    // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
+    // iterI = total_iters - 1 - i
+    // May go negative...
     Value minusI =
         rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
+    Value iterI = rewriter.create<arith::AddIOp>(
+        loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minus1),
+        minusI);
+    // newLastIter = lb + step * iterI
     Value newlastIter = rewriter.create<arith::AddIOp>(
-        loc, lb,
-        rewriter.create<arith::MulIOp>(
-            loc, step,
-            rewriter.create<arith::AddIOp>(loc, totalNumIteration, minusI)));
+        loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
+
     setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
+
+    if (dynamicLoop) {
+      // pred = iterI >= lb
+      predicates[i + 1] = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::sge, iterI, lb);
+    }
   }
+
   // Emit `maxStage - 1` epilogue part that includes operations from stages
   // [i; maxStage].
   for (int64_t i = 1; i <= maxStage; i++) {
+    SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size());
     for (Operation *op : opOrder) {
       if (stages[op] < i)
         continue;
+      unsigned currentVersion = maxStage - stages[op] + i;
+      unsigned nextVersion = currentVersion + 1;
       Operation *newOp =
           cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
             auto it = valueMapping.find(newOperand->get());
             if (it != valueMapping.end()) {
-              Value replacement = it->second[maxStage - stages[op] + i];
+              Value replacement = it->second[currentVersion];
               newOperand->set(replacement);
             }
           });
+      if (dynamicLoop) {
+        OpBuilder::InsertionGuard insertGuard(rewriter);
+        newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
+        if (!newOp)
+          return failure();
+      }
       if (annotateFn)
         annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
-      for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
-        setValueMapping(op->getResult(destId), newOp->getResult(destId),
-                        maxStage - stages[op] + i);
+
+      for (auto [opRes, newRes] :
+           llvm::zip(op->getResults(), newOp->getResults())) {
+        setValueMapping(opRes, newRes, currentVersion);
         // If the value is a loop carried dependency update the loop argument
         // mapping and keep track of the last version to replace the original
         // forOp uses.
         for (OpOperand &operand :
              forOp.getBody()->getTerminator()->getOpOperands()) {
-          if (operand.get() != op->getResult(destId))
+          if (operand.get() != opRes)
             continue;
-          unsigned version = maxStage - stages[op] + i + 1;
           // If the version is greater than maxStage it means it maps to the
           // original forOp returned value.
-          if (version > maxStage) {
-            returnValues[operand.getOperandNumber()] = newOp->getResult(destId);
-            continue;
-          }
-          setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
-                          newOp->getResult(destId), version);
+          unsigned ri = operand.getOperandNumber();
+          returnValues[ri] = newRes;
+          Value mapVal = forOp.getRegionIterArgs()[ri];
+          returnMap[ri] = std::make_pair(mapVal, currentVersion);
+          if (nextVersion <= maxStage)
+            setValueMapping(mapVal, newRes, nextVersion);
+        }
+      }
+    }
+    if (dynamicLoop) {
+      // Select return values from this stage (live outs) based on predication.
+      // If the stage is valid select the peeled value, else use previous stage
+      // value.
+      for (auto pair : llvm::enumerate(returnValues)) {
+        unsigned ri = pair.index();
+        auto [mapVal, currentVersion] = returnMap[ri];
+        if (mapVal) {
----------------
sjw36 wrote:

For example:
```
func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
  %cf0 = arith.constant 1.0 : f32
  %cf1 = arith.constant 33.0 : f32
  %cst = arith.constant 0 : index
  %res:1 = scf.for %i0 = %lb to %ub step %step iter_args (%arg0 = %cf0) -> (f32) {
    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
    %A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
    %A2_elem = arith.mulf %A1_elem, %cf1 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
    scf.yield %A2_elem : f32
  } { __test_pipelining_loop__ }
  memref.store %res#0, %result[%cst] : memref<?xf32>
  return
}
```
I see now that the example predicates every operation using the `predicateFn`, not just the side-effecting ops. As a result, the function above becomes:
```
  func.func @dynamic_loop_result(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: index, %arg3: index, %arg4: index) {
    %c-1 = arith.constant -1 : index
    %cst = arith.constant 0.000000e+00 : f32
    %cst_0 = arith.constant 1.000000e+00 : f32
    %cst_1 = arith.constant 3.300000e+01 : f32
    %c0 = arith.constant 0 : index
    %0 = arith.cmpi slt, %arg2, %arg3 : index
    %1 = scf.if %0 -> (f32) {
      %13 = memref.load %arg0[%arg2] : memref<?xf32>
      scf.yield %13 : f32
    } else {
      scf.yield %cst : f32
    }
    %2 = arith.subi %arg3, %arg4 : index
    %3:2 = scf.for %arg5 = %arg2 to %2 step %arg4 iter_args(%arg6 = %cst_0, %arg7 = %1) -> (f32, f32) {
      %13 = arith.addf %arg7, %arg6 : f32
      %14 = arith.mulf %13, %cst_1 : f32
      %15 = arith.addi %arg5, %arg4 : index
      %16 = memref.load %arg0[%15] : memref<?xf32>
      scf.yield %14, %16 : f32, f32
    }
    %4 = arith.subi %arg3, %arg2 : index
    %5 = arith.addi %4, %arg4 : index
    %6 = arith.addi %5, %c-1 : index
    %7 = arith.divui %6, %arg4 : index
    %8 = arith.addi %7, %c-1 : index
    %9 = arith.cmpi sge, %8, %arg2 : index
    %10 = scf.if %9 -> (f32) {
      %13 = arith.addf %3#1, %3#0 : f32
      scf.yield %13 : f32
    } else {
      scf.yield %cst : f32
    }
    %11 = scf.if %9 -> (f32) {
      %13 = arith.mulf %10, %cst_1 : f32
      scf.yield %13 : f32
    } else {
      scf.yield %cst : f32
    }
    %12 = arith.select %9, %11, %3#0 : f32   /// redundant
    memref.store %12, %arg1[%c0] : memref<?xf32>
    return
  }
```
As you can see, every operation is guarded (including ops that do not produce a loop result), so no speculative execution actually takes place.

If only side-effecting ops were guarded, and only the loop results were selected based on whether the stage is in range, the output would be:
```
  func.func @dynamic_loop_result(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: index, %arg3: index, %arg4: index) {
    %c-1 = arith.constant -1 : index
    %cst = arith.constant 0.000000e+00 : f32
    %cst_0 = arith.constant 1.000000e+00 : f32
    %cst_1 = arith.constant 3.300000e+01 : f32
    %c0 = arith.constant 0 : index
    %0 = arith.cmpi slt, %arg2, %arg3 : index
    %1 = scf.if %0 -> (f32) {
      %13 = memref.load %arg0[%arg2] : memref<?xf32>
      scf.yield %13 : f32
    } else {
      scf.yield %cst : f32
    }
    %2 = arith.subi %arg3, %arg4 : index
    %3:2 = scf.for %arg5 = %arg2 to %2 step %arg4 iter_args(%arg6 = %cst_0, %arg7 = %1) -> (f32, f32) {
      %13 = arith.addf %arg7, %arg6 : f32
      %14 = arith.mulf %13, %cst_1 : f32
      %15 = arith.addi %arg5, %arg4 : index
      %16 = memref.load %arg0[%15] : memref<?xf32>
      scf.yield %14, %16 : f32, f32
    }
    %4 = arith.subi %arg3, %arg2 : index
    %5 = arith.addi %4, %arg4 : index
    %6 = arith.addi %5, %c-1 : index
    %7 = arith.divui %6, %arg4 : index
    %8 = arith.addi %7, %c-1 : index
    %9 = arith.cmpi sge, %8, %arg2 : index
    %10 = arith.addf %3#1, %3#0 : f32
    %11 = arith.mulf %10, %cst_1 : f32
    %12 = arith.select %9, %11, %3#0 : f32
    memref.store %12, %arg1[%c0] : memref<?xf32>
    return
  }
```

This also seems to be what the prologue logic does (see line 343).

https://github.com/llvm/llvm-project/pull/106436


More information about the Mlir-commits mailing list