[Mlir-commits] [mlir] Introduce new Unroll And Jam loop transform for SCF/Affine loops (PR #94142)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jun 1 22:41:02 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Aviad Cohen (AviadCo)

<details>
<summary>Changes</summary>

Unroll-and-jam has long been supported in the affine dialect, but only through a pass. This commit exposes the transformation as a transform op and additionally adds partial support for SCF loops.

---

Patch is 37.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94142.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td (+35-1) 
- (modified) mlir/include/mlir/Dialect/SCF/Utils/Utils.h (+8) 
- (added) mlir/include/mlir/IR/LoopUtils.h (+56) 
- (modified) mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp (+2-29) 
- (modified) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (+22) 
- (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+209-7) 
- (modified) mlir/test/Dialect/SCF/transform-ops-invalid.mlir (+99) 
- (modified) mlir/test/Dialect/SCF/transform-ops.mlir (+215) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 5eefe2664d0a1..29651bed296e5 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -243,7 +243,41 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
     This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
     in the return. If all the operations referred to by the `target` operand
     unroll properly, the transform succeeds. Otherwise the transform produces a
-    silencebale failure.
+    silenceable failure.
+
+    Does not return handles as the operation may result in the loop being
+    removed after a full unrolling.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       ConfinedAttr<I64Attr, [IntPositive]>:$factor);
+
+  let assemblyFormat = "$target attr-dict `:` type($target)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+def LoopUnrollAndJamOp : Op<Transform_Dialect, "loop.unroll_and_jam",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let summary = "Unrolls and jam the given loop with the given unroll factor";
+  let description = [{
+    Unrolls & jams each loop associated with the given handle to have up to the given
+    number of loop body copies per iteration. If the unroll factor is larger
+    than the loop trip count, the latter is used as the unroll factor instead.
+
+    #### Return modes
+
+    This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
+    in the return. If all the operations referred to by the `target` operand
+    unroll properly, the transform succeeds. Otherwise the transform produces a
+    silenceable failure.
 
     Does not return handles as the operation may result in the loop being
     removed after a full unrolling.
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index bc09cc7f7fa5e..ad3113c960d37 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -120,6 +120,14 @@ LogicalResult loopUnrollByFactor(
     scf::ForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
 
+/// Unrolls and jams this `scf.for` operation by the specified unroll factor.
+/// Returns failure if the loop cannot be unroll-jammed, either due to
+/// restrictions or due to an invalid unroll factor. With an unroll factor of
+/// 1, the function bails out without doing anything (returns success).
+/// Currently, only constant trip counts that are a multiple of the unroll
+/// factor are supported, and loops with results are not supported.
+LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor);
+
 /// Tile a nest of standard for loops rooted at `rootForOp` by finding such
 /// parametric tile sizes that the outer loops have a fixed number of iterations
 /// as defined in `sizes`.
diff --git a/mlir/include/mlir/IR/LoopUtils.h b/mlir/include/mlir/IR/LoopUtils.h
new file mode 100644
index 0000000000000..410eec223e637
--- /dev/null
+++ b/mlir/include/mlir/IR/LoopUtils.h
@@ -0,0 +1,56 @@
+//===- LoopUtils.h -  LoopUtils Support ---------------------*- C++
+//-*-=============//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains loop utilities shared by multiple dialects, such as the
+// `JamBlockGatherer` helper that collects the sub-blocks to be jammed when a
+// loop is unrolled and jammed.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_LOOP_UTILS_H
+#define MLIR_IR_LOOP_UTILS_H
+
+#include "mlir/IR/BuiltinOps.h"
+
+namespace mlir {
+
+// Gathers all maximal sub-blocks of consecutive operations, none of which is
+// itself an `OpTy` (such an operation could still have a descendant `OpTy` in
+// its tree). Block terminators are ignored.
+template <typename OpTy>
+struct JamBlockGatherer {
+  // Iterators to the first and last op (both inclusive) of each sub-block.
+  llvm::SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
+
+  // Visits every block of every region of `op`. This is a linear time walk.
+  void walk(Operation *op) {
+    for (auto &region : op->getRegions())
+      for (auto &block : region)
+        walk(block);
+  }
+
+  void walk(Block &block) {
+    assert(!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>() &&
+           "expected block to have a terminator");
+    for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
+      auto subBlockStart = it;
+      while (it != e && !isa<OpTy>(&*it))
+        ++it; // Extend the run up to the next `OpTy` or the terminator.
+      if (it != subBlockStart) // Skip empty runs.
+        subBlocks.emplace_back(subBlockStart, std::prev(it));
+      // Recurse into each consecutive `OpTy` op that appears next.
+      while (it != e && isa<OpTy>(&*it))
+        walk(&*it++);
+    }
+  }
+};
+
+} // namespace mlir
+
+#endif // MLIR_IR_LOOP_UTILS_H
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 268050a30e002..6d35f06faa45b 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/LoopUtils.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/MapVector.h"
@@ -1102,34 +1103,6 @@ static bool areInnerBoundsInvariant(AffineForOp forOp) {
   return !walkResult.wasInterrupted();
 }
 
-// Gathers all maximal sub-blocks of operations that do not themselves
-// include a for op (a operation could have a descendant for op though
-// in its tree).  Ignore the block terminators.
-struct JamBlockGatherer {
-  // Store iterators to the first and last op of each sub-block found.
-  std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
-
-  // This is a linear time walk.
-  void walk(Operation *op) {
-    for (auto &region : op->getRegions())
-      for (auto &block : region)
-        walk(block);
-  }
-
-  void walk(Block &block) {
-    for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
-      auto subBlockStart = it;
-      while (it != e && !isa<AffineForOp>(&*it))
-        ++it;
-      if (it != subBlockStart)
-        subBlocks.emplace_back(subBlockStart, std::prev(it));
-      // Process all for ops that appear next.
-      while (it != e && isa<AffineForOp>(&*it))
-        walk(&*it++);
-    }
-  }
-};
-
 /// Unrolls and jams this loop by the specified factor.
 LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
                                                   uint64_t unrollJamFactor) {
@@ -1159,7 +1132,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
     return failure();
 
   // Gather all sub-blocks to jam upon the loop being unrolled.
-  JamBlockGatherer jbg;
+  JamBlockGatherer<AffineForOp> jbg;
   jbg.walk(forOp);
   auto &subBlocks = jbg.subBlocks;
 
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 69f83d8bd70da..df5701a40dc1b 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -304,6 +304,28 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// LoopUnrollAndJamOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *op,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  LogicalResult result(failure()); // Non-for targets keep this failure.
+  if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
+    result = loopUnrollJamByFactor(scfFor, getFactor());
+  else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
+    result = loopUnrollJamByFactor(affineFor, getFactor());
+
+  if (failed(result)) { // Also taken when `op` is neither kind of for loop.
+    DiagnosedSilenceableFailure diag = emitSilenceableError()
+                                       << "failed to unroll and jam";
+    return diag;
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // LoopCoalesceOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 6658cca03eba7..07ae3507c3ded 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/LoopUtils.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Support/MathExtras.h"
@@ -26,9 +27,15 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include <cstdint>
 
 using namespace mlir;
 
+#define DEBUG_TYPE "scf-utils"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
 namespace {
 // This structure is to pass and return sets of loop parameters without
 // confusing the order.
@@ -296,6 +303,23 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
   return builder.create<arith::DivUIOp>(loc, sum, divisor);
 }
 
+static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
+  std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
+  std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
+  std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
+  if (lbCstOp && ubCstOp && stepCstOp) {
+    // All three are constant: the trip count is ceilDiv(ub - lb, step).
+    int64_t lbCst = lbCstOp.value();
+    int64_t ubCst = ubCstOp.value();
+    int64_t stepCst = stepCstOp.value();
+    assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
+           "expected positive loop bounds and step");
+    return mlir::ceilDiv(ubCst - lbCst, stepCst);
+  }
+
+  return {}; // At least one of lb/ub/step has no constant value.
+}
+
 /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
 /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
 /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
@@ -375,22 +399,21 @@ LogicalResult mlir::loopUnrollByFactor(
   std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
   std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
   std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
-  if (lbCstOp && ubCstOp && stepCstOp) {
+  auto constTripCount = getConstantTripCount(forOp);
+  if (constTripCount) {
     // Constant loop bounds computation.
     int64_t lbCst = lbCstOp.value();
     int64_t ubCst = ubCstOp.value();
     int64_t stepCst = stepCstOp.value();
-    assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
-           "expected positive loop bounds and step");
-    int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst);
-
     if (unrollFactor == 1) {
-      if (tripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter)))
+      if (*constTripCount == 1 &&
+          failed(forOp.promoteIfSingleIteration(rewriter)))
         return failure();
       return success();
     }
 
-    int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
+    int64_t tripCountEvenMultiple =
+        *constTripCount - (*constTripCount % unrollFactor);
     int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
     int64_t stepUnrolledCst = stepCst * unrollFactor;
 
@@ -473,6 +496,185 @@ LogicalResult mlir::loopUnrollByFactor(
   return success();
 }
 
+/// Returns true iff the lower bound, upper bound and step of every `scf.for`
+/// visited by walking `forOp` are defined outside of `forOp`.
+static bool areInnerBoundsInvariant(scf::ForOp forOp) {
+  auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
+    if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
+        !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
+        !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
+      return WalkResult::interrupt(); // Stop at the first variant operand.
+
+    return WalkResult::advance();
+  });
+  return !walkResult.wasInterrupted();
+}
+
+/// Unrolls and jams this loop by the specified factor.
+LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
+                                          uint64_t unrollJamFactor) {
+  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
+
+  if (unrollJamFactor == 1)
+    return success();
+
+  // If any control operand of any inner loop of `forOp` is defined within
+  // `forOp`, no unroll jam.
+  if (!areInnerBoundsInvariant(forOp)) {
+    LDBG("failed to unroll and jam: inner bounds are not invariant");
+    return failure();
+  }
+
+  // Currently, for operations with results are not supported.
+  if (forOp->getNumResults() > 0) {
+    LDBG("failed to unroll and jam: unsupported loop with results");
+    return failure();
+  }
+
+  // Currently, only constant trip count that divided by the unroll factor is
+  // supported.
+  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+  if (!tripCount) {
+    // If the trip count is dynamic, do not unroll & jam.
+    LDBG("failed to unroll and jam: trip count could not be determined");
+    return failure();
+  }
+  if (unrollJamFactor > *tripCount) {
+    LDBG("unroll and jam factor is greater than trip count, set factor to trip "
+         "count");
+    unrollJamFactor = *tripCount;
+  } else if (*tripCount % unrollJamFactor != 0) {
+    LDBG("failed to unroll and jam: unsupported trip count that is not a "
+         "multiple of unroll jam factor");
+    return failure();
+  }
+
+  // Nothing in the loop body other than the terminator.
+  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
+    return success();
+
+  // Gather all sub-blocks to jam upon the loop being unrolled.
+  JamBlockGatherer<scf::ForOp> jbg;
+  jbg.walk(forOp);
+  auto &subBlocks = jbg.subBlocks;
+
+  // Collect inner loops.
+  SmallVector<scf::ForOp, 4> innerLoops;
+  forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
+
+  // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
+  // iteration. There are (`unrollJamFactor` - 1) iterations.
+  SmallVector<IRMapping, 4> operandMaps(unrollJamFactor - 1);
+
+  // For any loop with iter_args, replace it with a new loop that has
+  // `unrollJamFactor` copies of its iterOperands, iter_args and yield
+  // operands.
+  SmallVector<scf::ForOp, 4> newInnerLoops;
+  IRRewriter rewriter(forOp.getContext());
+  for (scf::ForOp oldForOp : innerLoops) {
+    SmallVector<Value> dupIterOperands, dupYieldOperands;
+    ValueRange oldIterOperands = oldForOp.getInits();
+    ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
+    ValueRange oldYieldOperands =
+        cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
+    // Get additional iterOperands, iterArgs, and yield operands. We will
+    // fix iterOperands and yield operands after cloning of sub-blocks.
+    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+      dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
+      dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
+    }
+    // Create a new loop with additional iterOperands, iter_args and yield
+    // operands. This new loop will take the loop body of the original loop.
+    bool forOpReplaced = oldForOp == forOp;
+    scf::ForOp newForOp =
+        cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
+            rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
+            [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
+              return dupYieldOperands;
+            }));
+    newInnerLoops.push_back(newForOp);
+    // `forOp` has been replaced with a new loop.
+    if (forOpReplaced)
+      forOp = newForOp;
+    // Update `operandMaps` for `newForOp` iterArgs and results.
+    ValueRange newIterArgs = newForOp.getRegionIterArgs();
+    unsigned oldNumIterArgs = oldIterArgs.size();
+    ValueRange newResults = newForOp.getResults();
+    unsigned oldNumResults = newResults.size() / unrollJamFactor;
+    assert(oldNumIterArgs == oldNumResults &&
+           "oldNumIterArgs must be the same as oldNumResults");
+    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+      for (unsigned j = 0; j < oldNumIterArgs; ++j) {
+        // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
+        // results. Update `operandMaps[i - 1]` to map old iterArgs and results
+        // to those in the `i`th new set.
+        operandMaps[i - 1].map(newIterArgs[j],
+                               newIterArgs[i * oldNumIterArgs + j]);
+        operandMaps[i - 1].map(newResults[j],
+                               newResults[i * oldNumResults + j]);
+      }
+    }
+  }
+
+  // Scale the step of loop being unroll-jammed by the unroll-jam factor.
+  rewriter.setInsertionPoint(forOp);
+  int64_t step = forOp.getConstantStep()->getSExtValue();
+  auto newStep = rewriter.createOrFold<arith::MulIOp>(
+      forOp.getLoc(), forOp.getStep(),
+      rewriter.createOrFold<arith::ConstantOp>(
+          forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
+  forOp.setStep(newStep);
+  auto forOpIV = forOp.getInductionVar();
+
+  // Unroll and jam (appends unrollJamFactor - 1 additional copies).
+  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+    for (auto &subBlock : subBlocks) {
+      // Builder to insert unroll-jammed bodies. Insert right at the end of
+      // sub-block.
+      OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
+
+      // If the induction variable is used, create a remapping to the value for
+      // this unrolled instance.
+      if (!forOpIV.use_empty()) {
+        // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
+        auto ivTag = builder.createOrFold<arith::ConstantOp>(
+            forOp.getLoc(), builder.getIndexAttr(step * i));
+        auto ivUnroll =
+            builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
+        operandMaps[i - 1].map(forOpIV, ivUnroll);
+      }
+      // Clone the sub-block being unroll-jammed.
+      for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
+        builder.clone(*it, operandMaps[i - 1]);
+    }
+    // Fix iterOperands and yield op operands of newly created loops.
+    for (auto newForOp : newInnerLoops) {
+      unsigned oldNumIterOperands =
+          newForOp.getNumRegionIterArgs() / unrollJamFactor;
+      unsigned numControlOperands = newForOp.getNumControlOperands();
+      auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+      unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
+      assert(oldNumIterOperands == oldNumYieldOperands &&
+             "oldNumIterOperands must be the same as oldNumYieldOperands");
+      for (unsigned j = 0; j < oldNumIterOperands; ++j) {
+        // The `i`th duplication of an old iterOperand or yield op operand
+        // needs to be replaced with a mapped value from `operandMaps[i - 1]`
+        // if such mapped value exists.
+        newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
+                            operandMaps[i - 1].lookupOrDefault(
+                   ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/94142


More information about the Mlir-commits mailing list