[Mlir-commits] [mlir] ebf0599 - [MLIR][SCF] Add support for loop pipeline peeling for dynamic loops. (#106436)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 4 12:25:02 PDT 2024


Author: SJW
Date: 2024-09-04T12:24:58-07:00
New Revision: ebf0599314e17c3ab89f303d452811b1db3e6d1e

URL: https://github.com/llvm/llvm-project/commit/ebf0599314e17c3ab89f303d452811b1db3e6d1e
DIFF: https://github.com/llvm/llvm-project/commit/ebf0599314e17c3ab89f303d452811b1db3e6d1e.diff

LOG: [MLIR][SCF] Add support for loop pipeline peeling for dynamic loops. (#106436)

Allow speculative execution and predicate results per stage.

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
    mlir/test/Dialect/SCF/loop-pipelining.mlir
    mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index d8e1cc0ecef88e..a34542f0161aca 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -94,8 +94,8 @@ struct LoopPipelinerInternal {
       RewriterBase &rewriter);
   /// Emits the epilogue, this creates `maxStage - 1` part which will contain
   /// operations from stages [i; maxStage], where i is the part index.
-  void emitEpilogue(RewriterBase &rewriter,
-                    llvm::SmallVector<Value> &returnValues);
+  LogicalResult emitEpilogue(RewriterBase &rewriter,
+                             llvm::SmallVector<Value> &returnValues);
 };
 
 bool LoopPipelinerInternal::initializeLoopInfo(
@@ -133,10 +133,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
     LDBG("--no epilogue or predicate set -> BAIL");
     return false;
   }
-  if (dynamicLoop && peelEpilogue) {
-    LDBG("--dynamic loop doesn't support epilogue yet -> BAIL");
-    return false;
-  }
   std::vector<std::pair<Operation *, unsigned>> schedule;
   options.getScheduleFn(forOp, schedule);
   if (schedule.empty()) {
@@ -313,10 +309,10 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
           });
       int predicateIdx = i - stages[op];
       if (predicates[predicateIdx]) {
+        OpBuilder::InsertionGuard insertGuard(rewriter);
         newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
         assert(newOp && "failed to predicate op.");
       }
-      rewriter.setInsertionPointAfter(newOp);
       if (annotateFn)
         annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
       for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
@@ -561,6 +557,7 @@ LogicalResult LoopPipelinerInternal::createKernel(
     }
 
     if (predicates[useStage]) {
+      OpBuilder::InsertionGuard insertGuard(rewriter);
       newOp = predicateFn(rewriter, newOp, predicates[useStage]);
       if (!newOp)
         return failure();
@@ -568,7 +565,6 @@ LogicalResult LoopPipelinerInternal::createKernel(
       for (auto values : llvm::zip(op->getResults(), newOp->getResults()))
         mapping.map(std::get<0>(values), std::get<1>(values));
     }
-    rewriter.setInsertionPointAfter(newOp);
     if (annotateFn)
       annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0);
   }
@@ -640,70 +636,113 @@ LogicalResult LoopPipelinerInternal::createKernel(
   return success();
 }
 
-void LoopPipelinerInternal::emitEpilogue(
-    RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
+LogicalResult
+LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
+                                    llvm::SmallVector<Value> &returnValues) {
+  Location loc = forOp.getLoc();
   // Emit different versions of the induction variable. They will be
   // removed by dead code if not used.
+
+  // total_iterations = (ub - lb + step - 1) / step. NOTE(review): divui
+  // assumes step > 0 and ub >= lb; a negative numerator reads as huge.
+  Type t = lb.getType();
+  Value zero =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
+  Value minus1 =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
+  Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
+  Value rangeIncr = rewriter.create<arith::AddIOp>(loc, boundsRange, step);
+  Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
+  Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
+  SmallVector<Value> predicates(maxStage + 1);
   for (int64_t i = 0; i < maxStage; i++) {
-    Location loc = forOp.getLoc();
-    Type t = lb.getType();
-    Value minusOne =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
-    // number of iterations = ((ub - 1) - lb) / step
-    Value totalNumIteration = rewriter.create<arith::DivUIOp>(
-        loc,
-        rewriter.create<arith::SubIOp>(
-            loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb),
-        step);
-    // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
+    // iterI = total_iters - 1 - i; negative once i >= total_iters.
+    // NOTE(review): not clamped; with step > 0 the IV below can precede lb.
     Value minusI =
         rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
+    Value iterI = rewriter.create<arith::AddIOp>(
+        loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minus1),
+        minusI);
+    // newLastIter = lb + step * iterI
     Value newlastIter = rewriter.create<arith::AddIOp>(
-        loc, lb,
-        rewriter.create<arith::MulIOp>(
-            loc, step,
-            rewriter.create<arith::AddIOp>(loc, totalNumIteration, minusI)));
+        loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
     setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
+
+    if (dynamicLoop) {
+      // pred = iterI >= 0: iterI is an iteration index, not an IV like lb.
+      predicates[i + 1] = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::sge, iterI, zero);
+    }
   }
+
   // Emit `maxStage - 1` epilogue part that includes operations from stages
   // [i; maxStage].
   for (int64_t i = 1; i <= maxStage; i++) {
+    SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size());
     for (Operation *op : opOrder) {
       if (stages[op] < i)
         continue;
+      unsigned currentVersion = maxStage - stages[op] + i;
+      unsigned nextVersion = currentVersion + 1;
       Operation *newOp =
           cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
             auto it = valueMapping.find(newOperand->get());
             if (it != valueMapping.end()) {
-              Value replacement = it->second[maxStage - stages[op] + i];
+              Value replacement = it->second[currentVersion];
               newOperand->set(replacement);
             }
           });
+      if (dynamicLoop) {
+        OpBuilder::InsertionGuard insertGuard(rewriter);
+        newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
+        if (!newOp)
+          return failure();
+      }
       if (annotateFn)
         annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
-      for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
-        setValueMapping(op->getResult(destId), newOp->getResult(destId),
-                        maxStage - stages[op] + i);
+
+      for (auto [opRes, newRes] :
+           llvm::zip(op->getResults(), newOp->getResults())) {
+        setValueMapping(opRes, newRes, currentVersion);
         // If the value is a loop carried dependency update the loop argument
         // mapping and keep track of the last version to replace the original
         // forOp uses.
         for (OpOperand &operand :
              forOp.getBody()->getTerminator()->getOpOperands()) {
-          if (operand.get() != op->getResult(destId))
+          if (operand.get() != opRes)
             continue;
-          unsigned version = maxStage - stages[op] + i + 1;
           // If the version is greater than maxStage it means it maps to the
           // original forOp returned value.
-          if (version > maxStage) {
-            returnValues[operand.getOperandNumber()] = newOp->getResult(destId);
-            continue;
-          }
-          setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
-                          newOp->getResult(destId), version);
+          unsigned ri = operand.getOperandNumber();
+          returnValues[ri] = newRes;
+          Value mapVal = forOp.getRegionIterArgs()[ri];
+          returnMap[ri] = std::make_pair(mapVal, currentVersion);
+          if (nextVersion <= maxStage)
+            setValueMapping(mapVal, newRes, nextVersion);
+        }
+      }
+    }
+    if (dynamicLoop) {
+      // Select return values from this stage (live outs) based on predication.
+      // If the stage is valid select the peeled value, else use previous stage
+      // value.
+      for (auto pair : llvm::enumerate(returnValues)) {
+        unsigned ri = pair.index();
+        auto [mapVal, currentVersion] = returnMap[ri];
+        if (mapVal) {
+          unsigned nextVersion = currentVersion + 1;
+          Value pred = predicates[currentVersion];
+          Value prevValue = valueMapping[mapVal][currentVersion];
+          auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
+                                                        prevValue);
+          returnValues[ri] = selOp;
+          if (nextVersion <= maxStage)
+            setValueMapping(mapVal, selOp, nextVersion);
         }
       }
     }
   }
+  return success();
 }
 
 void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
@@ -760,7 +799,8 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
   if (options.peelEpilogue) {
     // 4. Emit the epilogue after the new forOp.
     rewriter.setInsertionPointAfter(newForOp);
-    pipeliner.emitEpilogue(rewriter, returnValues);
+    if (failed(pipeliner.emitEpilogue(rewriter, returnValues)))
+      return failure();
   }
   // 5. Erase the original loop and replace the uses with the epilogue output.
   if (forOp->getNumResults() > 0)

diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 9687f80f5ddfc8..4a1406faabce1b 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -764,11 +764,44 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
 //      NOEPILOGUE:     memref.load %[[A]][%[[IV3]]] : memref<?xf32>
 //      NOEPILOGUE:   scf.yield %[[V2]], %[[L3]] : f32, f32
 
-// In case dynamic loop pipelining is off check that the transformation didn't
-// apply.
+// Check for predicated epilogue for dynamic loop.
 // CHECK-LABEL: dynamic_loop(
-//   CHECK-NOT:   memref.load
-//       CHECK:   scf.for
+//        CHECK:   %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
+//        CHECK:       memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]]
+//        CHECK:       %[[ADDF_24:.*]] = arith.addf %[[ARG7]], %{{.*}}
+//        CHECK:       %[[MULI_25:.*]] = arith.muli %{{.*}}, %{{.*}}
+//        CHECK:       %[[ADDI_26:.*]] = arith.addi %[[ARG5]], %[[MULI_25]]
+//        CHECK:       %[[LOAD_27:.*]] = memref.load %{{.*}}[%[[ADDI_26]]]
+//        CHECK:       scf.yield %[[ADDF_24]], %[[LOAD_27]]
+//        CHECK:   }
+//        CHECK:   %[[SUBI_10:.*]] = arith.subi %{{.*}}, %{{.*}}
+//        CHECK:   %[[ADDI_11:.*]] = arith.addi %[[SUBI_10]], %{{.*}}
+//        CHECK:   %[[ADDI_12:.*]] = arith.addi %[[ADDI_11]], %{{.*}}-1
+//        CHECK:   %[[DIVUI_13:.*]] = arith.divui %[[ADDI_12]], %{{.*}}
+//        CHECK:   %[[ADDI_14:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
+//        CHECK:   %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
+//        CHECK:   %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
+//        CHECK:   %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %{{.*}}
+//        CHECK:   %[[ADDI_18:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
+//        CHECK:   %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
+//        CHECK:   %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
+//        CHECK:   %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
+//        CHECK:   %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %{{.*}}
+//        CHECK:   scf.if %[[CMPI_17]] {
+//        CHECK:     memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
+//        CHECK:   } else {
+//        CHECK:   }
+//        CHECK:   %[[IF_23:.*]] = scf.if %[[CMPI_22]] -> (f32) {
+//        CHECK:     %[[ADDF_24:.*]] = arith.addf %{{.*}}#1, %{{.*}}
+//        CHECK:     scf.yield %[[ADDF_24]]
+//        CHECK:   } else {
+//        CHECK:     scf.yield %{{.*}}
+//        CHECK:   }
+//        CHECK:   scf.if %[[CMPI_22]] {
+//        CHECK:     memref.store %[[IF_23]], %{{.*}}[%[[ADDI_16]]]
+//        CHECK:   } else {
+//        CHECK:   }
+//        CHECK:   return
 func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
   %cf = arith.constant 1.0 : f32
   scf.for %i0 = %lb to %ub step %step {
@@ -781,6 +814,68 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
 
 // -----
 
+// NOEPILOGUE-LABEL:   func.func @dynamic_loop_result
+//       NOEPILOGUE:     %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
+//       NOEPILOGUE:       %[[SUBI_3:.*]] = arith.subi %{{.*}}, %{{.*}}
+//       NOEPILOGUE:       %[[CMPI_4:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_3]]
+//       NOEPILOGUE:       %[[ADDF_5:.*]] = arith.addf %[[ARG7]], %[[ARG6]]
+//       NOEPILOGUE:       %[[MULF_6:.*]] = arith.mulf %[[ADDF_5]], %{{.*}}
+//       NOEPILOGUE:       %[[ADDI_7:.*]] = arith.addi %[[ARG5]], %{{.*}}
+//       NOEPILOGUE:       %[[IF_8:.*]] = scf.if %[[CMPI_4]]
+//       NOEPILOGUE:         %[[LOAD_9:.*]] = memref.load %{{.*}}[%[[ADDI_7]]]
+//       NOEPILOGUE:         scf.yield %[[LOAD_9]]
+//       NOEPILOGUE:       } else {
+//       NOEPILOGUE:         scf.yield %{{.*}}
+//       NOEPILOGUE:       }
+//       NOEPILOGUE:       scf.yield %[[MULF_6]], %[[IF_8]]
+//       NOEPILOGUE:     }
+//       NOEPILOGUE:     memref.store %{{.*}}#0, %{{.*}}[%{{.*}}]
+
+// Check that the dynamic-loop epilogue is predicated and that the escaping loop result is picked with arith.select.
+// CHECK-LABEL:   func.func @dynamic_loop_result
+//       CHECK:     %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
+//       CHECK:       %[[ADDF_13:.*]] = arith.addf %[[ARG7]], %[[ARG6]]
+//       CHECK:       %[[MULF_14:.*]] = arith.mulf %[[ADDF_13]], %{{.*}}
+//       CHECK:       %[[ADDI_15:.*]] = arith.addi %[[ARG5]], %{{.*}}
+//       CHECK:       %[[LOAD_16:.*]] = memref.load %{{.*}}[%[[ADDI_15]]]
+//       CHECK:       scf.yield %[[MULF_14]], %[[LOAD_16]]
+//       CHECK:     }
+//       CHECK:     %[[SUBI_4:.*]] = arith.subi %{{.*}}, %{{.*}}
+//       CHECK:     %[[ADDI_5:.*]] = arith.addi %[[SUBI_4]], %{{.*}}
+//       CHECK:     %[[ADDI_6:.*]] = arith.addi %[[ADDI_5]], %{{.*}}-1
+//       CHECK:     %[[DIVUI_7:.*]] = arith.divui %[[ADDI_6]], %{{.*}}
+//       CHECK:     %[[ADDI_8:.*]] = arith.addi %[[DIVUI_7]], %{{.*}}-1
+//       CHECK:     %[[CMPI_9:.*]] = arith.cmpi sge, %[[ADDI_8]], %{{.*}}
+//       CHECK:     %[[IF_10:.*]] = scf.if %[[CMPI_9]]
+//       CHECK:       %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
+//       CHECK:       scf.yield %[[ADDF_13]]
+//       CHECK:     } else {
+//       CHECK:       scf.yield %{{.*}}
+//       CHECK:     }
+//       CHECK:     %[[IF_11:.*]] = scf.if %[[CMPI_9]]
+//       CHECK:       %[[MULF_13:.*]] = arith.mulf %[[IF_10]], %{{.*}}
+//       CHECK:       scf.yield %[[MULF_13]]
+//       CHECK:     } else {
+//       CHECK:       scf.yield %{{.*}}
+//       CHECK:     }
+//       CHECK:     %[[SELECT_12:.*]] = arith.select %[[CMPI_9]], %[[IF_11]], %{{.*}}#0
+//       CHECK:     memref.store %[[SELECT_12]], %{{.*}}[%{{.*}}]
+func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) { // Loop bounds are function arguments, not constants.
+  %cf0 = arith.constant 1.0 : f32
+  %cf1 = arith.constant 33.0 : f32
+  %cst = arith.constant 0 : index
+  %res:1 = scf.for %i0 = %lb to %ub step %step iter_args (%arg0 = %cf0) -> (f32) { // %arg0 is loop-carried; its final value escapes as %res#0.
+    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+    %A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
+    %A2_elem = arith.mulf %A1_elem, %cf1 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
+    scf.yield %A2_elem : f32 // Yielded value becomes %arg0 of the next iteration.
+  } { __test_pipelining_loop__ }
+  memref.store %res#0, %result[%cst] : memref<?xf32> // Live-out use of the loop result after the loop.
+  return
+}
+
+// -----
+
 // CHECK-LABEL: yield_constant_loop(
 //  CHECK-SAME:   %[[A:.*]]: memref<?xf32>) -> f32 {
 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index

diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 8a92d840ad1302..3ff7f9966e93da 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -214,12 +214,12 @@ struct TestSCFPipeliningPass
     RewritePatternSet patterns(&getContext());
     mlir::scf::PipeliningOption options;
     options.getScheduleFn = getSchedule;
+    options.supportDynamicLoops = true;
+    options.predicateFn = predicateOp;
     if (annotatePipeline)
       options.annotateFn = annotate;
     if (noEpiloguePeeling) {
-      options.supportDynamicLoops = true;
       options.peelEpilogue = false;
-      options.predicateFn = predicateOp;
     }
     scf::populateSCFLoopPipeliningPatterns(patterns, options);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));


        


More information about the Mlir-commits mailing list