[Mlir-commits] [mlir] 117db47 - [mlir][scf] Fix bug in software pipeliner and simplify the logic
Thomas Raoux
llvmlistbot at llvm.org
Wed Mar 8 12:06:32 PST 2023
Author: Thomas Raoux
Date: 2023-03-08T20:06:07Z
New Revision: 117db47d02c174e2ec039fa8b6a97381106e6238
URL: https://github.com/llvm/llvm-project/commit/117db47d02c174e2ec039fa8b6a97381106e6238
DIFF: https://github.com/llvm/llvm-project/commit/117db47d02c174e2ec039fa8b6a97381106e6238.diff
LOG: [mlir][scf] Fix bug in software pipeliner and simplify the logic
Fix a bug when pipelining a loop whose ops from different stages are
interleaved in the op order. Redo the logic so that only the operands of
the cloned op (and of its nested ops) are considered when updating the
use-def chain.
Differential Revision: https://reviews.llvm.org/D145598
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
mlir/test/Dialect/SCF/loop-pipelining.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index b9182f5a073ed..4a7175b109614 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -294,81 +294,6 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
return newForOp;
}
-/// Replace any use of `target` with `replacement` in `op`'s operands or within
-/// `op`'s nested regions.
-static void replaceInOp(Operation *op, Value target, Value replacement) {
- for (auto &use : llvm::make_early_inc_range(target.getUses())) {
- if (op->isAncestor(use.getOwner()))
- use.set(replacement);
- }
-}
-
-/// Given a cloned op in the new kernel body, updates induction variable uses.
-/// We replace it with a version incremented based on the stage where it is
-/// used.
-static void updateInductionVariableUses(RewriterBase &rewriter, Location loc,
- Operation *newOp, Value newForIv,
- unsigned maxStage, unsigned useStage,
- unsigned step) {
- rewriter.setInsertionPoint(newOp);
- Value offset = rewriter.create<arith::ConstantIndexOp>(
- loc, (maxStage - useStage) * step);
- Value iv = rewriter.create<arith::AddIOp>(loc, newForIv, offset);
- replaceInOp(newOp, newForIv, iv);
- rewriter.setInsertionPointAfter(newOp);
-}
-
-/// If the value is a loop carried value coming from stage N + 1 remap, it will
-/// become a direct use.
-static void updateIterArgUses(RewriterBase &rewriter, IRMapping &bvm,
- Operation *newOp, ForOp oldForOp, ForOp newForOp,
- unsigned useStage,
- const DenseMap<Operation *, unsigned> &stages) {
-
- for (unsigned i = 0; i < oldForOp.getNumRegionIterArgs(); i++) {
- Value yieldedVal = oldForOp.getBody()->getTerminator()->getOperand(i);
- Operation *dep = yieldedVal.getDefiningOp();
- if (!dep)
- continue;
- auto stageDep = stages.find(dep);
- if (stageDep == stages.end() || stageDep->second == useStage)
- continue;
- if (stageDep->second != useStage + 1)
- continue;
- Value replacement = bvm.lookup(yieldedVal);
- replaceInOp(newOp, newForOp.getRegionIterArg(i), replacement);
- }
-}
-
-/// For operands defined in a previous stage we need to remap it to use the
-/// correct region argument. We look for the right version of the Value based
-/// on the stage where it is used.
-static void updateCrossStageUses(
- RewriterBase &rewriter, Operation *newOp, IRMapping &bvm, ForOp newForOp,
- unsigned useStage, const DenseMap<Operation *, unsigned> &stages,
- const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) {
- // Because we automatically cloned the sub-regions, there's no simple way
- // to walk the nested regions in pairs of (oldOps, newOps), so we just
- // traverse the set of remapped loop arguments, filter which ones are
- // relevant, and replace any uses.
- for (auto [remapPair, newIterIdx] : loopArgMap) {
- auto [crossArgValue, stageIdx] = remapPair;
- Operation *def = crossArgValue.getDefiningOp();
- assert(def);
- unsigned stageDef = stages.lookup(def);
- if (useStage <= stageDef || useStage - stageDef != stageIdx)
- continue;
-
- // Use "lookupOrDefault" for the target value because some operations
- // are remapped, while in other cases the original will be present.
- Value target = bvm.lookupOrDefault(crossArgValue);
- Value replacement = newForOp.getRegionIterArg(newIterIdx);
-
- // Replace uses in the new op's operands and any nested uses.
- replaceInOp(newOp, target, replacement);
- }
-}
-
void LoopPipelinerInternal::createKernel(
scf::ForOp newForOp,
const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
@@ -400,16 +325,59 @@ void LoopPipelinerInternal::createKernel(
for (Operation *op : opOrder) {
int64_t useStage = stages[op];
auto *newOp = rewriter.clone(*op, mapping);
-
- // Within the kernel body, update uses of the induction variable, uses of
- // the original iter args, and uses of cross stage values.
- updateInductionVariableUses(rewriter, forOp.getLoc(), newOp,
- newForOp.getInductionVar(), maxStage,
- stages[op], step);
- updateIterArgUses(rewriter, mapping, newOp, forOp, newForOp, useStage,
- stages);
- updateCrossStageUses(rewriter, newOp, mapping, newForOp, useStage, stages,
- loopArgMap);
+ SmallVector<OpOperand *> operands;
+ // Collect all the operands for the cloned op and its nested ops.
+ op->walk([&operands](Operation *nestedOp) {
+ for (OpOperand &operand : nestedOp->getOpOperands()) {
+ operands.push_back(&operand);
+ }
+ });
+ for (OpOperand *operand : operands) {
+ Operation *nestedNewOp = mapping.lookup(operand->getOwner());
+ // Special case for the induction variable uses. We replace it with a
+ // version incremented based on the stage where it is used.
+ if (operand->get() == forOp.getInductionVar()) {
+ rewriter.setInsertionPoint(newOp);
+ Value offset = rewriter.create<arith::ConstantIndexOp>(
+ forOp.getLoc(), (maxStage - stages[op]) * step);
+ Value iv = rewriter.create<arith::AddIOp>(
+ forOp.getLoc(), newForOp.getInductionVar(), offset);
+ nestedNewOp->setOperand(operand->getOperandNumber(), iv);
+ rewriter.setInsertionPointAfter(newOp);
+ continue;
+ }
+ auto arg = operand->get().dyn_cast<BlockArgument>();
+ if (arg && arg.getOwner() == forOp.getBody()) {
+ // If the value is a loop carried value coming from stage N + 1 remap,
+ // it will become a direct use.
+ Value ret = forOp.getBody()->getTerminator()->getOperand(
+ arg.getArgNumber() - 1);
+ Operation *dep = ret.getDefiningOp();
+ if (!dep)
+ continue;
+ auto stageDep = stages.find(dep);
+ if (stageDep == stages.end() || stageDep->second == useStage)
+ continue;
+ assert(stageDep->second == useStage + 1);
+ nestedNewOp->setOperand(operand->getOperandNumber(),
+ mapping.lookupOrDefault(ret));
+ continue;
+ }
+ // For operands defined in a previous stage we need to remap it to use
+ // the correct region argument. We look for the right version of the
+ // Value based on the stage where it is used.
+ Operation *def = operand->get().getDefiningOp();
+ if (!def)
+ continue;
+ auto stageDef = stages.find(def);
+ if (stageDef == stages.end() || stageDef->second == useStage)
+ continue;
+ auto remap = loopArgMap.find(
+ std::make_pair(operand->get(), useStage - stageDef->second));
+ assert(remap != loopArgMap.end());
+ nestedNewOp->setOperand(operand->getOperandNumber(),
+ newForOp.getRegionIterArgs()[remap->second]);
+ }
if (predicates[useStage]) {
newOp = predicateFn(newOp, predicates[useStage], rewriter);
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 68b513362a250..0309287e409c1 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -627,3 +627,47 @@ func.func @pipeline_op_with_region(%A: memref<?xf32>, %B: memref<?xf32>, %result
} { __test_pipelining_loop__ }
return
}
+
+// -----
+
+// CHECK-LABEL: @backedge_mix_order
+// CHECK-SAME: (%[[A:.*]]: memref<?xf32>) -> f32 {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[CSTF:.*]] = arith.constant 2.000000e+00 : f32
+// Prologue:
+// CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
+// CHECK-NEXT: %[[L1:.*]] = memref.load %[[A]][%[[C1]]] : memref<?xf32>
+// Kernel:
+// CHECK-NEXT: %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]]
+// CHECK-SAME: step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
+// CHECK-SAME: %[[ARG1:.*]] = %[[L0]], %[[ARG2:.*]] = %[[L1]]) -> (f32, f32, f32) {
+// CHECK-NEXT: %[[IV2:.*]] = arith.addi %[[IV]], %[[C1]] : index
+// CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV2]]] : memref<?xf32>
+// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[C]], %[[ARG1]] : f32
+// CHECK-NEXT: %[[IV3:.*]] = arith.addi %[[IV]], %[[C1]] : index
+// CHECK-NEXT: %[[IV4:.*]] = arith.addi %[[IV3]], %[[C1]] : index
+// CHECK-NEXT: %[[L3:.*]] = memref.load %[[A]][%[[IV4]]] : memref<?xf32>
+// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[ARG2]], %[[MUL0]] : f32
+// CHECK-NEXT: scf.yield %[[MUL1]], %[[L2]], %[[L3]] : f32, f32, f32
+// CHECK-NEXT: }
+// Epilogue:
+// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[R]]#0, %[[R]]#1 : f32
+// CHECK-NEXT: %[[MUL2:.*]] = arith.mulf %[[R]]#2, %[[MUL1]] : f32
+// CHECK-NEXT: return %[[MUL2]] : f32
+func.func @backedge_mix_order(%A: memref<?xf32>) -> f32 {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ %cf = arith.constant 2.0 : f32
+ %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+ %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 0 } : memref<?xf32>
+ %A2_elem = arith.mulf %arg0, %A_elem { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
+ %i1 = arith.addi %i0, %c1 { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : index
+ %A1_elem = memref.load %A[%i1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 3 } : memref<?xf32>
+ %A3_elem = arith.mulf %A1_elem, %A2_elem { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 4 } : f32
+ scf.yield %A3_elem : f32
+ } { __test_pipelining_loop__ }
+ return %r : f32
+}
\ No newline at end of file
More information about the Mlir-commits mailing list