[Mlir-commits] [mlir] [SCF][PIPELINE] Handle the case when values from the peeled prologue may escape out of the loop (PR #105755)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Aug 22 16:50:54 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: None (pawelszczerbuk)
<details>
<summary>Changes</summary>
Previously, values in the peeled prologue that were not predicated by `predicateFn` were passed into the loop body without any other predication. If those values are later used outside the loop, they may be incorrect when the number of iterations is smaller than the number of stages minus one. These values therefore need the same masking that the main loop body applies, using the predicates that already exist.
---
Full diff: https://github.com/llvm/llvm-project/pull/105755.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp (+15-5)
- (modified) mlir/test/Dialect/SCF/loop-pipelining.mlir (+17-9)
``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index cc1a22d0d48a18..d8e1cc0ecef88e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -268,7 +268,7 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
}
void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
- // Initialize the iteration argument to the loop initiale values.
+ // Initialize the iteration argument to the loop initial values.
for (auto [arg, operand] :
llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
setValueMapping(arg, operand.get(), 0);
@@ -320,16 +320,26 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
if (annotateFn)
annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
- setValueMapping(op->getResult(destId), newOp->getResult(destId),
- i - stages[op]);
+ Value source = newOp->getResult(destId);
// If the value is a loop carried dependency update the loop argument
- // mapping.
for (OpOperand &operand : yield->getOpOperands()) {
if (operand.get() != op->getResult(destId))
continue;
+ if (predicates[predicateIdx] &&
+ !forOp.getResult(operand.getOperandNumber()).use_empty()) {
+ // If the value is used outside the loop, we need to make sure we
+ // return the correct version of it.
+ Value prevValue = valueMapping
+ [forOp.getRegionIterArgs()[operand.getOperandNumber()]]
+ [i - stages[op]];
+ source = rewriter.create<arith::SelectOp>(
+ loc, predicates[predicateIdx], source, prevValue);
+ }
setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
- newOp->getResult(destId), i - stages[op] + 1);
+ source, i - stages[op] + 1);
}
+ setValueMapping(op->getResult(destId), newOp->getResult(destId),
+ i - stages[op]);
}
}
}
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 46e7feca4329ee..9687f80f5ddfc8 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -703,18 +703,26 @@ func.func @distance_1_use(%A: memref<?xf32>, %result: memref<?xf32>) {
// -----
// NOEPILOGUE-LABEL: stage_0_value_escape(
-func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>) {
+func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub: index) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- %c4 = arith.constant 4 : index
%cf = arith.constant 1.0 : f32
-// NOEPILOGUE: %[[C3:.+]] = arith.constant 3 : index
-// NOEPILOGUE: %[[A:.+]] = arith.addf
-// NOEPILOGUE: scf.for %[[IV:.+]] = {{.*}} iter_args(%[[ARG:.+]] = %[[A]],
-// NOEPILOGUE: %[[C:.+]] = arith.cmpi slt, %[[IV]], %[[C3]] : index
-// NOEPILOGUE: %[[S:.+]] = arith.select %[[C]], %{{.+}}, %[[ARG]] : f32
-// NOEPILOGUE: scf.yield %[[S]]
- %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+// NOEPILOGUE: %[[UB:[^,]+]]: index)
+// NOEPILOGUE-DAG: %[[C0:.+]] = arith.constant 0 : index
+// NOEPILOGUE-DAG: %[[C1:.+]] = arith.constant 1 : index
+// NOEPILOGUE-DAG: %[[CF:.+]] = arith.constant 1.000000e+00
+// NOEPILOGUE: %[[CND0:.+]] = arith.cmpi sgt, %[[UB]], %[[C0]]
+// NOEPILOGUE: scf.if
+// NOEPILOGUE: %[[IF:.+]] = scf.if %[[CND0]]
+// NOEPILOGUE: %[[A:.+]] = arith.addf
+// NOEPILOGUE: scf.yield %[[A]]
+// NOEPILOGUE: %[[S0:.+]] = arith.select %[[CND0]], %[[IF]], %[[CF]]
+// NOEPILOGUE: scf.for %[[IV:.+]] = {{.*}} iter_args(%[[ARG:.+]] = %[[S0]],
+// NOEPILOGUE: %[[UB_1:.+]] = arith.subi %[[UB]], %[[C1]] : index
+// NOEPILOGUE: %[[CND1:.+]] = arith.cmpi slt, %[[IV]], %[[UB_1]] : index
+// NOEPILOGUE: %[[S1:.+]] = arith.select %[[CND1]], %{{.+}}, %[[ARG]] : f32
+// NOEPILOGUE: scf.yield %[[S1]]
+ %r = scf.for %i0 = %c0 to %ub step %c1 iter_args(%arg0 = %cf) -> (f32) {
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : memref<?xf32>
%A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
memref.store %A1_elem, %result[%c0] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/105755
More information about the Mlir-commits mailing list