[Mlir-commits] [mlir] cb8bd6f - Introduce new Unroll And Jam loop transform for SCF/Affine loops (#94142)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 21 05:48:14 PDT 2024


Author: Aviad Cohen
Date: 2024-06-21T15:48:11+03:00
New Revision: cb8bd6f77235ee15bbe549c8f3486392b8966447

URL: https://github.com/llvm/llvm-project/commit/cb8bd6f77235ee15bbe549c8f3486392b8966447
DIFF: https://github.com/llvm/llvm-project/commit/cb8bd6f77235ee15bbe549c8f3486392b8966447.diff

LOG: Introduce new Unroll And Jam loop transform for SCF/Affine loops (#94142)

Unroll And Jam has long been supported in the affine dialect through a pass.
This commit exposes the transformation as a transform op and, in addition,
adds partial support for SCF loops.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
    mlir/include/mlir/Dialect/SCF/Utils/Utils.h
    mlir/include/mlir/Interfaces/LoopLikeInterface.h
    mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
    mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
    mlir/lib/Dialect/SCF/Utils/Utils.cpp
    mlir/test/Dialect/SCF/transform-ops-invalid.mlir
    mlir/test/Dialect/SCF/transform-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 3d7fe7b0f093f..7bf914f6456ce 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -269,7 +269,41 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
     This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
     in the return. If all the operations referred to by the `target` operand
     unroll properly, the transform succeeds. Otherwise the transform produces a
-    silencebale failure.
+    silenceable failure.
+
+    Does not return handles as the operation may result in the loop being
+    removed after a full unrolling.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       ConfinedAttr<I64Attr, [IntPositive]>:$factor);
+
+  let assemblyFormat = "$target attr-dict `:` type($target)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+def LoopUnrollAndJamOp : Op<Transform_Dialect, "loop.unroll_and_jam",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let summary = "Unrolls and jam the given loop with the given unroll factor";
+  let description = [{
+    Unrolls & jams each loop associated with the given handle to have up to the given
+    number of loop body copies per iteration. If the unroll factor is larger
+    than the loop trip count, the latter is used as the unroll factor instead.
+
+    #### Return modes
+
+    This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
+    in the return. If all the operations referred to by the `target` operand
+    unroll properly, the transform succeeds. Otherwise the transform produces a
+    silenceable failure.
 
     Does not return handles as the operation may result in the loop being
     removed after a full unrolling.

diff  --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index da3fe3ceb86be..fea151b393152 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -120,6 +120,14 @@ LogicalResult loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
 
+/// Unrolls and jams this `scf.for` operation by the specified unroll factor.
+/// Returns failure if the loop cannot be unrolled either due to restrictions or
+/// due to invalid unroll factors. In case of unroll factor of 1, the function
+/// bails out without doing anything (returns success). Currently, only constant
+/// trip counts that are divisible by the unroll factor are supported. Loops
+/// (`scf.for` operations) with results are currently not supported either.
+LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor);
+
 /// Transform a loop with a strictly positive step
 ///   for %i = %lb to %ub step %s
 /// into a 0-based loop with step 1

diff  --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
index 42609e824c86a..9925fc6ce6ca9 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
@@ -48,6 +48,39 @@ class HasParallelRegion : public TraitBase<ConcreteType, HasParallelRegion> {
 };
 
 } // namespace OpTrait
+
+// Gathers all maximal sub-blocks of consecutive operations that are not
+// themselves an `OpTy` (such an operation may still have a descendant `OpTy`
+// nested in its regions). Block terminators are never part of a sub-block.
+template <typename OpTy>
+struct JamBlockGatherer {
+  // Inclusive (first op, last op) iterator pairs, one per sub-block found.
+  SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
+
+  // Visits every block in every region of `op`; this is a linear time walk.
+  void walk(Operation *op) {
+    for (Region &region : op->getRegions())
+      for (Block &block : region)
+        walk(block);
+  }
+
+  void walk(Block &block) {
+    assert(!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>() &&
+           "expected block to have a terminator");
+    for (Block::iterator it = block.begin(), e = std::prev(block.end());
+         it != e;) {
+      Block::iterator subBlockStart = it;
+      while (it != e && !isa<OpTy>(&*it))
+        ++it;
+      if (it != subBlockStart)
+        subBlocks.emplace_back(subBlockStart, std::prev(it));
+      // Recurse into the run of consecutive `OpTy` ops that appears next.
+      while (it != e && isa<OpTy>(&*it))
+        walk(&*it++);
+    }
+  }
+};
+
 } // namespace mlir
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index c934f229bb6c4..605737542e9fc 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -37,16 +37,6 @@ using namespace affine;
 using namespace presburger;
 using llvm::SmallMapVector;
 
-namespace {
-// This structure is to pass and return sets of loop parameters without
-// confusing the order.
-struct LoopParams {
-  Value lowerBound;
-  Value upperBound;
-  Value step;
-};
-} // namespace
-
 /// Computes the cleanup loop lower bound of the loop being unrolled with
 /// the specified unroll factor; this bound will also be upper bound of the main
 /// part of the unrolled loop. Computes the bound as an AffineMap with its
@@ -1101,34 +1091,6 @@ static bool areInnerBoundsInvariant(AffineForOp forOp) {
   return !walkResult.wasInterrupted();
 }
 
-// Gathers all maximal sub-blocks of operations that do not themselves
-// include a for op (a operation could have a descendant for op though
-// in its tree).  Ignore the block terminators.
-struct JamBlockGatherer {
-  // Store iterators to the first and last op of each sub-block found.
-  std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
-
-  // This is a linear time walk.
-  void walk(Operation *op) {
-    for (auto &region : op->getRegions())
-      for (auto &block : region)
-        walk(block);
-  }
-
-  void walk(Block &block) {
-    for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
-      auto subBlockStart = it;
-      while (it != e && !isa<AffineForOp>(&*it))
-        ++it;
-      if (it != subBlockStart)
-        subBlocks.emplace_back(subBlockStart, std::prev(it));
-      // Process all for ops that appear next.
-      while (it != e && isa<AffineForOp>(&*it))
-        walk(&*it++);
-    }
-  }
-};
-
 /// Unrolls and jams this loop by the specified factor.
 LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
                                                   uint64_t unrollJamFactor) {
@@ -1158,7 +1120,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
     return failure();
 
   // Gather all sub-blocks to jam upon the loop being unrolled.
-  JamBlockGatherer jbg;
+  JamBlockGatherer<AffineForOp> jbg;
   jbg.walk(forOp);
   auto &subBlocks = jbg.subBlocks;
 

diff  --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 6919de393b668..56ff2709a589e 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -348,12 +348,36 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
     result = loopUnrollByFactor(scfFor, getFactor());
   else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
     result = loopUnrollByFactor(affineFor, getFactor());
+  else
+    return emitSilenceableError()
+           << "failed to unroll, incorrect type of payload";
+
+  if (failed(result))
+    return emitSilenceableError() << "failed to unroll";
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// LoopUnrollAndJamOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *op,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  LogicalResult result(failure());
+  if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
+    result = loopUnrollJamByFactor(scfFor, getFactor());
+  else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
+    result = loopUnrollJamByFactor(affineFor, getFactor());
+  else
+    return emitSilenceableError()
+           << "failed to unroll and jam, incorrect type of payload";
+
+  if (failed(result))
+    return emitSilenceableError() << "failed to unroll and jam";
 
-  if (failed(result)) {
-    DiagnosedSilenceableFailure diag = emitSilenceableError()
-                                       << "failed to unroll";
-    return diag;
-  }
   return DiagnosedSilenceableFailure::success();
 }
 

diff  --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index ff5e3a002263d..c0ee9d2afe91c 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -26,10 +26,16 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/MathExtras.h"
+#include <cstdint>
 
 using namespace mlir;
 
+#define DEBUG_TYPE "scf-utils"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
     RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
     ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
@@ -287,6 +293,25 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
   return builder.create<arith::DivUIOp>(loc, sum, divisor);
 }
 
+/// Returns the trip count of `forOp` if its lower bound, upper bound and step
+/// are all constants, or std::nullopt otherwise. The trip count is computed as
+/// ceilDiv(upperBound - lowerBound, step).
+static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
+  std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
+  std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
+  std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
+  if (!lbCstOp.has_value() || !ubCstOp.has_value() || !stepCstOp.has_value())
+    return {};
+
+  // All three operands are constants; compute the trip count from them.
+  int64_t lbCst = lbCstOp.value();
+  int64_t ubCst = ubCstOp.value();
+  int64_t stepCst = stepCstOp.value();
+  assert(lbCst >= 0 && ubCst >= 0 && stepCst > 0 &&
+         "expected positive loop bounds and step");
+  return llvm::divideCeilSigned(ubCst - lbCst, stepCst);
+}
+
 /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
 /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
 /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
@@ -363,25 +388,21 @@ LogicalResult mlir::loopUnrollByFactor(
   Value stepUnrolled;
   bool generateEpilogueLoop = true;
 
-  std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
-  std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
-  std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
-  if (lbCstOp && ubCstOp && stepCstOp) {
+  std::optional<int64_t> constTripCount = getConstantTripCount(forOp);
+  if (constTripCount) {
     // Constant loop bounds computation.
-    int64_t lbCst = lbCstOp.value();
-    int64_t ubCst = ubCstOp.value();
-    int64_t stepCst = stepCstOp.value();
-    assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
-           "expected positive loop bounds and step");
-    int64_t tripCount = llvm::divideCeilSigned(ubCst - lbCst, stepCst);
-
+    int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
+    int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value();
+    int64_t stepCst = getConstantIntValue(forOp.getStep()).value();
     if (unrollFactor == 1) {
-      if (tripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter)))
+      if (*constTripCount == 1 &&
+          failed(forOp.promoteIfSingleIteration(rewriter)))
         return failure();
       return success();
     }
 
-    int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
+    int64_t tripCountEvenMultiple =
+        *constTripCount - (*constTripCount % unrollFactor);
     int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
     int64_t stepUnrolledCst = stepCst * unrollFactor;
 
@@ -464,6 +485,185 @@ LogicalResult mlir::loopUnrollByFactor(
   return success();
 }
 
+/// Returns true iff the lower bound, upper bound and step of every `scf.for`
+/// reached by walking `forOp` are defined outside of `forOp`.
+static bool areInnerBoundsInvariant(scf::ForOp forOp) {
+  auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
+    if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
+        !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
+        !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
+      return WalkResult::interrupt();
+
+    return WalkResult::advance();
+  });
+  return !walkResult.wasInterrupted();
+}
+
+/// Unrolls and jams the `scf.for` loop `forOp` by `unrollJamFactor`.
+LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
+                                          uint64_t unrollJamFactor) {
+  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
+
+  if (unrollJamFactor == 1)
+    return success();
+
+  // If any control operand of any inner loop of `forOp` is defined within
+  // `forOp`, no unroll jam.
+  if (!areInnerBoundsInvariant(forOp)) {
+    LDBG("failed to unroll and jam: inner bounds are not invariant");
+    return failure();
+  }
+
+  // Currently, `scf.for` operations with results are not supported.
+  if (forOp->getNumResults() > 0) {
+    LDBG("failed to unroll and jam: unsupported loop with results");
+    return failure();
+  }
+
+  // Only constant trip counts divisible by the unroll factor are supported.
+  // NOTE(review): a trip count of 0 clamps the factor to 0 below -- confirm.
+  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+  if (!tripCount.has_value()) {
+    // If the trip count is dynamic, do not unroll & jam.
+    LDBG("failed to unroll and jam: trip count could not be determined");
+    return failure();
+  }
+  if (unrollJamFactor > *tripCount) {
+    LDBG("unroll and jam factor is greater than trip count, set factor to trip "
+         "count");
+    unrollJamFactor = *tripCount;
+  } else if (*tripCount % unrollJamFactor != 0) {
+    LDBG("failed to unroll and jam: unsupported trip count that is not a "
+         "multiple of unroll jam factor");
+    return failure();
+  }
+
+  // Nothing in the loop body other than the terminator.
+  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
+    return success();
+
+  // Gather all sub-blocks to jam upon the loop being unrolled.
+  JamBlockGatherer<scf::ForOp> jbg;
+  jbg.walk(forOp);
+  auto &subBlocks = jbg.subBlocks;
+
+  // Collect inner loops.
+  SmallVector<scf::ForOp> innerLoops;
+  forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
+
+  // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
+  // iteration. There are (`unrollJamFactor` - 1) iterations.
+  SmallVector<IRMapping> operandMaps(unrollJamFactor - 1);
+
+  // For any loop with iter_args, replace it with a new loop that has
+  // `unrollJamFactor` copies of its iterOperands, iter_args and yield
+  // operands.
+  SmallVector<scf::ForOp> newInnerLoops;
+  IRRewriter rewriter(forOp.getContext());
+  for (scf::ForOp oldForOp : innerLoops) {
+    SmallVector<Value> dupIterOperands, dupYieldOperands;
+    ValueRange oldIterOperands = oldForOp.getInits();
+    ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
+    ValueRange oldYieldOperands =
+        cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
+    // Get additional iterOperands, iterArgs, and yield operands. We will
+    // fix iterOperands and yield operands after cloning of sub-blocks.
+    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+      dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
+      dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
+    }
+    // Create a new loop with additional iterOperands, iter_args and yield
+    // operands. This new loop will take the loop body of the original loop.
+    bool forOpReplaced = oldForOp == forOp;
+    scf::ForOp newForOp =
+        cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
+            rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
+            [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
+              return dupYieldOperands;
+            }));
+    newInnerLoops.push_back(newForOp);
+    // `forOp` has been replaced with a new loop.
+    if (forOpReplaced)
+      forOp = newForOp;
+    // Update `operandMaps` for `newForOp` iterArgs and results.
+    ValueRange newIterArgs = newForOp.getRegionIterArgs();
+    unsigned oldNumIterArgs = oldIterArgs.size();
+    ValueRange newResults = newForOp.getResults();
+    unsigned oldNumResults = newResults.size() / unrollJamFactor;
+    assert(oldNumIterArgs == oldNumResults &&
+           "oldNumIterArgs must be the same as oldNumResults");
+    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+      for (unsigned j = 0; j < oldNumIterArgs; ++j) {
+        // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
+        // results. Update `operandMaps[i - 1]` to map old iterArgs and results
+        // to those in the `i`th new set.
+        operandMaps[i - 1].map(newIterArgs[j],
+                               newIterArgs[i * oldNumIterArgs + j]);
+        operandMaps[i - 1].map(newResults[j],
+                               newResults[i * oldNumResults + j]);
+      }
+    }
+  }
+
+  // Scale the step of loop being unroll-jammed by the unroll-jam factor.
+  rewriter.setInsertionPoint(forOp);
+  int64_t step = forOp.getConstantStep()->getSExtValue();
+  auto newStep = rewriter.createOrFold<arith::MulIOp>(
+      forOp.getLoc(), forOp.getStep(),
+      rewriter.createOrFold<arith::ConstantOp>(
+          forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
+  forOp.setStep(newStep);
+  auto forOpIV = forOp.getInductionVar();
+
+  // Unroll and jam (appends unrollJamFactor - 1 additional copies).
+  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+    for (auto &subBlock : subBlocks) {
+      // Builder to insert unroll-jammed bodies. Insert right at the end of
+      // sub-block.
+      OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
+
+      // If the induction variable is used, create a remapping to the value for
+      // this unrolled instance.
+      if (!forOpIV.use_empty()) {
+        // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
+        auto ivTag = builder.createOrFold<arith::ConstantOp>(
+            forOp.getLoc(), builder.getIndexAttr(step * i));
+        auto ivUnroll =
+            builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
+        operandMaps[i - 1].map(forOpIV, ivUnroll);
+      }
+      // Clone the sub-block being unroll-jammed.
+      for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
+        builder.clone(*it, operandMaps[i - 1]);
+    }
+    // Fix iterOperands and yield op operands of newly created loops.
+    for (auto newForOp : newInnerLoops) {
+      unsigned oldNumIterOperands =
+          newForOp.getNumRegionIterArgs() / unrollJamFactor;
+      unsigned numControlOperands = newForOp.getNumControlOperands();
+      auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+      unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
+      assert(oldNumIterOperands == oldNumYieldOperands &&
+             "oldNumIterOperands must be the same as oldNumYieldOperands");
+      for (unsigned j = 0; j < oldNumIterOperands; ++j) {
+        // The `i`th duplication of an old iterOperand or yield op operand
+        // needs to be replaced with a mapped value from `operandMaps[i - 1]`
+        // if such mapped value exists.
+        newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
+                            operandMaps[i - 1].lookupOrDefault(
+                                newForOp.getOperand(numControlOperands + j)));
+        yieldOp.setOperand(
+            i * oldNumYieldOperands + j,
+            operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
+      }
+    }
+  }
+
+  // Promote the loop body up if this has turned into a single iteration loop.
+  (void)forOp.promoteIfSingleIteration(rewriter);
+  return success();
+}
+
 Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
                                      OpFoldResult lb, OpFoldResult ub,
                                      OpFoldResult step) {

diff  --git a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
index a01a7995d6189..742b8a2861839 100644
--- a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
@@ -41,6 +41,105 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @loop_unroll_and_jam_unsupported_trip_count_not_multiple_of_factor() {
+  %c0 = arith.constant 0 : index
+  %c40 = arith.constant 40 : index
+  %c2 = arith.constant 2 : index
+  scf.for %i = %c0 to %c40 step %c2 {
+    arith.addi %i, %i : index
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    // expected-error @below {{failed to unroll and jam}}
+    transform.loop.unroll_and_jam %1 { factor = 3 } : !transform.op<"scf.for">
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @loop_unroll_and_jam_unsupported_loop_with_results() -> index {
+  %c0 = arith.constant 0 : index
+  %c40 = arith.constant 40 : index
+  %c2 = arith.constant 2 : index
+  %sum = scf.for %i = %c0 to %c40 step %c2 iter_args(%does_not_alias_aggregated = %c0) -> (index) {
+    %sum = arith.addi %i, %i : index
+    scf.yield %sum : index
+  }
+  return %sum : index
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    // expected-error @below {{failed to unroll and jam}}
+    transform.loop.unroll_and_jam %1 { factor = 4 } : !transform.op<"scf.for">
+    transform.yield
+  }
+}
+
+// -----
+
+func.func private @loop_unroll_and_jam_unsupported_dynamic_trip_count(%arg0: memref<96x128xi8, 3>, %arg1: memref<128xi8, 3>) {
+  %c96 = arith.constant 96 : index
+  %c1 = arith.constant 1 : index
+  %c128 = arith.constant 128 : index
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  scf.for %arg4 = %c0 to %c4 step %c1 {
+    scf.for %arg2 = %c0 to %c128 step %arg4 {
+      %3 = memref.load %arg1[%arg2] : memref<128xi8, 3>
+      %sum = scf.for %arg3 = %c0 to %c96 step %c1 iter_args(%does_not_alias_aggregated = %3) -> (i8) {
+      %2 = memref.load %arg0[%arg3, %arg2] : memref<96x128xi8, 3>
+      %4 = arith.addi %2, %3 : i8
+      scf.yield %4 : i8
+      }
+      memref.store %sum, %arg1[%arg2] : memref<128xi8, 3>
+    }
+    scf.yield
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["memref.store"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    // expected-error @below {{failed to unroll and jam}}
+    transform.loop.unroll_and_jam %1 { factor = 4 } : !transform.op<"scf.for">
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @loop_unroll_and_jam_unsupported_dynamic_trip_count(%upper_bound: index) {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  scf.for %i = %c0 to %upper_bound step %c2 {
+    arith.addi %i, %i : index
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    // expected-error @below {{failed to unroll and jam}}
+    transform.loop.unroll_and_jam %1 { factor = 2 } : !transform.op<"scf.for">
+    transform.yield
+  }
+}
+
+// -----
+
 func.func private @cond() -> i1
 func.func private @body()
 

diff  --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir
index b91225bf45b96..d9445182769e7 100644
--- a/mlir/test/Dialect/SCF/transform-ops.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops.mlir
@@ -168,6 +168,221 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: @loop_unroll_and_jam_op
+func.func @loop_unroll_and_jam_op() {  // Payload: constant-bound loop with a single-op body.
+  // CHECK:           %[[VAL_0:.*]] = arith.constant 0 : index
+  // CHECK:           %[[VAL_1:.*]] = arith.constant 40 : index
+  // CHECK:           %[[VAL_2:.*]] = arith.constant 2 : index
+  // CHECK:           %[[FACTOR:.*]] = arith.constant 4 : index
+  // CHECK:           %[[STEP:.*]] = arith.constant 8 : index
+  %c0 = arith.constant 0 : index
+  %c40 = arith.constant 40 : index
+  %c2 = arith.constant 2 : index
+  // CHECK:           scf.for %[[VAL_5:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[STEP]] {
+  scf.for %i = %c0 to %c40 step %c2 {  // Input step is 2; the directives above expect step 8 after unrolling by 4.
+  // CHECK:             %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_5]] : index
+  // CHECK:             %[[VAL_7:.*]] = arith.constant 2 : index
+  // CHECK:             %[[VAL_8:.*]] = arith.addi %[[VAL_5]], %[[VAL_7]] : index
+  // CHECK:             %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_8]] : index
+  // CHECK:             %[[VAL_10:.*]] = arith.constant 4 : index
+  // CHECK:             %[[VAL_11:.*]] = arith.addi %[[VAL_5]], %[[VAL_10]] : index
+  // CHECK:             %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_11]] : index
+  // CHECK:             %[[VAL_13:.*]] = arith.constant 6 : index
+  // CHECK:             %[[VAL_14:.*]] = arith.addi %[[VAL_5]], %[[VAL_13]] : index
+  // CHECK:             %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_14]] : index
+    arith.addi %i, %i : index  // Anchor op; the directives above expect four copies at offsets 0, 2, 4 and 6.
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op  // Select the arith.addi in the loop body.
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">  // Step up to the enclosing scf.for.
+    transform.loop.unroll_and_jam %1 { factor = 4 } : !transform.op<"scf.for">  // Unroll-and-jam that loop by 4.
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @loop_unroll_and_jam_op
+// CHECK:       %[[VAL_0:.*]]: memref<96x128xi8, 3>, %[[VAL_1:.*]]: memref<128xi8, 3>) {
+func.func private @loop_unroll_and_jam_op(%arg0: memref<96x128xi8, 3>, %arg1: memref<128xi8, 3>) {  // Payload: two nested scf.for loops; the inner one carries a value.
+  // CHECK:           %[[UB_INNER:.*]] = arith.constant 96
+  // CHECK:           %[[STEP_INNER:.*]] = arith.constant 1
+  // CHECK:           %[[UB_OUTER:.*]] = arith.constant 128
+  // CHECK:           %[[LB:.*]] = arith.constant 0
+  // CHECK:           %[[UNUSED:.*]] = arith.constant 4
+  // CHECK:           %[[UNROLL_FACTOR:.*]] = arith.constant 4
+  %c96 = arith.constant 96 : index
+  %c1 = arith.constant 1 : index
+  %c128 = arith.constant 128 : index
+  %c0 = arith.constant 0 : index
+  // CHECK:           scf.for %[[OUTER_I:.*]] = %[[LB]] to %[[UB_OUTER]] step %[[UNROLL_FACTOR]] {
+	scf.for %arg2 = %c0 to %c128 step %c1 {  // Outer loop; it encloses the memref.store, so it is the loop the script unrolls by 4.
+    // CHECK:             %[[LOAD0:.*]] = memref.load %[[VAL_1]]{{\[}}%[[OUTER_I]]]
+    // CHECK:             %[[ONE_0:.*]] = arith.constant 1
+    // CHECK:             %[[INC_LOAD1:.*]] = arith.addi %[[OUTER_I]], %[[ONE_0]]
+    // CHECK:             %[[LOAD1:.*]] = memref.load %[[VAL_1]]{{\[}}%[[INC_LOAD1]]]
+    // CHECK:             %[[TWO_0:.*]] = arith.constant 2
+    // CHECK:             %[[INC_LOAD2:.*]] = arith.addi %[[OUTER_I]], %[[TWO_0]]
+    // CHECK:             %[[LOAD2:.*]] = memref.load %[[VAL_1]]{{\[}}%[[INC_LOAD2]]]
+    // CHECK:             %[[THREE_0:.*]] = arith.constant 3
+    // CHECK:             %[[INC_LOAD3:.*]] = arith.addi %[[OUTER_I]], %[[THREE_0]]
+    // CHECK:             %[[LOAD3:.*]] = memref.load %[[VAL_1]]{{\[}}%[[INC_LOAD3]]]
+	  %3 = memref.load %arg1[%arg2] : memref<128xi8, 3>  // Loaded once per outer iteration; used as the inner loop's init value and as an addend in its body.
+    // CHECK:             %[[VAL_19:.*]]:4 = scf.for %[[VAL_20:.*]] = %[[LB]] to %[[UB_INNER]] step %[[STEP_INNER]] iter_args(%[[VAL_21:.*]] = %[[LOAD0]], %[[VAL_22:.*]] = %[[LOAD1]], %[[VAL_23:.*]] = %[[LOAD2]], %[[VAL_24:.*]] = %[[LOAD3]])
+	  %sum = scf.for %arg3 = %c0 to %c96 step %c1 iter_args(%does_not_alias_aggregated = %3) -> (i8) {  // Carries one i8 here; the directive above expects four carried values after the jam.
+    // CHECK:               %[[LOAD0_INNER:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_20]], %[[OUTER_I]]]
+    // CHECK:               %[[SUM_0:.*]] = arith.addi %[[LOAD0_INNER]], %[[LOAD0]]
+    // CHECK:               %[[ONE_1:.*]] = arith.constant 1
+    // CHECK:               %[[INC1_INNER:.*]] = arith.addi %[[OUTER_I]], %[[ONE_1]]
+    // CHECK:               %[[LOAD1_INNER:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_20]], %[[INC1_INNER]]]
+    // CHECK:               %[[SUM_1:.*]] = arith.addi %[[LOAD1_INNER]], %[[LOAD1]]
+    // CHECK:               %[[TWO_1:.*]] = arith.constant 2
+    // CHECK:               %[[INC2_INNER:.*]] = arith.addi %[[OUTER_I]], %[[TWO_1]]
+    // CHECK:               %[[LOAD2_INNER:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_20]], %[[INC2_INNER]]]
+    // CHECK:               %[[SUM_2:.*]] = arith.addi %[[LOAD2_INNER]], %[[LOAD2]]
+    // CHECK:               %[[THREE_1:.*]] = arith.constant 3
+    // CHECK:               %[[INC3_INNER:.*]] = arith.addi %[[OUTER_I]], %[[THREE_1]]
+    // CHECK:               %[[LOAD3_INNER:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_20]], %[[INC3_INNER]]]
+    // CHECK:               %[[SUM_3:.*]] = arith.addi %[[LOAD3_INNER]], %[[LOAD3]]
+		%2 = memref.load %arg0[%arg3, %arg2] : memref<96x128xi8, 3>
+		%4 = arith.addi %2, %3 : i8  // Adds the loaded element to %3; the carried block argument itself is not read.
+    // CHECK:               scf.yield %[[SUM_0]], %[[SUM_1]], %[[SUM_2]], %[[SUM_3]]
+		scf.yield %4 : i8
+	  }
+	  memref.store %sum, %arg1[%arg2] : memref<128xi8, 3>  // Anchor op for the transform script; the directives below expect one store per inner-loop result #0 to #3.
+    // CHECK:             memref.store %[[VAL_39:.*]]#0, %[[VAL_1]]{{\[}}%[[OUTER_I]]]
+    // CHECK:             %[[ONE_2:.*]] = arith.constant 1
+    // CHECK:             %[[INC_STORE1:.*]] = arith.addi %[[OUTER_I]], %[[ONE_2]]
+    // CHECK:             memref.store %[[VAL_39]]#1, %[[VAL_1]]{{\[}}%[[INC_STORE1]]]
+    // CHECK:             %[[TWO_2:.*]] = arith.constant 2
+    // CHECK:             %[[INC_STORE2:.*]] = arith.addi %[[OUTER_I]], %[[TWO_2]]
+    // CHECK:             memref.store %[[VAL_39]]#2, %[[VAL_1]]{{\[}}%[[INC_STORE2]]]
+    // CHECK:             %[[THREE_2:.*]] = arith.constant 3
+    // CHECK:             %[[INC_STORE3:.*]] = arith.addi %[[OUTER_I]], %[[THREE_2]]
+    // CHECK:             memref.store %[[VAL_39]]#3, %[[VAL_1]]{{\[}}%[[INC_STORE3]]]
+	}
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["memref.store"]} in %arg1 : (!transform.any_op) -> !transform.any_op  // Select the memref.store in the outer loop body.
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">  // Step up to the scf.for enclosing the store (the outer loop).
+    transform.loop.unroll_and_jam %1 { factor = 4 } : !transform.op<"scf.for">  // Unroll-and-jam that loop by 4.
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @loop_unroll_and_jam_op
+func.func @loop_unroll_and_jam_op() {  // Payload: loop with fewer iterations than the factor requested by the script.
+  // CHECK:           %[[ZERO:.*]] = arith.constant 0
+  // CHECK:           %[[UNUSED_FACTOR:.*]] = arith.constant 4
+  // CHECK:           %[[UNUSED_STEP:.*]] = arith.constant 2
+  // CHECK:           %[[UNUSED_STEP2:.*]] = arith.constant 2
+  // CHECK:           %[[UNUSED_UB:.*]] = arith.constant 4
+  // CHECK:           %[[ITER_0_RES:.*]] = arith.addi %[[ZERO]], %[[ZERO]]
+  // CHECK:           %[[TWO:.*]] = arith.constant 2
+  // CHECK:           %[[STEP_1_I:.*]] = arith.addi %[[ZERO]], %[[TWO]]
+  // CHECK:           %[[ITER_1_RES:.*]] = arith.addi %[[STEP_1_I]], %[[STEP_1_I]]
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c2 = arith.constant 2 : index
+  scf.for %i = %c0 to %c4 step %c2 {  // 0 to 4 step 2: two iterations, versus a requested factor of 4.
+    arith.addi %i, %i : index  // Anchor op; the directives above expect two copies, for induction values 0 and 0+2.
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op  // Select the arith.addi in the loop body.
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">  // Step up to the enclosing scf.for.
+    transform.loop.unroll_and_jam %1 { factor = 4 } : !transform.op<"scf.for">  // Request a factor of 4 on a loop that runs only twice.
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @loop_unroll_and_jam_op
+func.func @loop_unroll_and_jam_op() {  // Payload: loop whose body is empty apart from the terminator.
+  // CHECK:           %[[LB:.*]] = arith.constant 0
+  // CHECK:           %[[UB:.*]] = arith.constant 4
+  // CHECK:           %[[STEP:.*]] = arith.constant 2
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c2 = arith.constant 2 : index
+  // CHECK:           scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+  scf.for %i = %c0 to %c4 step %c2 {  // The directive above expects this loop to survive with its original bounds and step.
+    scf.yield
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op  // NOTE(review): the payload func in this test case contains no arith.addi, so this presumably matches nothing - confirm intent.
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">  // Parent scf.for of whatever was matched above.
+    transform.loop.unroll_and_jam %1 { factor = 2 } : !transform.op<"scf.for">  // NOTE(review): if the handle is empty this has no loop to act on; verify the test really exercises the empty-body case.
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @loop_unroll_and_jam_op
+// CHECK:  %[[VAL_0:.*]]: memref<21x30xf32, 1>, %[[INIT0:.*]]: f32, %[[INIT1:.*]]: f32) {
+func.func @loop_unroll_and_jam_op(%arg0: memref<21x30xf32, 1>, %init : f32, %init1 : f32) {  // Payload: two nested affine.for loops, each carrying one f32.
+  // CHECK:           %[[LAST_OUT_ITER:.*]] = arith.constant 20
+  // CHECK:           %[[VAL_4:.*]]:2 = affine.for %[[OUTER_I:.*]] = 0 to 20 step 2 iter_args(%[[VAL_6:.*]] = %[[INIT0]], %[[VAL_7:.*]] = %[[INIT0]])
+  %0 = affine.for %arg3 = 0 to 21 iter_args(%arg4 = %init) -> (f32) {  // 21 iterations; the directives expect a main loop over 0 to 20 step 2 plus a separate copy for index 20.
+    // CHECK:             %[[VAL_8:.*]]:2 = affine.for %[[INNER_I:.*]] = 0 to 30 iter_args(%[[SUM0:.*]] = %[[INIT1]], %[[SUM1:.*]] = %[[INIT1]])
+    %1 = affine.for %arg5 = 0 to 30 iter_args(%arg6 = %init1) -> (f32) {  // Inner loop: accumulates elements of one row into %arg6.
+      // CHECK:               %[[LOAD0:.*]] = affine.load %[[VAL_0]]{{\[}}%[[OUTER_I]], %[[INNER_I]]]
+      // CHECK:               %[[ITER_SUM0:.*]] = arith.addf %[[SUM0]], %[[LOAD0]]
+      // CHECK:               %[[APPLY_OUTER_I:.*]] = affine.apply #map(%[[OUTER_I]])
+      // CHECK:               %[[LOAD1:.*]] = affine.load %[[VAL_0]]{{\[}}%[[APPLY_OUTER_I]], %[[INNER_I]]]
+      // CHECK:               %[[ITER_SUM1:.*]] = arith.addf %[[SUM1]], %[[LOAD1]]
+      // CHECK:               affine.yield %[[ITER_SUM0]], %[[ITER_SUM1]]
+      %3 = affine.load %arg0[%arg3, %arg5] : memref<21x30xf32, 1>
+      %4 = arith.addf %arg6, %3 : f32  // Anchor op that the transform script matches on.
+      affine.yield %4 : f32
+    }
+    // CHECK:             %[[MUL0:.*]] = arith.mulf %[[VAL_6]], %[[VAL_18:.*]]#0
+    // CHECK:             %[[VAL_19:.*]] = affine.apply #map(%[[OUTER_I]])
+    // CHECK:             %[[MUL1:.*]] = arith.mulf %[[VAL_7]], %[[VAL_18]]#1
+    // CHECK:             affine.yield %[[MUL0]], %[[MUL1]]
+    // CHECK:           }
+    // CHECK:           %[[VAL_21:.*]] = arith.mulf %[[VAL_22:.*]]#0, %[[VAL_22]]#1
+    // CHECK:           %[[VAL_23:.*]] = affine.for %[[SUFFIX_I:.*]] = 0 to 30 iter_args(%[[ITER_I:.*]] = %[[INIT1]])
+    // CHECK:             %[[LOAD_SUFFIX:.*]] = affine.load %[[VAL_0]]{{\[}}%[[LAST_OUT_ITER]], %[[SUFFIX_I]]]
+    // CHECK:             %[[RES:.*]] = arith.addf %[[ITER_I]], %[[LOAD_SUFFIX]]
+    // CHECK:             affine.yield %[[RES]]
+    %2 = arith.mulf %arg4, %1 : f32  // Multiplies the outer carried value by the inner loop's result.
+    affine.yield %2 : f32
+  }
+  // CHECK:           %[[VAL_28:.*]] = arith.mulf %[[VAL_21]], %[[VAL_29:.*]]
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addf"]} in %arg1 : (!transform.any_op) -> !transform.any_op  // Select the arith.addf in the inner loop body.
+    %1 = transform.get_parent_op %0 {op_name = "affine.for"} : (!transform.any_op) -> !transform.op<"affine.for">  // First hop: the inner affine.for.
+    %2 = transform.get_parent_op %1 {op_name = "affine.for"} : (!transform.op<"affine.for">) -> !transform.op<"affine.for">  // Second hop: the outer affine.for.
+    transform.loop.unroll_and_jam %2 { factor = 2 } : !transform.op<"affine.for">  // Unroll-and-jam the outer loop by 2.
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @loop_unroll_op() {
   %c0 = arith.constant 0 : index
   %c42 = arith.constant 42 : index


        


More information about the Mlir-commits mailing list