[Mlir-commits] [mlir] e66f97e - [mlir] Fix loop pipelining when the operand of `yield` is not defined in the loop body (#75423)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Dec 13 19:19:17 PST 2023
Author: Keren Zhou
Date: 2023-12-13T19:19:13-08:00
New Revision: e66f97e8a80bdd1acebfe6833380467a0454d2e1
URL: https://github.com/llvm/llvm-project/commit/e66f97e8a80bdd1acebfe6833380467a0454d2e1
DIFF: https://github.com/llvm/llvm-project/commit/e66f97e8a80bdd1acebfe6833380467a0454d2e1.diff
LOG: [mlir] Fix loop pipelining when the operand of `yield` is not defined in the loop body (#75423)
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
mlir/test/Dialect/SCF/loop-pipelining.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 6c36600975a597..7d45b484f76575 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -90,7 +90,8 @@ struct LoopPipelinerInternal {
RewriterBase &rewriter);
/// Emits the epilogue, this creates `maxStage - 1` part which will contain
/// operations from stages [i; maxStage], where i is the part index.
- llvm::SmallVector<Value> emitEpilogue(RewriterBase &rewriter);
+ void emitEpilogue(RewriterBase &rewriter,
+ llvm::SmallVector<Value> &returnValues);
};
bool LoopPipelinerInternal::initializeLoopInfo(
@@ -175,15 +176,18 @@ bool LoopPipelinerInternal::initializeLoopInfo(
}
}
- // Only support loop carried dependency with a distance of 1. This means the
- // source of all the scf.yield operands needs to be defined by operations in
- // the loop.
+ // Support only loop-carried dependencies with a distance of one iteration or
+ // those defined outside of the loop. This means that any dependency within a
+ // loop should either be on the immediately preceding iteration, the current
+ // iteration, or on variables whose values are set before entering the loop.
if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
[this](Value operand) {
Operation *def = operand.getDefiningOp();
- return !def || !stages.contains(def);
+ return !def ||
+ (!stages.contains(def) && forOp->isAncestor(def));
})) {
- LDBG("--only support loop carried dependency with a distance of 1 -> BAIL");
+ LDBG("--only support loop carried dependency with a distance of 1 or "
+ "defined outside of the loop -> BAIL");
return false;
}
annotateFn = options.annotateFn;
@@ -341,12 +345,17 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
for (const auto &retVal :
llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
Operation *def = retVal.value().getDefiningOp();
- assert(def && "Only support loop carried dependencies of distance 1");
- unsigned defStage = stages[def];
- Value valueVersion = valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
- [maxStage - defStage];
- assert(valueVersion);
- newLoopArg.push_back(valueVersion);
+ assert(def && "Only support loop carried dependencies of distance of 1 or "
+ "outside the loop");
+ auto defStage = stages.find(def);
+ if (defStage != stages.end()) {
+ Value valueVersion =
+ valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
+ [maxStage - defStage->second];
+ assert(valueVersion);
+ newLoopArg.push_back(valueVersion);
+ } else
+ newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]);
}
for (auto escape : crossStageValues) {
LiverangeInfo &info = escape.second;
@@ -551,21 +560,25 @@ LogicalResult LoopPipelinerInternal::createKernel(
for (const auto &retVal :
llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
Operation *def = retVal.value().getDefiningOp();
- assert(def && "Only support loop carried dependencies of distance 1");
- unsigned defStage = stages[def];
- if (defStage > 0) {
+ assert(def && "Only support loop carried dependencies of distance of 1 or "
+ "defined outside the loop");
+ auto defStage = stages.find(def);
+ if (defStage == stages.end()) {
+ for (unsigned int stage = 1; stage <= maxStage; stage++)
+ setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
+ retVal.value(), stage);
+ } else if (defStage->second > 0) {
setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
newForOp->getResult(retVal.index()),
- maxStage - defStage + 1);
+ maxStage - defStage->second + 1);
}
}
rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
return success();
}
-llvm::SmallVector<Value>
-LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) {
- llvm::SmallVector<Value> returnValues(forOp->getNumResults());
+void LoopPipelinerInternal::emitEpilogue(
+ RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
// Emit different versions of the induction variable. They will be
// removed by dead code if not used.
for (int64_t i = 0; i < maxStage; i++) {
@@ -628,7 +641,6 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) {
}
}
}
- return returnValues;
}
void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
@@ -685,7 +697,7 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
if (options.peelEpilogue) {
// 4. Emit the epilogue after the new forOp.
rewriter.setInsertionPointAfter(newForOp);
- returnValues = pipeliner.emitEpilogue(rewriter);
+ pipeliner.emitEpilogue(rewriter, returnValues);
}
// 5. Erase the original loop and replace the uses with the epilogue output.
if (forOp->getNumResults() > 0)
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 8a57ddccfee665..a18c850c3f05f1 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -770,3 +770,47 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
} { __test_pipelining_loop__ }
return
}
+
+// -----
+
+// CHECK-LABEL: yield_constant_loop(
+// CHECK-SAME: %[[A:.*]]: memref<?xf32>) -> f32 {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
+// Prologue:
+// CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
+// Kernel:
+// CHECK-NEXT: %[[L1:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]]
+// CHECK-SAME: step %[[C1]] iter_args(%[[ARG0:.*]] = %[[CST2]], %[[ARG1:.*]] = %[[L0]]) -> (f32, f32) {
+// CHECK-NEXT: %[[ADD0:.*]] = arith.addf %[[ARG1]], %[[ARG0]] : f32
+// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[ADD0]], %[[CST0]] : f32
+// CHECK-NEXT: memref.store %[[MUL0]], %[[A]][%[[IV]]] : memref<?xf32>
+// CHECK-NEXT: %[[IV1:.*]] = arith.addi %[[IV]], %[[C1]] : index
+// CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32>
+// CHECK-NEXT: scf.yield %[[CST0]], %[[L2]] : f32
+// CHECK-NEXT: }
+// Epilogue:
+// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[L1]]#1, %[[CST0]] : f32
+// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[ADD1]], %[[CST0]] : f32
+// CHECK-NEXT: memref.store %[[MUL1]], %[[A]][%[[C3]]] : memref<?xf32>
+// CHECK-NEXT: return %[[L1]]#0 : f32
+
+func.func @yield_constant_loop(%A: memref<?xf32>) -> f32 {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %cf0 = arith.constant 0.0 : f32
+ %cf2 = arith.constant 2.0 : f32
+ %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf2) -> f32 {
+ %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 3 } : memref<?xf32>
+ %A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
+ %A2_elem = arith.mulf %cf0, %A1_elem { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
+ memref.store %A2_elem, %A[%i0] { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+ scf.yield %cf0: f32
+ } { __test_pipelining_loop__ }
+ return %r : f32
+}
+
More information about the Mlir-commits
mailing list