[Mlir-commits] [mlir] [mlir] Fix loop pipelining when the operand of `yield` is not defined in the loop body (PR #75423)

Keren Zhou llvmlistbot at llvm.org
Wed Dec 13 18:19:13 PST 2023


https://github.com/Jokeren updated https://github.com/llvm/llvm-project/pull/75423

>From dcc8d5cc450ca2874de1f6f715b81df7d402ce6d Mon Sep 17 00:00:00 2001
From: Jokeren <robinho364 at gmail.com>
Date: Wed, 13 Dec 2023 14:37:16 -0500
Subject: [PATCH 1/4] Update

---
 .../Dialect/SCF/Transforms/LoopPipelining.cpp | 55 +++++++++++--------
 1 file changed, 33 insertions(+), 22 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 6c36600975a597..80d8255df1dd98 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -90,7 +90,8 @@ struct LoopPipelinerInternal {
       RewriterBase &rewriter);
   /// Emits the epilogue, this creates `maxStage - 1` part which will contain
   /// operations from stages [i; maxStage], where i is the part index.
-  llvm::SmallVector<Value> emitEpilogue(RewriterBase &rewriter);
+  void emitEpilogue(RewriterBase &rewriter,
+                    llvm::SmallVector<Value> &returnValues);
 };
 
 bool LoopPipelinerInternal::initializeLoopInfo(
@@ -175,15 +176,18 @@ bool LoopPipelinerInternal::initializeLoopInfo(
     }
   }
 
-  // Only support loop carried dependency with a distance of 1. This means the
-  // source of all the scf.yield operands needs to be defined by operations in
-  // the loop.
+  // Support only loop-carried dependencies with a distance of one iteration or
+  // those defined outside of the loop. This means that any dependency within a
+  // loop should either be on the immediately preceding iteration, the current
+  // iteration, or on variables whose values are set before entering the loop.
   if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
                    [this](Value operand) {
                      Operation *def = operand.getDefiningOp();
-                     return !def || !stages.contains(def);
+                     return !def ||
+                            (!stages.contains(def) && forOp->isAncestor(def));
                    })) {
-    LDBG("--only support loop carried dependency with a distance of 1 -> BAIL");
+    LDBG("--only support loop carried dependency with a distance of 1 or "
+         "defined outside of the loop -> BAIL");
     return false;
   }
   annotateFn = options.annotateFn;
@@ -341,12 +345,17 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
   for (const auto &retVal :
        llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
     Operation *def = retVal.value().getDefiningOp();
-    assert(def && "Only support loop carried dependencies of distance 1");
-    unsigned defStage = stages[def];
-    Value valueVersion = valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
-                                     [maxStage - defStage];
-    assert(valueVersion);
-    newLoopArg.push_back(valueVersion);
+    assert(def && "Only support loop carried dependencies of distance of 1 or "
+                  "outside the loop");
+    auto defStage = stages.find(def);
+    if (defStage != stages.end()) {
+      Value valueVersion =
+          valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
+                      [maxStage - defStage->second];
+      assert(valueVersion);
+      newLoopArg.push_back(valueVersion);
+    } else
+      newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]);
   }
   for (auto escape : crossStageValues) {
     LiverangeInfo &info = escape.second;
@@ -551,21 +560,24 @@ LogicalResult LoopPipelinerInternal::createKernel(
   for (const auto &retVal :
        llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
     Operation *def = retVal.value().getDefiningOp();
-    assert(def && "Only support loop carried dependencies of distance 1");
-    unsigned defStage = stages[def];
-    if (defStage > 0) {
+    assert(def && "Only support loop carried dependencies of distance of 1 or "
+                  "defined outside the loop");
+    auto defStage = stages.find(def);
+    if (defStage != stages.end() && defStage->second > 0)
       setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
                       newForOp->getResult(retVal.index()),
-                      maxStage - defStage + 1);
-    }
+                      maxStage - defStage->second + 1);
+    else
+      for (unsigned int stage = 1; stage <= maxStage; stage++)
+        setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
+                        retVal.value(), stage);
   }
   rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
   return success();
 }
 
-llvm::SmallVector<Value>
-LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) {
-  llvm::SmallVector<Value> returnValues(forOp->getNumResults());
+void LoopPipelinerInternal::emitEpilogue(
+    RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
   // Emit different versions of the induction variable. They will be
   // removed by dead code if not used.
   for (int64_t i = 0; i < maxStage; i++) {
@@ -628,7 +640,6 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) {
       }
     }
   }
-  return returnValues;
 }
 
 void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
@@ -685,7 +696,7 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
   if (options.peelEpilogue) {
     // 4. Emit the epilogue after the new forOp.
     rewriter.setInsertionPointAfter(newForOp);
-    returnValues = pipeliner.emitEpilogue(rewriter);
+    pipeliner.emitEpilogue(rewriter, returnValues);
   }
   // 5. Erase the original loop and replace the uses with the epilogue output.
   if (forOp->getNumResults() > 0)

>From 8a04013651696d403b675b57868c018da24775c6 Mon Sep 17 00:00:00 2001
From: Jokeren <robinho364 at gmail.com>
Date: Wed, 13 Dec 2023 20:59:08 -0500
Subject: [PATCH 2/4] Update

---
 .../Dialect/SCF/Transforms/LoopPipelining.cpp |  1 +
 mlir/test/Dialect/SCF/loop-pipelining.mlir    | 44 +++++++++++++++++++
 2 files changed, 45 insertions(+)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 80d8255df1dd98..be5d397ef22265 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -704,6 +704,7 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
   else
     rewriter.eraseOp(forOp);
 
+  llvm::errs() << *newForOp->getParentOp() << "\n";
   return newForOp;
 }
 
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 8a57ddccfee665..a18c850c3f05f1 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -770,3 +770,47 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
   } { __test_pipelining_loop__ }
   return
 }
+
+// -----
+
+// CHECK-LABEL: yield_constant_loop(
+//  CHECK-SAME:   %[[A:.*]]: memref<?xf32>) -> f32 {
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+//   CHECK-DAG:   %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
+// Prologue:
+//       CHECK:   %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
+// Kernel:
+//  CHECK-NEXT:   %[[L1:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]]
+//  CHECK-SAME:     step %[[C1]] iter_args(%[[ARG0:.*]] = %[[CST2]], %[[ARG1:.*]] = %[[L0]]) -> (f32, f32) {
+//  CHECK-NEXT:     %[[ADD0:.*]] = arith.addf %[[ARG1]], %[[ARG0]] : f32
+//  CHECK-NEXT:     %[[MUL0:.*]] = arith.mulf %[[ADD0]], %[[CST0]] : f32
+//  CHECK-NEXT:     memref.store %[[MUL0]], %[[A]][%[[IV]]] : memref<?xf32>
+//  CHECK-NEXT:     %[[IV1:.*]] = arith.addi %[[IV]], %[[C1]] : index
+//  CHECK-NEXT:     %[[L2:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32>
+//  CHECK-NEXT:     scf.yield %[[CST0]], %[[L2]] : f32
+//  CHECK-NEXT:   }
+// Epilogue:
+//  CHECK-NEXT:   %[[ADD1:.*]] = arith.addf %[[L1]]#1, %[[CST0]] : f32
+//  CHECK-NEXT:   %[[MUL1:.*]] = arith.mulf %[[ADD1]], %[[CST0]] : f32
+//  CHECK-NEXT:   memref.store %[[MUL1]], %[[A]][%[[C3]]] : memref<?xf32>
+//  CHECK-NEXT:   return %[[L1]]#0 : f32
+
+func.func @yield_constant_loop(%A: memref<?xf32>) -> f32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cf0 = arith.constant 0.0 : f32
+  %cf2 = arith.constant 2.0 : f32
+  %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf2) -> f32 {
+    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 3 } : memref<?xf32>
+    %A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
+    %A2_elem = arith.mulf %cf0, %A1_elem { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
+    memref.store %A2_elem, %A[%i0] { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+    scf.yield %cf0: f32
+  }  { __test_pipelining_loop__ }
+  return %r : f32
+}
+

>From f0921232dd06b9f8549e5427a65b641d3f18018b Mon Sep 17 00:00:00 2001
From: Jokeren <robinho364 at gmail.com>
Date: Wed, 13 Dec 2023 21:11:36 -0500
Subject: [PATCH 3/4] Remove output

---
 mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index be5d397ef22265..f14f2de7af1236 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -563,14 +563,14 @@ LogicalResult LoopPipelinerInternal::createKernel(
     assert(def && "Only support loop carried dependencies of distance of 1 or "
                   "defined outside the loop");
     auto defStage = stages.find(def);
-    if (defStage != stages.end() && defStage->second > 0)
-      setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
-                      newForOp->getResult(retVal.index()),
-                      maxStage - defStage->second + 1);
-    else
+    if (defStage == stages.end())
       for (unsigned int stage = 1; stage <= maxStage; stage++)
         setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
                         retVal.value(), stage);
+    else if (defStage->second > 0)
+      setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
+                      newForOp->getResult(retVal.index()),
+                      maxStage - defStage->second + 1);
   }
   rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
   return success();
@@ -704,7 +704,6 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
   else
     rewriter.eraseOp(forOp);
 
-  llvm::errs() << *newForOp->getParentOp() << "\n";
   return newForOp;
 }
 

>From 1f3f1054d81d759c085157bde5bf6af675edf378 Mon Sep 17 00:00:00 2001
From: Jokeren <robinho364 at gmail.com>
Date: Wed, 13 Dec 2023 21:19:00 -0500
Subject: [PATCH 4/4] Multiline if/else

---
 mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index f14f2de7af1236..7d45b484f76575 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -563,14 +563,15 @@ LogicalResult LoopPipelinerInternal::createKernel(
     assert(def && "Only support loop carried dependencies of distance of 1 or "
                   "defined outside the loop");
     auto defStage = stages.find(def);
-    if (defStage == stages.end())
+    if (defStage == stages.end()) {
       for (unsigned int stage = 1; stage <= maxStage; stage++)
         setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
                         retVal.value(), stage);
-    else if (defStage->second > 0)
+    } else if (defStage->second > 0) {
       setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
                       newForOp->getResult(retVal.index()),
                       maxStage - defStage->second + 1);
+    }
   }
   rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
   return success();



More information about the Mlir-commits mailing list