[Mlir-commits] [mlir] [MLIR][SCF] Add support for loop pipeline peeling for dynamic loops. (PR #106436)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Aug 29 12:46:20 PDT 2024


https://github.com/sjw36 updated https://github.com/llvm/llvm-project/pull/106436

>From 9be66a1a091d268b9d61bceb39d43cd2b66e63bd Mon Sep 17 00:00:00 2001
From: SJW <swaters at amd.com>
Date: Thu, 29 Aug 2024 19:45:57 +0000
Subject: [PATCH] [MLIR][SCF] Add support for loop pipeline peeling for dynamic
 loops. Allow speculative execution and predicate results per stage.

---
 .../Dialect/SCF/Transforms/LoopPipelining.cpp | 126 ++++++++++++------
 mlir/test/Dialect/SCF/loop-pipelining.mlir    |  43 +++++-
 mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp    |   4 +-
 3 files changed, 129 insertions(+), 44 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index d8e1cc0ecef88e..258e075e263259 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -94,8 +94,8 @@ struct LoopPipelinerInternal {
       RewriterBase &rewriter);
   /// Emits the epilogue, this creates `maxStage - 1` part which will contain
   /// operations from stages [i; maxStage], where i is the part index.
-  void emitEpilogue(RewriterBase &rewriter,
-                    llvm::SmallVector<Value> &returnValues);
+  LogicalResult emitEpilogue(RewriterBase &rewriter,
+                             llvm::SmallVector<Value> &returnValues);
 };
 
 bool LoopPipelinerInternal::initializeLoopInfo(
@@ -133,10 +133,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
     LDBG("--no epilogue or predicate set -> BAIL");
     return false;
   }
-  if (dynamicLoop && peelEpilogue) {
-    LDBG("--dynamic loop doesn't support epilogue yet -> BAIL");
-    return false;
-  }
   std::vector<std::pair<Operation *, unsigned>> schedule;
   options.getScheduleFn(forOp, schedule);
   if (schedule.empty()) {
@@ -313,10 +309,10 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
           });
       int predicateIdx = i - stages[op];
       if (predicates[predicateIdx]) {
+        OpBuilder::InsertionGuard insertGuard(rewriter);
         newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
         assert(newOp && "failed to predicate op.");
       }
-      rewriter.setInsertionPointAfter(newOp);
       if (annotateFn)
         annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
       for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
@@ -561,6 +557,7 @@ LogicalResult LoopPipelinerInternal::createKernel(
     }
 
     if (predicates[useStage]) {
+      OpBuilder::InsertionGuard insertGuard(rewriter);
       newOp = predicateFn(rewriter, newOp, predicates[useStage]);
       if (!newOp)
         return failure();
@@ -568,7 +565,6 @@ LogicalResult LoopPipelinerInternal::createKernel(
       for (auto values : llvm::zip(op->getResults(), newOp->getResults()))
         mapping.map(std::get<0>(values), std::get<1>(values));
     }
-    rewriter.setInsertionPointAfter(newOp);
     if (annotateFn)
       annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0);
   }
@@ -640,70 +636,123 @@ LogicalResult LoopPipelinerInternal::createKernel(
   return success();
 }
 
-void LoopPipelinerInternal::emitEpilogue(
-    RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
+LogicalResult
+LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
+                                    llvm::SmallVector<Value> &returnValues) {
+  Location loc = forOp.getLoc();
   // Emit different versions of the induction variable. They will be
   // removed by dead code if not used.
+
+  // bounds_range = ub - lb
+  // total_iterations = bounds_range / step + (bounds_range % step ? 1 : 0)
+  Type t = lb.getType();
+  Value minus1 =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
+
+  Value const_0 =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
+  Value const_1 =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
+  Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
+  Value boundsRem = rewriter.create<arith::RemUIOp>(loc, boundsRange, step);
+  Value hasRem = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
+                                                boundsRem, const_0);
+  Value selRem =
+      rewriter.create<arith::SelectOp>(loc, hasRem, const_1, const_0);
+  Value boundsDiv = rewriter.create<arith::DivUIOp>(loc, boundsRange, step);
+  Value totalIterations =
+      rewriter.create<arith::AddIOp>(loc, boundsDiv, selRem);
+
+  SmallVector<Value> predicates(maxStage + 1);
   for (int64_t i = 0; i < maxStage; i++) {
-    Location loc = forOp.getLoc();
-    Type t = lb.getType();
-    Value minusOne =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
-    // number of iterations = ((ub - 1) - lb) / step
-    Value totalNumIteration = rewriter.create<arith::DivUIOp>(
-        loc,
-        rewriter.create<arith::SubIOp>(
-            loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb),
-        step);
-    // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
+    // iterI = total_iters - 1 - i
+    // May go negative when the loop runs fewer than i + 1 iterations.
     Value minusI =
         rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
+    Value iterI = rewriter.create<arith::AddIOp>(
+        loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minus1),
+        minusI);
+    // newLastIter = lb + step * iterI
     Value newlastIter = rewriter.create<arith::AddIOp>(
-        loc, lb,
-        rewriter.create<arith::MulIOp>(
-            loc, step,
-            rewriter.create<arith::AddIOp>(loc, totalNumIteration, minusI)));
+        loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
+
     setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
+
+    if (dynamicLoop) {
+      // pred = iterI >= lb
+      predicates[i + 1] = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::sge, iterI, lb);
+    }
   }
+
   // Emit `maxStage - 1` epilogue part that includes operations from stages
   // [i; maxStage].
   for (int64_t i = 1; i <= maxStage; i++) {
+    SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size());
     for (Operation *op : opOrder) {
       if (stages[op] < i)
         continue;
+      unsigned currentVersion = maxStage - stages[op] + i;
+      unsigned nextVersion = currentVersion + 1;
       Operation *newOp =
           cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
             auto it = valueMapping.find(newOperand->get());
             if (it != valueMapping.end()) {
-              Value replacement = it->second[maxStage - stages[op] + i];
+              Value replacement = it->second[currentVersion];
               newOperand->set(replacement);
             }
           });
       if (annotateFn)
         annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
-      for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
-        setValueMapping(op->getResult(destId), newOp->getResult(destId),
-                        maxStage - stages[op] + i);
+      if (dynamicLoop) {
+        OpBuilder::InsertionGuard insertGuard(rewriter);
+        newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
+        if (!newOp)
+          return failure();
+      }
+
+      for (auto [opRes, newRes] :
+           llvm::zip(op->getResults(), newOp->getResults())) {
+        setValueMapping(opRes, newRes, currentVersion);
         // If the value is a loop carried dependency update the loop argument
         // mapping and keep track of the last version to replace the original
         // forOp uses.
         for (OpOperand &operand :
              forOp.getBody()->getTerminator()->getOpOperands()) {
-          if (operand.get() != op->getResult(destId))
+          if (operand.get() != opRes)
             continue;
-          unsigned version = maxStage - stages[op] + i + 1;
           // If the version is greater than maxStage it means it maps to the
           // original forOp returned value.
-          if (version > maxStage) {
-            returnValues[operand.getOperandNumber()] = newOp->getResult(destId);
-            continue;
-          }
-          setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
-                          newOp->getResult(destId), version);
+          unsigned ri = operand.getOperandNumber();
+          returnValues[ri] = newRes;
+          Value mapVal = forOp.getRegionIterArgs()[ri];
+          returnMap[ri] = std::make_pair(mapVal, currentVersion);
+          if (nextVersion <= maxStage)
+            setValueMapping(mapVal, newRes, nextVersion);
+        }
+      }
+    }
+    if (dynamicLoop) {
+      // Select return values from this stage (live outs) based on predication.
+      // If the stage is valid, select the peeled value; otherwise use the
+      // previous stage's value.
+      for (auto pair : llvm::enumerate(returnValues)) {
+        unsigned ri = pair.index();
+        auto [mapVal, currentVersion] = returnMap[ri];
+        if (mapVal) {
+          unsigned nextVersion = currentVersion + 1;
+          Value pred = predicates[currentVersion];
+          Value prevValue = valueMapping[mapVal][currentVersion];
+          auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
+                                                        prevValue);
+          returnValues[ri] = selOp;
+          if (nextVersion <= maxStage)
+            setValueMapping(mapVal, selOp, nextVersion);
         }
       }
     }
   }
+  return success();
 }
 
 void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
@@ -760,7 +809,8 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
   if (options.peelEpilogue) {
     // 4. Emit the epilogue after the new forOp.
     rewriter.setInsertionPointAfter(newForOp);
-    pipeliner.emitEpilogue(rewriter, returnValues);
+    if (failed(pipeliner.emitEpilogue(rewriter, returnValues)))
+      return failure();
   }
   // 5. Erase the original loop and replace the uses with the epilogue output.
   if (forOp->getNumResults() > 0)
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 9687f80f5ddfc8..957dc5295c0583 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -764,11 +764,46 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
 //      NOEPILOGUE:     memref.load %[[A]][%[[IV3]]] : memref<?xf32>
 //      NOEPILOGUE:   scf.yield %[[V2]], %[[L3]] : f32, f32
 
-// In case dynamic loop pipelining is off check that the transformation didn't
-// apply.
+// Check for predicated epilogue for dynamic loop.
 // CHECK-LABEL: dynamic_loop(
-//   CHECK-NOT:   memref.load
-//       CHECK:   scf.for
+//        CHECK:   %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
+//        CHECK:       memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]] 
+//        CHECK:       %[[ADDF_26:.*]] = arith.addf %[[ARG7]], %{{.*}} 
+//        CHECK:       %[[MULI_27:.*]] = arith.muli %{{.*}}, %{{.*}} 
+//        CHECK:       %[[ADDI_28:.*]] = arith.addi %[[ARG5]], %[[MULI_27]] 
+//        CHECK:       %[[LOAD_29:.*]] = memref.load %{{.*}}[%[[ADDI_28]]] 
+//        CHECK:       scf.yield %[[ADDF_26]], %[[LOAD_29]] 
+//        CHECK:   }
+//        CHECK:   %[[SUBI_10:.*]] = arith.subi %{{.*}}, %{{.*}} 
+//        CHECK:   %[[REMUI_11:.*]] = arith.remui %[[SUBI_10]], %{{.*}} 
+//        CHECK:   %[[CMPI_12:.*]] = arith.cmpi ne, %[[REMUI_11]], %{{.*}} 
+//        CHECK:   %[[SELECT_13:.*]] = arith.select %[[CMPI_12]], %{{.*}}, %{{.*}} 
+//        CHECK:   %[[DIVUI_14:.*]] = arith.divui %[[SUBI_10]], %{{.*}} 
+//        CHECK:   %[[ADDI_15:.*]] = arith.addi %[[DIVUI_14]], %[[SELECT_13]] 
+//        CHECK:   %[[ADDI_16:.*]] = arith.addi %[[ADDI_15]], %{{.*}}-1 
+//        CHECK:   %[[MULI_17:.*]] = arith.muli %{{.*}}, %[[ADDI_16]] 
+//        CHECK:   %[[ADDI_18:.*]] = arith.addi %{{.*}}, %[[MULI_17]] 
+//        CHECK:   %[[CMPI_19:.*]] = arith.cmpi sge, %[[ADDI_16]], %{{.*}} 
+//        CHECK:   %[[ADDI_20:.*]] = arith.addi %[[ADDI_15]], %{{.*}}-1 
+//        CHECK:   %[[ADDI_21:.*]] = arith.addi %[[ADDI_20]], %{{.*}}-1 
+//        CHECK:   %[[MULI_22:.*]] = arith.muli %{{.*}}, %[[ADDI_21]] 
+//        CHECK:   %[[ADDI_23:.*]] = arith.addi %{{.*}}, %[[MULI_22]] 
+//        CHECK:   %[[CMPI_24:.*]] = arith.cmpi sge, %[[ADDI_21]], %{{.*}} 
+//        CHECK:   scf.if %[[CMPI_19]] {
+//        CHECK:     memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_23]]] 
+//        CHECK:   } else {
+//        CHECK:   }
+//        CHECK:   %[[IF_25:.*]] = scf.if %[[CMPI_24]] -> (f32) {
+//        CHECK:     %[[ADDF_26:.*]] = arith.addf %{{.*}}#1, %{{.*}} 
+//        CHECK:     scf.yield %[[ADDF_26]] 
+//        CHECK:   } else {
+//        CHECK:     scf.yield %{{.*}} 
+//        CHECK:   }
+//        CHECK:   scf.if %[[CMPI_24]] {
+//        CHECK:     memref.store %[[IF_25]], %{{.*}}[%[[ADDI_18]]] 
+//        CHECK:   } else {
+//        CHECK:   }
+//        CHECK:   return
 func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
   %cf = arith.constant 1.0 : f32
   scf.for %i0 = %lb to %ub step %step {
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 8a92d840ad1302..3ff7f9966e93da 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -214,12 +214,12 @@ struct TestSCFPipeliningPass
     RewritePatternSet patterns(&getContext());
     mlir::scf::PipeliningOption options;
     options.getScheduleFn = getSchedule;
+    options.supportDynamicLoops = true;
+    options.predicateFn = predicateOp;
     if (annotatePipeline)
       options.annotateFn = annotate;
     if (noEpiloguePeeling) {
-      options.supportDynamicLoops = true;
       options.peelEpilogue = false;
-      options.predicateFn = predicateOp;
     }
     scf::populateSCFLoopPipeliningPatterns(patterns, options);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));



More information about the Mlir-commits mailing list