[Mlir-commits] [mlir] [SCF][PIPELINE] Handle the case when values from the peeled prologue may escape out of the loop (PR #105755)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Aug 22 16:50:26 PDT 2024


https://github.com/pawelszczerbuk created https://github.com/llvm/llvm-project/pull/105755

Previously, values in the peeled prologue that were not handled by `predicateFn` were passed into the loop body without any other predication. If such a value is later used outside of the loop, it may be incorrect when the number of iterations is smaller than the number of stages minus one (`numStages - 1`). These values need masking similar to what is already done in the main loop body, and that masking can reuse the predicates that already exist.

>From 0e65cb8339684c1e7d5fe576bd3df51fc1ddef6f Mon Sep 17 00:00:00 2001
From: Pawel Szczerbuk <pawel.szczerbuk at openai.com>
Date: Thu, 22 Aug 2024 16:44:04 -0700
Subject: [PATCH] Handle the case when values from the peeled prologue may
 escape out of the loop.

---
 .../Dialect/SCF/Transforms/LoopPipelining.cpp | 20 ++++++++++----
 mlir/test/Dialect/SCF/loop-pipelining.mlir    | 26 ++++++++++++-------
 2 files changed, 32 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index cc1a22d0d48a18..d8e1cc0ecef88e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -268,7 +268,7 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
 }
 
 void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
-  // Initialize the iteration argument to the loop initiale values.
+  // Initialize the iteration argument to the loop initial values.
   for (auto [arg, operand] :
        llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
     setValueMapping(arg, operand.get(), 0);
@@ -320,16 +320,26 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
       if (annotateFn)
         annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
       for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
-        setValueMapping(op->getResult(destId), newOp->getResult(destId),
-                        i - stages[op]);
+        Value source = newOp->getResult(destId);
         // If the value is a loop carried dependency update the loop argument
-        // mapping.
         for (OpOperand &operand : yield->getOpOperands()) {
           if (operand.get() != op->getResult(destId))
             continue;
+          if (predicates[predicateIdx] &&
+              !forOp.getResult(operand.getOperandNumber()).use_empty()) {
+            // If the value is used outside the loop, we need to make sure we
+            // return the correct version of it.
+            Value prevValue = valueMapping
+                [forOp.getRegionIterArgs()[operand.getOperandNumber()]]
+                [i - stages[op]];
+            source = rewriter.create<arith::SelectOp>(
+                loc, predicates[predicateIdx], source, prevValue);
+          }
           setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
-                          newOp->getResult(destId), i - stages[op] + 1);
+                          source, i - stages[op] + 1);
         }
+        setValueMapping(op->getResult(destId), newOp->getResult(destId),
+                        i - stages[op]);
       }
     }
   }
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 46e7feca4329ee..9687f80f5ddfc8 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -703,18 +703,26 @@ func.func @distance_1_use(%A: memref<?xf32>, %result: memref<?xf32>) {
 // -----
 
 // NOEPILOGUE-LABEL: stage_0_value_escape(
-func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>) {
+func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub: index) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  %c4 = arith.constant 4 : index
   %cf = arith.constant 1.0 : f32
-// NOEPILOGUE: %[[C3:.+]] = arith.constant 3 : index
-// NOEPILOGUE: %[[A:.+]] = arith.addf
-// NOEPILOGUE: scf.for %[[IV:.+]] = {{.*}} iter_args(%[[ARG:.+]] = %[[A]],
-// NOEPILOGUE:   %[[C:.+]] = arith.cmpi slt, %[[IV]], %[[C3]] : index
-// NOEPILOGUE:   %[[S:.+]] = arith.select %[[C]], %{{.+}}, %[[ARG]] : f32
-// NOEPILOGUE:   scf.yield %[[S]]
-  %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+// NOEPILOGUE: %[[UB:[^,]+]]: index)
+// NOEPILOGUE-DAG: %[[C0:.+]] = arith.constant 0 : index
+// NOEPILOGUE-DAG: %[[C1:.+]] = arith.constant 1 : index
+// NOEPILOGUE-DAG: %[[CF:.+]] = arith.constant 1.000000e+00
+// NOEPILOGUE: %[[CND0:.+]] = arith.cmpi sgt, %[[UB]], %[[C0]]
+// NOEPILOGUE: scf.if
+// NOEPILOGUE: %[[IF:.+]] = scf.if %[[CND0]]
+// NOEPILOGUE:   %[[A:.+]] = arith.addf
+// NOEPILOGUE:   scf.yield %[[A]]
+// NOEPILOGUE: %[[S0:.+]] = arith.select %[[CND0]], %[[IF]], %[[CF]]
+// NOEPILOGUE: scf.for %[[IV:.+]] = {{.*}} iter_args(%[[ARG:.+]] = %[[S0]],
+// NOEPILOGUE:   %[[UB_1:.+]] = arith.subi %[[UB]], %[[C1]] : index
+// NOEPILOGUE:   %[[CND1:.+]] = arith.cmpi slt, %[[IV]], %[[UB_1]] : index
+// NOEPILOGUE:   %[[S1:.+]] = arith.select %[[CND1]], %{{.+}}, %[[ARG]] : f32
+// NOEPILOGUE:   scf.yield %[[S1]]
+  %r = scf.for %i0 = %c0 to %ub step %c1 iter_args(%arg0 = %cf) -> (f32) {
     %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : memref<?xf32>
     %A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
     memref.store %A1_elem, %result[%c0] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>



More information about the Mlir-commits mailing list