[Mlir-commits] [mlir] [MLIR][SCF] Handle more cases in pipelining transform (PR #74007)

Thomas Raoux llvmlistbot at llvm.org
Fri Dec 1 20:16:11 PST 2023


https://github.com/ThomasRaoux updated https://github.com/llvm/llvm-project/pull/74007

From a4e8d3d150632b0d431d21c66ebda7e5fcaf2098 Mon Sep 17 00:00:00 2001
From: Thomas Raoux <thomas.raoux at openai.com>
Date: Thu, 30 Nov 2023 15:58:49 -0800
Subject: [PATCH 1/2] [MLIR][SCF] Handle more cases in pipelining transform

- Fix the case where an op scheduled in stage 0 is used with a distance of 1.
- Fix the case where the epilogue is not peeled and a value not defined in the last stage is used outside the loop.
---
 .../Dialect/SCF/Transforms/LoopPipelining.cpp | 81 +++++++++++++++----
 mlir/test/Dialect/SCF/loop-pipelining.mlir    | 54 ++++++++++++-
 2 files changed, 119 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 5537a8b212c51f7..f25318fe52093ec 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -61,6 +61,11 @@ struct LoopPipelinerInternal {
   /// `idx` of `key` in the epilogue.
   void setValueMapping(Value key, Value el, int64_t idx);
 
+  /// Return the defining op of the given value. If the value is an iter_arg
+  /// of the loop, return the op defining the corresponding yielded value and a
+  /// distance of 1 instead of 0. Returns {nullptr, 0} when there is no op.
+  std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value);
+
 public:
   /// Initalize the information for the given `op`, return true if it
   /// satisfies the pre-condition to apply pipelining.
@@ -240,11 +245,12 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
     unsigned stage = stages[op];
 
     auto analyzeOperand = [&](OpOperand &operand) {
-      Operation *def = operand.get().getDefiningOp();
+      auto [def, distance] = getDefiningOpAndDistance(operand.get());
       if (!def)
         return;
       auto defStage = stages.find(def);
-      if (defStage == stages.end() || defStage->second == stage)
+      if (defStage == stages.end() || defStage->second == stage ||
+          defStage->second == stage + distance)
         return;
       assert(stage > defStage->second);
       LiverangeInfo &info = crossStageValues[operand.get()];
@@ -261,6 +267,25 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
   return crossStageValues;
 }
 
+std::pair<Operation *, int64_t>
+LoopPipelinerInternal::getDefiningOpAndDistance(Value value) {
+  int64_t distance = 0;
+  if (auto arg = dyn_cast<BlockArgument>(value)) {
+    if (arg.getOwner() != forOp.getBody())
+      return {nullptr, 0};
+    // Ignore induction variable.
+    if (arg.getArgNumber() == 0)
+      return {nullptr, 0};
+    distance++;
+    value =
+        forOp.getBody()->getTerminator()->getOperand(arg.getArgNumber() - 1);
+  }
+  Operation *def = value.getDefiningOp();
+  if (!def)
+    return {nullptr, 0};
+  return {def, distance};
+}
+
 scf::ForOp LoopPipelinerInternal::createKernelLoop(
     const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
         &crossStageValues,
@@ -366,10 +391,9 @@ LogicalResult LoopPipelinerInternal::createKernel(
         rewriter.setInsertionPointAfter(newOp);
         continue;
       }
-      auto arg = dyn_cast<BlockArgument>(operand->get());
+      Value source = operand->get();
+      auto arg = dyn_cast<BlockArgument>(source);
       if (arg && arg.getOwner() == forOp.getBody()) {
-        // If the value is a loop carried value coming from stage N + 1 remap,
-        // it will become a direct use.
         Value ret = forOp.getBody()->getTerminator()->getOperand(
             arg.getArgNumber() - 1);
         Operation *dep = ret.getDefiningOp();
@@ -378,15 +402,19 @@ LogicalResult LoopPipelinerInternal::createKernel(
         auto stageDep = stages.find(dep);
         if (stageDep == stages.end() || stageDep->second == useStage)
           continue;
-        assert(stageDep->second == useStage + 1);
-        nestedNewOp->setOperand(operand->getOperandNumber(),
-                                mapping.lookupOrDefault(ret));
-        continue;
+        // If the value is a loop-carried value coming from stage N + 1, remap
+        // it to a direct use of the yielded value.
+        if (stageDep->second == useStage + 1) {
+          nestedNewOp->setOperand(operand->getOperandNumber(),
+                                  mapping.lookupOrDefault(ret));
+          continue;
+        }
+        source = ret;
       }
       // For operands defined in a previous stage we need to remap it to use
       // the correct region argument. We look for the right version of the
       // Value based on the stage where it is used.
-      Operation *def = operand->get().getDefiningOp();
+      Operation *def = source.getDefiningOp();
       if (!def)
         continue;
       auto stageDef = stages.find(def);
@@ -418,9 +446,30 @@ LogicalResult LoopPipelinerInternal::createKernel(
   // We create a mapping between original values and the associated loop
   // returned values that will be needed by the epilogue.
   llvm::SmallVector<Value> yieldOperands;
-  for (Value retVal : forOp.getBody()->getTerminator()->getOperands()) {
-    yieldOperands.push_back(mapping.lookupOrDefault(retVal));
+  for (OpOperand &yielOperand :
+       forOp.getBody()->getTerminator()->getOpOperands()) {
+    Value source = mapping.lookupOrDefault(yielOperand.get());
+    // When we don't peel the epilogue the yield value is used outside the loop
+    // we need to make sure we return the version from numStages - defStage.
+    if (!peelEpilogue &&
+        !forOp.getResult(yielOperand.getOperandNumber()).use_empty()) {
+      auto [def, distance] = getDefiningOpAndDistance(yielOperand.get());
+      if (def) {
+        auto defStage = stages.find(def);
+        if (defStage != stages.end()) {
+          Value pred = predicates[defStage->second];
+          if (pred) {
+            source = rewriter.create<arith::SelectOp>(
+                pred.getLoc(), pred, source,
+                newForOp.getBody()
+                    ->getArguments()[yielOperand.getOperandNumber() + 1]);
+          }
+        }
+      }
+    }
+    yieldOperands.push_back(source);
   }
+
   for (auto &it : crossStageValues) {
     int64_t version = maxStage - it.second.lastUseStage + 1;
     unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
@@ -444,9 +493,11 @@ LogicalResult LoopPipelinerInternal::createKernel(
     Operation *def = retVal.value().getDefiningOp();
     assert(def && "Only support loop carried dependencies of distance 1");
     unsigned defStage = stages[def];
-    setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
-                    newForOp->getResult(retVal.index()),
-                    maxStage - defStage + 1);
+    if (defStage > 0) {
+      setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
+                      newForOp->getResult(retVal.index()),
+                      maxStage - defStage + 1);
+    }
   }
   rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
   return success();
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 0309287e409c184..4cd686d2cdb86b6 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -670,4 +670,56 @@ func.func @backedge_mix_order(%A: memref<?xf32>) -> f32 {
     scf.yield %A3_elem : f32
   }  { __test_pipelining_loop__ }
   return %r : f32
-}
\ No newline at end of file
+}
+
+// -----
+
+// CHECK-LABEL: @distance_1_use
+//  CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// Prologue:
+//  CHECK: %[[L0:.+]] = memref.load %{{.*}}[%[[C0]]] : memref<?xf32>
+//  CHECK: %[[L1:.+]] = memref.load %{{.*}}[%[[C1]]] : memref<?xf32>
+//  CHECK: %[[R:.+]]:5 = scf.for {{.*}} iter_args(%[[IDX0:.+]] = %[[C2]], %[[L2:.+]] = %[[L0]], %[[L3:.+]] = %[[L1]]
+//  CHECK:   %[[L4:.+]] = memref.load %{{.*}}[%[[IDX0]]] : memref<?xf32>
+//  CHECK:   %[[IDX1:.+]] = arith.addi %[[IDX0]], %[[C1]] : index
+//  CHECK:   memref.store %[[L2]]
+//  CHECK:   scf.yield %[[IDX1]], %[[L3]], %[[L4]]
+func.func @distance_1_use(%A: memref<?xf32>, %result: memref<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cf = arith.constant 1.0 : f32
+  %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%idx = %c0) -> (index) {
+    %A_elem = memref.load %A[%idx] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 0 } : memref<?xf32>
+    %idx1 = arith.addi %idx, %c1 { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : index
+    memref.store %A_elem, %result[%idx] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+    scf.yield %idx1 : index
+  }  { __test_pipelining_loop__ }
+  return
+}
+
+// -----
+
+// NOEPILOGUE-LABEL: stage_0_value_escape(
+func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cf = arith.constant 1.0 : f32
+// NOEPILOGUE: %[[C3:.+]] = arith.constant 3 : index
+// NOEPILOGUE: %[[A:.+]] = arith.addf
+// NOEPILOGUE: scf.for %[[IV:.+]] = {{.*}} iter_args(%[[ARG:.+]] = %[[A]],
+// NOEPILOGUE:   %[[C:.+]] = arith.cmpi slt, %[[IV]], %[[C3]] : index
+// NOEPILOGUE:   %[[S:.+]] = arith.select %[[C]], %{{.+}}, %[[ARG]] : f32
+// NOEPILOGUE:   scf.yield %[[S]]
+  %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : memref<?xf32>
+    %A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
+    memref.store %A1_elem, %result[%c0] { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+    scf.yield %A1_elem : f32
+  }  { __test_pipelining_loop__ }
+  memref.store %r, %result[%c1] : memref<?xf32>
+  return
+}

From 6f2cd1a6d27746c68c5366c14dbb29502719c00f Mon Sep 17 00:00:00 2001
From: Thomas Raoux <thomas.raoux at openai.com>
Date: Fri, 1 Dec 2023 20:06:29 -0800
Subject: [PATCH 2/2] Address review comments

---
 .../Dialect/SCF/Transforms/LoopPipelining.cpp | 25 +++++++++----------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index f25318fe52093ec..20fa8089201aa19 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -446,24 +446,23 @@ LogicalResult LoopPipelinerInternal::createKernel(
   // We create a mapping between original values and the associated loop
   // returned values that will be needed by the epilogue.
   llvm::SmallVector<Value> yieldOperands;
-  for (OpOperand &yielOperand :
+  for (OpOperand &yieldOperand :
        forOp.getBody()->getTerminator()->getOpOperands()) {
-    Value source = mapping.lookupOrDefault(yielOperand.get());
-    // When we don't peel the epilogue the yield value is used outside the loop
-    // we need to make sure we return the version from numStages - defStage.
+    Value source = mapping.lookupOrDefault(yieldOperand.get());
+    // When we don't peel the epilogue and the yield value is used outside the
+    // loop we need to make sure we return the version from numStages -
+    // defStage.
     if (!peelEpilogue &&
-        !forOp.getResult(yielOperand.getOperandNumber()).use_empty()) {
-      auto [def, distance] = getDefiningOpAndDistance(yielOperand.get());
+        !forOp.getResult(yieldOperand.getOperandNumber()).use_empty()) {
+      Operation *def = getDefiningOpAndDistance(yieldOperand.get()).first;
       if (def) {
         auto defStage = stages.find(def);
-        if (defStage != stages.end()) {
+        if (defStage != stages.end() && defStage->second < maxStage) {
           Value pred = predicates[defStage->second];
-          if (pred) {
-            source = rewriter.create<arith::SelectOp>(
-                pred.getLoc(), pred, source,
-                newForOp.getBody()
-                    ->getArguments()[yielOperand.getOperandNumber() + 1]);
-          }
+          source = rewriter.create<arith::SelectOp>(
+              pred.getLoc(), pred, source,
+              newForOp.getBody()
+                  ->getArguments()[yieldOperand.getOperandNumber() + 1]);
         }
       }
     }



More information about the Mlir-commits mailing list