[Mlir-commits] [mlir] [MLIR][SCF] Add support for pipelining dynamic loops (PR #74350)

Thomas Raoux llvmlistbot at llvm.org
Mon Dec 4 09:47:22 PST 2023


https://github.com/ThomasRaoux created https://github.com/llvm/llvm-project/pull/74350

Support loops without static bounds. Since the number of iterations is not known, we need to predicate the prologue and epilogue in case the number of iterations is smaller than the number of stages.

This patch includes work from @chengjunlu

>From 0159da9342a3eaeb0d33d18d658a0b5bd18473b1 Mon Sep 17 00:00:00 2001
From: Thomas Raoux <thomas.raoux at openai.com>
Date: Fri, 1 Dec 2023 19:51:11 -0800
Subject: [PATCH] [MLIR][SCF] Add support for pipelining dynamic loops

Support loops that don't have static bounds. Since the number of
iterations is not known, we need to predicate the prologue and epilogue
in case the number of iterations is smaller than the number of stages.
---
 .../mlir/Dialect/SCF/Transforms/Transforms.h  |   7 +
 .../Dialect/SCF/Transforms/LoopPipelining.cpp | 137 ++++++++++++++----
 .../NVGPU/transform-pipeline-shared.mlir      |  26 ++--
 mlir/test/Dialect/SCF/loop-pipelining.mlir    |  44 ++++++
 mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp    |   1 +
 5 files changed, 170 insertions(+), 45 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
index 347beb9e4c64f..cad5173599453 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
@@ -128,6 +128,13 @@ struct PipeliningOption {
   /// lambda to generate the predicated version of operations.
   bool peelEpilogue = true;
 
+  /// Control whether the transformation checks that the number of iterations is
+  /// greater than or equal to the number of stages and skips the transformation
+  /// if this is not the case. If this is set to true and the loop bounds are
+  /// not static, the pipeliner will have to predicate operations in the
+  /// prologue/epilogue.
+  bool supportDynamicLoops = false;
+
   // Callback to predicate operations when the prologue or epilogue are not
   // peeled. This takes the original operation, an i1 predicate value and the
   // pattern rewriter. It is expected to replace the given operation with
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 20fa8089201aa..81ed1826a5508 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -44,9 +44,10 @@ struct LoopPipelinerInternal {
   unsigned maxStage = 0;
   DenseMap<Operation *, unsigned> stages;
   std::vector<Operation *> opOrder;
-  int64_t ub;
-  int64_t lb;
-  int64_t step;
+  Value ub;
+  Value lb;
+  Value step;
+  bool dynamicLoop;
   PipeliningOption::AnnotationlFnType annotateFn = nullptr;
   bool peelEpilogue;
   PipeliningOption::PredicateOpFn predicateFn = nullptr;
@@ -96,25 +97,41 @@ bool LoopPipelinerInternal::initializeLoopInfo(
     ForOp op, const PipeliningOption &options) {
   LDBG("Start initializeLoopInfo");
   forOp = op;
-  auto upperBoundCst =
-      forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
-  auto lowerBoundCst =
-      forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
-  auto stepCst = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
+  ub = forOp.getUpperBound();
+  lb = forOp.getLowerBound();
+  step = forOp.getStep();
+
+  dynamicLoop = true;
+  auto upperBoundCst = getConstantIntValue(ub);
+  auto lowerBoundCst = getConstantIntValue(lb);
+  auto stepCst = getConstantIntValue(step);
   if (!upperBoundCst || !lowerBoundCst || !stepCst) {
-    LDBG("--no constant bounds or step -> BAIL");
-    return false;
+    if (!options.supportDynamicLoops) {
+      LDBG("--dynamic loop not supported -> BAIL");
+      return false;
+    }
+  } else {
+    int64_t ubImm = upperBoundCst.value();
+    int64_t lbImm = lowerBoundCst.value();
+    int64_t stepImm = stepCst.value();
+    int64_t numIteration = ceilDiv(ubImm - lbImm, stepImm);
+    if (numIteration > maxStage) {
+      dynamicLoop = false;
+    } else if (!options.supportDynamicLoops) {
+      LDBG("--fewer loop iterations than pipeline stages -> BAIL");
+      return false;
+    }
   }
-  ub = upperBoundCst.value();
-  lb = lowerBoundCst.value();
-  step = stepCst.value();
   peelEpilogue = options.peelEpilogue;
   predicateFn = options.predicateFn;
-  if (!peelEpilogue && predicateFn == nullptr) {
+  if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) {
     LDBG("--no epilogue or predicate set -> BAIL");
     return false;
   }
-  int64_t numIteration = ceilDiv(ub - lb, step);
+  if (dynamicLoop && peelEpilogue) {
+    LDBG("--dynamic loop doesn't support epilogue yet -> BAIL");
+    return false;
+  }
   std::vector<std::pair<Operation *, unsigned>> schedule;
   options.getScheduleFn(forOp, schedule);
   if (schedule.empty()) {
@@ -128,10 +145,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
     stages[opSchedule.first] = opSchedule.second;
     opOrder.push_back(opSchedule.first);
   }
-  if (numIteration <= maxStage) {
-    LDBG("--fewer loop iterations than pipeline stages -> BAIL");
-    return false;
-  }
 
   // All operations need to have a stage.
   for (Operation &op : forOp.getBody()->without_terminator()) {
@@ -204,10 +217,31 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
     setValueMapping(arg, operand.get(), 0);
   }
   auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+  Location loc = forOp.getLoc();
+  SmallVector<Value> predicates(maxStage);
   for (int64_t i = 0; i < maxStage; i++) {
+    if (dynamicLoop) {
+      Type t = ub.getType();
+      // pred = ub > lb + (i * step)
+      Value iv = rewriter.create<arith::AddIOp>(
+          loc, lb,
+          rewriter.create<arith::MulIOp>(
+              loc, step,
+              rewriter.create<arith::ConstantOp>(
+                  loc, rewriter.getIntegerAttr(t, i))));
+      predicates[i] = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::slt, iv, ub);
+    }
+
     // special handling for induction variable as the increment is implicit.
-    Value iv =
-        rewriter.create<arith::ConstantIndexOp>(forOp.getLoc(), lb + i * step);
+    // iv = lb + i * step
+    Type t = lb.getType();
+    Value iv = rewriter.create<arith::AddIOp>(
+        loc, lb,
+        rewriter.create<arith::MulIOp>(
+            loc, step,
+            rewriter.create<arith::ConstantOp>(loc,
+                                               rewriter.getIntegerAttr(t, i))));
     setValueMapping(forOp.getInductionVar(), iv, i);
     for (Operation *op : opOrder) {
       if (stages[op] > i)
@@ -220,6 +254,12 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
               newOperand->set(replacement);
             }
           });
+      int predicateIdx = i - stages[op];
+      if (predicates[predicateIdx]) {
+        newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
+        assert(newOp && "failed to predicate op.");
+      }
+      rewriter.setInsertionPointAfter(newOp);
       if (annotateFn)
         annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
       for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
@@ -326,9 +366,16 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
   // `numStages - 1` iterations. Then we adjust the upper bound to remove those
   // iterations.
   Value newUb = forOp.getUpperBound();
-  if (peelEpilogue)
-    newUb = rewriter.create<arith::ConstantIndexOp>(forOp.getLoc(),
-                                                    ub - maxStage * step);
+  if (peelEpilogue) {
+    Type t = ub.getType();
+    Location loc = forOp.getLoc();
+    // newUb = ub - maxStage * step
+    Value maxStageByStep = rewriter.create<arith::MulIOp>(
+        loc, step,
+        rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getIntegerAttr(t, maxStage)));
+    newUb = rewriter.create<arith::SubIOp>(loc, ub, maxStageByStep);
+  }
   auto newForOp =
       rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
                                   forOp.getStep(), newLoopArg);
@@ -358,9 +405,17 @@ LogicalResult LoopPipelinerInternal::createKernel(
   SmallVector<Value> predicates(maxStage + 1, nullptr);
   if (!peelEpilogue) {
     // Create a predicate for each stage except the last stage.
+    Location loc = newForOp.getLoc();
+    Type t = ub.getType();
     for (unsigned i = 0; i < maxStage; i++) {
-      Value c = rewriter.create<arith::ConstantIndexOp>(
-          newForOp.getLoc(), ub - (maxStage - i) * step);
+      // c = ub - (maxStage - i) * step
+      Value c = rewriter.create<arith::AddIOp>(
+          loc, ub,
+          rewriter.create<arith::MulIOp>(
+              loc, step,
+              rewriter.create<arith::ConstantOp>(
+                  loc, rewriter.getIntegerAttr(t, -int64_t(maxStage - i)))));
+
       Value pred = rewriter.create<arith::CmpIOp>(
           newForOp.getLoc(), arith::CmpIPredicate::slt,
           newForOp.getInductionVar(), c);
@@ -383,8 +438,14 @@ LogicalResult LoopPipelinerInternal::createKernel(
       // version incremented based on the stage where it is used.
       if (operand->get() == forOp.getInductionVar()) {
         rewriter.setInsertionPoint(newOp);
-        Value offset = rewriter.create<arith::ConstantIndexOp>(
-            forOp.getLoc(), (maxStage - stages[op]) * step);
+
+        // offset = (maxStage - stages[op]) * step
+        Type t = step.getType();
+        Value offset = rewriter.create<arith::MulIOp>(
+            forOp.getLoc(), step,
+            rewriter.create<arith::ConstantOp>(
+                forOp.getLoc(),
+                rewriter.getIntegerAttr(t, maxStage - stages[op])));
         Value iv = rewriter.create<arith::AddIOp>(
             forOp.getLoc(), newForOp.getInductionVar(), offset);
         nestedNewOp->setOperand(operand->getOperandNumber(), iv);
@@ -508,8 +569,24 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) {
   // Emit different versions of the induction variable. They will be
   // removed by dead code if not used.
   for (int64_t i = 0; i < maxStage; i++) {
-    Value newlastIter = rewriter.create<arith::ConstantIndexOp>(
-        forOp.getLoc(), lb + step * ((((ub - 1) - lb) / step) - i));
+    Location loc = forOp.getLoc();
+    Type t = lb.getType();
+    Value minusOne =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
+    // (number of iterations) - 1 = ((ub - 1) - lb) / step
+    Value totlaNumIteration = rewriter.create<arith::DivUIOp>(
+        loc,
+        rewriter.create<arith::SubIOp>(
+            loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb),
+        step);
+    // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
+    Value minusI =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
+    Value newlastIter = rewriter.create<arith::AddIOp>(
+        loc, lb,
+        rewriter.create<arith::MulIOp>(
+            loc, step,
+            rewriter.create<arith::AddIOp>(loc, totlaNumIteration, minusI)));
     setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
   }
   // Emit `maxStage - 1` epilogue part that includes operations from stages
diff --git a/mlir/test/Dialect/NVGPU/transform-pipeline-shared.mlir b/mlir/test/Dialect/NVGPU/transform-pipeline-shared.mlir
index 42b072374261e..e959949babd9e 100644
--- a/mlir/test/Dialect/NVGPU/transform-pipeline-shared.mlir
+++ b/mlir/test/Dialect/NVGPU/transform-pipeline-shared.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter -canonicalize --split-input-file --verify-diagnostics | FileCheck %s
 
 func.func @simple_depth_2_unpeeled(%global: memref<?xf32>, %result: memref<?xf32> ) {
   %c0 = arith.constant 0 : index
@@ -78,15 +78,19 @@ module attributes {transform.with_named_sequence} {
 
 // CHECK-LABEL: @async_depth_2_predicated
 // CHECK-SAME: %[[GLOBAL:.+]]: memref
-func.func @async_depth_2_predicated(%global: memref<?xf32>) {
+func.func @async_depth_2_predicated(%global: memref<?xf32>, %alloc_size: index) {
   %c0 = arith.constant 0 : index
   %c98 = arith.constant 98 : index
   %c100 = arith.constant 100 : index
-  %c200 = arith.constant 200 : index
-  // CHECK: %[[C4:.+]] = arith.constant 4
+  // CHECK-DAG: %[[C4:.+]] = arith.constant 4
+  // CHECK-DAG:   %[[C90:.+]] = arith.constant 90
+  // CHECK-DAG:   %[[C96:.+]] = arith.constant 96
+  // CHECK-DAG:   %[[C8:.+]] = arith.constant 8
+  // CHECK-DAG:   %[[C2:.+]] = arith.constant 2
+  // CHECK-DAG:   %[[C0:.+]] = arith.constant 0
   %c4 = arith.constant 4 : index
   // CHECK: %[[SHARED:.+]] = memref.alloc{{.*}} #gpu.address_space<workgroup>
-  %shared = memref.alloc(%c200) : memref<?xf32, #gpu.address_space<workgroup>>
+  %shared = memref.alloc(%alloc_size) : memref<?xf32, #gpu.address_space<workgroup>>
   %c0f = arith.constant 0.0 : f32
   // CHECK: %[[TOKEN0:.+]] = nvgpu.device_async_copy
   // CHECK: %[[TOKEN1:.+]] = nvgpu.device_async_copy
@@ -95,16 +99,11 @@ func.func @async_depth_2_predicated(%global: memref<?xf32>) {
   // CHECK-SAME: %[[ITER_ARG1:.+]] = %[[TOKEN1]]
   scf.for %i = %c0 to %c98 step %c4 {
     // Condition for the predication "select" below.
-    // CHECK:   %[[C90:.+]] = arith.constant 90
     // CHECK:   %[[CMP0:.+]] = arith.cmpi slt, %[[I]], %[[C90]]
     // CHECK:   nvgpu.device_async_wait %[[ITER_ARG0]] {numGroups = 1
-
     // Original "select" with updated induction variable.
-    // CHECK:   %[[C96:.+]] = arith.constant 96
-    // CHECK:   %[[C8:.+]] = arith.constant 8
     // CHECK:   %[[I_PLUS_8:.+]] = arith.addi %[[I]], %[[C8]]
     // CHECK:   %[[CMP1:.+]] = arith.cmpi slt, %[[I_PLUS_8]], %[[C96]]
-    // CHECK:   %[[C2:.+]] = arith.constant 2
     // CHECK:   %[[SELECTED0:.+]] = arith.select %[[CMP1]], %[[C4]], %[[C2]]
     %c96 = arith.constant 96 : index
     %cond = arith.cmpi slt, %i, %c96 : index
@@ -113,14 +112,11 @@ func.func @async_depth_2_predicated(%global: memref<?xf32>) {
 
     // Updated induction variables (two more) for the device_async_copy below.
     // These are generated repeatedly by the pipeliner.
-    // CHECK:   %[[C8_2:.+]] = arith.constant 8
-    // CHECK:   %[[I_PLUS_8_2:.+]] = arith.addi %[[I]], %[[C8_2]]
-    // CHECK:   %[[C8_3:.+]] = arith.constant 8
-    // CHECK:   %[[I_PLUS_8_3:.+]] = arith.addi %[[I]], %[[C8_3]]
+    // CHECK:   %[[I_PLUS_8_2:.+]] = arith.addi %[[I]], %[[C8]]
+    // CHECK:   %[[I_PLUS_8_3:.+]] = arith.addi %[[I]], %[[C8]]
 
     // The second "select" is generated by predication and selects 0 for
     // the two last iterations.
-    // CHECK:   %[[C0:.+]] = arith.constant 0
     // CHECK:   %[[SELECTED1:.+]] = arith.select %[[CMP0]], %[[SELECTED0]], %[[C0]]
     // CHECK:   %[[ASYNC_TOKEN:.+]] = nvgpu.device_async_copy %[[GLOBAL]][%[[I_PLUS_8_3]]], %[[SHARED]][%[[I_PLUS_8_2]]], 4, %[[SELECTED1]]
     %token = nvgpu.device_async_copy %global[%i], %shared[%i], 4, %read_size
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 4cd686d2cdb86..9cc8ce69dbba4 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -723,3 +723,47 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>) {
   memref.store %r, %result[%c1] : memref<?xf32>
   return
 }
+
+// -----
+
+// NOEPILOGUE-LABEL: dynamic_loop(
+//  NOEPILOGUE-SAME:   %[[A:.*]]: memref<?xf32>, %[[R:.*]]: memref<?xf32>, %[[LB:.+]]: index, %[[UB:.+]]: index, %[[STEP:.+]]: index) {
+//  NOEPILOGUE-DAG: %[[CM2:.+]] = arith.constant -2 : index
+//  NOEPILOGUE-DAG: %[[CM1:.+]] = arith.constant -1 : index
+//  NOEPILOGUE-DAG: %[[C2:.+]] = arith.constant 2 : index
+//  NOEPILOGUE-DAG: %[[CSTF:.+]] = arith.constant 1.000000e+00 : f32
+// Prologue:
+//      NOEPILOGUE: %[[P_I0:.+]] = arith.cmpi slt, %[[LB]], %[[UB]] : index
+//      NOEPILOGUE: %[[L0:.+]] = scf.if %[[P_I0]] -> (f32) {
+// NOEPILOGUE-NEXT:   memref.load %[[A]][%[[LB]]] : memref<?xf32>
+//      NOEPILOGUE: %[[IV1:.+]] = arith.addi %[[LB]], %[[STEP]] : index
+//      NOEPILOGUE: %[[P_I1:.+]] = arith.cmpi slt, %[[IV1]], %[[UB]] : index
+//      NOEPILOGUE: %[[IV1_2:.+]] = arith.addi %[[LB]], %[[STEP]] : index
+//      NOEPILOGUE: %[[V0:.+]] = scf.if %[[P_I0]] -> (f32) {
+// NOEPILOGUE-NEXT:   arith.addf %[[L0]], %[[CSTF]] : f32
+//      NOEPILOGUE: %[[L1:.+]] = scf.if %[[P_I1]] -> (f32) {
+// NOEPILOGUE-NEXT:   memref.load %[[A]][%[[IV1_2]]] : memref<?xf32>
+//  NOEPILOGUE: scf.for %[[IV2:.+]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[V1:.+]] = %[[V0]], %[[L2:.+]] = %[[L1]]) -> (f32, f32) {
+//  NOEPILOGUE-DAG:   %[[S2:.+]] = arith.muli %[[STEP]], %[[CM2]] : index
+//  NOEPILOGUE-DAG:   %[[IT2:.+]] = arith.addi %[[UB]], %[[S2]] : index
+//  NOEPILOGUE-DAG:   %[[P_I2:.+]] = arith.cmpi slt, %[[IV2]], %[[IT2]] : index
+//  NOEPILOGUE-DAG:   %[[S3:.+]] = arith.muli %[[STEP]], %[[CM1]] : index
+//  NOEPILOGUE-DAG:   %[[IT3:.+]] = arith.addi %[[UB]], %[[S3]] : index
+//  NOEPILOGUE-DAG:   %[[P_I3:.+]] = arith.cmpi slt, %[[IV2]], %[[IT3]] : index
+//      NOEPILOGUE:   memref.store %[[V1]], %[[R]][%[[IV2]]] : memref<?xf32>
+//      NOEPILOGUE:   %[[V2:.+]] = scf.if %[[P_I3]] -> (f32) {
+//      NOEPILOGUE:     arith.addf %[[L2]], %[[CSTF]] : f32
+//      NOEPILOGUE:   %[[IT4:.+]] = arith.muli %[[STEP]], %[[C2]] : index
+//      NOEPILOGUE:   %[[IV3:.+]] = arith.addi %[[IV2]], %[[IT4]] : index
+//      NOEPILOGUE:   %[[L3:.+]] = scf.if %[[P_I2]] -> (f32) {
+//      NOEPILOGUE:     memref.load %[[A]][%[[IV3]]] : memref<?xf32>
+//      NOEPILOGUE:   scf.yield %[[V2]], %[[L3]] : f32, f32
+func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
+  %cf = arith.constant 1.0 : f32
+  scf.for %i0 = %lb to %ub step %step {
+    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+    %A1_elem = arith.addf %A_elem, %cf { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
+    memref.store %A1_elem, %result[%i0] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 0 } : memref<?xf32>
+  } { __test_pipelining_loop__ }
+  return
+}
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 565d07669792f..9fed60e567881 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -214,6 +214,7 @@ struct TestSCFPipeliningPass
     RewritePatternSet patterns(&getContext());
     mlir::scf::PipeliningOption options;
     options.getScheduleFn = getSchedule;
+    options.supportDynamicLoops = true;
     if (annotatePipeline)
       options.annotateFn = annotate;
     if (noEpiloguePeeling) {



More information about the Mlir-commits mailing list