[Mlir-commits] [mlir] 7c90081 - [SCF][PIPELINE] Handle the case when values from the peeled prologue may escape out of the loop (#105755)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 23 08:23:15 PDT 2024
Author: pawelszczerbuk
Date: 2024-08-23T08:23:11-07:00
New Revision: 7c9008115a2a24788f07bb476fb28dcf5e661ae4
URL: https://github.com/llvm/llvm-project/commit/7c9008115a2a24788f07bb476fb28dcf5e661ae4
DIFF: https://github.com/llvm/llvm-project/commit/7c9008115a2a24788f07bb476fb28dcf5e661ae4.diff
LOG: [SCF][PIPELINE] Handle the case when values from the peeled prologue may escape out of the loop (#105755)
Previously, values in the peeled prologue that were not handled by the
`predicateFn` were passed to the loop body without any further
predication. If those values are later used outside of the loop body,
they may be incorrect when the number of iterations is smaller than the
number of stages minus 1. Such values need the same masking that is
already applied in the main loop body, reusing the existing predicates.
Added:
Modified:
mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
mlir/test/Dialect/SCF/loop-pipelining.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index cc1a22d0d48a18..d8e1cc0ecef88e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -268,7 +268,7 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
}
void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
- // Initialize the iteration argument to the loop initiale values.
+ // Initialize the iteration argument to the loop initial values.
for (auto [arg, operand] :
llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
setValueMapping(arg, operand.get(), 0);
@@ -320,16 +320,26 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
if (annotateFn)
annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
- setValueMapping(op->getResult(destId), newOp->getResult(destId),
- i - stages[op]);
+ Value source = newOp->getResult(destId);
// If the value is a loop carried dependency update the loop argument
- // mapping.
for (OpOperand &operand : yield->getOpOperands()) {
if (operand.get() != op->getResult(destId))
continue;
+ if (predicates[predicateIdx] &&
+ !forOp.getResult(operand.getOperandNumber()).use_empty()) {
+ // If the value is used outside the loop, we need to make sure we
+ // return the correct version of it.
+ Value prevValue = valueMapping
+ [forOp.getRegionIterArgs()[operand.getOperandNumber()]]
+ [i - stages[op]];
+ source = rewriter.create<arith::SelectOp>(
+ loc, predicates[predicateIdx], source, prevValue);
+ }
setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
- newOp->getResult(destId), i - stages[op] + 1);
+ source, i - stages[op] + 1);
}
+ setValueMapping(op->getResult(destId), newOp->getResult(destId),
+ i - stages[op]);
}
}
}
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 46e7feca4329ee..9687f80f5ddfc8 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -703,18 +703,26 @@ func.func @distance_1_use(%A: memref<?xf32>, %result: memref<?xf32>) {
// -----
// NOEPILOGUE-LABEL: stage_0_value_escape(
-func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>) {
+func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub: index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- %c4 = arith.constant 4 : index
%cf = arith.constant 1.0 : f32
-// NOEPILOGUE: %[[C3:.+]] = arith.constant 3 : index
-// NOEPILOGUE: %[[A:.+]] = arith.addf
-// NOEPILOGUE: scf.for %[[IV:.+]] = {{.*}} iter_args(%[[ARG:.+]] = %[[A]],
-// NOEPILOGUE: %[[C:.+]] = arith.cmpi slt, %[[IV]], %[[C3]] : index
-// NOEPILOGUE: %[[S:.+]] = arith.select %[[C]], %{{.+}}, %[[ARG]] : f32
-// NOEPILOGUE: scf.yield %[[S]]
- %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+// NOEPILOGUE: %[[UB:[^,]+]]: index)
+// NOEPILOGUE-DAG: %[[C0:.+]] = arith.constant 0 : index
+// NOEPILOGUE-DAG: %[[C1:.+]] = arith.constant 1 : index
+// NOEPILOGUE-DAG: %[[CF:.+]] = arith.constant 1.000000e+00
+// NOEPILOGUE: %[[CND0:.+]] = arith.cmpi sgt, %[[UB]], %[[C0]]
+// NOEPILOGUE: scf.if
+// NOEPILOGUE: %[[IF:.+]] = scf.if %[[CND0]]
+// NOEPILOGUE: %[[A:.+]] = arith.addf
+// NOEPILOGUE: scf.yield %[[A]]
+// NOEPILOGUE: %[[S0:.+]] = arith.select %[[CND0]], %[[IF]], %[[CF]]
+// NOEPILOGUE: scf.for %[[IV:.+]] = {{.*}} iter_args(%[[ARG:.+]] = %[[S0]],
+// NOEPILOGUE: %[[UB_1:.+]] = arith.subi %[[UB]], %[[C1]] : index
+// NOEPILOGUE: %[[CND1:.+]] = arith.cmpi slt, %[[IV]], %[[UB_1]] : index
+// NOEPILOGUE: %[[S1:.+]] = arith.select %[[CND1]], %{{.+}}, %[[ARG]] : f32
+// NOEPILOGUE: scf.yield %[[S1]]
+ %r = scf.for %i0 = %c0 to %ub step %c1 iter_args(%arg0 = %cf) -> (f32) {
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : memref<?xf32>
%A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
memref.store %A1_elem, %result[%c0] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>
More information about the Mlir-commits
mailing list