[llvm-branch-commits] [mlir] 19e068b - [MLIR][SCF] Handle more cases in pipelining transform (#74007)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Dec 4 11:02:15 PST 2023
Author: Thomas Raoux
Date: 2023-12-01T21:28:21-08:00
New Revision: 19e068b048591feb8fa66b164669c753090dfd3a
URL: https://github.com/llvm/llvm-project/commit/19e068b048591feb8fa66b164669c753090dfd3a
DIFF: https://github.com/llvm/llvm-project/commit/19e068b048591feb8fa66b164669c753090dfd3a.diff
LOG: [MLIR][SCF] Handle more cases in pipelining transform (#74007)
- Fix the case where an op is scheduled in stage 0 and its result is used
  through a loop-carried value (distance of 1).
- Fix the case where the epilogue is not peeled and a value that is not
  defined in the last stage is used outside the loop.
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
mlir/test/Dialect/SCF/loop-pipelining.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 5537a8b212c51..20fa8089201aa 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -61,6 +61,11 @@ struct LoopPipelinerInternal {
/// `idx` of `key` in the epilogue.
void setValueMapping(Value key, Value el, int64_t idx);
+ /// Return the defining op of the given value, if the Value is an argument of
+ /// the loop return the associated defining op in the loop and its distance to
+ /// the Value.
+ std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value);
+
public:
/// Initalize the information for the given `op`, return true if it
/// satisfies the pre-condition to apply pipelining.
@@ -240,11 +245,12 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
unsigned stage = stages[op];
auto analyzeOperand = [&](OpOperand &operand) {
- Operation *def = operand.get().getDefiningOp();
+ auto [def, distance] = getDefiningOpAndDistance(operand.get());
if (!def)
return;
auto defStage = stages.find(def);
- if (defStage == stages.end() || defStage->second == stage)
+ if (defStage == stages.end() || defStage->second == stage ||
+ defStage->second == stage + distance)
return;
assert(stage > defStage->second);
LiverangeInfo &info = crossStageValues[operand.get()];
@@ -261,6 +267,25 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
return crossStageValues;
}
+std::pair<Operation *, int64_t>
+LoopPipelinerInternal::getDefiningOpAndDistance(Value value) {
+ int64_t distance = 0;
+ if (auto arg = dyn_cast<BlockArgument>(value)) {
+ if (arg.getOwner() != forOp.getBody())
+ return {nullptr, 0};
+ // Ignore induction variable.
+ if (arg.getArgNumber() == 0)
+ return {nullptr, 0};
+ distance++;
+ value =
+ forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
+ }
+ Operation *def = value.getDefiningOp();
+ if (!def)
+ return {nullptr, 0};
+ return {def, distance};
+}
+
scf::ForOp LoopPipelinerInternal::createKernelLoop(
const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
&crossStageValues,
@@ -366,10 +391,9 @@ LogicalResult LoopPipelinerInternal::createKernel(
rewriter.setInsertionPointAfter(newOp);
continue;
}
- auto arg = dyn_cast<BlockArgument>(operand->get());
+ Value source = operand->get();
+ auto arg = dyn_cast<BlockArgument>(source);
if (arg && arg.getOwner() == forOp.getBody()) {
- // If the value is a loop carried value coming from stage N + 1 remap,
- // it will become a direct use.
Value ret = forOp.getBody()->getTerminator()->getOperand(
arg.getArgNumber() - 1);
Operation *dep = ret.getDefiningOp();
@@ -378,15 +402,19 @@ LogicalResult LoopPipelinerInternal::createKernel(
auto stageDep = stages.find(dep);
if (stageDep == stages.end() || stageDep->second == useStage)
continue;
- assert(stageDep->second == useStage + 1);
- nestedNewOp->setOperand(operand->getOperandNumber(),
- mapping.lookupOrDefault(ret));
- continue;
+ // If the value is a loop carried value coming from stage N + 1 remap,
+ // it will become a direct use.
+ if (stageDep->second == useStage + 1) {
+ nestedNewOp->setOperand(operand->getOperandNumber(),
+ mapping.lookupOrDefault(ret));
+ continue;
+ }
+ source = ret;
}
// For operands defined in a previous stage we need to remap it to use
// the correct region argument. We look for the right version of the
// Value based on the stage where it is used.
- Operation *def = operand->get().getDefiningOp();
+ Operation *def = source.getDefiningOp();
if (!def)
continue;
auto stageDef = stages.find(def);
@@ -418,9 +446,29 @@ LogicalResult LoopPipelinerInternal::createKernel(
// We create a mapping between original values and the associated loop
// returned values that will be needed by the epilogue.
llvm::SmallVector<Value> yieldOperands;
- for (Value retVal : forOp.getBody()->getTerminator()->getOperands()) {
- yieldOperands.push_back(mapping.lookupOrDefault(retVal));
+ for (OpOperand &yieldOperand :
+ forOp.getBody()->getTerminator()->getOpOperands()) {
+ Value source = mapping.lookupOrDefault(yieldOperand.get());
+ // When we don't peel the epilogue and the yield value is used outside the
+ // loop we need to make sure we return the version from numStages -
+ // defStage.
+ if (!peelEpilogue &&
+ !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) {
+ Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first;
+ if (def) {
+ auto defStage = stages.find(def);
+ if (defStage != stages.end() && defStage->second < maxStage) {
+ Value pred = predicates[defStage->second];
+ source = rewriter.create<arith::SelectOp>(
+ pred.getLoc(), pred, source,
+ newForOp.getBody()
+ ->getArguments()[yieldOperand.getOperandNumber() + 1]);
+ }
+ }
+ }
+ yieldOperands.push_back(source);
}
+
for (auto &it : crossStageValues) {
int64_t version = maxStage - it.second.lastUseStage + 1;
unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
@@ -444,9 +492,11 @@ LogicalResult LoopPipelinerInternal::createKernel(
Operation *def = retVal.value().getDefiningOp();
assert(def && "Only support loop carried dependencies of distance 1");
unsigned defStage = stages[def];
- setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
- newForOp->getResult(retVal.index()),
- maxStage - defStage + 1);
+ if (defStage > 0) {
+ setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
+ newForOp->getResult(retVal.index()),
+ maxStage - defStage + 1);
+ }
}
rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
return success();
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 0309287e409c1..4cd686d2cdb86 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -670,4 +670,56 @@ func.func @backedge_mix_order(%A: memref<?xf32>) -> f32 {
scf.yield %A3_elem : f32
} { __test_pipelining_loop__ }
return %r : f32
-}
\ No newline at end of file
+}
+
+// -----
+
+// CHECK-LABEL: @distance_1_use
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// Prologue:
+// CHECK: %[[L0:.+]] = memref.load %{{.*}}[%[[C0]]] : memref<?xf32>
+// CHECK: %[[L1:.+]] = memref.load %{{.*}}[%[[C1]]] : memref<?xf32>
+// CHECK: %[[R:.+]]:5 = scf.for {{.*}} iter_args(%[[IDX0:.+]] = %[[C2]], %[[L2:.+]] = %[[L0]], %[[L3:.+]] = %[[L1]]
+// CHECK: %[[L4:.+]] = memref.load %{{.*}}[%[[IDX0]]] : memref<?xf32>
+// CHECK: %[[IDX1:.+]] = arith.addi %[[IDX0]], %[[C1]] : index
+// CHECK: memref.store %[[L2]]
+// CHECK: scf.yield %[[IDX1]], %[[L3]], %[[L4]]
+func.func @distance_1_use(%A: memref<?xf32>, %result: memref<?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %cf = arith.constant 1.0 : f32
+ %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%idx = %c0) -> (index) {
+ %A_elem = memref.load %A[%idx] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 0 } : memref<?xf32>
+ %idx1 = arith.addi %idx, %c1 { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : index
+ memref.store %A_elem, %result[%idx] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+ scf.yield %idx1 : index
+ } { __test_pipelining_loop__ }
+ return
+}
+
+// -----
+
+// NOEPILOGUE-LABEL: stage_0_value_escape(
+func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %cf = arith.constant 1.0 : f32
+// NOEPILOGUE: %[[C3:.+]] = arith.constant 3 : index
+// NOEPILOGUE: %[[A:.+]] = arith.addf
+// NOEPILOGUE: scf.for %[[IV:.+]] = {{.*}} iter_args(%[[ARG:.+]] = %[[A]],
+// NOEPILOGUE: %[[C:.+]] = arith.cmpi slt, %[[IV]], %[[C3]] : index
+// NOEPILOGUE: %[[S:.+]] = arith.select %[[C]], %{{.+}}, %[[ARG]] : f32
+// NOEPILOGUE: scf.yield %[[S]]
+ %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+ %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : memref<?xf32>
+ %A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
+ memref.store %A1_elem, %result[%c0] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+ scf.yield %A1_elem : f32
+ } { __test_pipelining_loop__ }
+ memref.store %r, %result[%c1] : memref<?xf32>
+ return
+}
More information about the llvm-branch-commits mailing list