[Mlir-commits] [mlir] 8da5aa1 - [mlir][SCF] Fix dynamic loop pipeline peeling for num_stages > total_iters (#112418)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 15 13:13:52 PDT 2024


Author: SJW
Date: 2024-10-15T13:13:49-07:00
New Revision: 8da5aa16f65bc297663573bacd3030f975b9fcde

URL: https://github.com/llvm/llvm-project/commit/8da5aa16f65bc297663573bacd3030f975b9fcde
DIFF: https://github.com/llvm/llvm-project/commit/8da5aa16f65bc297663573bacd3030f975b9fcde.diff

LOG: [mlir][SCF] Fix dynamic loop pipeline peeling for num_stages > total_iters (#112418)

When pipelining an `scf.for` with dynamic loop bounds, the epilogue
ramp-down must align with the prologue when num_stages >
total_iterations.

For example:
```
scf.for (0..ub) {
  load(i)
  add(i)
  store(i)
}
```
When num_stages=3 the pipeline follows:
```
load(0)  -  add(0)      -  scf.for (0..ub-2)    -  store(ub-2)
            load(1)     -                       -  add(ub-1)     -  store(ub-1)

```
The trailing `store(ub-2)`, at index `i=ub-2`, must align with the
ramp-up for `i=0` when `ub < num_stages-1`, so the index `i` should be
`max(0, ub-2)`, with each subsequent index incremented by one. The
predicate must also handle this scenario, so it becomes `predicate[0] =
total_iterations > epilogue_stage`.

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
    mlir/test/Dialect/SCF/loop-pipelining.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 83c9cf69ba0364..1b458f410af601 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -642,22 +642,25 @@ LogicalResult
 LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
                                     llvm::SmallVector<Value> &returnValues) {
   Location loc = forOp.getLoc();
+  Type t = lb.getType();
+
   // Emit different versions of the induction variable. They will be
   // removed by dead code if not used.
 
-  // bounds_range = ub - lb
-  // total_iterations = (bounds_range + step - 1) / step
-  Type t = lb.getType();
-  Value zero =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
-  Value one =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
-  Value minusOne =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
+  auto createConst = [&](int v) {
+    return rewriter.create<arith::ConstantOp>(loc,
+                                              rewriter.getIntegerAttr(t, v));
+  };
+
+  // total_iterations = cdiv(range_diff, step);
+  // - range_diff = ub - lb
+  // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
+  Value zero = createConst(0);
+  Value one = createConst(1);
   Value stepLessZero = rewriter.create<arith::CmpIOp>(
       loc, arith::CmpIPredicate::slt, step, zero);
   Value stepDecr =
-      rewriter.create<arith::SelectOp>(loc, stepLessZero, one, minusOne);
+      rewriter.create<arith::SelectOp>(loc, stepLessZero, one, createConst(-1));
 
   Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
   Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
@@ -665,25 +668,31 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
       rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
   Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step);
 
+  // If total_iters < max_stage, start the epilogue at zero to match the
+  // ramp-up in the prologue.
+  // start_iter = max(0, total_iters - max_stage)
+  Value iterI = rewriter.create<arith::SubIOp>(loc, totalIterations,
+                                               createConst(maxStage));
+  iterI = rewriter.create<arith::MaxSIOp>(loc, zero, iterI);
+
+  // Capture predicates for dynamic loops.
   SmallVector<Value> predicates(maxStage + 1);
-  for (int64_t i = 0; i < maxStage; i++) {
-    // iterI = total_iters - 1 - i
-    // May go negative...
-    Value minusI =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
-    Value iterI = rewriter.create<arith::AddIOp>(
-        loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minusOne),
-        minusI);
+
+  for (int64_t i = 1; i <= maxStage; i++) {
     // newLastIter = lb + step * iterI
     Value newlastIter = rewriter.create<arith::AddIOp>(
         loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
 
-    setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
+    setValueMapping(forOp.getInductionVar(), newlastIter, i);
+
+    // increment to next iterI
+    iterI = rewriter.create<arith::AddIOp>(loc, iterI, one);
 
     if (dynamicLoop) {
-      // pred = iterI >= 0
-      predicates[i + 1] = rewriter.create<arith::CmpIOp>(
-          loc, arith::CmpIPredicate::sge, iterI, zero);
+      // Disable stages when `i` is greater than total_iters.
+      // pred = total_iters >= i
+      predicates[i] = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::sge, totalIterations, createConst(i));
     }
   }
 

diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index af49d2afc049ba..c879c83275bf86 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -767,6 +767,7 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
 // Check for predicated epilogue for dynamic loop.
 // CHECK-LABEL: dynamic_loop(
 //    CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//    CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
 //    CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //    CHECK-DAG:   %[[CM1:.*]] = arith.constant -1 : index
 //        CHECK:   %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}}
@@ -779,32 +780,32 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
 //        CHECK:       scf.yield %[[ADDF_24]], %[[LOAD_27]]
 //        CHECK:   }
 //        CHECK:   %[[CMPI_10:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]]
-//        CHECK:   %[[SEL_10:.*]] = arith.select %[[CMPI_10]], %[[C1]], %[[CM1]]
-//        CHECK:   %[[SUBI_10:.*]] = arith.subi %[[UB]], %[[LB]]
-//        CHECK:   %[[ADDI_11:.*]] = arith.addi %[[SUBI_10]], %[[STEP]]
-//        CHECK:   %[[ADDI_12:.*]] = arith.addi %[[ADDI_11]], %[[SEL_10]]
-//        CHECK:   %[[DIVSI_13:.*]] = arith.divsi %[[ADDI_12]], %[[STEP]]
-//        CHECK:   %[[ADDI_14:.*]] = arith.addi %[[DIVSI_13]], %[[CM1]]
-//        CHECK:   %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
-//        CHECK:   %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
-//        CHECK:   %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %[[C0]]
-//        CHECK:   %[[ADDI_18:.*]] = arith.addi %[[DIVSI_13]], %{{.*}}-1
-//        CHECK:   %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
-//        CHECK:   %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
-//        CHECK:   %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
-//        CHECK:   %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %[[C0]]
-//        CHECK:   scf.if %[[CMPI_17]] {
-//        CHECK:     memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
+//        CHECK:   %[[SELECT_11:.*]] = arith.select %[[CMPI_10]], %[[C1]], %[[CM1]]
+//        CHECK:   %[[SUBI_12:.*]] = arith.subi %[[UB]], %[[LB]]
+//        CHECK:   %[[ADDI_13:.*]] = arith.addi %[[SUBI_12]], %[[STEP]]
+//        CHECK:   %[[ADDI_14:.*]] = arith.addi %[[ADDI_13]], %[[SELECT_11]]
+//        CHECK:   %[[DIVSI_15:.*]] = arith.divsi %[[ADDI_14]], %[[STEP]]
+//        CHECK:   %[[SUBI_17:.*]] = arith.subi %[[DIVSI_15]], %[[C2]]
+//        CHECK:   %[[MAXSI_18:.*]] = arith.maxsi %[[SUBI_17]], %[[C0]]
+//        CHECK:   %[[MULI_19:.*]] = arith.muli %[[STEP]], %[[MAXSI_18]]
+//        CHECK:   %[[ADDI_20:.*]] = arith.addi %[[LB]], %[[MULI_19]]
+//        CHECK:   %[[ADDI_21:.*]] = arith.addi %[[MAXSI_18]], %[[C1]]
+//        CHECK:   %[[CMPI_22:.*]] = arith.cmpi sge, %[[DIVSI_15]], %[[C1]]
+//        CHECK:   %[[MULI_23:.*]] = arith.muli %[[STEP]], %[[ADDI_21]]
+//        CHECK:   %[[ADDI_24:.*]] = arith.addi %[[LB]], %[[MULI_23]]
+//        CHECK:   %[[CMPI_25:.*]] = arith.cmpi sge, %[[DIVSI_15]], %[[C2]]
+//        CHECK:   scf.if %[[CMPI_22]] {
+//        CHECK:     memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_20]]]
 //        CHECK:   } else {
 //        CHECK:   }
-//        CHECK:   %[[IF_23:.*]] = scf.if %[[CMPI_22]] -> (f32) {
-//        CHECK:     %[[ADDF_24:.*]] = arith.addf %{{.*}}#1, %{{.*}}
-//        CHECK:     scf.yield %[[ADDF_24]]
+//        CHECK:   %[[IF_26:.*]] = scf.if %[[CMPI_25]]
+//        CHECK:     %[[ADDF_27:.*]] = arith.addf %{{.*}}#1, %{{.*}}
+//        CHECK:     scf.yield %[[ADDF_27]]
 //        CHECK:   } else {
 //        CHECK:     scf.yield %{{.*}}
 //        CHECK:   }
-//        CHECK:   scf.if %[[CMPI_22]] {
-//        CHECK:     memref.store %[[IF_23]], %{{.*}}[%[[ADDI_16]]]
+//        CHECK:   scf.if %[[CMPI_25]] {
+//        CHECK:     memref.store %[[IF_26]], %{{.*}}[%[[ADDI_24]]]
 //        CHECK:   } else {
 //        CHECK:   }
 //        CHECK:   return
@@ -842,6 +843,7 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
 //   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG:   %[[CM1:.*]] = arith.constant -1 : index
+//   CHECK-DAG:   %[[CF0:.*]] = arith.constant 0.000000e+00
 //       CHECK:   %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}}
 //       CHECK:   %{{.*}}:2 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UBM]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
 //       CHECK:       %[[ADDF_13:.*]] = arith.addf %[[ARG7]], %[[ARG6]]
@@ -856,22 +858,21 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
 //       CHECK:     %[[ADDI_7:.*]] = arith.addi %[[SUBI_6]], %[[STEP]]
 //       CHECK:     %[[ADDI_8:.*]] = arith.addi %[[ADDI_7]], %[[SELECT_5]]
 //       CHECK:     %[[DIVSI_9:.*]] = arith.divsi %[[ADDI_8]], %[[STEP]]
-//       CHECK:     %[[ADDI_10:.*]] = arith.addi %[[DIVSI_9]], %[[CM1]]
-//       CHECK:     %[[CMPI_11:.*]] = arith.cmpi sge, %[[ADDI_10]], %[[C0]]
-//       CHECK:     %[[IF_10:.*]] = scf.if %[[CMPI_11]]
-//       CHECK:       %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
-//       CHECK:       scf.yield %[[ADDF_13]]
+//       CHECK:     %[[CMPI_10:.*]] = arith.cmpi sge, %[[DIVSI_9]], %[[C1]]
+//       CHECK:     %[[IF_11:.*]] = scf.if %[[CMPI_10]]
+//       CHECK:       %[[ADDF_14:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
+//       CHECK:       scf.yield %[[ADDF_14]]
 //       CHECK:     } else {
-//       CHECK:       scf.yield %{{.*}}
+//       CHECK:       scf.yield %[[CF0]]
 //       CHECK:     }
-//       CHECK:     %[[IF_11:.*]] = scf.if %[[CMPI_11]]
-//       CHECK:       %[[MULF_13:.*]] = arith.mulf %[[IF_10]], %{{.*}}
-//       CHECK:       scf.yield %[[MULF_13]]
+//       CHECK:     %[[IF_12:.*]] = scf.if %[[CMPI_10]]
+//       CHECK:       %[[MULF_14:.*]] = arith.mulf %[[IF_11]], %{{.*}}
+//       CHECK:       scf.yield %[[MULF_14]]
 //       CHECK:     } else {
-//       CHECK:       scf.yield %{{.*}}
+//       CHECK:       scf.yield %[[CF0]]
 //       CHECK:     }
-//       CHECK:     %[[SELECT_12:.*]] = arith.select %[[CMPI_11]], %[[IF_11]], %{{.*}}#0
-//       CHECK:     memref.store %[[SELECT_12]], %{{.*}}[%{{.*}}]
+//       CHECK:     %[[SELECT_13:.*]] = arith.select %[[CMPI_10]], %[[IF_12]], %{{.*}}#0
+//       CHECK:     memref.store %[[SELECT_13]], %{{.*}}[%[[C0]]]
 func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
   %cf0 = arith.constant 1.0 : f32
   %cf1 = arith.constant 33.0 : f32


        


More information about the Mlir-commits mailing list