[Mlir-commits] [mlir] [mlir] Fix loop pipelining when the operand of `yield` is not defined in the loop body (PR #75423)
    llvmlistbot at llvm.org 
    llvmlistbot at llvm.org
       
    Wed Dec 13 18:06:45 PST 2023
    
    
  
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-scf
Author: Keren Zhou (Jokeren)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/75423.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp (+34-22) 
- (modified) mlir/test/Dialect/SCF/loop-pipelining.mlir (+44) 
``````````diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 6c36600975a597..be5d397ef22265 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -90,7 +90,8 @@ struct LoopPipelinerInternal {
       RewriterBase &rewriter);
   /// Emits the epilogue, this creates `maxStage - 1` part which will contain
   /// operations from stages [i; maxStage], where i is the part index.
-  llvm::SmallVector<Value> emitEpilogue(RewriterBase &rewriter);
+  void emitEpilogue(RewriterBase &rewriter,
+                    llvm::SmallVector<Value> &returnValues);
 };
 
 bool LoopPipelinerInternal::initializeLoopInfo(
@@ -175,15 +176,18 @@ bool LoopPipelinerInternal::initializeLoopInfo(
     }
   }
 
-  // Only support loop carried dependency with a distance of 1. This means the
-  // source of all the scf.yield operands needs to be defined by operations in
-  // the loop.
+  // Support only loop-carried dependencies with a distance of one iteration or
+  // those defined outside of the loop. This means that any dependency within a
+  // loop should either be on the immediately preceding iteration, the current
+  // iteration, or on variables whose values are set before entering the loop.
   if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
                    [this](Value operand) {
                      Operation *def = operand.getDefiningOp();
-                     return !def || !stages.contains(def);
+                     return !def ||
+                            (!stages.contains(def) && forOp->isAncestor(def));
                    })) {
-    LDBG("--only support loop carried dependency with a distance of 1 -> BAIL");
+    LDBG("--only support loop carried dependency with a distance of 1 or "
+         "defined outside of the loop -> BAIL");
     return false;
   }
   annotateFn = options.annotateFn;
@@ -341,12 +345,17 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
   for (const auto &retVal :
        llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
     Operation *def = retVal.value().getDefiningOp();
-    assert(def && "Only support loop carried dependencies of distance 1");
-    unsigned defStage = stages[def];
-    Value valueVersion = valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
-                                     [maxStage - defStage];
-    assert(valueVersion);
-    newLoopArg.push_back(valueVersion);
+    assert(def && "Only support loop carried dependencies of distance of 1 or "
+                  "outside the loop");
+    auto defStage = stages.find(def);
+    if (defStage != stages.end()) {
+      Value valueVersion =
+          valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
+                      [maxStage - defStage->second];
+      assert(valueVersion);
+      newLoopArg.push_back(valueVersion);
+    } else
+      newLoopArg.push_back(forOp.getInitArgs()[retVal.index()]);
   }
   for (auto escape : crossStageValues) {
     LiverangeInfo &info = escape.second;
@@ -551,21 +560,24 @@ LogicalResult LoopPipelinerInternal::createKernel(
   for (const auto &retVal :
        llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
     Operation *def = retVal.value().getDefiningOp();
-    assert(def && "Only support loop carried dependencies of distance 1");
-    unsigned defStage = stages[def];
-    if (defStage > 0) {
+    assert(def && "Only support loop carried dependencies of distance of 1 or "
+                  "defined outside the loop");
+    auto defStage = stages.find(def);
+    if (defStage != stages.end() && defStage->second > 0)
       setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
                       newForOp->getResult(retVal.index()),
-                      maxStage - defStage + 1);
-    }
+                      maxStage - defStage->second + 1);
+    else
+      for (unsigned int stage = 1; stage <= maxStage; stage++)
+        setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
+                        retVal.value(), stage);
   }
   rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
   return success();
 }
 
-llvm::SmallVector<Value>
-LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) {
-  llvm::SmallVector<Value> returnValues(forOp->getNumResults());
+void LoopPipelinerInternal::emitEpilogue(
+    RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
   // Emit different versions of the induction variable. They will be
   // removed by dead code if not used.
   for (int64_t i = 0; i < maxStage; i++) {
@@ -628,7 +640,6 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) {
       }
     }
   }
-  return returnValues;
 }
 
 void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
@@ -685,7 +696,7 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
   if (options.peelEpilogue) {
     // 4. Emit the epilogue after the new forOp.
     rewriter.setInsertionPointAfter(newForOp);
-    returnValues = pipeliner.emitEpilogue(rewriter);
+    pipeliner.emitEpilogue(rewriter, returnValues);
   }
   // 5. Erase the original loop and replace the uses with the epilogue output.
   if (forOp->getNumResults() > 0)
@@ -693,6 +704,7 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
   else
     rewriter.eraseOp(forOp);
 
+  llvm::errs() << *newForOp->getParentOp() << "\n";
   return newForOp;
 }
 
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 8a57ddccfee665..a18c850c3f05f1 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -770,3 +770,47 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
   } { __test_pipelining_loop__ }
   return
 }
+
+// -----
+
+// CHECK-LABEL: yield_constant_loop(
+//  CHECK-SAME:   %[[A:.*]]: memref<?xf32>) -> f32 {
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+//   CHECK-DAG:   %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
+// Prologue:
+//       CHECK:   %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
+// Kernel:
+//  CHECK-NEXT:   %[[L1:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]]
+//  CHECK-SAME:     step %[[C1]] iter_args(%[[ARG0:.*]] = %[[CST2]], %[[ARG1:.*]] = %[[L0]]) -> (f32, f32) {
+//  CHECK-NEXT:     %[[ADD0:.*]] = arith.addf %[[ARG1]], %[[ARG0]] : f32
+//  CHECK-NEXT:     %[[MUL0:.*]] = arith.mulf %[[ADD0]], %[[CST0]] : f32
+//  CHECK-NEXT:     memref.store %[[MUL0]], %[[A]][%[[IV]]] : memref<?xf32>
+//  CHECK-NEXT:     %[[IV1:.*]] = arith.addi %[[IV]], %[[C1]] : index
+//  CHECK-NEXT:     %[[L2:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32>
+//  CHECK-NEXT:     scf.yield %[[CST0]], %[[L2]] : f32
+//  CHECK-NEXT:   }
+// Epilogue:
+//  CHECK-NEXT:   %[[ADD1:.*]] = arith.addf %[[L1]]#1, %[[CST0]] : f32
+//  CHECK-NEXT:   %[[MUL1:.*]] = arith.mulf %[[ADD1]], %[[CST0]] : f32
+//  CHECK-NEXT:   memref.store %[[MUL1]], %[[A]][%[[C3]]] : memref<?xf32>
+//  CHECK-NEXT:   return %[[L1]]#0 : f32
+
+func.func @yield_constant_loop(%A: memref<?xf32>) -> f32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cf0 = arith.constant 0.0 : f32
+  %cf2 = arith.constant 2.0 : f32
+  %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf2) -> f32 {
+    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 3 } : memref<?xf32>
+    %A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
+    %A2_elem = arith.mulf %cf0, %A1_elem { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
+    memref.store %A2_elem, %A[%i0] { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+    scf.yield %cf0: f32
+  }  { __test_pipelining_loop__ }
+  return %r : f32
+}
+
``````````
</details>
https://github.com/llvm/llvm-project/pull/75423
    
    
More information about the Mlir-commits
mailing list