[Mlir-commits] [mlir] 117db47 - [mlir][scf] Fix bug in software pipeliner and simplify the logic

Thomas Raoux llvmlistbot at llvm.org
Wed Mar 8 12:06:32 PST 2023


Author: Thomas Raoux
Date: 2023-03-08T20:06:07Z
New Revision: 117db47d02c174e2ec039fa8b6a97381106e6238

URL: https://github.com/llvm/llvm-project/commit/117db47d02c174e2ec039fa8b6a97381106e6238
DIFF: https://github.com/llvm/llvm-project/commit/117db47d02c174e2ec039fa8b6a97381106e6238.diff

LOG: [mlir][scf] Fix bug in software pipeliner and simplify the logic

Fix a bug when pipelining a loop whose stages are interleaved in the op
order. Rework the logic to only consider the operands of the cloned ops
(including their nested ops) when updating the use-def chains.

Differential Revision: https://reviews.llvm.org/D145598

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
    mlir/test/Dialect/SCF/loop-pipelining.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index b9182f5a073ed..4a7175b109614 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -294,81 +294,6 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
   return newForOp;
 }
 
-/// Replace any use of `target` with `replacement` in `op`'s operands or within
-/// `op`'s nested regions.
-static void replaceInOp(Operation *op, Value target, Value replacement) {
-  for (auto &use : llvm::make_early_inc_range(target.getUses())) {
-    if (op->isAncestor(use.getOwner()))
-      use.set(replacement);
-  }
-}
-
-/// Given a cloned op in the new kernel body, updates induction variable uses.
-/// We replace it with a version incremented based on the stage where it is
-/// used.
-static void updateInductionVariableUses(RewriterBase &rewriter, Location loc,
-                                        Operation *newOp, Value newForIv,
-                                        unsigned maxStage, unsigned useStage,
-                                        unsigned step) {
-  rewriter.setInsertionPoint(newOp);
-  Value offset = rewriter.create<arith::ConstantIndexOp>(
-      loc, (maxStage - useStage) * step);
-  Value iv = rewriter.create<arith::AddIOp>(loc, newForIv, offset);
-  replaceInOp(newOp, newForIv, iv);
-  rewriter.setInsertionPointAfter(newOp);
-}
-
-/// If the value is a loop carried value coming from stage N + 1 remap, it will
-/// become a direct use.
-static void updateIterArgUses(RewriterBase &rewriter, IRMapping &bvm,
-                              Operation *newOp, ForOp oldForOp, ForOp newForOp,
-                              unsigned useStage,
-                              const DenseMap<Operation *, unsigned> &stages) {
-
-  for (unsigned i = 0; i < oldForOp.getNumRegionIterArgs(); i++) {
-    Value yieldedVal = oldForOp.getBody()->getTerminator()->getOperand(i);
-    Operation *dep = yieldedVal.getDefiningOp();
-    if (!dep)
-      continue;
-    auto stageDep = stages.find(dep);
-    if (stageDep == stages.end() || stageDep->second == useStage)
-      continue;
-    if (stageDep->second != useStage + 1)
-      continue;
-    Value replacement = bvm.lookup(yieldedVal);
-    replaceInOp(newOp, newForOp.getRegionIterArg(i), replacement);
-  }
-}
-
-/// For operands defined in a previous stage we need to remap it to use the
-/// correct region argument. We look for the right version of the Value based
-/// on the stage where it is used.
-static void updateCrossStageUses(
-    RewriterBase &rewriter, Operation *newOp, IRMapping &bvm, ForOp newForOp,
-    unsigned useStage, const DenseMap<Operation *, unsigned> &stages,
-    const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) {
-  // Because we automatically cloned the sub-regions, there's no simple way
-  // to walk the nested regions in pairs of (oldOps, newOps), so we just
-  // traverse the set of remapped loop arguments, filter which ones are
-  // relevant, and replace any uses.
-  for (auto [remapPair, newIterIdx] : loopArgMap) {
-    auto [crossArgValue, stageIdx] = remapPair;
-    Operation *def = crossArgValue.getDefiningOp();
-    assert(def);
-    unsigned stageDef = stages.lookup(def);
-    if (useStage <= stageDef || useStage - stageDef != stageIdx)
-      continue;
-
-    // Use "lookupOrDefault" for the target value because some operations
-    // are remapped, while in other cases the original will be present.
-    Value target = bvm.lookupOrDefault(crossArgValue);
-    Value replacement = newForOp.getRegionIterArg(newIterIdx);
-
-    // Replace uses in the new op's operands and any nested uses.
-    replaceInOp(newOp, target, replacement);
-  }
-}
-
 void LoopPipelinerInternal::createKernel(
     scf::ForOp newForOp,
     const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
@@ -400,16 +325,59 @@ void LoopPipelinerInternal::createKernel(
   for (Operation *op : opOrder) {
     int64_t useStage = stages[op];
     auto *newOp = rewriter.clone(*op, mapping);
-
-    // Within the kernel body, update uses of the induction variable, uses of
-    // the original iter args, and uses of cross stage values.
-    updateInductionVariableUses(rewriter, forOp.getLoc(), newOp,
-                                newForOp.getInductionVar(), maxStage,
-                                stages[op], step);
-    updateIterArgUses(rewriter, mapping, newOp, forOp, newForOp, useStage,
-                      stages);
-    updateCrossStageUses(rewriter, newOp, mapping, newForOp, useStage, stages,
-                         loopArgMap);
+    SmallVector<OpOperand *> operands;
+    // Collect all the operands for the cloned op and its nested ops.
+    op->walk([&operands](Operation *nestedOp) {
+      for (OpOperand &operand : nestedOp->getOpOperands()) {
+        operands.push_back(&operand);
+      }
+    });
+    for (OpOperand *operand : operands) {
+      Operation *nestedNewOp = mapping.lookup(operand->getOwner());
+      // Special case for the induction variable uses. We replace it with a
+      // version incremented based on the stage where it is used.
+      if (operand->get() == forOp.getInductionVar()) {
+        rewriter.setInsertionPoint(newOp);
+        Value offset = rewriter.create<arith::ConstantIndexOp>(
+            forOp.getLoc(), (maxStage - stages[op]) * step);
+        Value iv = rewriter.create<arith::AddIOp>(
+            forOp.getLoc(), newForOp.getInductionVar(), offset);
+        nestedNewOp->setOperand(operand->getOperandNumber(), iv);
+        rewriter.setInsertionPointAfter(newOp);
+        continue;
+      }
+      auto arg = operand->get().dyn_cast<BlockArgument>();
+      if (arg && arg.getOwner() == forOp.getBody()) {
+        // If the value is a loop carried value coming from stage N + 1 remap,
+        // it will become a direct use.
+        Value ret = forOp.getBody()->getTerminator()->getOperand(
+            arg.getArgNumber() - 1);
+        Operation *dep = ret.getDefiningOp();
+        if (!dep)
+          continue;
+        auto stageDep = stages.find(dep);
+        if (stageDep == stages.end() || stageDep->second == useStage)
+          continue;
+        assert(stageDep->second == useStage + 1);
+        nestedNewOp->setOperand(operand->getOperandNumber(),
+                                mapping.lookupOrDefault(ret));
+        continue;
+      }
+      // For operands defined in a previous stage we need to remap it to use
+      // the correct region argument. We look for the right version of the
+      // Value based on the stage where it is used.
+      Operation *def = operand->get().getDefiningOp();
+      if (!def)
+        continue;
+      auto stageDef = stages.find(def);
+      if (stageDef == stages.end() || stageDef->second == useStage)
+        continue;
+      auto remap = loopArgMap.find(
+          std::make_pair(operand->get(), useStage - stageDef->second));
+      assert(remap != loopArgMap.end());
+      nestedNewOp->setOperand(operand->getOperandNumber(),
+                              newForOp.getRegionIterArgs()[remap->second]);
+    }
 
     if (predicates[useStage]) {
       newOp = predicateFn(newOp, predicates[useStage], rewriter);

diff  --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 68b513362a250..0309287e409c1 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -627,3 +627,47 @@ func.func @pipeline_op_with_region(%A: memref<?xf32>, %B: memref<?xf32>, %result
   }  { __test_pipelining_loop__ }
   return
 }
+
+// -----
+
+// CHECK-LABEL: @backedge_mix_order
+//  CHECK-SAME:   (%[[A:.*]]: memref<?xf32>) -> f32 {
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[CSTF:.*]] = arith.constant 2.000000e+00 : f32
+// Prologue:
+//       CHECK:   %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
+//  CHECK-NEXT:   %[[L1:.*]] = memref.load %[[A]][%[[C1]]] : memref<?xf32>
+// Kernel:
+//  CHECK-NEXT:   %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]]
+//  CHECK-SAME:     step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
+//  CHECK-SAME:     %[[ARG1:.*]] = %[[L0]], %[[ARG2:.*]] = %[[L1]]) -> (f32, f32, f32) {
+//  CHECK-NEXT:     %[[IV2:.*]] = arith.addi %[[IV]], %[[C1]] : index
+//  CHECK-NEXT:     %[[L2:.*]] = memref.load %[[A]][%[[IV2]]] : memref<?xf32>
+//  CHECK-NEXT:     %[[MUL0:.*]] = arith.mulf %[[C]], %[[ARG1]] : f32
+//  CHECK-NEXT:     %[[IV3:.*]] = arith.addi %[[IV]], %[[C1]] : index
+//  CHECK-NEXT:     %[[IV4:.*]] = arith.addi %[[IV3]], %[[C1]] : index
+//  CHECK-NEXT:     %[[L3:.*]] = memref.load %[[A]][%[[IV4]]] : memref<?xf32>
+//  CHECK-NEXT:     %[[MUL1:.*]] = arith.mulf %[[ARG2]], %[[MUL0]] : f32
+//  CHECK-NEXT:     scf.yield %[[MUL1]], %[[L2]], %[[L3]] : f32, f32, f32
+//  CHECK-NEXT:   }
+// Epilogue:
+//  CHECK-NEXT:   %[[MUL1:.*]] = arith.mulf %[[R]]#0, %[[R]]#1 : f32
+//  CHECK-NEXT:   %[[MUL2:.*]] = arith.mulf %[[R]]#2, %[[MUL1]] : f32
+//  CHECK-NEXT:   return %[[MUL2]] : f32
+func.func @backedge_mix_order(%A: memref<?xf32>) -> f32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cf = arith.constant 2.0 : f32
+  %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 0 } : memref<?xf32>
+    %A2_elem = arith.mulf %arg0, %A_elem { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
+    %i1 = arith.addi %i0, %c1 { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : index
+    %A1_elem = memref.load %A[%i1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 3 } : memref<?xf32>
+    %A3_elem = arith.mulf %A1_elem, %A2_elem { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 4 } : f32
+    scf.yield %A3_elem : f32
+  }  { __test_pipelining_loop__ }
+  return %r : f32
+}
\ No newline at end of file


        


More information about the Mlir-commits mailing list