[Mlir-commits] [mlir] Introduce new Unroll And Jam loop transform for SCF/Affine loops (PR #94142)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jun 1 22:41:02 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Aviad Cohen (AviadCo)
<details>
<summary>Changes</summary>
Unroll-and-jam has long been supported in the Affine dialect through a pass. This commit exposes the transformation as a transform op and, in addition, adds partial support for SCF loops.
---
Patch is 37.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94142.diff
8 Files Affected:
- (modified) mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td (+35-1)
- (modified) mlir/include/mlir/Dialect/SCF/Utils/Utils.h (+8)
- (added) mlir/include/mlir/IR/LoopUtils.h (+56)
- (modified) mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp (+2-29)
- (modified) mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp (+22)
- (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+209-7)
- (modified) mlir/test/Dialect/SCF/transform-ops-invalid.mlir (+99)
- (modified) mlir/test/Dialect/SCF/transform-ops.mlir (+215)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 5eefe2664d0a1..29651bed296e5 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -243,7 +243,41 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
in the return. If all the operations referred to by the `target` operand
unroll properly, the transform succeeds. Otherwise the transform produces a
- silencebale failure.
+ silenceable failure.
+
+ Does not return handles as the operation may result in the loop being
+ removed after a full unrolling.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ ConfinedAttr<I64Attr, [IntPositive]>:$factor);
+
+ let assemblyFormat = "$target attr-dict `:` type($target)";
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::Operation *target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
+def LoopUnrollAndJamOp : Op<Transform_Dialect, "loop.unroll_and_jam",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformOpInterface, TransformEachOpTrait]> {
+ let summary = "Unrolls and jam the given loop with the given unroll factor";
+ let description = [{
+ Unrolls and jams each loop associated with the given handle so that it has
+ up to the given number of loop body copies per iteration. If the unroll
+ factor is larger than the loop trip count, the latter is used as the unroll
+ factor instead.
+
+ #### Return modes
+
+ This operation only applies to `scf.for` and `affine.for` ops. If all the
+ operations referred to by the `target` operand are unrolled and jammed
+ successfully, the transform succeeds. Otherwise, including when a target is
+ not one of the supported loop ops, the transform produces a silenceable
+ failure.
Does not return handles as the operation may result in the loop being
removed after a full unrolling.
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index bc09cc7f7fa5e..ad3113c960d37 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -120,6 +120,14 @@ LogicalResult loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
+/// Unrolls and jams this for operation by the specified unroll factor, which
+/// must be positive. Returns failure if the loop cannot be unrolled and
+/// jammed, either due to restrictions or due to an invalid unroll factor. With
+/// an unroll factor of 1, the function bails out without doing anything and
+/// returns success.
+/// Current restrictions:
+///  - the loop must have a constant trip count that is a multiple of the
+///    unroll factor; if the factor exceeds the trip count, the trip count is
+///    used as the factor instead;
+///  - loops that produce results are not supported;
+///  - the bounds and steps of all nested `scf.for` loops must be defined
+///    outside of `forOp`.
+LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor);
+
/// Tile a nest of standard for loops rooted at `rootForOp` by finding such
/// parametric tile sizes that the outer loops have a fixed number of iterations
/// as defined in `sizes`.
diff --git a/mlir/include/mlir/IR/LoopUtils.h b/mlir/include/mlir/IR/LoopUtils.h
new file mode 100644
index 0000000000000..410eec223e637
--- /dev/null
+++ b/mlir/include/mlir/IR/LoopUtils.h
@@ -0,0 +1,56 @@
+//===- LoopUtils.h - Loop utilities shared across dialects ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains loop utilities that are shared between dialects, such as
+// JamBlockGatherer, which collects the sub-blocks of a loop body that are
+// replicated ("jammed") when unroll-and-jam is applied to a loop.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_LOOP_UTILS_H
+#define MLIR_IR_LOOP_UTILS_H
+
+#include "mlir/IR/BuiltinOps.h"
+
+namespace mlir {
+
+/// Gathers all maximal sub-blocks of operations that do not themselves
+/// include an `OpTy` (an operation could still have a descendant `OpTy` in
+/// its tree). Block terminators are ignored.
+///
+/// Call `walk(op)` on the loop to be unroll-and-jammed; afterwards
+/// `subBlocks` holds one entry per maximal run of non-`OpTy` operations, in
+/// the order they were encountered. Within the regions of `op`, only nested
+/// `OpTy` operations are descended into.
+template <typename OpTy>
+struct JamBlockGatherer {
+  /// Iterators to the first and the last (inclusive) op of each sub-block
+  /// found.
+  llvm::SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
+
+  /// Visits every block of every region of `op`. This is a linear time walk.
+  void walk(Operation *op) {
+    for (auto &region : op->getRegions())
+      for (auto &block : region)
+        walk(block);
+  }
+
+  /// Splits `block` (minus its terminator) into maximal runs of non-`OpTy`
+  /// ops, recording each run and recursing into the `OpTy` ops between runs.
+  void walk(Block &block) {
+    assert(!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>() &&
+           "expected block to have a terminator");
+    // `e` points at the terminator, so it is never part of a sub-block.
+    for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
+      auto subBlockStart = it;
+      while (it != e && !isa<OpTy>(&*it))
+        ++it;
+      if (it != subBlockStart)
+        subBlocks.emplace_back(subBlockStart, std::prev(it));
+      // Process all `OpTy` ops that appear next.
+      while (it != e && isa<OpTy>(&*it))
+        walk(&*it++);
+    }
+  }
+};
+
+} // namespace mlir
+
+#endif // MLIR_IR_LOOP_UTILS_H
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 268050a30e002..6d35f06faa45b 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -23,6 +23,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/LoopUtils.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/MapVector.h"
@@ -1102,34 +1103,6 @@ static bool areInnerBoundsInvariant(AffineForOp forOp) {
return !walkResult.wasInterrupted();
}
-// Gathers all maximal sub-blocks of operations that do not themselves
-// include a for op (a operation could have a descendant for op though
-// in its tree). Ignore the block terminators.
-struct JamBlockGatherer {
-  // Store iterators to the first and last op of each sub-block found.
-  std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
-
-  // This is a linear time walk.
-  void walk(Operation *op) {
-    for (auto &region : op->getRegions())
-      for (auto &block : region)
-        walk(block);
-  }
-
-  void walk(Block &block) {
-    for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
-      auto subBlockStart = it;
-      while (it != e && !isa<AffineForOp>(&*it))
-        ++it;
-      if (it != subBlockStart)
-        subBlocks.emplace_back(subBlockStart, std::prev(it));
-      // Process all for ops that appear next.
-      while (it != e && isa<AffineForOp>(&*it))
-        walk(&*it++);
-    }
-  }
-};
-
/// Unrolls and jams this loop by the specified factor.
LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
uint64_t unrollJamFactor) {
@@ -1159,7 +1132,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
return failure();
// Gather all sub-blocks to jam upon the loop being unrolled.
- JamBlockGatherer jbg;
+ JamBlockGatherer<AffineForOp> jbg;
jbg.walk(forOp);
auto &subBlocks = jbg.subBlocks;
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 69f83d8bd70da..df5701a40dc1b 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -304,6 +304,28 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// LoopUnrollAndJamOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *op,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // Dispatch on the loop flavor. Any op that is neither an `scf.for` nor an
+  // `affine.for` keeps the default failure status and is reported below.
+  LogicalResult unrolled = failure();
+  if (auto scfLoop = dyn_cast<scf::ForOp>(op)) {
+    unrolled = loopUnrollJamByFactor(scfLoop, getFactor());
+  } else if (auto affineLoop = dyn_cast<AffineForOp>(op)) {
+    unrolled = loopUnrollJamByFactor(affineLoop, getFactor());
+  }
+
+  if (succeeded(unrolled))
+    return DiagnosedSilenceableFailure::success();
+
+  // Unsupported target or a transformation that bailed out.
+  return emitSilenceableError() << "failed to unroll and jam";
+}
+
//===----------------------------------------------------------------------===//
// LoopCoalesceOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 6658cca03eba7..07ae3507c3ded 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/LoopUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/MathExtras.h"
@@ -26,9 +27,15 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include <cstdint>
using namespace mlir;
+#define DEBUG_TYPE "scf-utils"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
namespace {
// This structure is to pass and return sets of loop parameters without
// confusing the order.
@@ -296,6 +303,23 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
return builder.create<arith::DivUIOp>(loc, sum, divisor);
}
+/// Returns ceilDiv(ub - lb, step) for `forOp` when its lower bound, upper
+/// bound and step are all compile-time constants, and std::nullopt otherwise.
+/// NOTE(review): the result is negative when ub < lb; callers that store it
+/// in an unsigned type should account for that. The assert below also rejects
+/// negative constant bounds in debug builds — confirm that restriction is
+/// intended for scf.for.
+static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
+  std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound());
+  std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound());
+  std::optional<int64_t> step = getConstantIntValue(forOp.getStep());
+
+  // Any non-constant control operand makes the trip count unknown.
+  if (!lb || !ub || !step)
+    return std::nullopt;
+
+  assert(*lb >= 0 && *ub >= 0 && *step >= 0 &&
+         "expected positive loop bounds and step");
+  return mlir::ceilDiv(*ub - *lb, *step);
+}
+
/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
/// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
@@ -375,22 +399,21 @@ LogicalResult mlir::loopUnrollByFactor(
std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
- if (lbCstOp && ubCstOp && stepCstOp) {
+ auto constTripCount = getConstantTripCount(forOp);
+ if (constTripCount) {
// Constant loop bounds computation.
int64_t lbCst = lbCstOp.value();
int64_t ubCst = ubCstOp.value();
int64_t stepCst = stepCstOp.value();
- assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
- "expected positive loop bounds and step");
- int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst);
-
if (unrollFactor == 1) {
- if (tripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter)))
+ if (*constTripCount == 1 &&
+ failed(forOp.promoteIfSingleIteration(rewriter)))
return failure();
return success();
}
- int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
+ int64_t tripCountEvenMultiple =
+ *constTripCount - (*constTripCount % unrollFactor);
int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
int64_t stepUnrolledCst = stepCst * unrollFactor;
@@ -473,6 +496,185 @@ LogicalResult mlir::loopUnrollByFactor(
return success();
}
+/// Returns true iff the lower bound, upper bound and step of every `scf.for`
+/// nested in `forOp` are defined outside of `forOp`. The walk also visits
+/// `forOp` itself, whose own control operands trivially satisfy the check.
+static bool areInnerBoundsInvariant(scf::ForOp forOp) {
+  auto isInvariant = [&](Value v) { return forOp.isDefinedOutsideOfLoop(v); };
+
+  WalkResult result = forOp.walk([&](scf::ForOp nested) {
+    bool invariant = isInvariant(nested.getLowerBound()) &&
+                     isInvariant(nested.getUpperBound()) &&
+                     isInvariant(nested.getStep());
+    // Stop at the first loop whose bounds vary with `forOp`.
+    return invariant ? WalkResult::advance() : WalkResult::interrupt();
+  });
+  return !result.wasInterrupted();
+}
+
+/// Unrolls and jams this loop by the specified factor.
+LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
+ uint64_t unrollJamFactor) {
+ assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
+
+ if (unrollJamFactor == 1)
+ return success();
+
+ // If any control operand of any inner loop of `forOp` is defined within
+ // `forOp`, no unroll jam.
+ if (!areInnerBoundsInvariant(forOp)) {
+ LDBG("failed to unroll and jam: inner bounds are not invariant");
+ return failure();
+ }
+
+ // Currently, for operations with results are not supported.
+ if (forOp->getNumResults() > 0) {
+ LDBG("failed to unroll and jam: unsupported loop with results");
+ return failure();
+ }
+
+ // Currently, only constant trip count that divided by the unroll factor is
+ // supported.
+ std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+ if (!tripCount) {
+ // If the trip count is dynamic, do not unroll & jam.
+ LDBG("failed to unroll and jam: trip count could not be determined");
+ return failure();
+ }
+ if (unrollJamFactor > *tripCount) {
+ LDBG("unroll and jam factor is greater than trip count, set factor to trip "
+ "count");
+ unrollJamFactor = *tripCount;
+ } else if (*tripCount % unrollJamFactor != 0) {
+ LDBG("failed to unroll and jam: unsupported trip count that is not a "
+ "multiple of unroll jam factor");
+ return failure();
+ }
+
+ // Nothing in the loop body other than the terminator.
+ if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
+ return success();
+
+ // Gather all sub-blocks to jam upon the loop being unrolled.
+ JamBlockGatherer<scf::ForOp> jbg;
+ jbg.walk(forOp);
+ auto &subBlocks = jbg.subBlocks;
+
+ // Collect inner loops.
+ SmallVector<scf::ForOp, 4> innerLoops;
+ forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
+
+ // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
+ // iteration. There are (`unrollJamFactor` - 1) iterations.
+ SmallVector<IRMapping, 4> operandMaps(unrollJamFactor - 1);
+
+ // For any loop with iter_args, replace it with a new loop that has
+ // `unrollJamFactor` copies of its iterOperands, iter_args and yield
+ // operands.
+ SmallVector<scf::ForOp, 4> newInnerLoops;
+ IRRewriter rewriter(forOp.getContext());
+ for (scf::ForOp oldForOp : innerLoops) {
+ SmallVector<Value> dupIterOperands, dupYieldOperands;
+ ValueRange oldIterOperands = oldForOp.getInits();
+ ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
+ ValueRange oldYieldOperands =
+ cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
+ // Get additional iterOperands, iterArgs, and yield operands. We will
+ // fix iterOperands and yield operands after cloning of sub-blocks.
+ for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+ dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
+ dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
+ }
+ // Create a new loop with additional iterOperands, iter_args and yield
+ // operands. This new loop will take the loop body of the original loop.
+ bool forOpReplaced = oldForOp == forOp;
+ scf::ForOp newForOp =
+ cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
+ rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
+ [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
+ return dupYieldOperands;
+ }));
+ newInnerLoops.push_back(newForOp);
+ // `forOp` has been replaced with a new loop.
+ if (forOpReplaced)
+ forOp = newForOp;
+ // Update `operandMaps` for `newForOp` iterArgs and results.
+ ValueRange newIterArgs = newForOp.getRegionIterArgs();
+ unsigned oldNumIterArgs = oldIterArgs.size();
+ ValueRange newResults = newForOp.getResults();
+ unsigned oldNumResults = newResults.size() / unrollJamFactor;
+ assert(oldNumIterArgs == oldNumResults &&
+ "oldNumIterArgs must be the same as oldNumResults");
+ for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+ for (unsigned j = 0; j < oldNumIterArgs; ++j) {
+ // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
+ // results. Update `operandMaps[i - 1]` to map old iterArgs and results
+ // to those in the `i`th new set.
+ operandMaps[i - 1].map(newIterArgs[j],
+ newIterArgs[i * oldNumIterArgs + j]);
+ operandMaps[i - 1].map(newResults[j],
+ newResults[i * oldNumResults + j]);
+ }
+ }
+ }
+
+ // Scale the step of loop being unroll-jammed by the unroll-jam factor.
+ rewriter.setInsertionPoint(forOp);
+ int64_t step = forOp.getConstantStep()->getSExtValue();
+ auto newStep = rewriter.createOrFold<arith::MulIOp>(
+ forOp.getLoc(), forOp.getStep(),
+ rewriter.createOrFold<arith::ConstantOp>(
+ forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
+ forOp.setStep(newStep);
+ auto forOpIV = forOp.getInductionVar();
+
+ // Unroll and jam (appends unrollJamFactor - 1 additional copies).
+ for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+ for (auto &subBlock : subBlocks) {
+ // Builder to insert unroll-jammed bodies. Insert right at the end of
+ // sub-block.
+ OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
+
+ // If the induction variable is used, create a remapping to the value for
+ // this unrolled instance.
+ if (!forOpIV.use_empty()) {
+ // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
+ auto ivTag = builder.createOrFold<arith::ConstantOp>(
+ forOp.getLoc(), builder.getIndexAttr(step * i));
+ auto ivUnroll =
+ builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
+ operandMaps[i - 1].map(forOpIV, ivUnroll);
+ }
+ // Clone the sub-block being unroll-jammed.
+ for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
+ builder.clone(*it, operandMaps[i - 1]);
+ }
+ // Fix iterOperands and yield op operands of newly created loops.
+ for (auto newForOp : newInnerLoops) {
+ unsigned oldNumIterOperands =
+ newForOp.getNumRegionIterArgs() / unrollJamFactor;
+ unsigned numControlOperands = newForOp.getNumControlOperands();
+ auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+ unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
+ assert(oldNumIterOperands == oldNumYieldOperands &&
+ "oldNumIterOperands must be the same as oldNumYieldOperands");
+ for (unsigned j = 0; j < oldNumIterOperands; ++j) {
+ // The `i`th duplication of an old iterOperand or yield op operand
+ // needs to be replaced with a mapped value from `operandMaps[i - 1]`
+ // if such mapped value exists.
+ newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
+ operandMaps[i - 1].lookupOrDefault(
+ ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/94142
More information about the Mlir-commits
mailing list