[Mlir-commits] [mlir] f5fe92f - [mlir][SCF] Fix loop pipelining unable to handle ops with regions

Christopher Bate llvmlistbot at llvm.org
Tue Sep 20 20:59:03 PDT 2022


Author: Christopher Bate
Date: 2022-09-20T21:58:53-06:00
New Revision: f5fe92f6938511a8e8c6ec850b842a499c8b040f

URL: https://github.com/llvm/llvm-project/commit/f5fe92f6938511a8e8c6ec850b842a499c8b040f
DIFF: https://github.com/llvm/llvm-project/commit/f5fe92f6938511a8e8c6ec850b842a499c8b040f.diff

LOG: [mlir][SCF] Fix loop pipelining unable to handle ops with regions

This change allows the SCF LoopPipelining transform to handle ops with
nested regions within the pipelined `scf.for` body. The op and nested
regions are treated as a single unit from the transform's perspective.
This change also makes explicit the requirement that only ops whose
parent Block is the loop body Block are allowed to be scheduled by the
caller.

Reviewed By: ThomasRaoux, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D133965

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
    mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
    mlir/test/Dialect/SCF/loop-pipelining.mlir
    mlir/test/lib/Dialect/SCF/CMakeLists.txt
    mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 9ed0ae8556b61..fe7f1b03d3f1e 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -233,6 +233,12 @@ def ForOp : SCF_Op<"for",
     Block::BlockArgListType getRegionIterArgs() {
       return getBody()->getArguments().drop_front(getNumInductionVars());
     }
+    /// Return the `index`-th region iteration argument.
+    BlockArgument getRegionIterArg(unsigned index) {
+      assert(index < getNumRegionIterArgs() && 
+        "expected an index less than the number of region iter args");
+      return getBody()->getArguments().drop_front(getNumInductionVars())[index];
+    }
     Operation::operand_range getIterOperands() {
       return getOperands().drop_front(getNumControlOperands());
     }

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 7a162dce81935..fc3725d21e475 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/MathExtras.h"
+#include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/MapVector.h"
 
 using namespace mlir;
@@ -114,15 +115,28 @@ bool LoopPipelinerInternal::initializeLoopInfo(
     return false;
 
   // All operations need to have a stage.
-  if (forOp
-          .walk([this](Operation *op) {
-            if (op != forOp.getOperation() && !isa<scf::YieldOp>(op) &&
-                stages.find(op) == stages.end())
-              return WalkResult::interrupt();
-            return WalkResult::advance();
-          })
-          .wasInterrupted())
-    return false;
+  for (Operation &op : forOp.getBody()->without_terminator()) {
+    if (stages.find(&op) == stages.end()) {
+      op.emitOpError("not assigned a pipeline stage");
+      return false;
+    }
+  }
+
+  // Currently, we do not support assigning stages to ops in nested regions. The
+  // block of all operations assigned a stage should be the single `scf.for`
+  // body block.
+  for (const auto &[op, stageNum] : stages) {
+    (void)stageNum;
+    if (op == forOp.getBody()->getTerminator()) {
+      op->emitError("terminator should not be assigned a stage");
+      return false;
+    }
+    if (op->getBlock() != forOp.getBody()) {
+      op->emitOpError("the owning Block of all operations assigned a stage "
+                      "should be the loop body block");
+      return false;
+    }
+  }
 
   // Only support loop carried dependency with a distance of 1. This means the
   // source of all the scf.yield operands needs to be defined by operations in
@@ -137,6 +151,27 @@ bool LoopPipelinerInternal::initializeLoopInfo(
   return true;
 }
 
+/// Clone `op` and call `callback` on the cloned op's operands as well as any
+/// operands of nested ops that:
+/// 1) aren't defined within the new op or
+/// 2) are block arguments.
+static Operation *
+cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
+                       function_ref<void(OpOperand *newOperand)> callback) {
+  Operation *clone = rewriter.clone(*op);
+  for (OpOperand &operand : clone->getOpOperands())
+    callback(&operand);
+  clone->walk([&](Operation *nested) {
+    for (OpOperand &operand : nested->getOpOperands()) {
+      Operation *def = operand.get().getDefiningOp();
+      if ((def && !clone->isAncestor(def)) ||
+          operand.get().isa<BlockArgument>())
+        callback(&operand);
+    }
+  });
+  return clone;
+}
+
 void LoopPipelinerInternal::emitPrologue(PatternRewriter &rewriter) {
   // Initialize the iteration argument to the loop initiale values.
   for (BlockArgument &arg : forOp.getRegionIterArgs()) {
@@ -152,12 +187,14 @@ void LoopPipelinerInternal::emitPrologue(PatternRewriter &rewriter) {
     for (Operation *op : opOrder) {
       if (stages[op] > i)
         continue;
-      Operation *newOp = rewriter.clone(*op);
-      for (unsigned opIdx = 0; opIdx < op->getNumOperands(); opIdx++) {
-        auto it = valueMapping.find(op->getOperand(opIdx));
-        if (it != valueMapping.end())
-          newOp->setOperand(opIdx, it->second[i - stages[op]]);
-      }
+      Operation *newOp =
+          cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
+            auto it = valueMapping.find(newOperand->get());
+            if (it != valueMapping.end()) {
+              Value replacement = it->second[i - stages[op]];
+              newOperand->set(replacement);
+            }
+          });
       if (annotateFn)
         annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
       for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
@@ -181,18 +218,25 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
   llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
   for (Operation *op : opOrder) {
     unsigned stage = stages[op];
-    for (OpOperand &operand : op->getOpOperands()) {
+
+    auto analyzeOperand = [&](OpOperand &operand) {
       Operation *def = operand.get().getDefiningOp();
       if (!def)
-        continue;
+        return;
       auto defStage = stages.find(def);
       if (defStage == stages.end() || defStage->second == stage)
-        continue;
+        return;
       assert(stage > defStage->second);
       LiverangeInfo &info = crossStageValues[operand.get()];
       info.defStage = defStage->second;
       info.lastUseStage = std::max(info.lastUseStage, stage);
-    }
+    };
+
+    for (OpOperand &operand : op->getOpOperands())
+      analyzeOperand(operand);
+    visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) {
+      analyzeOperand(*operand);
+    });
   }
   return crossStageValues;
 }
@@ -243,9 +287,89 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
   auto newForOp =
       rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
                                   forOp.getStep(), newLoopArg);
+  // When there are no iter args, the loop body terminator will be created.
+  // Since we always create it below, remove the terminator if it was created.
+  if (!newForOp.getBody()->empty())
+    rewriter.eraseOp(newForOp.getBody()->getTerminator());
   return newForOp;
 }
 
+/// Replace any use of `target` with `replacement` in `op`'s operands or within
+/// `op`'s nested regions.
+static void replaceInOp(Operation *op, Value target, Value replacement) {
+  for (auto &use : llvm::make_early_inc_range(target.getUses())) {
+    if (op->isAncestor(use.getOwner()))
+      use.set(replacement);
+  }
+}
+
+/// Given a cloned op in the new kernel body, updates induction variable uses.
+/// We replace it with a version incremented based on the stage where it is
+/// used.
+static void updateInductionVariableUses(RewriterBase &rewriter, Location loc,
+                                        Operation *newOp, Value newForIv,
+                                        unsigned maxStage, unsigned useStage,
+                                        unsigned step) {
+  rewriter.setInsertionPoint(newOp);
+  Value offset = rewriter.create<arith::ConstantIndexOp>(
+      loc, (maxStage - useStage) * step);
+  Value iv = rewriter.create<arith::AddIOp>(loc, newForIv, offset);
+  replaceInOp(newOp, newForIv, iv);
+  rewriter.setInsertionPointAfter(newOp);
+}
+
+/// If the value is a loop carried value coming from stage N + 1 remap, it will
+/// become a direct use.
+static void updateIterArgUses(RewriterBase &rewriter, BlockAndValueMapping &bvm,
+                              Operation *newOp, ForOp oldForOp, ForOp newForOp,
+                              unsigned useStage,
+                              const DenseMap<Operation *, unsigned> &stages) {
+
+  for (unsigned i = 0; i < oldForOp.getNumRegionIterArgs(); i++) {
+    Value yieldedVal = oldForOp.getBody()->getTerminator()->getOperand(i);
+    Operation *dep = yieldedVal.getDefiningOp();
+    if (!dep)
+      continue;
+    auto stageDep = stages.find(dep);
+    if (stageDep == stages.end() || stageDep->second == useStage)
+      continue;
+    if (stageDep->second != useStage + 1)
+      continue;
+    Value replacement = bvm.lookup(yieldedVal);
+    replaceInOp(newOp, newForOp.getRegionIterArg(i), replacement);
+  }
+}
+
+/// For operands defined in a previous stage we need to remap it to use the
+/// correct region argument. We look for the right version of the Value based
+/// on the stage where it is used.
+static void updateCrossStageUses(
+    RewriterBase &rewriter, Operation *newOp, BlockAndValueMapping &bvm,
+    ForOp newForOp, unsigned useStage,
+    const DenseMap<Operation *, unsigned> &stages,
+    const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) {
+  // Because we automatically cloned the sub-regions, there's no simple way
+  // to walk the nested regions in pairs of (oldOps, newOps), so we just
+  // traverse the set of remapped loop arguments, filter which ones are
+  // relevant, and replace any uses.
+  for (auto [remapPair, newIterIdx] : loopArgMap) {
+    auto [crossArgValue, stageIdx] = remapPair;
+    Operation *def = crossArgValue.getDefiningOp();
+    assert(def);
+    unsigned stageDef = stages.lookup(def);
+    if (useStage <= stageDef || useStage - stageDef != stageIdx)
+      continue;
+
+    // Use "lookupOrDefault" for the target value because some operations
+    // are remapped, while in other cases the original will be present.
+    Value target = bvm.lookupOrDefault(crossArgValue);
+    Value replacement = newForOp.getRegionIterArg(newIterIdx);
+
+    // Replace uses in the new op's operands and any nested uses.
+    replaceInOp(newOp, target, replacement);
+  }
+}
+
 void LoopPipelinerInternal::createKernel(
     scf::ForOp newForOp,
     const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
@@ -277,51 +401,17 @@ void LoopPipelinerInternal::createKernel(
   for (Operation *op : opOrder) {
     int64_t useStage = stages[op];
     auto *newOp = rewriter.clone(*op, mapping);
-    for (OpOperand &operand : op->getOpOperands()) {
-      // Special case for the induction variable uses. We replace it with a
-      // version incremented based on the stage where it is used.
-      if (operand.get() == forOp.getInductionVar()) {
-        rewriter.setInsertionPoint(newOp);
-        Value offset = rewriter.create<arith::ConstantIndexOp>(
-            forOp.getLoc(), (maxStage - stages[op]) * step);
-        Value iv = rewriter.create<arith::AddIOp>(
-            forOp.getLoc(), newForOp.getInductionVar(), offset);
-        newOp->setOperand(operand.getOperandNumber(), iv);
-        rewriter.setInsertionPointAfter(newOp);
-        continue;
-      }
-      auto arg = operand.get().dyn_cast<BlockArgument>();
-      if (arg && arg.getOwner() == forOp.getBody()) {
-        // If the value is a loop carried value coming from stage N + 1 remap,
-        // it will become a direct use.
-        Value ret = forOp.getBody()->getTerminator()->getOperand(
-            arg.getArgNumber() - 1);
-        Operation *dep = ret.getDefiningOp();
-        if (!dep)
-          continue;
-        auto stageDep = stages.find(dep);
-        if (stageDep == stages.end() || stageDep->second == useStage)
-          continue;
-        assert(stageDep->second == useStage + 1);
-        newOp->setOperand(operand.getOperandNumber(),
-                          mapping.lookupOrDefault(ret));
-        continue;
-      }
-      // For operands defined in a previous stage we need to remap it to use
-      // the correct region argument. We look for the right version of the
-      // Value based on the stage where it is used.
-      Operation *def = operand.get().getDefiningOp();
-      if (!def)
-        continue;
-      auto stageDef = stages.find(def);
-      if (stageDef == stages.end() || stageDef->second == useStage)
-        continue;
-      auto remap = loopArgMap.find(
-          std::make_pair(operand.get(), useStage - stageDef->second));
-      assert(remap != loopArgMap.end());
-      newOp->setOperand(operand.getOperandNumber(),
-                        newForOp.getRegionIterArgs()[remap->second]);
-    }
+
+    // Within the kernel body, update uses of the induction variable, uses of
+    // the original iter args, and uses of cross stage values.
+    updateInductionVariableUses(rewriter, forOp.getLoc(), newOp,
+                                newForOp.getInductionVar(), maxStage,
+                                stages[op], step);
+    updateIterArgUses(rewriter, mapping, newOp, forOp, newForOp, useStage,
+                      stages);
+    updateCrossStageUses(rewriter, newOp, mapping, newForOp, useStage, stages,
+                         loopArgMap);
+
     if (predicates[useStage]) {
       newOp = predicateFn(newOp, predicates[useStage], rewriter);
       // Remap the results to the new predicated one.
@@ -382,21 +472,20 @@ LoopPipelinerInternal::emitEpilogue(PatternRewriter &rewriter) {
         forOp.getLoc(), lb + step * ((((ub - 1) - lb) / step) - i));
     setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
   }
-  // Emit `maxStage - 1` epilogue part that includes operations fro stages
+  // Emit `maxStage - 1` epilogue part that includes operations from stages
   // [i; maxStage].
   for (int64_t i = 1; i <= maxStage; i++) {
     for (Operation *op : opOrder) {
       if (stages[op] < i)
         continue;
-      Operation *newOp = rewriter.clone(*op);
-      for (unsigned opIdx = 0; opIdx < op->getNumOperands(); opIdx++) {
-        auto it = valueMapping.find(op->getOperand(opIdx));
-        if (it != valueMapping.end()) {
-          Value v = it->second[maxStage - stages[op] + i];
-          assert(v);
-          newOp->setOperand(opIdx, v);
-        }
-      }
+      Operation *newOp =
+          cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
+            auto it = valueMapping.find(newOperand->get());
+            if (it != valueMapping.end()) {
+              Value replacement = it->second[maxStage - stages[op] + i];
+              newOperand->set(replacement);
+            }
+          });
       if (annotateFn)
         annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
       for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {

diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 0246231f5b743..56e01fed6c008 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -34,6 +34,54 @@ func.func @simple_pipeline(%A: memref<?xf32>, %result: memref<?xf32>) {
   return
 }
 
+
+// -----
+
+// CHECK-LABEL: simple_pipeline_region(
+//  CHECK-SAME:   %[[A:.*]]: memref<?xf32>, %[[R:.*]]: memref<?xf32>) {
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+// Prologue:
+//       CHECK:   %[[L0:.*]] = scf.execute_region 
+//  CHECK-NEXT:     memref.load %[[A]][%[[C0]]] : memref<?xf32>
+// Kernel:
+//       CHECK:   %[[L1:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]]
+//  CHECK-SAME:     step %[[C1]] iter_args(%[[LARG:.*]] = %[[L0]]) -> (f32) {
+//  CHECK-NEXT:     %[[ADD0:.*]] = scf.execute_region
+//  CHECK-NEXT:       arith.addf %[[LARG]], %{{.*}} : f32
+//       CHECK:     memref.store %[[ADD0]], %[[R]][%[[IV]]] : memref<?xf32>
+//  CHECK-NEXT:     %[[IV1:.*]] = arith.addi %[[IV]], %[[C1]] : index
+//  CHECK-NEXT:     %[[LR:.*]] = scf.execute_region
+//  CHECK-NEXT:       memref.load %[[A]][%[[IV1]]] : memref<?xf32>
+//       CHECK:     scf.yield %[[LR]] : f32
+//  CHECK-NEXT:   }
+// Epilogue:
+//  CHECK-NEXT:   %[[ADD1:.*]] = scf.execute_region
+//  CHECK-NEXT:     arith.addf %[[L1]], %{{.*}} : f32
+//       CHECK:   memref.store %[[ADD1]], %[[R]][%[[C3]]] : memref<?xf32>
+func.func @simple_pipeline_region(%A: memref<?xf32>, %result: memref<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cf = arith.constant 1.0 : f32
+  scf.for %i0 = %c0 to %c4 step %c1 {
+
+    %A_elem = scf.execute_region -> f32 {
+      %A_elem1 = memref.load %A[%i0]  : memref<?xf32>
+      scf.yield %A_elem1 : f32
+    } { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 }
+
+    %A1_elem = scf.execute_region -> f32 {
+      %A1_elem1 = arith.addf %A_elem, %cf  : f32
+      scf.yield %A1_elem1 : f32
+    } { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 }
+
+    memref.store %A1_elem, %result[%i0] { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : memref<?xf32>
+  }  { __test_pipelining_loop__ }
+  return
+}
+
 // -----
 
 // CHECK-LABEL: simple_pipeline_step(
@@ -269,6 +317,65 @@ func.func @multiple_uses(%A: memref<?xf32>, %result: memref<?xf32>) {
 
 // -----
 
+// CHECK-LABEL: region_multiple_uses(
+//  CHECK-SAME:   %[[A:.*]]: memref<?xf32>, %[[R:.*]]: memref<?xf32>) {
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG:   %[[C7:.*]] = arith.constant 7 : index
+//   CHECK-DAG:   %[[C8:.*]] = arith.constant 8 : index
+//   CHECK-DAG:   %[[C9:.*]] = arith.constant 9 : index
+// Prologue:
+//       CHECK:   %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
+//  CHECK-NEXT:   %[[ADD0:.*]] = arith.addf %[[L0]], %{{.*}} : f32
+//  CHECK-NEXT:   %[[L1:.*]] = memref.load %[[A]][%[[C1]]] : memref<?xf32>
+//  CHECK-NEXT:   %[[ADD1:.*]] = arith.addf %[[L1]], %{{.*}} : f32
+//  CHECK-NEXT:   %[[MUL0:.*]] = scf.execute_region
+// arith.mulf %[[ADD0]], %[[L0]] : f32
+//  CHECK:   %[[L2:.*]] = memref.load %[[A]][%[[C2]]] : memref<?xf32>
+// Kernel:
+//  CHECK-NEXT:   %[[LR:.*]]:4 = scf.for %[[IV:.*]] = %[[C0]] to %[[C7]]
+//  CHECK-SAME:     step %[[C1]] iter_args(%[[LA1:.*]] = %[[L1]],
+//  CHECK-SAME:     %[[LA2:.*]] = %[[L2]], %[[ADDARG1:.*]] = %[[ADD1]],
+//  CHECK-SAME:     %[[MULARG0:.*]] = %[[MUL0]]) -> (f32, f32, f32, f32) {
+//  CHECK-NEXT:     %[[ADD2:.*]] = arith.addf %[[LA2]], %{{.*}} : f32
+//  CHECK-NEXT:     %[[MUL1:.*]] = scf.execute_region
+// arith.mulf %[[ADDARG1]], %[[LA1]] : f32
+//       CHECK:     memref.store %[[MULARG0]], %[[R]][%[[IV]]] : memref<?xf32>
+//  CHECK-NEXT:     %[[IV3:.*]] = arith.addi %[[IV]], %[[C3]] : index
+//  CHECK-NEXT:     %[[L3:.*]] = memref.load %[[A]][%[[IV3]]] : memref<?xf32>
+//  CHECK-NEXT:     scf.yield %[[LA2]], %[[L3]], %[[ADD2]], %[[MUL1]] : f32, f32, f32, f32
+//  CHECK-NEXT:   }
+// Epilogue:
+//  CHECK-NEXT:   %[[ADD3:.*]] = arith.addf %[[LR]]#1, %{{.*}} : f32
+//  CHECK-NEXT:   %[[MUL2:.*]] = scf.execute_region
+// arith.mulf %[[LR]]#2, %[[LR]]#0 : f32
+//       CHECK:   memref.store %[[LR]]#3, %[[R]][%[[C7]]] : memref<?xf32>
+//  CHECK-NEXT:   %[[MUL3:.*]] = scf.execute_region 
+/// %[[ADD3]], %[[LR]]#1 : f32
+//       CHECK:   memref.store %[[MUL2]], %[[R]][%[[C8]]] : memref<?xf32>
+//  CHECK-NEXT:   memref.store %[[MUL3]], %[[R]][%[[C9]]] : memref<?xf32>
+
+func.func @region_multiple_uses(%A: memref<?xf32>, %result: memref<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %cf = arith.constant 1.0 : f32
+  scf.for %i0 = %c0 to %c10 step %c1 {
+    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 3 } : memref<?xf32>
+    %A1_elem = arith.addf %A_elem, %cf { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
+    %A2_elem = scf.execute_region -> f32 {
+      %A2_elem1 = arith.mulf %A1_elem, %A_elem : f32
+      scf.yield %A2_elem1 : f32
+    } { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 1 }
+    memref.store %A2_elem, %result[%i0] { __test_pipelining_stage__ = 3, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+  } { __test_pipelining_loop__ }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: loop_carried(
 //  CHECK-SAME:   %[[A:.*]]: memref<?xf32>, %[[R:.*]]: memref<?xf32>) {
 //   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
@@ -341,6 +448,58 @@ func.func @backedge_different_stage(%A: memref<?xf32>) -> f32 {
   return %r : f32
 }
 
+// -----
+
+// CHECK-LABEL: region_backedge_different_stage
+//  CHECK-SAME:   (%[[A:.*]]: memref<?xf32>) -> f32 {
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG:   %[[CSTF:.*]] = arith.constant 1.000000e+00 : f32
+// Prologue:
+//       CHECK:   %[[L0:.*]] = scf.execute_region
+//  CHECK-NEXT:     memref.load %[[A]][%[[C0]]] : memref<?xf32>
+//       CHECK:   %[[ADD0:.*]] = scf.execute_region
+//  CHECK-NEXT:   arith.addf %[[L0]], %[[CSTF]] : f32
+//       CHECK:   %[[L1:.*]] = scf.execute_region
+//  CHECK-NEXT:     memref.load %[[A]][%[[C1]]] : memref<?xf32>
+// Kernel:
+//       CHECK:   %[[R:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] to %[[C2]]
+//  CHECK-SAME:     step %[[C1]] iter_args(%[[C:.*]] = %[[CSTF]],
+//  CHECK-SAME:     %[[ADDARG:.*]] = %[[ADD0]], %[[LARG:.*]] = %[[L1]]) -> (f32, f32, f32) {
+//       CHECK:     %[[ADD1:.*]] = scf.execute_region
+//  CHECK-NEXT:       arith.addf %[[LARG]], %[[ADDARG]] : f32
+//       CHECK:     %[[IV2:.*]] = arith.addi %[[IV]], %[[C2]] : index
+//       CHECK:     %[[L2:.*]] = scf.execute_region
+//  CHECK-NEXT:       memref.load %[[A]][%[[IV2]]] : memref<?xf32>
+//       CHECK:     scf.yield %[[ADDARG]], %[[ADD1]], %[[L2]] : f32, f32, f32
+//  CHECK-NEXT:   }
+// Epilogue:
+//       CHECK:   %[[ADD2:.*]] = scf.execute_region
+//  CHECK-NEXT:    arith.addf %[[R]]#2, %[[R]]#1 : f32
+//       CHECK:   return %[[ADD2]] : f32
+
+func.func @region_backedge_different_stage(%A: memref<?xf32>) -> f32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cf = arith.constant 1.0 : f32
+  %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf) -> (f32) {
+    %A_elem = scf.execute_region -> f32 {
+      %A_elem1 = memref.load %A[%i0] : memref<?xf32>
+      scf.yield %A_elem1 : f32
+    } { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 }    
+    %A1_elem = scf.execute_region -> f32 {
+      %inner = arith.addf %A_elem, %arg0 : f32
+      scf.yield %inner : f32 
+    }  { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 }
+    %A2_elem = arith.mulf %cf, %A1_elem { __test_pipelining_stage__ = 2, __test_pipelining_op_order__ = 0 } : f32
+    scf.yield %A2_elem : f32
+  }  { __test_pipelining_loop__ }
+  return %r : f32
+}
+
+
 // -----
 
 // CHECK-LABEL: backedge_same_stage
@@ -376,3 +535,88 @@ func.func @backedge_same_stage(%A: memref<?xf32>) -> f32 {
   }  { __test_pipelining_loop__ }
   return %r : f32
 }
+
+// -----
+
+// CHECK: @pipeline_op_with_region(%[[ARG0:.+]]: memref<?xf32>, %[[ARG1:.+]]: memref<?xf32>, %[[ARG2:.+]]: memref<?xf32>) {
+// CHECK:   %[[C0:.+]] = arith.constant 0 :
+// CHECK:   %[[C3:.+]] = arith.constant 3 :
+// CHECK:   %[[C1:.+]] = arith.constant 1 :
+// CHECK:   %[[APRO:.+]] = memref.alloc() :
+// CHECK:   %[[BPRO:.+]] = memref.alloc() :
+// CHECK:   %[[ASV0:.+]] = memref.subview %[[ARG0]][%[[C0]]] [8] [1] : 
+// CHECK:   %[[BSV0:.+]] = memref.subview %[[ARG1]][%[[C0]]] [8] [1] : 
+
+// Prologue:
+// CHECK:   %[[PAV0:.+]] = memref.subview %[[APRO]][%[[C0]], 0] [1, 8] [1, 1] :
+// CHECK:   %[[PBV0:.+]] = memref.subview %[[BPRO]][%[[C0]], 0] [1, 8] [1, 1] :
+// CHECK:   memref.copy %[[ASV0]], %[[PAV0]] : 
+// CHECK:   memref.copy %[[BSV0]], %[[PBV0]] : 
+
+// Kernel:
+// CHECK:   %[[R:.+]]:2 = scf.for %[[IV:.+]] = %[[C0]] to %[[C3]] step %[[C1]] 
+// CHECK-SAME: iter_args(%[[IA:.+]] = %[[PAV0]], %[[IB:.+]] = %[[PBV0:.+]])
+// CHECK:     %[[CV:.+]] = memref.subview %[[ARG2]]
+// CHECK:     linalg.generic
+// CHECK-SAME:  ins(%[[IA]], %[[IB]], %{{.*}} : {{.*}}) outs(%[[CV]] :
+// CHECK:     %[[NEXT:.+]] = arith.addi %[[IV]], %[[C1]] 
+// CHECK:     %[[ASV:.+]] = memref.subview %[[ARG0]][%[[NEXT]]] [8] [1] :
+// CHECK:     %[[NEXT:.+]] = arith.addi %[[IV]], %[[C1]] :
+// CHECK:     %[[BSV:.+]] = memref.subview %[[ARG1]][%[[NEXT]]] [8] [1] :
+// CHECK:     %[[NEXT:.+]] = arith.addi %[[IV]], %[[C1]] :
+// CHECK:     %[[BUFIDX:.+]] = affine.apply
+// CHECK:     %[[APROSV:.+]] = memref.subview %[[APRO]][%[[BUFIDX]], 0] [1, 8] [1, 1] : 
+// CHECK:     %[[BPROSV:.+]] = memref.subview %[[BPRO]][%[[BUFIDX]], 0] [1, 8] [1, 1] : 
+// CHECK:     memref.copy %[[ASV]], %[[APROSV]] :
+// CHECK:     memref.copy %[[BSV]], %[[BPROSV]] :
+// CHECK:     scf.yield %[[APROSV]], %[[BPROSV]] :
+// CHECK:   }
+// CHECK:   %[[CV:.+]] = memref.subview %[[ARG2]][%[[C3]]] [8] [1] :
+// CHECK:   linalg.generic 
+// CHECK-SAME: ins(%[[R]]#0, %[[R]]#1, %{{.*}} : {{.*}}) outs(%[[CV]] :
+
+
+#map = affine_map<(d0)[s0]->(d0 + s0)>
+#map1 = affine_map<(d0)->(d0)>
+#map2 = affine_map<(d0)->()>
+#linalg_attrs = {
+  indexing_maps = [
+      #map1,
+      #map1,
+      #map2,
+      #map1        
+    ],
+  iterator_types = ["parallel"],
+  __test_pipelining_stage__ = 1,
+  __test_pipelining_op_order__ = 2
+}
+func.func @pipeline_op_with_region(%A: memref<?xf32>, %B: memref<?xf32>, %result: memref<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %cf = arith.constant 1.0 : f32
+  %a_buf = memref.alloc() : memref<2x8xf32>
+  %b_buf = memref.alloc() : memref<2x8xf32>
+  scf.for %i0 = %c0 to %c4 step %c1 {
+    %A_view = memref.subview %A[%i0][8][1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 3 } : memref<?xf32> to memref<8xf32, #map>    
+    %B_view = memref.subview %B[%i0][8][1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 4 } : memref<?xf32> to memref<8xf32, #map>
+    %buf_idx = affine.apply  affine_map<(d0)->(d0 mod 2)> (%i0)[] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 5 }
+    %a_buf_view = memref.subview %a_buf[%buf_idx,0][1,8][1,1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 6 } : memref<2x8xf32> to memref<8xf32, #map>
+    %b_buf_view = memref.subview %b_buf[%buf_idx,0][1,8][1,1] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 7 } : memref<2x8xf32> to memref<8xf32, #map>
+    memref.copy %A_view , %a_buf_view {__test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 8} : memref<8xf32, #map> to memref<8xf32, #map>
+    memref.copy %B_view , %b_buf_view {__test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 9} : memref<8xf32, #map> to memref<8xf32, #map>
+    %C_view = memref.subview %result[%i0][8][1] { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : memref<?xf32> to memref<8xf32, #map>
+    %scalar = arith.addf %cf, %cf {__test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1} : f32
+    linalg.generic #linalg_attrs ins(%a_buf_view, %b_buf_view, %scalar : memref<8xf32, #map>, memref<8xf32, #map>, f32)
+      outs(%C_view: memref<8xf32, #map>) {
+      ^bb0(%a: f32, %b: f32, %s: f32, %c: f32):
+        %add = arith.addf %a, %b : f32
+        %accum = arith.addf %add, %c : f32 
+        %accum1 = arith.addf %scalar, %accum : f32
+        %accum2 = arith.addf %s, %accum1 : f32        
+        linalg.yield %accum2 : f32
+    }
+    scf.yield
+  }  { __test_pipelining_loop__ }
+  return
+}

diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
index d720d591c82c9..36c41ab0e93bd 100644
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_library(MLIRSCFTestPasses
   EXCLUDE_FROM_LIBMLIR
 
   LINK_LIBS PUBLIC
+  MLIRMemRefDialect
   MLIRPass
   MLIRSCFDialect
   MLIRSCFTransforms

diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index d62f67564cf3e..9018f8eb378d4 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/Dialect/SCF/Utils/Utils.h"
@@ -131,6 +132,7 @@ struct TestSCFPipeliningPass
               std::vector<std::pair<Operation *, unsigned>> &schedule) {
     if (!forOp->hasAttr(kTestPipeliningLoopMarker))
       return;
+
     schedule.resize(forOp.getBody()->getOperations().size() - 1);
     forOp.walk([&schedule](Operation *op) {
       auto attrStage =
@@ -153,17 +155,30 @@ struct TestSCFPipeliningPass
         rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
     // True branch.
     op->moveBefore(&ifOp.getThenRegion().front(),
-                   ifOp.getThenRegion().front().end());
+                   ifOp.getThenRegion().front().begin());
     rewriter.setInsertionPointAfter(op);
-    rewriter.create<scf::YieldOp>(loc, op->getResults());
+    if (op->getNumResults() > 0)
+      rewriter.create<scf::YieldOp>(loc, op->getResults());
     // False branch.
     rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
-    SmallVector<Value> zeros;
-    for (Type type : op->getResultTypes()) {
-      zeros.push_back(
-          rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(type)));
+    SmallVector<Value> elseYieldOperands;
+    elseYieldOperands.reserve(ifOp.getNumResults());
+    if (auto viewOp = dyn_cast<memref::SubViewOp>(op)) {
+      // For sub-views, just clone the op.
+      // NOTE: This is okay in the test because we use dynamic memref sizes, so
+      // the verifier will not complain. Otherwise, we may create a logically
+      // out-of-bounds view and a different technique should be used.
+      Operation *opClone = rewriter.clone(*op);
+      elseYieldOperands.append(opClone->result_begin(), opClone->result_end());
+    } else {
+      // Default to assuming constant numeric values.
+      for (Type type : op->getResultTypes()) {
+        elseYieldOperands.push_back(rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getZeroAttr(type)));
+      }
     }
-    rewriter.create<scf::YieldOp>(loc, zeros);
+    if (op->getNumResults() > 0)
+      rewriter.create<scf::YieldOp>(loc, elseYieldOperands);
     return ifOp.getOperation();
   }
 
@@ -187,7 +202,7 @@ struct TestSCFPipeliningPass
   }
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<arith::ArithmeticDialect>();
+    registry.insert<arith::ArithmeticDialect, memref::MemRefDialect>();
   }
 
   void runOnOperation() override {


        


More information about the Mlir-commits mailing list