[Mlir-commits] [mlir] 45cb414 - [mlir] Extend scf pipelining to support loop carried dependencies

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 21 18:32:53 PDT 2021


Author: thomasraoux
Date: 2021-07-21T18:32:38-07:00
New Revision: 45cb4140eb13804f9e2f62fc1c91f9a64eb81351

URL: https://github.com/llvm/llvm-project/commit/45cb4140eb13804f9e2f62fc1c91f9a64eb81351
DIFF: https://github.com/llvm/llvm-project/commit/45cb4140eb13804f9e2f62fc1c91f9a64eb81351.diff

LOG: [mlir] Extend scf pipelining to support loop carried dependencies

Differential Revision: https://reviews.llvm.org/D106325
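
For illustration, the effect on a loop with a distance-1 carried dependency looks
roughly like this (a sketch condensed from the new loop_carried test added below;
SSA names are illustrative):

    %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
      %A_elem = memref.load %A[%i0] : memref<?xf32>    // stage 0
      %A1_elem = addf %A_elem, %arg0 : f32             // stage 1, uses carried %arg0
      scf.yield %A1_elem : f32                         // backedge, distance 1
    }

is pipelined into:

    // Prologue: peel the stage-0 load of iteration 0.
    %l0 = memref.load %A[%c0] : memref<?xf32>
    // Kernel: the carried addf result and the prefetched load both travel
    // through iter_args of the new loop.
    %lr:2 = scf.for %iv = %c0 to %c3 step %c1
        iter_args(%c = %cf, %larg = %l0) -> (f32, f32) {
      %add0 = addf %larg, %c : f32                     // stage 1 of iteration i
      %iv1 = addi %iv, %c1 : index
      %l1 = memref.load %A[%iv1] : memref<?xf32>       // stage 0 of iteration i+1
      scf.yield %add0, %l1 : f32, f32
    }
    // Epilogue: the final stage-1 addf replaces the uses of the original %r.
    %r = addf %lr#1, %lr#0 : f32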

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
    mlir/test/Dialect/SCF/loop-pipelining.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 7cb36c958f4de..70297a93b1b50 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -74,7 +74,7 @@ struct LoopPipelinerInternal {
       PatternRewriter &rewriter);
   /// Emits the epilogue; this creates `maxStage - 1` parts which will contain
   /// operations from stages [i; maxStage], where i is the part index.
-  void emitEpilogue(PatternRewriter &rewriter);
+  llvm::SmallVector<Value> emitEpilogue(PatternRewriter &rewriter);
 };
 
 bool LoopPipelinerInternal::initializeLoopInfo(
@@ -114,14 +114,25 @@ bool LoopPipelinerInternal::initializeLoopInfo(
           .wasInterrupted())
     return false;
 
-  // TODO: Add support for loop with operands.
-  if (forOp.getNumIterOperands() > 0)
+  // Only support loop-carried dependencies with a distance of 1. This means
+  // the source of every scf.yield operand needs to be defined by an operation
+  // in the loop.
+  if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
+                   [this](Value operand) {
+                     Operation *def = operand.getDefiningOp();
+                     return !def || stages.find(def) == stages.end();
+                   }))
     return false;
-
   return true;
 }
 
 void LoopPipelinerInternal::emitPrologue(PatternRewriter &rewriter) {
+  // Initialize the iteration arguments to the loop initial values.
+  for (BlockArgument &arg : forOp.getRegionIterArgs()) {
+    OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
+    setValueMapping(arg, operand.get(), 0);
+  }
+  auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
   for (int64_t i = 0; i < maxStage; i++) {
     // special handling for induction variable as the increment is implicit.
     Value iv = rewriter.create<ConstantIndexOp>(forOp.getLoc(), lb + i);
@@ -138,6 +149,14 @@ void LoopPipelinerInternal::emitPrologue(PatternRewriter &rewriter) {
       for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
         setValueMapping(op->getResult(destId), newOp->getResult(destId),
                         i - stages[op]);
+        // If the value is a loop carried dependency, update the loop
+        // argument mapping.
+        for (OpOperand &operand : yield->getOpOperands()) {
+          if (operand.get() != op->getResult(destId))
+            continue;
+          setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
+                          newOp->getResult(destId), i - stages[op] + 1);
+        }
       }
     }
   }
@@ -173,7 +192,19 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
   // stages. The initial values come from the prologue created above.
   // Keep track of the kernel argument associated to each version of the
   // values passed to the kernel.
-  auto newLoopArg = llvm::to_vector<8>(forOp.getIterOperands());
+  llvm::SmallVector<Value> newLoopArg;
+  // For existing loop arguments, initialize them with the right version from
+  // the prologue.
+  for (auto retVal :
+       llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
+    Operation *def = retVal.value().getDefiningOp();
+    assert(def && "Only support loop carried dependencies of distance 1");
+    unsigned defStage = stages[def];
+    Value valueVersion = valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
+                                     [maxStage - defStage];
+    assert(valueVersion);
+    newLoopArg.push_back(valueVersion);
+  }
   for (auto escape : crossStageValues) {
     LiverangeInfo &info = escape.second;
     Value value = escape.first;
@@ -210,6 +241,9 @@ void LoopPipelinerInternal::createKernel(
   rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
   BlockAndValueMapping mapping;
   mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
+  for (auto arg : llvm::enumerate(forOp.getRegionIterArgs())) {
+    mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
+  }
   for (Operation *op : opOrder) {
     int64_t useStage = stages[op];
     auto *newOp = rewriter.clone(*op, mapping);
@@ -226,6 +260,23 @@ void LoopPipelinerInternal::createKernel(
         rewriter.setInsertionPointAfter(newOp);
         continue;
       }
+      auto arg = operand.get().dyn_cast<BlockArgument>();
+      if (arg && arg.getOwner() == forOp.getBody()) {
+        // If the value is a loop carried value coming from stage N + 1,
+        // remap it; it will become a direct use.
+        Value ret = forOp.getBody()->getTerminator()->getOperand(
+            arg.getArgNumber() - 1);
+        Operation *dep = ret.getDefiningOp();
+        if (!dep)
+          continue;
+        auto stageDep = stages.find(dep);
+        if (stageDep == stages.end() || stageDep->second == useStage)
+          continue;
+        assert(stageDep->second == useStage + 1);
+        newOp->setOperand(operand.getOperandNumber(),
+                          mapping.lookupOrDefault(ret));
+        continue;
+      }
       // For operands defined in a previous stage we need to remap it to use
       // the correct region argument. We look for the right version of the
       // Value based on the stage where it is used.
@@ -249,6 +300,9 @@ void LoopPipelinerInternal::createKernel(
   // We create a mapping between original values and the associated loop
   // returned values that will be needed by the epilogue.
   llvm::SmallVector<Value> yieldOperands;
+  for (Value retVal : forOp.getBody()->getTerminator()->getOperands()) {
+    yieldOperands.push_back(mapping.lookupOrDefault(retVal));
+  }
   for (auto &it : crossStageValues) {
     int64_t version = maxStage - it.second.lastUseStage + 1;
     unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
@@ -266,10 +320,22 @@ void LoopPipelinerInternal::createKernel(
                     version++);
     yieldOperands.push_back(mapping.lookupOrDefault(it.first));
   }
+  // Map the yield operands to the forOp returned values.
+  for (auto retVal :
+       llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
+    Operation *def = retVal.value().getDefiningOp();
+    assert(def && "Only support loop carried dependencies of distance 1");
+    unsigned defStage = stages[def];
+    setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
+                    newForOp->getResult(retVal.index()),
+                    maxStage - defStage + 1);
+  }
   rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
 }
 
-void LoopPipelinerInternal::emitEpilogue(PatternRewriter &rewriter) {
+llvm::SmallVector<Value>
+LoopPipelinerInternal::emitEpilogue(PatternRewriter &rewriter) {
+  llvm::SmallVector<Value> returnValues(forOp->getNumResults());
   // Emit different versions of the induction variable. They will be
   // removed by dead code if not used.
   for (int64_t i = 0; i < maxStage; i++) {
@@ -295,9 +361,27 @@ void LoopPipelinerInternal::emitEpilogue(PatternRewriter &rewriter) {
       for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
         setValueMapping(op->getResult(destId), newOp->getResult(destId),
                         maxStage - stages[op] + i);
+        // If the value is a loop carried dependency, update the loop argument
+        // mapping and keep track of the last version to replace the original
+        // forOp uses.
+        for (OpOperand &operand :
+             forOp.getBody()->getTerminator()->getOpOperands()) {
+          if (operand.get() != op->getResult(destId))
+            continue;
+          unsigned version = maxStage - stages[op] + i + 1;
+          // If the version is greater than maxStage, it means it maps to the
+          // original forOp returned value.
+          if (version > maxStage) {
+            returnValues[operand.getOperandNumber()] = newOp->getResult(destId);
+            continue;
+          }
+          setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
+                          newOp->getResult(destId), version);
+        }
       }
     }
   }
+  return returnValues;
 }
 
 void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
@@ -361,12 +445,11 @@ struct ForLoopPipelining : public OpRewritePattern<ForOp> {
 
     // 4. Emit the epilogue after the new forOp.
     rewriter.setInsertionPointAfter(newForOp);
-    pipeliner.emitEpilogue(rewriter);
+    llvm::SmallVector<Value> returnValues = pipeliner.emitEpilogue(rewriter);
 
     // 5. Erase the original loop and replace the uses with the epilogue output.
     if (forOp->getNumResults() > 0)
-      rewriter.replaceOp(
-          forOp, newForOp.getResults().take_front(forOp->getNumResults()));
+      rewriter.replaceOp(forOp, returnValues);
     else
       rewriter.eraseOp(forOp);
 

diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index fb3cce1ed7869..1b7e62571bb95 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -171,3 +171,118 @@ func @multiple_uses(%A: memref<?xf32>, %result: memref<?xf32>) {
   } { __test_pipelining_loop__ }
   return
 }
+
+// -----
+
+// CHECK-LABEL: loop_carried(
+//  CHECK-SAME:   %[[A:.*]]: memref<?xf32>, %[[R:.*]]: memref<?xf32>) {
+//   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = constant 1 : index
+//   CHECK-DAG:   %[[C3:.*]] = constant 3 : index
+//   CHECK-DAG:   %[[CSTF:.*]] = constant 1.000000e+00 : f32
+// Prologue:
+//       CHECK:   %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
+// Kernel:
+//  CHECK-NEXT:   %[[LR:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]]
+//  CHECK-SAME:     step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
+//  CHECK-SAME:     %[[LARG:.*]] = %[[L0]]) -> (f32, f32) {
+//  CHECK-NEXT:     %[[ADD0:.*]] = addf %[[LARG]], %[[C]] : f32
+//  CHECK-NEXT:     %[[IV1:.*]] = addi %[[IV]], %[[C1]] : index
+//  CHECK-NEXT:     %[[L1:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32>
+//  CHECK-NEXT:     scf.yield %[[ADD0]], %[[L1]] : f32, f32
+//  CHECK-NEXT:   }
+// Epilogue:
+//  CHECK-NEXT:   %[[ADD1:.*]] = addf %[[LR]]#1, %[[LR]]#0 : f32
+//  CHECK-NEXT:   memref.store %[[ADD1]], %[[R]][%[[C0]]] : memref<?xf32>
+func @loop_carried(%A: memref<?xf32>, %result: memref<?xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %cf = constant 1.0 : f32
+  %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 1 } : memref<?xf32>
+    %A1_elem = addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
+    scf.yield %A1_elem : f32
+  }  { __test_pipelining_loop__ }
+  memref.store %r, %result[%c0] : memref<?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: backedge_different_stage
+//  CHECK-SAME:   (%[[A:.*]]: memref<?xf32>) -> f32 {
+//   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = constant 1 : index
+//   CHECK-DAG:   %[[C2:.*]] = constant 2 : index
+//   CHECK-DAG:   %[[CSTF:.*]] = constant 1.000000e+00 : f32
+// Prologue:
+//       CHECK:   %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
+//  CHECK-NEXT:   %[[ADD0:.*]] = addf %[[L0]], %[[CSTF]] : f32
+//  CHECK-NEXT:   %[[L1:.*]] = memref.load %[[A]][%[[C1]]] : memref<?xf32>
+// Kernel:
+//  CHECK-NEXT:   %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C2]]
+//  CHECK-SAME:     step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
+//  CHECK-SAME:     %[[ADDARG:.*]] = %[[ADD0]], %[[LARG:.*]] = %[[L1]]) -> (f32, f32, f32) {
+//  CHECK-NEXT:     %[[MUL0:.*]] = mulf %[[CSTF]], %[[ADDARG]] : f32
+//  CHECK-NEXT:     %[[ADD1:.*]] = addf %[[LARG]], %[[MUL0]] : f32
+//  CHECK-NEXT:     %[[IV2:.*]] = addi %[[IV]], %[[C2]] : index
+//  CHECK-NEXT:     %[[L2:.*]] = memref.load %[[A]][%[[IV2]]] : memref<?xf32>
+//  CHECK-NEXT:     scf.yield %[[MUL0]], %[[ADD1]], %[[L2]] : f32, f32, f32
+//  CHECK-NEXT:   }
+// Epilogue:
+//  CHECK-NEXT:   %[[MUL1:.*]] = mulf %[[CSTF]], %[[R]]#1 : f32
+//  CHECK-NEXT:   %[[ADD2:.*]] = addf %[[R]]#2, %[[MUL1]] : f32
+//  CHECK-NEXT:   %[[MUL2:.*]] = mulf %[[CSTF]], %[[ADD2]] : f32
+//  CHECK-NEXT:   return %[[MUL2]] : f32
+func @backedge_different_stage(%A: memref<?xf32>) -> f32 {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %cf = constant 1.0 : f32
+  %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+    %A1_elem = addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
+    %A2_elem = mulf %cf, %A1_elem { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 0 } : f32
+    scf.yield %A2_elem : f32
+  }  { __test_pipelining_loop__ }
+  return %r : f32
+}
+
+// -----
+
+// CHECK-LABEL: backedge_same_stage
+//  CHECK-SAME:   (%[[A:.*]]: memref<?xf32>) -> f32 {
+//   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = constant 1 : index
+//   CHECK-DAG:   %[[C3:.*]] = constant 3 : index
+//   CHECK-DAG:   %[[CSTF:.*]] = constant 1.000000e+00 : f32
+// Prologue:
+//       CHECK:   %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
+// Kernel:
+//  CHECK-NEXT:   %[[R:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]]
+//  CHECK-SAME:     step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
+//  CHECK-SAME:     %[[LARG:.*]] = %[[L0]]) -> (f32, f32) {
+//  CHECK-NEXT:     %[[ADD0:.*]] = addf %[[LARG]], %[[C]] : f32
+//  CHECK-NEXT:     %[[MUL0:.*]] = mulf %[[CSTF]], %[[ADD0]] : f32
+//  CHECK-NEXT:     %[[IV1:.*]] = addi %[[IV]], %[[C1]] : index
+//  CHECK-NEXT:     %[[L2:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32>
+//  CHECK-NEXT:     scf.yield %[[MUL0]], %[[L2]] : f32, f32
+//  CHECK-NEXT:   }
+// Epilogue:
+//  CHECK-NEXT:   %[[ADD1:.*]] = addf %[[R]]#1, %[[R]]#0 : f32
+//  CHECK-NEXT:   %[[MUL1:.*]] = mulf %[[CSTF]], %[[ADD1]] : f32
+//  CHECK-NEXT:   return %[[MUL1]] : f32
+func @backedge_same_stage(%A: memref<?xf32>) -> f32 {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c4 = constant 4 : index
+  %cf = constant 1.0 : f32
+  %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+    %A1_elem = addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
+    %A2_elem = mulf %cf, %A1_elem { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
+    scf.yield %A2_elem : f32
+  }  { __test_pipelining_loop__ }
+  return %r : f32
+}


        


More information about the Mlir-commits mailing list