[Mlir-commits] [mlir] 45cb414 - [mlir] Extend scf pipelining to support loop carried dependencies
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 21 18:32:53 PDT 2021
Author: thomasraoux
Date: 2021-07-21T18:32:38-07:00
New Revision: 45cb4140eb13804f9e2f62fc1c91f9a64eb81351
URL: https://github.com/llvm/llvm-project/commit/45cb4140eb13804f9e2f62fc1c91f9a64eb81351
DIFF: https://github.com/llvm/llvm-project/commit/45cb4140eb13804f9e2f62fc1c91f9a64eb81351.diff
LOG: [mlir] Extend scf pipelining to support loop carried dependencies
Differential Revision: https://reviews.llvm.org/D106325
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
mlir/test/Dialect/SCF/loop-pipelining.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 7cb36c958f4de..70297a93b1b50 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -74,7 +74,7 @@ struct LoopPipelinerInternal {
PatternRewriter &rewriter);
/// Emits the epilogue, this creates `maxStage - 1` parts which will contain
/// operations from stages [i; maxStage], where i is the part index.
- void emitEpilogue(PatternRewriter &rewriter);
+ llvm::SmallVector<Value> emitEpilogue(PatternRewriter &rewriter);
};
bool LoopPipelinerInternal::initializeLoopInfo(
@@ -114,14 +114,25 @@ bool LoopPipelinerInternal::initializeLoopInfo(
.wasInterrupted())
return false;
- // TODO: Add support for loop with operands.
- if (forOp.getNumIterOperands() > 0)
+ // Only support loop carried dependency with a distance of 1. This means the
+ // source of all the scf.yield operands needs to be defined by operations in
+ // the loop.
+ if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
+ [this](Value operand) {
+ Operation *def = operand.getDefiningOp();
+ return !def || stages.find(def) == stages.end();
+ }))
return false;
-
return true;
}
void LoopPipelinerInternal::emitPrologue(PatternRewriter &rewriter) {
+ // Initialize the iteration arguments to the loop initial values.
+ for (BlockArgument &arg : forOp.getRegionIterArgs()) {
+ OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
+ setValueMapping(arg, operand.get(), 0);
+ }
+ auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
for (int64_t i = 0; i < maxStage; i++) {
// special handling for induction variable as the increment is implicit.
Value iv = rewriter.create<ConstantIndexOp>(forOp.getLoc(), lb + i);
@@ -138,6 +149,14 @@ void LoopPipelinerInternal::emitPrologue(PatternRewriter &rewriter) {
for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
setValueMapping(op->getResult(destId), newOp->getResult(destId),
i - stages[op]);
+ // If the value is a loop carried dependency update the loop argument
+ // mapping.
+ for (OpOperand &operand : yield->getOpOperands()) {
+ if (operand.get() != op->getResult(destId))
+ continue;
+ setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
+ newOp->getResult(destId), i - stages[op] + 1);
+ }
}
}
}
@@ -173,7 +192,19 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
// stages. The initial values come from the prologue created above.
// Keep track of the kernel argument associated to each version of the
// values passed to the kernel.
- auto newLoopArg = llvm::to_vector<8>(forOp.getIterOperands());
+ llvm::SmallVector<Value> newLoopArg;
+ // For existing loop argument initialize them with the right version from the
+ // prologue.
+ for (auto retVal :
+ llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
+ Operation *def = retVal.value().getDefiningOp();
+ assert(def && "Only support loop carried dependencies of distance 1");
+ unsigned defStage = stages[def];
+ Value valueVersion = valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
+ [maxStage - defStage];
+ assert(valueVersion);
+ newLoopArg.push_back(valueVersion);
+ }
for (auto escape : crossStageValues) {
LiverangeInfo &info = escape.second;
Value value = escape.first;
@@ -210,6 +241,9 @@ void LoopPipelinerInternal::createKernel(
rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
BlockAndValueMapping mapping;
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
+ for (auto arg : llvm::enumerate(forOp.getRegionIterArgs())) {
+ mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
+ }
for (Operation *op : opOrder) {
int64_t useStage = stages[op];
auto *newOp = rewriter.clone(*op, mapping);
@@ -226,6 +260,23 @@ void LoopPipelinerInternal::createKernel(
rewriter.setInsertionPointAfter(newOp);
continue;
}
+ auto arg = operand.get().dyn_cast<BlockArgument>();
+ if (arg && arg.getOwner() == forOp.getBody()) {
+ // If the value is a loop carried value coming from stage N + 1, remap
+ // it; it will become a direct use.
+ Value ret = forOp.getBody()->getTerminator()->getOperand(
+ arg.getArgNumber() - 1);
+ Operation *dep = ret.getDefiningOp();
+ if (!dep)
+ continue;
+ auto stageDep = stages.find(dep);
+ if (stageDep == stages.end() || stageDep->second == useStage)
+ continue;
+ assert(stageDep->second == useStage + 1);
+ newOp->setOperand(operand.getOperandNumber(),
+ mapping.lookupOrDefault(ret));
+ continue;
+ }
// For operands defined in a previous stage we need to remap it to use
// the correct region argument. We look for the right version of the
// Value based on the stage where it is used.
@@ -249,6 +300,9 @@ void LoopPipelinerInternal::createKernel(
// We create a mapping between original values and the associated loop
// returned values that will be needed by the epilogue.
llvm::SmallVector<Value> yieldOperands;
+ for (Value retVal : forOp.getBody()->getTerminator()->getOperands()) {
+ yieldOperands.push_back(mapping.lookupOrDefault(retVal));
+ }
for (auto &it : crossStageValues) {
int64_t version = maxStage - it.second.lastUseStage + 1;
unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
@@ -266,10 +320,22 @@ void LoopPipelinerInternal::createKernel(
version++);
yieldOperands.push_back(mapping.lookupOrDefault(it.first));
}
+ // Map the yield operand to the forOp returned value.
+ for (auto retVal :
+ llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
+ Operation *def = retVal.value().getDefiningOp();
+ assert(def && "Only support loop carried dependencies of distance 1");
+ unsigned defStage = stages[def];
+ setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
+ newForOp->getResult(retVal.index()),
+ maxStage - defStage + 1);
+ }
rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
}
-void LoopPipelinerInternal::emitEpilogue(PatternRewriter &rewriter) {
+llvm::SmallVector<Value>
+LoopPipelinerInternal::emitEpilogue(PatternRewriter &rewriter) {
+ llvm::SmallVector<Value> returnValues(forOp->getNumResults());
// Emit different versions of the induction variable. They will be
// removed by dead code if not used.
for (int64_t i = 0; i < maxStage; i++) {
@@ -295,9 +361,27 @@ void LoopPipelinerInternal::emitEpilogue(PatternRewriter &rewriter) {
for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
setValueMapping(op->getResult(destId), newOp->getResult(destId),
maxStage - stages[op] + i);
+ // If the value is a loop carried dependency update the loop argument
+ // mapping and keep track of the last version to replace the original
+ // forOp uses.
+ for (OpOperand &operand :
+ forOp.getBody()->getTerminator()->getOpOperands()) {
+ if (operand.get() != op->getResult(destId))
+ continue;
+ unsigned version = maxStage - stages[op] + i + 1;
+ // If the version is greater than maxStage it means it maps to the
+ // original forOp returned value.
+ if (version > maxStage) {
+ returnValues[operand.getOperandNumber()] = newOp->getResult(destId);
+ continue;
+ }
+ setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
+ newOp->getResult(destId), version);
+ }
}
}
}
+ return returnValues;
}
void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
@@ -361,12 +445,11 @@ struct ForLoopPipelining : public OpRewritePattern<ForOp> {
// 4. Emit the epilogue after the new forOp.
rewriter.setInsertionPointAfter(newForOp);
- pipeliner.emitEpilogue(rewriter);
+ llvm::SmallVector<Value> returnValues = pipeliner.emitEpilogue(rewriter);
// 5. Erase the original loop and replace the uses with the epilogue output.
if (forOp->getNumResults() > 0)
- rewriter.replaceOp(
- forOp, newForOp.getResults().take_front(forOp->getNumResults()));
+ rewriter.replaceOp(forOp, returnValues);
else
rewriter.eraseOp(forOp);
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index fb3cce1ed7869..1b7e62571bb95 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -171,3 +171,118 @@ func @multiple_uses(%A: memref<?xf32>, %result: memref<?xf32>) {
} { __test_pipelining_loop__ }
return
}
+
+// -----
+
+// CHECK-LABEL: loop_carried(
+// CHECK-SAME: %[[A:.*]]: memref<?xf32>, %[[R:.*]]: memref<?xf32>) {
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[C3:.*]] = constant 3 : index
+// CHECK-DAG: %[[CSTF:.*]] = constant 1.000000e+00 : f32
+// Prologue:
+// CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
+// Kernel:
+// CHECK-NEXT: %[[LR:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]]
+// CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
+// CHECK-SAME: %[[LARG:.*]] = %[[L0]]) -> (f32, f32) {
+// CHECK-NEXT: %[[ADD0:.*]] = addf %[[LARG]], %[[C]] : f32
+// CHECK-NEXT: %[[IV1:.*]] = addi %[[IV]], %[[C1]] : index
+// CHECK-NEXT: %[[L1:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32>
+// CHECK-NEXT: scf.yield %[[ADD0]], %[[L1]] : f32, f32
+// CHECK-NEXT: }
+// Epilogue:
+// CHECK-NEXT: %[[ADD1:.*]] = addf %[[LR]]#1, %[[LR]]#0 : f32
+// CHECK-NEXT: memref.store %[[ADD1]], %[[R]][%[[C0]]] : memref<?xf32>
+func @loop_carried(%A: memref<?xf32>, %result: memref<?xf32>) {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ %cf = constant 1.0 : f32
+ %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+ %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : memref<?xf32>
+ %A1_elem = addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
+ scf.yield %A1_elem : f32
+ } { __test_pipelining_loop__ }
+ memref.store %r, %result[%c0] : memref<?xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: backedge_different_stage
+// CHECK-SAME: (%[[A:.*]]: memref<?xf32>) -> f32 {
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = constant 2 : index
+// CHECK-DAG: %[[CSTF:.*]] = constant 1.000000e+00 : f32
+// Prologue:
+// CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
+// CHECK-NEXT: %[[ADD0:.*]] = addf %[[L0]], %[[CSTF]] : f32
+// CHECK-NEXT: %[[L1:.*]] = memref.load %[[A]][%[[C1]]] : memref<?xf32>
+// Kernel:
+// CHECK-NEXT: %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C2]]
+// CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
+// CHECK-SAME: %[[ADDARG:.*]] = %[[ADD0]], %[[LARG:.*]] = %[[L1]]) -> (f32, f32, f32) {
+// CHECK-NEXT: %[[MUL0:.*]] = mulf %[[CSTF]], %[[ADDARG]] : f32
+// CHECK-NEXT: %[[ADD1:.*]] = addf %[[LARG]], %[[MUL0]] : f32
+// CHECK-NEXT: %[[IV2:.*]] = addi %[[IV]], %[[C2]] : index
+// CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV2]]] : memref<?xf32>
+// CHECK-NEXT: scf.yield %[[MUL0]], %[[ADD1]], %[[L2]] : f32, f32, f32
+// CHECK-NEXT: }
+// Epilogue:
+// CHECK-NEXT: %[[MUL1:.*]] = mulf %[[CSTF]], %[[R]]#1 : f32
+// CHECK-NEXT: %[[ADD2:.*]] = addf %[[R]]#2, %[[MUL1]] : f32
+// CHECK-NEXT: %[[MUL2:.*]] = mulf %[[CSTF]], %[[ADD2]] : f32
+// CHECK-NEXT: return %[[MUL2]] : f32
+func @backedge_different_stage(%A: memref<?xf32>) -> f32 {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ %cf = constant 1.0 : f32
+ %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+ %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+ %A1_elem = addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
+ %A2_elem = mulf %cf, %A1_elem { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 0 } : f32
+ scf.yield %A2_elem : f32
+ } { __test_pipelining_loop__ }
+ return %r : f32
+}
+
+// -----
+
+// CHECK-LABEL: backedge_same_stage
+// CHECK-SAME: (%[[A:.*]]: memref<?xf32>) -> f32 {
+// CHECK-DAG: %[[C0:.*]] = constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = constant 1 : index
+// CHECK-DAG: %[[C3:.*]] = constant 3 : index
+// CHECK-DAG: %[[CSTF:.*]] = constant 1.000000e+00 : f32
+// Prologue:
+// CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
+// Kernel:
+// CHECK-NEXT: %[[R:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]]
+// CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
+// CHECK-SAME: %[[LARG:.*]] = %[[L0]]) -> (f32, f32) {
+// CHECK-NEXT: %[[ADD0:.*]] = addf %[[LARG]], %[[C]] : f32
+// CHECK-NEXT: %[[MUL0:.*]] = mulf %[[CSTF]], %[[ADD0]] : f32
+// CHECK-NEXT: %[[IV1:.*]] = addi %[[IV]], %[[C1]] : index
+// CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32>
+// CHECK-NEXT: scf.yield %[[MUL0]], %[[L2]] : f32, f32
+// CHECK-NEXT: }
+// Epilogue:
+// CHECK-NEXT: %[[ADD1:.*]] = addf %[[R]]#1, %[[R]]#0 : f32
+// CHECK-NEXT: %[[MUL1:.*]] = mulf %[[CSTF]], %[[ADD1]] : f32
+// CHECK-NEXT: return %[[MUL1]] : f32
+func @backedge_same_stage(%A: memref<?xf32>) -> f32 {
+ %c0 = constant 0 : index
+ %c1 = constant 1 : index
+ %c4 = constant 4 : index
+ %cf = constant 1.0 : f32
+ %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+ %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+ %A1_elem = addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
+ %A2_elem = mulf %cf, %A1_elem { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
+ scf.yield %A2_elem : f32
+ } { __test_pipelining_loop__ }
+ return %r : f32
+}
More information about the Mlir-commits
mailing list