[Mlir-commits] [mlir] Introduce new Unroll And Jam loop transform for SCF/Affine loops (PR #94142)

Aviad Cohen llvmlistbot at llvm.org
Sun Jun 16 20:27:18 PDT 2024


https://github.com/AviadCo updated https://github.com/llvm/llvm-project/pull/94142

>From e02c2ab0171f6886b30321e1951acc43d539408e Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviad.cohen2 at mobileye.com>
Date: Wed, 29 May 2024 11:25:39 +0300
Subject: [PATCH] Introduce new Unroll And Jam loop transform for SCF/Affine
 loops

Unroll And Jam has long been supported in the affine dialect through a pass.
This commit exposes the transformation as a transform op and, in addition,
adds partial support for SCF loops.
---
 .../SCF/TransformOps/SCFTransformOps.td       |  36 ++-
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h   |   8 +
 .../mlir/Interfaces/LoopLikeInterface.h       |  33 +++
 mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp   |  40 +---
 .../SCF/TransformOps/SCFTransformOps.cpp      |  34 ++-
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          | 226 +++++++++++++++++-
 .../Dialect/SCF/transform-ops-invalid.mlir    |  99 ++++++++
 mlir/test/Dialect/SCF/transform-ops.mlir      | 215 +++++++++++++++++
 8 files changed, 633 insertions(+), 58 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 3d7fe7b0f093f..7bf914f6456ce 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -269,7 +269,41 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
     This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
     in the return. If all the operations referred to by the `target` operand
     unroll properly, the transform succeeds. Otherwise the transform produces a
-    silencebale failure.
+    silenceable failure.
+
+    Does not return handles as the operation may result in the loop being
+    removed after a full unrolling.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       ConfinedAttr<I64Attr, [IntPositive]>:$factor);
+
+  let assemblyFormat = "$target attr-dict `:` type($target)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+def LoopUnrollAndJamOp : Op<Transform_Dialect, "loop.unroll_and_jam",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let summary = "Unrolls and jam the given loop with the given unroll factor";
+  let description = [{
+    Unrolls & jams each loop associated with the given handle to have up to the given
+    number of loop body copies per iteration. If the unroll factor is larger
+    than the loop trip count, the latter is used as the unroll factor instead.
+
+    #### Return modes
+
+    This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
+    in the return. If all the operations referred to by the `target` operand
+    unroll properly, the transform succeeds. Otherwise the transform produces a
+    silenceable failure.
 
     Does not return handles as the operation may result in the loop being
     removed after a full unrolling.
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index da3fe3ceb86be..fea151b393152 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -120,6 +120,14 @@ LogicalResult loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
 
+/// Unrolls and jams this `scf.for` operation by the specified unroll factor.
+/// Returns failure if the loop cannot be unrolled either due to restrictions or
+/// due to invalid unroll factors. In case of unroll factor of 1, the function
+/// bails out without doing anything (returns success). Currently, only constant
+/// trip counts that are divisible by the unroll factor are supported, and
+/// `scf.for` operations with results are not supported.
+LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor);
+
 /// Transform a loop with a strictly positive step
 ///   for %i = %lb to %ub step %s
 /// into a 0-based loop with step 1
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
index 42609e824c86a..9925fc6ce6ca9 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
@@ -48,6 +48,39 @@ class HasParallelRegion : public TraitBase<ConcreteType, HasParallelRegion> {
 };
 
 } // namespace OpTrait
+
+// Gathers all maximal sub-blocks of operations that do not themselves
+// include an `OpTy` (an operation could have a descendant `OpTy` though
+// in its tree). Ignores the block terminators.
+template <typename OpTy>
+struct JamBlockGatherer {
+  // Store iterators to the first and last op of each sub-block found.
+  SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
+
+  // This is a linear time walk.
+  void walk(Operation *op) {
+    for (Region &region : op->getRegions())
+      for (Block &block : region)
+        walk(block);
+  }
+
+  void walk(Block &block) {
+    assert(!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>() &&
+           "expected block to have a terminator");
+    for (Block::iterator it = block.begin(), e = std::prev(block.end());
+         it != e;) {
+      Block::iterator subBlockStart = it;
+      while (it != e && !isa<OpTy>(&*it))
+        ++it;
+      if (it != subBlockStart)
+        subBlocks.emplace_back(subBlockStart, std::prev(it));
+      // Process all `OpTy` ops that appear next.
+      while (it != e && isa<OpTy>(&*it))
+        walk(&*it++);
+    }
+  }
+};
+
 } // namespace mlir
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index c934f229bb6c4..605737542e9fc 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -37,16 +37,6 @@ using namespace affine;
 using namespace presburger;
 using llvm::SmallMapVector;
 
-namespace {
-// This structure is to pass and return sets of loop parameters without
-// confusing the order.
-struct LoopParams {
-  Value lowerBound;
-  Value upperBound;
-  Value step;
-};
-} // namespace
-
 /// Computes the cleanup loop lower bound of the loop being unrolled with
 /// the specified unroll factor; this bound will also be upper bound of the main
 /// part of the unrolled loop. Computes the bound as an AffineMap with its
@@ -1101,34 +1091,6 @@ static bool areInnerBoundsInvariant(AffineForOp forOp) {
   return !walkResult.wasInterrupted();
 }
 
-// Gathers all maximal sub-blocks of operations that do not themselves
-// include a for op (a operation could have a descendant for op though
-// in its tree).  Ignore the block terminators.
-struct JamBlockGatherer {
-  // Store iterators to the first and last op of each sub-block found.
-  std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
-
-  // This is a linear time walk.
-  void walk(Operation *op) {
-    for (auto &region : op->getRegions())
-      for (auto &block : region)
-        walk(block);
-  }
-
-  void walk(Block &block) {
-    for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
-      auto subBlockStart = it;
-      while (it != e && !isa<AffineForOp>(&*it))
-        ++it;
-      if (it != subBlockStart)
-        subBlocks.emplace_back(subBlockStart, std::prev(it));
-      // Process all for ops that appear next.
-      while (it != e && isa<AffineForOp>(&*it))
-        walk(&*it++);
-    }
-  }
-};
-
 /// Unrolls and jams this loop by the specified factor.
 LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
                                                   uint64_t unrollJamFactor) {
@@ -1158,7 +1120,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
     return failure();
 
   // Gather all sub-blocks to jam upon the loop being unrolled.
-  JamBlockGatherer jbg;
+  JamBlockGatherer<AffineForOp> jbg;
   jbg.walk(forOp);
   auto &subBlocks = jbg.subBlocks;
 
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 15b8aca5e267f..d065eb167a357 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -348,12 +348,36 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
     result = loopUnrollByFactor(scfFor, getFactor());
   else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
     result = loopUnrollByFactor(affineFor, getFactor());
+  else
+    return emitSilenceableError()
+           << "failed to unroll, incorrect type of payload";
+
+  if (failed(result))
+    return emitSilenceableError() << "failed to unroll";
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// LoopUnrollAndJamOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *op,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  LogicalResult result(failure());
+  if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
+    result = loopUnrollJamByFactor(scfFor, getFactor());
+  else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
+    result = loopUnrollJamByFactor(affineFor, getFactor());
+  else
+    return emitSilenceableError()
+           << "failed to unroll and jam, incorrect type of payload";
+
+  if (failed(result))
+    return emitSilenceableError() << "failed to unroll and jam";
 
-  if (failed(result)) {
-    DiagnosedSilenceableFailure diag = emitSilenceableError()
-                                       << "failed to unroll";
-    return diag;
-  }
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index ff5e3a002263d..efe619f1a89f9 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -26,10 +26,16 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/MathExtras.h"
+#include <cstdint>
 
 using namespace mlir;
 
+#define DEBUG_TYPE "scf-utils"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
     RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
     ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
@@ -287,6 +293,25 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
   return builder.create<arith::DivUIOp>(loc, sum, divisor);
 }
 
+// Returns the trip count of `forOp` if its lower bound, upper bound and step
+// are constants, or std::nullopt otherwise. The trip count is computed as
+// ceilDiv(upperBound - lowerBound, step).
+static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
+  std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
+  std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
+  std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
+  if (!lbCstOp.has_value() || !ubCstOp.has_value() || !stepCstOp.has_value())
+    return {};
+
+  // Constant loop bounds computation.
+  int64_t lbCst = lbCstOp.value();
+  int64_t ubCst = ubCstOp.value();
+  int64_t stepCst = stepCstOp.value();
+  assert(lbCst >= 0 && ubCst >= 0 && stepCst > 0 &&
+         "expected positive loop bounds and step");
+  return llvm::divideCeilSigned(ubCst - lbCst, stepCst);
+}
+
 /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
 /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
 /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
@@ -363,25 +388,21 @@ LogicalResult mlir::loopUnrollByFactor(
   Value stepUnrolled;
   bool generateEpilogueLoop = true;
 
-  std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
-  std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
-  std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
-  if (lbCstOp && ubCstOp && stepCstOp) {
+  std::optional<int64_t> constTripCount = getConstantTripCount(forOp);
+  if (constTripCount) {
     // Constant loop bounds computation.
-    int64_t lbCst = lbCstOp.value();
-    int64_t ubCst = ubCstOp.value();
-    int64_t stepCst = stepCstOp.value();
-    assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
-           "expected positive loop bounds and step");
-    int64_t tripCount = llvm::divideCeilSigned(ubCst - lbCst, stepCst);
-
+    int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
+    int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value();
+    int64_t stepCst = getConstantIntValue(forOp.getStep()).value();
     if (unrollFactor == 1) {
-      if (tripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter)))
+      if (*constTripCount == 1 &&
+          failed(forOp.promoteIfSingleIteration(rewriter)))
         return failure();
       return success();
     }
 
-    int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
+    int64_t tripCountEvenMultiple =
+        *constTripCount - (*constTripCount % unrollFactor);
     int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
     int64_t stepUnrolledCst = stepCst * unrollFactor;
 
@@ -464,6 +485,185 @@ LogicalResult mlir::loopUnrollByFactor(
   return success();
 }
 
+/// Check if bounds of all inner loops are defined outside of `forOp`
+/// and return false if not.
+static bool areInnerBoundsInvariant(scf::ForOp forOp) {
+  auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
+    if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
+        !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
+        !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
+      return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+  return !walkResult.wasInterrupted();
+}
+
+/// Unrolls and jams this loop by the specified factor.
+LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
+                                          uint64_t unrollJamFactor) {
+  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
+
+  if (unrollJamFactor == 1)
+    return success();
+
+  // If any control operand of any inner loop of `forOp` is defined within
+  // `forOp`, no unroll jam.
+  if (!areInnerBoundsInvariant(forOp)) {
+    LDBG("failed to unroll and jam: inner bounds are not invariant");
+    return failure();
+  }
+
+  // Currently, `scf.for` operations with results are not supported.
+  if (forOp->getNumResults() > 0) {
+    LDBG("failed to unroll and jam: unsupported loop with results");
+    return failure();
+  }
+
+  // Currently, only constant trip counts that are divisible by the unroll
+  // factor are supported.
+  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+  if (!tripCount.has_value()) {
+    // If the trip count is dynamic, do not unroll & jam.
+    LDBG("failed to unroll and jam: trip count could not be determined");
+    return failure();
+  }
+  if (unrollJamFactor > *tripCount) {
+    LDBG("unroll and jam factor is greater than trip count, set factor to trip "
+         "count");
+    unrollJamFactor = *tripCount;
+  } else if (*tripCount % unrollJamFactor != 0) {
+    LDBG("failed to unroll and jam: unsupported trip count that is not a "
+         "multiple of unroll jam factor");
+    return failure();
+  }
+
+  // Nothing in the loop body other than the terminator.
+  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
+    return success();
+
+  // Gather all sub-blocks to jam upon the loop being unrolled.
+  JamBlockGatherer<scf::ForOp> jbg;
+  jbg.walk(forOp);
+  auto &subBlocks = jbg.subBlocks;
+
+  // Collect inner loops.
+  SmallVector<scf::ForOp> innerLoops;
+  forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
+
+  // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
+  // iteration. There are (`unrollJamFactor` - 1) iterations.
+  SmallVector<IRMapping> operandMaps(unrollJamFactor - 1);
+
+  // For any loop with iter_args, replace it with a new loop that has
+  // `unrollJamFactor` copies of its iterOperands, iter_args and yield
+  // operands.
+  SmallVector<scf::ForOp> newInnerLoops;
+  IRRewriter rewriter(forOp.getContext());
+  for (scf::ForOp oldForOp : innerLoops) {
+    SmallVector<Value> dupIterOperands, dupYieldOperands;
+    ValueRange oldIterOperands = oldForOp.getInits();
+    ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
+    ValueRange oldYieldOperands =
+        cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
+    // Get additional iterOperands, iterArgs, and yield operands. We will
+    // fix iterOperands and yield operands after cloning of sub-blocks.
+    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+      dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
+      dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
+    }
+    // Create a new loop with additional iterOperands, iter_args and yield
+    // operands. This new loop will take the loop body of the original loop.
+    bool forOpReplaced = oldForOp == forOp;
+    scf::ForOp newForOp =
+        cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
+            rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
+            [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
+              return dupYieldOperands;
+            }));
+    newInnerLoops.push_back(newForOp);
+    // `forOp` has been replaced with a new loop.
+    if (forOpReplaced)
+      forOp = newForOp;
+    // Update `operandMaps` for `newForOp` iterArgs and results.
+    ValueRange newIterArgs = newForOp.getRegionIterArgs();
+    unsigned oldNumIterArgs = oldIterArgs.size();
+    ValueRange newResults = newForOp.getResults();
+    unsigned oldNumResults = newResults.size() / unrollJamFactor;
+    assert(oldNumIterArgs == oldNumResults &&
+           "oldNumIterArgs must be the same as oldNumResults");
+    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+      for (unsigned j = 0; j < oldNumIterArgs; ++j) {
+        // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
+        // results. Update `operandMaps[i - 1]` to map old iterArgs and results
+        // to those in the `i`th new set.
+        operandMaps[i - 1].map(newIterArgs[j],
+                               newIterArgs[i * oldNumIterArgs + j]);
+        operandMaps[i - 1].map(newResults[j],
+                               newResults[i * oldNumResults + j]);
+      }
+    }
+  }
+
+  // Scale the step of loop being unroll-jammed by the unroll-jam factor.
+  rewriter.setInsertionPoint(forOp);
+  int64_t step = forOp.getConstantStep()->getSExtValue();
+  auto newStep = rewriter.createOrFold<arith::MulIOp>(
+      forOp.getLoc(), forOp.getStep(),
+      rewriter.createOrFold<arith::ConstantOp>(
+          forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
+  forOp.setStep(newStep);
+  auto forOpIV = forOp.getInductionVar();
+
+  // Unroll and jam (appends unrollJamFactor - 1 additional copies).
+  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+    for (auto &subBlock : subBlocks) {
+      // Builder to insert unroll-jammed bodies. Insert right at the end of
+      // sub-block.
+      OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
+
+      // If the induction variable is used, create a remapping to the value for
+      // this unrolled instance.
+      if (!forOpIV.use_empty()) {
+        // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
+        auto ivTag = builder.createOrFold<arith::ConstantOp>(
+            forOp.getLoc(), builder.getIndexAttr(step * i));
+        auto ivUnroll =
+            builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
+        operandMaps[i - 1].map(forOpIV, ivUnroll);
+      }
+      // Clone the sub-block being unroll-jammed.
+      for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
+        builder.clone(*it, operandMaps[i - 1]);
+    }
+    // Fix iterOperands and yield op operands of newly created loops.
+    for (auto newForOp : newInnerLoops) {
+      unsigned oldNumIterOperands =
+          newForOp.getNumRegionIterArgs() / unrollJamFactor;
+      unsigned numControlOperands = newForOp.getNumControlOperands();
+      auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+      unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
+      assert(oldNumIterOperands == oldNumYieldOperands &&
+             "oldNumIterOperands must be the same as oldNumYieldOperands");
+      for (unsigned j = 0; j < oldNumIterOperands; ++j) {
+        // The `i`th duplication of an old iterOperand or yield op operand
+        // needs to be replaced with a mapped value from `operandMaps[i - 1]`
+        // if such mapped value exists.
+        newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
+                            operandMaps[i - 1].lookupOrDefault(
+                                newForOp.getOperand(numControlOperands + j)));
+        yieldOp.setOperand(
+            i * oldNumYieldOperands + j,
+            operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
+      }
+    }
+  }
+
+  // Promote the loop body up if this has turned into a single iteration loop.
+  (void)forOp.promoteIfSingleIteration(rewriter);
+  return success();
+}
+
 Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
                                      OpFoldResult lb, OpFoldResult ub,
                                      OpFoldResult step) {
diff --git a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
index a01a7995d6189..742b8a2861839 100644
--- a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
@@ -41,6 +41,105 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @loop_unroll_and_jam_unsupported_trip_count_not_multiple_of_factor() {
+  %c0 = arith.constant 0 : index
+  %c40 = arith.constant 40 : index
+  %c2 = arith.constant 2 : index
+  scf.for %i = %c0 to %c40 step %c2 {
+    arith.addi %i, %i : index
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    // expected-error @below {{failed to unroll and jam}}
+    transform.loop.unroll_and_jam %1 { factor = 3 } : !transform.op<"scf.for">
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @loop_unroll_and_jam_unsupported_loop_with_results() -> index {
+  %c0 = arith.constant 0 : index
+  %c40 = arith.constant 40 : index
+  %c2 = arith.constant 2 : index
+  %sum = scf.for %i = %c0 to %c40 step %c2 iter_args(%does_not_alias_aggregated = %c0) -> (index) {
+    %sum = arith.addi %i, %i : index
+    scf.yield %sum : index
+  }
+  return %sum : index
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    // expected-error @below {{failed to unroll and jam}}
+    transform.loop.unroll_and_jam %1 { factor = 4 } : !transform.op<"scf.for">
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @loop_unroll_and_jam_unsupported_dynamic_trip_count(%arg0: memref<96x128xi8, 3>, %arg1: memref<128xi8, 3>) {
+  %c96 = arith.constant 96 : index
+  %c1 = arith.constant 1 : index
+  %c128 = arith.constant 128 : index
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  scf.for %arg4 = %c0 to %c4 step %c1 {
+    scf.for %arg2 = %c0 to %c128 step %arg4 {
+      %3 = memref.load %arg1[%arg2] : memref<128xi8, 3>
+      %sum = scf.for %arg3 = %c0 to %c96 step %c1 iter_args(%does_not_alias_aggregated = %3) -> (i8) {
+      %2 = memref.load %arg0[%arg3, %arg2] : memref<96x128xi8, 3>
+      %4 = arith.addi %2, %3 : i8
+      scf.yield %4 : i8
+      }
+      memref.store %sum, %arg1[%arg2] : memref<128xi8, 3>
+    }
+    scf.yield
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["memref.store"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    // expected-error @below {{failed to unroll and jam}}
+    transform.loop.unroll_and_jam %1 { factor = 4 } : !transform.op<"scf.for">
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @loop_unroll_and_jam_unsupported_dynamic_trip_count(%upper_bound: index) {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  scf.for %i = %c0 to %upper_bound step %c2 {
+    arith.addi %i, %i : index
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    // expected-error @below {{failed to unroll and jam}}
+    transform.loop.unroll_and_jam %1 { factor = 2 } : !transform.op<"scf.for">
+    transform.yield
+  }
+}
+
+// -----
+
 func.func private @cond() -> i1
 func.func private @body()
 
diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir
index b91225bf45b96..d9445182769e7 100644
--- a/mlir/test/Dialect/SCF/transform-ops.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops.mlir
@@ -168,6 +168,221 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: @loop_unroll_and_jam_op
+func.func @loop_unroll_and_jam_op() {
+  // CHECK:           %[[VAL_0:.*]] = arith.constant 0 : index
+  // CHECK:           %[[VAL_1:.*]] = arith.constant 40 : index
+  // CHECK:           %[[VAL_2:.*]] = arith.constant 2 : index
+  // CHECK:           %[[FACTOR:.*]] = arith.constant 4 : index
+  // CHECK:           %[[STEP:.*]] = arith.constant 8 : index
+  %c0 = arith.constant 0 : index
+  %c40 = arith.constant 40 : index
+  %c2 = arith.constant 2 : index
+  // CHECK:           scf.for %[[VAL_5:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[STEP]] {
+  scf.for %i = %c0 to %c40 step %c2 {
+  // CHECK:             %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_5]] : index
+  // CHECK:             %[[VAL_7:.*]] = arith.constant 2 : index
+  // CHECK:             %[[VAL_8:.*]] = arith.addi %[[VAL_5]], %[[VAL_7]] : index
+  // CHECK:             %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_8]] : index
+  // CHECK:             %[[VAL_10:.*]] = arith.constant 4 : index
+  // CHECK:             %[[VAL_11:.*]] = arith.addi %[[VAL_5]], %[[VAL_10]] : index
+  // CHECK:             %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_11]] : index
+  // CHECK:             %[[VAL_13:.*]] = arith.constant 6 : index
+  // CHECK:             %[[VAL_14:.*]] = arith.addi %[[VAL_5]], %[[VAL_13]] : index
+  // CHECK:             %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_14]] : index
+    arith.addi %i, %i : index
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    transform.loop.unroll_and_jam %1 { factor = 4 } : !transform.op<"scf.for">
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @loop_unroll_and_jam_op
+// CHECK:       %[[VAL_0:.*]]: memref<96x128xi8, 3>, %[[VAL_1:.*]]: memref<128xi8, 3>) {
+func.func private @loop_unroll_and_jam_op(%arg0: memref<96x128xi8, 3>, %arg1: memref<128xi8, 3>) {
+  // CHECK:           %[[UB_INNER:.*]] = arith.constant 96
+  // CHECK:           %[[STEP_INNER:.*]] = arith.constant 1
+  // CHECK:           %[[UB_OUTER:.*]] = arith.constant 128
+  // CHECK:           %[[LB:.*]] = arith.constant 0
+  // CHECK:           %[[UNUSED:.*]] = arith.constant 4
+  // CHECK:           %[[UNROLL_FACTOR:.*]] = arith.constant 4
+  %c96 = arith.constant 96 : index
+  %c1 = arith.constant 1 : index
+  %c128 = arith.constant 128 : index
+  %c0 = arith.constant 0 : index
+  // CHECK:           scf.for %[[OUTER_I:.*]] = %[[LB]] to %[[UB_OUTER]] step %[[UNROLL_FACTOR]] {
+  scf.for %arg2 = %c0 to %c128 step %c1 {
+    // CHECK:             %[[LOAD0:.*]] = memref.load %[[VAL_1]]{{\[}}%[[OUTER_I]]]
+    // CHECK:             %[[ONE_0:.*]] = arith.constant 1
+    // CHECK:             %[[INC_LOAD1:.*]] = arith.addi %[[OUTER_I]], %[[ONE_0]]
+    // CHECK:             %[[LOAD1:.*]] = memref.load %[[VAL_1]]{{\[}}%[[INC_LOAD1]]]
+    // CHECK:             %[[TWO_0:.*]] = arith.constant 2
+    // CHECK:             %[[INC_LOAD2:.*]] = arith.addi %[[OUTER_I]], %[[TWO_0]]
+    // CHECK:             %[[LOAD2:.*]] = memref.load %[[VAL_1]]{{\[}}%[[INC_LOAD2]]]
+    // CHECK:             %[[THREE_0:.*]] = arith.constant 3
+    // CHECK:             %[[INC_LOAD3:.*]] = arith.addi %[[OUTER_I]], %[[THREE_0]]
+    // CHECK:             %[[LOAD3:.*]] = memref.load %[[VAL_1]]{{\[}}%[[INC_LOAD3]]]
+    %3 = memref.load %arg1[%arg2] : memref<128xi8, 3>
+    // CHECK:             %[[VAL_19:.*]]:4 = scf.for %[[VAL_20:.*]] = %[[LB]] to %[[UB_INNER]] step %[[STEP_INNER]] iter_args(%[[VAL_21:.*]] = %[[LOAD0]], %[[VAL_22:.*]] = %[[LOAD1]], %[[VAL_23:.*]] = %[[LOAD2]], %[[VAL_24:.*]] = %[[LOAD3]])
+    %sum = scf.for %arg3 = %c0 to %c96 step %c1 iter_args(%does_not_alias_aggregated = %3) -> (i8) {
+    // CHECK:               %[[LOAD0_INNER:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_20]], %[[OUTER_I]]]
+    // CHECK:               %[[SUM_0:.*]] = arith.addi %[[LOAD0_INNER]], %[[LOAD0]]
+    // CHECK:               %[[ONE_1:.*]] = arith.constant 1
+    // CHECK:               %[[INC1_INNER:.*]] = arith.addi %[[OUTER_I]], %[[ONE_1]]
+    // CHECK:               %[[LOAD1_INNER:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_20]], %[[INC1_INNER]]]
+    // CHECK:               %[[SUM_1:.*]] = arith.addi %[[LOAD1_INNER]], %[[LOAD1]]
+    // CHECK:               %[[TWO_1:.*]] = arith.constant 2
+    // CHECK:               %[[INC2_INNER:.*]] = arith.addi %[[OUTER_I]], %[[TWO_1]]
+    // CHECK:               %[[LOAD2_INNER:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_20]], %[[INC2_INNER]]]
+    // CHECK:               %[[SUM_2:.*]] = arith.addi %[[LOAD2_INNER]], %[[LOAD2]]
+    // CHECK:               %[[THREE_1:.*]] = arith.constant 3
+    // CHECK:               %[[INC3_INNER:.*]] = arith.addi %[[OUTER_I]], %[[THREE_1]]
+    // CHECK:               %[[LOAD3_INNER:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_20]], %[[INC3_INNER]]]
+    // CHECK:               %[[SUM_3:.*]] = arith.addi %[[LOAD3_INNER]], %[[LOAD3]]
+      %2 = memref.load %arg0[%arg3, %arg2] : memref<96x128xi8, 3>
+      %4 = arith.addi %2, %3 : i8
+    // CHECK:               scf.yield %[[SUM_0]], %[[SUM_1]], %[[SUM_2]], %[[SUM_3]]
+      scf.yield %4 : i8
+    }
+    memref.store %sum, %arg1[%arg2] : memref<128xi8, 3>
+    // CHECK:             memref.store %[[VAL_39:.*]]#0, %[[VAL_1]]{{\[}}%[[OUTER_I]]]
+    // CHECK:             %[[ONE_2:.*]] = arith.constant 1
+    // CHECK:             %[[INC_STORE1:.*]] = arith.addi %[[OUTER_I]], %[[ONE_2]]
+    // CHECK:             memref.store %[[VAL_39]]#1, %[[VAL_1]]{{\[}}%[[INC_STORE1]]]
+    // CHECK:             %[[TWO_2:.*]] = arith.constant 2
+    // CHECK:             %[[INC_STORE2:.*]] = arith.addi %[[OUTER_I]], %[[TWO_2]]
+    // CHECK:             memref.store %[[VAL_39]]#2, %[[VAL_1]]{{\[}}%[[INC_STORE2]]]
+    // CHECK:             %[[THREE_2:.*]] = arith.constant 3
+    // CHECK:             %[[INC_STORE3:.*]] = arith.addi %[[OUTER_I]], %[[THREE_2]]
+    // CHECK:             memref.store %[[VAL_39]]#3, %[[VAL_1]]{{\[}}%[[INC_STORE3]]]
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["memref.store"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    transform.loop.unroll_and_jam %1 { factor = 4 } : !transform.op<"scf.for">
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @loop_unroll_and_jam_op
+func.func @loop_unroll_and_jam_op() {
+  // CHECK:           %[[ZERO:.*]] = arith.constant 0
+  // CHECK:           %[[UNUSED_FACTOR:.*]] = arith.constant 4
+  // CHECK:           %[[UNUSED_STEP:.*]] = arith.constant 2
+  // CHECK:           %[[UNUSED_STEP2:.*]] = arith.constant 2
+  // CHECK:           %[[UNUSED_UB:.*]] = arith.constant 4
+  // CHECK:           %[[ITER_0_RES:.*]] = arith.addi %[[ZERO]], %[[ZERO]]
+  // CHECK:           %[[TWO:.*]] = arith.constant 2
+  // CHECK:           %[[STEP_1_I:.*]] = arith.addi %[[ZERO]], %[[TWO]]
+  // CHECK:           %[[ITER_1_RES:.*]] = arith.addi %[[STEP_1_I]], %[[STEP_1_I]]
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c2 = arith.constant 2 : index
+  scf.for %i = %c0 to %c4 step %c2 {
+    arith.addi %i, %i : index
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    transform.loop.unroll_and_jam %1 { factor = 4 } : !transform.op<"scf.for">
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @loop_unroll_and_jam_op
+func.func @loop_unroll_and_jam_op() {
+  // CHECK:           %[[LB:.*]] = arith.constant 0
+  // CHECK:           %[[UB:.*]] = arith.constant 4
+  // CHECK:           %[[STEP:.*]] = arith.constant 2
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c2 = arith.constant 2 : index
+  // CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+  scf.for %i = %c0 to %c4 step %c2 {
+    scf.yield
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.yield"]} in %arg1 : (!transform.any_op) -> !transform.any_op // The loop body holds only scf.yield, so anchor on it to reach the loop.
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    transform.loop.unroll_and_jam %1 { factor = 2 } : !transform.op<"scf.for">
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @loop_unroll_and_jam_op
+// CHECK:  %[[VAL_0:.*]]: memref<21x30xf32, 1>, %[[INIT0:.*]]: f32, %[[INIT1:.*]]: f32) {
+func.func @loop_unroll_and_jam_op(%arg0: memref<21x30xf32, 1>, %init : f32, %init1 : f32) {
+  // CHECK:           %[[LAST_OUT_ITER:.*]] = arith.constant 20
+  // CHECK:           %[[VAL_4:.*]]:2 = affine.for %[[OUTER_I:.*]] = 0 to 20 step 2 iter_args(%[[VAL_6:.*]] = %[[INIT0]], %[[VAL_7:.*]] = %[[INIT0]])
+  %0 = affine.for %arg3 = 0 to 21 iter_args(%arg4 = %init) -> (f32) {
+    // CHECK:             %[[VAL_8:.*]]:2 = affine.for %[[INNER_I:.*]] = 0 to 30 iter_args(%[[SUM0:.*]] = %[[INIT1]], %[[SUM1:.*]] = %[[INIT1]])
+    %1 = affine.for %arg5 = 0 to 30 iter_args(%arg6 = %init1) -> (f32) {
+      // CHECK:               %[[LOAD0:.*]] = affine.load %[[VAL_0]]{{\[}}%[[OUTER_I]], %[[INNER_I]]]
+      // CHECK:               %[[ITER_SUM0:.*]] = arith.addf %[[SUM0]], %[[LOAD0]]
+      // CHECK:               %[[APPLY_OUTER_I:.*]] = affine.apply #map(%[[OUTER_I]])
+      // CHECK:               %[[LOAD1:.*]] = affine.load %[[VAL_0]]{{\[}}%[[APPLY_OUTER_I]], %[[INNER_I]]]
+      // CHECK:               %[[ITER_SUM1:.*]] = arith.addf %[[SUM1]], %[[LOAD1]]
+      // CHECK:               affine.yield %[[ITER_SUM0]], %[[ITER_SUM1]]
+      %3 = affine.load %arg0[%arg3, %arg5] : memref<21x30xf32, 1>
+      %4 = arith.addf %arg6, %3 : f32
+      affine.yield %4 : f32
+    }
+    // CHECK:             %[[MUL0:.*]] = arith.mulf %[[VAL_6]], %[[VAL_18:.*]]#0
+    // CHECK:             %[[VAL_19:.*]] = affine.apply #map(%[[OUTER_I]])
+    // CHECK:             %[[MUL1:.*]] = arith.mulf %[[VAL_7]], %[[VAL_18]]#1
+    // CHECK:             affine.yield %[[MUL0]], %[[MUL1]]
+    // CHECK:           }
+    // CHECK:           %[[VAL_21:.*]] = arith.mulf %[[VAL_22:.*]]#0, %[[VAL_22]]#1
+    // CHECK:           %[[VAL_23:.*]] = affine.for %[[SUFFIX_I:.*]] = 0 to 30 iter_args(%[[ITER_I:.*]] = %[[INIT1]])
+    // CHECK:             %[[LOAD_SUFFIX:.*]] = affine.load %[[VAL_0]]{{\[}}%[[LAST_OUT_ITER]], %[[SUFFIX_I]]]
+    // CHECK:             %[[RES:.*]] = arith.addf %[[ITER_I]], %[[LOAD_SUFFIX]]
+    // CHECK:             affine.yield %[[RES]]
+    %2 = arith.mulf %arg4, %1 : f32
+    affine.yield %2 : f32
+  }
+  // CHECK:           %[[VAL_28:.*]] = arith.mulf %[[VAL_21]], %[[VAL_29:.*]]
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {op_name = "affine.for"} : (!transform.any_op) -> !transform.op<"affine.for">
+    %2 = transform.get_parent_op %1 {op_name = "affine.for"} : (!transform.op<"affine.for">) -> !transform.op<"affine.for">
+    transform.loop.unroll_and_jam %2 { factor = 2 } : !transform.op<"affine.for">
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @loop_unroll_op() {
   %c0 = arith.constant 0 : index
   %c42 = arith.constant 42 : index



More information about the Mlir-commits mailing list