[Mlir-commits] [mlir] 8da5aa1 - [mlir][SCF] Fix dynamic loop pipeline peeling for num_stages > total_iters (#112418)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 15 13:13:52 PDT 2024
Author: SJW
Date: 2024-10-15T13:13:49-07:00
New Revision: 8da5aa16f65bc297663573bacd3030f975b9fcde
URL: https://github.com/llvm/llvm-project/commit/8da5aa16f65bc297663573bacd3030f975b9fcde
DIFF: https://github.com/llvm/llvm-project/commit/8da5aa16f65bc297663573bacd3030f975b9fcde.diff
LOG: [mlir][SCF] Fix dynamic loop pipeline peeling for num_stages > total_iters (#112418)
When pipelining an `scf.for` with dynamic loop bounds, the epilogue
ramp-down must align with the prologue when num_stages >
total_iterations.
For example:
```
scf.for (0..ub) {
load(i)
add(i)
store(i)
}
```
When num_stages=3 the pipeline follows:
```
load(0)  -  add(0)   -  scf.for (0..ub-2)  -  store(ub-2)
            load(1)  -                     -  add(ub-1)    -  store(ub-1)
```
The trailing `store(ub-2)`, `i=ub-2`, must align with the ramp-up for
`i=0` when `ub < num_stages-1`, so the index `i` should be `max(0,
ub-2)` and each subsequent index is an increment. The predicate must
also handle this scenario, so it becomes `predicate[0] =
total_iterations > epilogue_stage`.
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
mlir/test/Dialect/SCF/loop-pipelining.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 83c9cf69ba0364..1b458f410af601 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -642,22 +642,25 @@ LogicalResult
LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
llvm::SmallVector<Value> &returnValues) {
Location loc = forOp.getLoc();
+ Type t = lb.getType();
+
// Emit different versions of the induction variable. They will be
// removed by dead code if not used.
- // bounds_range = ub - lb
- // total_iterations = (bounds_range + step - 1) / step
- Type t = lb.getType();
- Value zero =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
- Value one =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
- Value minusOne =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
+ auto createConst = [&](int v) {
+ return rewriter.create<arith::ConstantOp>(loc,
+ rewriter.getIntegerAttr(t, v));
+ };
+
+ // total_iterations = cdiv(range_diff, step);
+ // - range_diff = ub - lb
+ // - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
+ Value zero = createConst(0);
+ Value one = createConst(1);
Value stepLessZero = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, step, zero);
Value stepDecr =
- rewriter.create<arith::SelectOp>(loc, stepLessZero, one, minusOne);
+ rewriter.create<arith::SelectOp>(loc, stepLessZero, one, createConst(-1));
Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
@@ -665,25 +668,31 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step);
+ // If total_iters < max_stage, start the epilogue at zero to match the
+ // ramp-up in the prologue.
+ // start_iter = max(0, total_iters - max_stage)
+ Value iterI = rewriter.create<arith::SubIOp>(loc, totalIterations,
+ createConst(maxStage));
+ iterI = rewriter.create<arith::MaxSIOp>(loc, zero, iterI);
+
+ // Capture predicates for dynamic loops.
SmallVector<Value> predicates(maxStage + 1);
- for (int64_t i = 0; i < maxStage; i++) {
- // iterI = total_iters - 1 - i
- // May go negative...
- Value minusI =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
- Value iterI = rewriter.create<arith::AddIOp>(
- loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minusOne),
- minusI);
+
+ for (int64_t i = 1; i <= maxStage; i++) {
// newLastIter = lb + step * iterI
Value newlastIter = rewriter.create<arith::AddIOp>(
loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
- setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
+ setValueMapping(forOp.getInductionVar(), newlastIter, i);
+
+ // increment to next iterI
+ iterI = rewriter.create<arith::AddIOp>(loc, iterI, one);
if (dynamicLoop) {
- // pred = iterI >= 0
- predicates[i + 1] = rewriter.create<arith::CmpIOp>(
- loc, arith::CmpIPredicate::sge, iterI, zero);
+ // Disable stages when `i` is greater than total_iters.
+ // pred = total_iters >= i
+ predicates[i] = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::sge, totalIterations, createConst(i));
}
}
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index af49d2afc049ba..c879c83275bf86 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -767,6 +767,7 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
// Check for predicated epilogue for dynamic loop.
// CHECK-LABEL: dynamic_loop(
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[CM1:.*]] = arith.constant -1 : index
// CHECK: %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}}
@@ -779,32 +780,32 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
// CHECK: scf.yield %[[ADDF_24]], %[[LOAD_27]]
// CHECK: }
// CHECK: %[[CMPI_10:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]]
-// CHECK: %[[SEL_10:.*]] = arith.select %[[CMPI_10]], %[[C1]], %[[CM1]]
-// CHECK: %[[SUBI_10:.*]] = arith.subi %[[UB]], %[[LB]]
-// CHECK: %[[ADDI_11:.*]] = arith.addi %[[SUBI_10]], %[[STEP]]
-// CHECK: %[[ADDI_12:.*]] = arith.addi %[[ADDI_11]], %[[SEL_10]]
-// CHECK: %[[DIVSI_13:.*]] = arith.divsi %[[ADDI_12]], %[[STEP]]
-// CHECK: %[[ADDI_14:.*]] = arith.addi %[[DIVSI_13]], %[[CM1]]
-// CHECK: %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
-// CHECK: %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
-// CHECK: %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %[[C0]]
-// CHECK: %[[ADDI_18:.*]] = arith.addi %[[DIVSI_13]], %{{.*}}-1
-// CHECK: %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
-// CHECK: %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
-// CHECK: %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
-// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %[[C0]]
-// CHECK: scf.if %[[CMPI_17]] {
-// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
+// CHECK: %[[SELECT_11:.*]] = arith.select %[[CMPI_10]], %[[C1]], %[[CM1]]
+// CHECK: %[[SUBI_12:.*]] = arith.subi %[[UB]], %[[LB]]
+// CHECK: %[[ADDI_13:.*]] = arith.addi %[[SUBI_12]], %[[STEP]]
+// CHECK: %[[ADDI_14:.*]] = arith.addi %[[ADDI_13]], %[[SELECT_11]]
+// CHECK: %[[DIVSI_15:.*]] = arith.divsi %[[ADDI_14]], %[[STEP]]
+// CHECK: %[[SUBI_17:.*]] = arith.subi %[[DIVSI_15]], %[[C2]]
+// CHECK: %[[MAXSI_18:.*]] = arith.maxsi %[[SUBI_17]], %[[C0]]
+// CHECK: %[[MULI_19:.*]] = arith.muli %[[STEP]], %[[MAXSI_18]]
+// CHECK: %[[ADDI_20:.*]] = arith.addi %[[LB]], %[[MULI_19]]
+// CHECK: %[[ADDI_21:.*]] = arith.addi %[[MAXSI_18]], %[[C1]]
+// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[DIVSI_15]], %[[C1]]
+// CHECK: %[[MULI_23:.*]] = arith.muli %[[STEP]], %[[ADDI_21]]
+// CHECK: %[[ADDI_24:.*]] = arith.addi %[[LB]], %[[MULI_23]]
+// CHECK: %[[CMPI_25:.*]] = arith.cmpi sge, %[[DIVSI_15]], %[[C2]]
+// CHECK: scf.if %[[CMPI_22]] {
+// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_20]]]
// CHECK: } else {
// CHECK: }
-// CHECK: %[[IF_23:.*]] = scf.if %[[CMPI_22]] -> (f32) {
-// CHECK: %[[ADDF_24:.*]] = arith.addf %{{.*}}#1, %{{.*}}
-// CHECK: scf.yield %[[ADDF_24]]
+// CHECK: %[[IF_26:.*]] = scf.if %[[CMPI_25]]
+// CHECK: %[[ADDF_27:.*]] = arith.addf %{{.*}}#1, %{{.*}}
+// CHECK: scf.yield %[[ADDF_27]]
// CHECK: } else {
// CHECK: scf.yield %{{.*}}
// CHECK: }
-// CHECK: scf.if %[[CMPI_22]] {
-// CHECK: memref.store %[[IF_23]], %{{.*}}[%[[ADDI_16]]]
+// CHECK: scf.if %[[CMPI_25]] {
+// CHECK: memref.store %[[IF_26]], %{{.*}}[%[[ADDI_24]]]
// CHECK: } else {
// CHECK: }
// CHECK: return
@@ -842,6 +843,7 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[CM1:.*]] = arith.constant -1 : index
+// CHECK-DAG: %[[CF0:.*]] = arith.constant 0.000000e+00
// CHECK: %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}}
// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UBM]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
// CHECK: %[[ADDF_13:.*]] = arith.addf %[[ARG7]], %[[ARG6]]
@@ -856,22 +858,21 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
// CHECK: %[[ADDI_7:.*]] = arith.addi %[[SUBI_6]], %[[STEP]]
// CHECK: %[[ADDI_8:.*]] = arith.addi %[[ADDI_7]], %[[SELECT_5]]
// CHECK: %[[DIVSI_9:.*]] = arith.divsi %[[ADDI_8]], %[[STEP]]
-// CHECK: %[[ADDI_10:.*]] = arith.addi %[[DIVSI_9]], %[[CM1]]
-// CHECK: %[[CMPI_11:.*]] = arith.cmpi sge, %[[ADDI_10]], %[[C0]]
-// CHECK: %[[IF_10:.*]] = scf.if %[[CMPI_11]]
-// CHECK: %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
-// CHECK: scf.yield %[[ADDF_13]]
+// CHECK: %[[CMPI_10:.*]] = arith.cmpi sge, %[[DIVSI_9]], %[[C1]]
+// CHECK: %[[IF_11:.*]] = scf.if %[[CMPI_10]]
+// CHECK: %[[ADDF_14:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
+// CHECK: scf.yield %[[ADDF_14]]
// CHECK: } else {
-// CHECK: scf.yield %{{.*}}
+// CHECK: scf.yield %[[CF0]]
// CHECK: }
-// CHECK: %[[IF_11:.*]] = scf.if %[[CMPI_11]]
-// CHECK: %[[MULF_13:.*]] = arith.mulf %[[IF_10]], %{{.*}}
-// CHECK: scf.yield %[[MULF_13]]
+// CHECK: %[[IF_12:.*]] = scf.if %[[CMPI_10]]
+// CHECK: %[[MULF_14:.*]] = arith.mulf %[[IF_11]], %{{.*}}
+// CHECK: scf.yield %[[MULF_14]]
// CHECK: } else {
-// CHECK: scf.yield %{{.*}}
+// CHECK: scf.yield %[[CF0]]
// CHECK: }
-// CHECK: %[[SELECT_12:.*]] = arith.select %[[CMPI_11]], %[[IF_11]], %{{.*}}#0
-// CHECK: memref.store %[[SELECT_12]], %{{.*}}[%{{.*}}]
+// CHECK: %[[SELECT_13:.*]] = arith.select %[[CMPI_10]], %[[IF_12]], %{{.*}}#0
+// CHECK: memref.store %[[SELECT_13]], %{{.*}}[%[[C0]]]
func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
%cf0 = arith.constant 1.0 : f32
%cf1 = arith.constant 33.0 : f32
More information about the Mlir-commits
mailing list