[Mlir-commits] [mlir] Introduce new Unroll And Jam loop transform for SCF/Affine loops (PR #94142)
Aviad Cohen
llvmlistbot at llvm.org
Sat Jun 1 22:40:32 PDT 2024
https://github.com/AviadCo created https://github.com/llvm/llvm-project/pull/94142
Unroll-and-jam has long been supported in the Affine dialect through a pass. This commit exposes the transformation as a transform op and additionally adds partial support for SCF loops.
>From f52e7b2ae08ad24c607ab38757e8be7e5bad5c67 Mon Sep 17 00:00:00 2001
From: Aviad Cohen <aviad.cohen2 at mobileye.com>
Date: Wed, 29 May 2024 11:25:39 +0300
Subject: [PATCH] Introduce new Unroll And Jam loop transform for SCF/Affine
loops
Unroll-and-jam has long been supported in the Affine dialect through a pass.
This commit exposes the transformation as a transform op and additionally
adds partial support for SCF loops.
---
.../SCF/TransformOps/SCFTransformOps.td | 36 ++-
mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 8 +
mlir/include/mlir/IR/LoopUtils.h | 56 +++++
mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp | 31 +--
.../SCF/TransformOps/SCFTransformOps.cpp | 22 ++
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 216 +++++++++++++++++-
.../Dialect/SCF/transform-ops-invalid.mlir | 99 ++++++++
mlir/test/Dialect/SCF/transform-ops.mlir | 215 +++++++++++++++++
8 files changed, 646 insertions(+), 37 deletions(-)
create mode 100644 mlir/include/mlir/IR/LoopUtils.h
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 5eefe2664d0a1..29651bed296e5 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -243,7 +243,41 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
in the return. If all the operations referred to by the `target` operand
unroll properly, the transform succeeds. Otherwise the transform produces a
- silencebale failure.
+ silenceable failure.
+
+ Does not return handles as the operation may result in the loop being
+ removed after a full unrolling.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ ConfinedAttr<I64Attr, [IntPositive]>:$factor);
+
+ let assemblyFormat = "$target attr-dict `:` type($target)";
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::Operation *target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
+def LoopUnrollAndJamOp : Op<Transform_Dialect, "loop.unroll_and_jam",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformOpInterface, TransformEachOpTrait]> {
+ let summary = "Unrolls and jams the given loop with the given unroll factor";
+ let description = [{
+ Unrolls & jams each loop associated with the given handle to have up to the given
+ number of loop body copies per iteration. If the unroll factor is larger
+ than the loop trip count, the latter is used as the unroll factor instead.
+
+ #### Return modes
+
+ If all the operations referred to by the `target` operand are unrolled and
+ jammed successfully, the transform succeeds. Otherwise the transform
+ produces a silenceable failure; this includes targets that are neither
+ `scf.for` nor `affine.for` operations.
Does not return handles as the operation may result in the loop being
removed after a full unrolling.
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index bc09cc7f7fa5e..ad3113c960d37 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -120,6 +120,14 @@ LogicalResult loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
+/// Unrolls and jams this `scf.for` operation by the specified unroll factor.
+/// Returns failure if the loop cannot be unrolled and jammed, either because
+/// of the restrictions below or because of an invalid unroll factor. With an
+/// unroll factor of 1 the function bails out without doing anything (returns
+/// success). If the unroll factor exceeds the trip count, the trip count is
+/// used as the unroll factor instead. Current restrictions:
+///   - the trip count must be a compile-time constant that is a multiple of
+///     the unroll factor;
+///   - the loop must not have results;
+///   - the bounds and step of every nested `scf.for` must be defined outside
+///     of the loop.
+LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor);
+
/// Tile a nest of standard for loops rooted at `rootForOp` by finding such
/// parametric tile sizes that the outer loops have a fixed number of iterations
/// as defined in `sizes`.
diff --git a/mlir/include/mlir/IR/LoopUtils.h b/mlir/include/mlir/IR/LoopUtils.h
new file mode 100644
index 0000000000000..410eec223e637
--- /dev/null
+++ b/mlir/include/mlir/IR/LoopUtils.h
@@ -0,0 +1,56 @@
+//===- LoopUtils.h - Loop utilities shared by loop dialects -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains dialect-agnostic helpers for loop transformations, such
+// as the sub-block gatherer used to unroll and jam `affine.for` and `scf.for`
+// loops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_LOOP_UTILS_H
+#define MLIR_IR_LOOP_UTILS_H
+
+#include "mlir/IR/BuiltinOps.h"
+
+namespace mlir {
+
+/// Gathers all maximal sub-blocks of operations that do not themselves
+/// include an `OpTy` (an operation could still have a descendant `OpTy` in
+/// its tree). Block terminators are ignored. `OpTy` operations found between
+/// sub-blocks are walked recursively so that their bodies are partitioned too.
+template <typename OpTy>
+struct JamBlockGatherer {
+  /// Iterators to the first and last op (inclusive) of each sub-block found.
+  llvm::SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
+
+  /// Walks every block of every region of `op`. This is a linear time walk.
+  void walk(Operation *op) {
+    for (Region &region : op->getRegions())
+      for (Block &block : region)
+        walk(block);
+  }
+
+  /// Splits `block` (excluding its terminator) into maximal runs of
+  /// non-`OpTy` operations, recursing into the `OpTy` operations in between.
+  void walk(Block &block) {
+    // `std::prev(block.end())` below is only valid on a non-empty block whose
+    // last operation is the terminator to skip.
+    assert(!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>() &&
+           "expected block to have a terminator");
+    for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
+      auto subBlockStart = it;
+      while (it != e && !isa<OpTy>(&*it))
+        ++it;
+      if (it != subBlockStart)
+        subBlocks.emplace_back(subBlockStart, std::prev(it));
+      // Process all `OpTy` ops that appear next.
+      while (it != e && isa<OpTy>(&*it))
+        walk(&*it++);
+    }
+  }
+};
+
+} // namespace mlir
+
+#endif // MLIR_IR_LOOP_UTILS_H
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 268050a30e002..6d35f06faa45b 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -23,6 +23,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/LoopUtils.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/MapVector.h"
@@ -1102,34 +1103,6 @@ static bool areInnerBoundsInvariant(AffineForOp forOp) {
return !walkResult.wasInterrupted();
}
-// Gathers all maximal sub-blocks of operations that do not themselves
-// include a for op (a operation could have a descendant for op though
-// in its tree). Ignore the block terminators.
-struct JamBlockGatherer {
- // Store iterators to the first and last op of each sub-block found.
- std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
-
- // This is a linear time walk.
- void walk(Operation *op) {
- for (auto &region : op->getRegions())
- for (auto &block : region)
- walk(block);
- }
-
- void walk(Block &block) {
- for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
- auto subBlockStart = it;
- while (it != e && !isa<AffineForOp>(&*it))
- ++it;
- if (it != subBlockStart)
- subBlocks.emplace_back(subBlockStart, std::prev(it));
- // Process all for ops that appear next.
- while (it != e && isa<AffineForOp>(&*it))
- walk(&*it++);
- }
- }
-};
-
/// Unrolls and jams this loop by the specified factor.
LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
uint64_t unrollJamFactor) {
@@ -1159,7 +1132,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
return failure();
// Gather all sub-blocks to jam upon the loop being unrolled.
- JamBlockGatherer jbg;
+ JamBlockGatherer<AffineForOp> jbg;
jbg.walk(forOp);
auto &subBlocks = jbg.subBlocks;
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 69f83d8bd70da..df5701a40dc1b 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -304,6 +304,28 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}
+//===----------------------------------------------------------------------===//
+// LoopUnrollAndJamOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *op,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // Dispatch on the loop flavor. Any other operation leaves `status` at
+  // failure and is reported as a silenceable error below.
+  LogicalResult status = failure();
+  if (auto scfLoop = dyn_cast<scf::ForOp>(op))
+    status = loopUnrollJamByFactor(scfLoop, getFactor());
+  else if (auto affineLoop = dyn_cast<AffineForOp>(op))
+    status = loopUnrollJamByFactor(affineLoop, getFactor());
+
+  if (succeeded(status))
+    return DiagnosedSilenceableFailure::success();
+  return emitSilenceableError() << "failed to unroll and jam";
+}
+
//===----------------------------------------------------------------------===//
// LoopCoalesceOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 6658cca03eba7..07ae3507c3ded 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/LoopUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/MathExtras.h"
@@ -26,9 +27,15 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include <cstdint>
using namespace mlir;
+#define DEBUG_TYPE "scf-utils"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
namespace {
// This structure is to pass and return sets of loop parameters without
// confusing the order.
@@ -296,6 +303,23 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
return builder.create<arith::DivUIOp>(loc, sum, divisor);
}
+/// Returns the number of iterations of `forOp` when its lower bound, upper
+/// bound and step are all compile-time constants, and std::nullopt otherwise.
+/// The result is never negative: a loop whose upper bound does not exceed its
+/// lower bound yields 0. Bounds may be negative; the count is computed with
+/// signed arithmetic.
+static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
+  std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
+  std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
+  std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
+  if (!lbCstOp || !ubCstOp || !stepCstOp)
+    return std::nullopt;
+
+  int64_t lbCst = *lbCstOp;
+  int64_t ubCst = *ubCstOp;
+  int64_t stepCst = *stepCstOp;
+  // The division below requires a positive step.
+  assert(stepCst > 0 && "expected positive loop step");
+  // An empty iteration space must not be reported as a negative count:
+  // callers size unrolled loops and epilogues from this value.
+  if (ubCst <= lbCst)
+    return 0;
+  return mlir::ceilDiv(ubCst - lbCst, stepCst);
+}
+
/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
/// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
@@ -375,22 +399,21 @@ LogicalResult mlir::loopUnrollByFactor(
std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
- if (lbCstOp && ubCstOp && stepCstOp) {
+ auto constTripCount = getConstantTripCount(forOp);
+ if (constTripCount) {
// Constant loop bounds computation.
int64_t lbCst = lbCstOp.value();
int64_t ubCst = ubCstOp.value();
int64_t stepCst = stepCstOp.value();
- assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
- "expected positive loop bounds and step");
- int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst);
-
if (unrollFactor == 1) {
- if (tripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter)))
+ if (*constTripCount == 1 &&
+ failed(forOp.promoteIfSingleIteration(rewriter)))
return failure();
return success();
}
- int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
+ int64_t tripCountEvenMultiple =
+ *constTripCount - (*constTripCount % unrollFactor);
int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
int64_t stepUnrolledCst = stepCst * unrollFactor;
@@ -473,6 +496,185 @@ LogicalResult mlir::loopUnrollByFactor(
return success();
}
+/// Returns true when the lower bound, upper bound and step of every `scf.for`
+/// visited by `forOp.walk` (nested loops and `forOp` itself) are defined
+/// outside of `forOp`; returns false as soon as one of them is not.
+static bool areInnerBoundsInvariant(scf::ForOp forOp) {
+  auto isHoisted = [&](Value v) { return forOp.isDefinedOutsideOfLoop(v); };
+  WalkResult result = forOp.walk([&](scf::ForOp nested) -> WalkResult {
+    bool invariant = isHoisted(nested.getLowerBound()) &&
+                     isHoisted(nested.getUpperBound()) &&
+                     isHoisted(nested.getStep());
+    return invariant ? WalkResult::advance() : WalkResult::interrupt();
+  });
+  return !result.wasInterrupted();
+}
+
+/// Unrolls and jams `forOp` by `unrollJamFactor`.
+///
+/// The loop body is partitioned into maximal sub-blocks that contain no
+/// `scf.for` (see `JamBlockGatherer`). Each sub-block is cloned
+/// `unrollJamFactor - 1` times right after itself, with the induction
+/// variable remapped to `iv + i * step`, and the step of `forOp` is multiplied
+/// by `unrollJamFactor`. Nested loops carrying iter_args are rebuilt with
+/// `unrollJamFactor` copies of their iter_args and yielded values so that each
+/// jammed copy gets its own loop-carried chain.
+///
+/// Returns failure when a nested loop's bounds/step are defined inside
+/// `forOp`, when `forOp` has results, when the trip count is not a
+/// compile-time constant, or when it is not a multiple of `unrollJamFactor`
+/// (a factor larger than the trip count is clamped to the trip count). A loop
+/// that never executes is left untouched and reported as success.
+LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
+                                          uint64_t unrollJamFactor) {
+  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
+
+  if (unrollJamFactor == 1)
+    return success();
+
+  // If any control operand of any inner loop of `forOp` is defined within
+  // `forOp`, no unroll jam.
+  if (!areInnerBoundsInvariant(forOp)) {
+    LDBG("failed to unroll and jam: inner bounds are not invariant");
+    return failure();
+  }
+
+  // Currently, for operations with results are not supported.
+  if (forOp->getNumResults() > 0) {
+    LDBG("failed to unroll and jam: unsupported loop with results");
+    return failure();
+  }
+
+  // Currently, only a constant trip count that is a multiple of the unroll
+  // factor is supported.
+  std::optional<int64_t> tripCount = getConstantTripCount(forOp);
+  if (!tripCount) {
+    // If the trip count is dynamic, do not unroll & jam.
+    LDBG("failed to unroll and jam: trip count could not be determined");
+    return failure();
+  }
+  // A loop that never executes has nothing to unroll. This must be handled
+  // before clamping: a factor clamped to 0 would underflow
+  // `unrollJamFactor - 1` and divide by zero below.
+  if (*tripCount <= 0) {
+    LDBG("loop never executes, nothing to unroll and jam");
+    return success();
+  }
+  uint64_t constTripCount = static_cast<uint64_t>(*tripCount);
+  if (unrollJamFactor > constTripCount) {
+    LDBG("unroll and jam factor is greater than trip count, set factor to trip "
+         "count");
+    unrollJamFactor = constTripCount;
+  } else if (constTripCount % unrollJamFactor != 0) {
+    LDBG("failed to unroll and jam: unsupported trip count that is not a "
+         "multiple of unroll jam factor");
+    return failure();
+  }
+
+  // Nothing in the loop body other than the terminator.
+  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
+    return success();
+
+  // Gather all sub-blocks to jam upon the loop being unrolled.
+  JamBlockGatherer<scf::ForOp> jbg;
+  jbg.walk(forOp);
+  auto &subBlocks = jbg.subBlocks;
+
+  // Collect inner loops.
+  SmallVector<scf::ForOp, 4> innerLoops;
+  forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
+
+  // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
+  // iteration. There are (`unrollJamFactor` - 1) iterations.
+  SmallVector<IRMapping, 4> operandMaps(unrollJamFactor - 1);
+
+  // For any loop with iter_args, replace it with a new loop that has
+  // `unrollJamFactor` copies of its iterOperands, iter_args and yield
+  // operands.
+  SmallVector<scf::ForOp, 4> newInnerLoops;
+  IRRewriter rewriter(forOp.getContext());
+  for (scf::ForOp oldForOp : innerLoops) {
+    SmallVector<Value> dupIterOperands, dupYieldOperands;
+    ValueRange oldIterOperands = oldForOp.getInits();
+    ValueRange oldYieldOperands =
+        cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
+    // Read this before `oldForOp` (and its body block) is erased below.
+    unsigned oldNumIterArgs = oldForOp.getNumRegionIterArgs();
+    // Get additional iterOperands, iterArgs, and yield operands. We will
+    // fix iterOperands and yield operands after cloning of sub-blocks.
+    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+      dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
+      dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
+    }
+    // Create a new loop with additional iterOperands, iter_args and yield
+    // operands. This new loop will take the loop body of the original loop.
+    bool forOpReplaced = oldForOp == forOp;
+    scf::ForOp newForOp =
+        cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
+            rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
+            [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
+              return dupYieldOperands;
+            }));
+    newInnerLoops.push_back(newForOp);
+    // `forOp` has been replaced with a new loop.
+    if (forOpReplaced)
+      forOp = newForOp;
+    // Update `operandMaps` for `newForOp` iterArgs and results.
+    ValueRange newIterArgs = newForOp.getRegionIterArgs();
+    ValueRange newResults = newForOp.getResults();
+    unsigned oldNumResults = newResults.size() / unrollJamFactor;
+    assert(oldNumIterArgs == oldNumResults &&
+           "oldNumIterArgs must be the same as oldNumResults");
+    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+      for (unsigned j = 0; j < oldNumIterArgs; ++j) {
+        // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
+        // results. Update `operandMaps[i - 1]` to map old iterArgs and results
+        // to those in the `i`th new set.
+        operandMaps[i - 1].map(newIterArgs[j],
+                               newIterArgs[i * oldNumIterArgs + j]);
+        operandMaps[i - 1].map(newResults[j],
+                               newResults[i * oldNumResults + j]);
+      }
+    }
+  }
+
+  // Scale the step of loop being unroll-jammed by the unroll-jam factor. The
+  // new constants use the induction variable's type: it is not necessarily
+  // `index`, and arith ops require matching operand types.
+  rewriter.setInsertionPoint(forOp);
+  Value forOpIV = forOp.getInductionVar();
+  Type ivType = forOpIV.getType();
+  int64_t step = forOp.getConstantStep()->getSExtValue();
+  Value newStep = rewriter.createOrFold<arith::MulIOp>(
+      forOp.getLoc(), forOp.getStep(),
+      rewriter.createOrFold<arith::ConstantOp>(
+          forOp.getLoc(), rewriter.getIntegerAttr(ivType, unrollJamFactor)));
+  forOp.setStep(newStep);
+
+  // Unroll and jam (appends unrollJamFactor - 1 additional copies).
+  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+    for (auto &subBlock : subBlocks) {
+      // Builder to insert unroll-jammed bodies. Insert right at the end of
+      // sub-block.
+      OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
+
+      // If the induction variable is used, create a remapping to the value for
+      // this unrolled instance.
+      if (!forOpIV.use_empty()) {
+        // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
+        auto ivTag = builder.createOrFold<arith::ConstantOp>(
+            forOp.getLoc(), builder.getIntegerAttr(ivType, step * i));
+        auto ivUnroll =
+            builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
+        operandMaps[i - 1].map(forOpIV, ivUnroll);
+      }
+      // Clone the sub-block being unroll-jammed.
+      for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
+        builder.clone(*it, operandMaps[i - 1]);
+    }
+    // Fix iterOperands and yield op operands of newly created loops.
+    for (scf::ForOp newForOp : newInnerLoops) {
+      unsigned oldNumIterOperands =
+          newForOp.getNumRegionIterArgs() / unrollJamFactor;
+      unsigned numControlOperands = newForOp.getNumControlOperands();
+      auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+      unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
+      assert(oldNumIterOperands == oldNumYieldOperands &&
+             "oldNumIterOperands must be the same as oldNumYieldOperands");
+      for (unsigned j = 0; j < oldNumIterOperands; ++j) {
+        // The `i`th duplication of an old iterOperand or yield op operand
+        // needs to be replaced with a mapped value from `operandMaps[i - 1]`
+        // if such mapped value exists.
+        newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
+                            operandMaps[i - 1].lookupOrDefault(
+                                newForOp.getOperand(numControlOperands + j)));
+        yieldOp.setOperand(
+            i * oldNumYieldOperands + j,
+            operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
+      }
+    }
+  }
+
+  // Promote the loop body up if this has turned into a single iteration loop.
+  (void)forOp.promoteIfSingleIteration(rewriter);
+  return success();
+}
+
/// Transform a loop with a strictly positive step
/// for %i = %lb to %ub step %s
/// into a 0-based loop with step 1
diff --git a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
index a01a7995d6189..742b8a2861839 100644
--- a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
@@ -41,6 +41,105 @@ module attributes {transform.with_named_sequence} {
// -----
+// The loop runs 20 iterations (0 to 40 step 2); 20 is not a multiple of the
+// unroll-and-jam factor 3, which scf.for unroll-and-jam does not support.
+func.func @loop_unroll_and_jam_unsupported_trip_count_not_multiple_of_factor() {
+ %c0 = arith.constant 0 : index
+ %c40 = arith.constant 40 : index
+ %c2 = arith.constant 2 : index
+ scf.for %i = %c0 to %c40 step %c2 {
+ arith.addi %i, %i : index
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+ // expected-error @below {{failed to unroll and jam}}
+ transform.loop.unroll_and_jam %1 { factor = 3 } : !transform.op<"scf.for">
+ transform.yield
+ }
+}
+
+// -----
+
+// The targeted loop returns a value through iter_args; scf.for loops with
+// results are rejected by unroll-and-jam.
+func.func @loop_unroll_and_jam_unsupported_loop_with_results() -> index {
+ %c0 = arith.constant 0 : index
+ %c40 = arith.constant 40 : index
+ %c2 = arith.constant 2 : index
+ %sum = scf.for %i = %c0 to %c40 step %c2 iter_args(%does_not_alias_aggregated = %c0) -> (index) {
+ %sum = arith.addi %i, %i : index
+ scf.yield %sum : index
+ }
+ return %sum : index
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+ // expected-error @below {{failed to unroll and jam}}
+ transform.loop.unroll_and_jam %1 { factor = 4 } : !transform.op<"scf.for">
+ transform.yield
+ }
+}
+
+// -----
+
+// The step of the targeted loop (the one enclosing memref.store) is the
+// induction variable of the enclosing loop, so its trip count is not a
+// compile-time constant. The enclosing loop starts at 1 so that this dynamic
+// step is never zero.
+func.func private @loop_unroll_and_jam_unsupported_dynamic_step(%arg0: memref<96x128xi8, 3>, %arg1: memref<128xi8, 3>) {
+  %c96 = arith.constant 96 : index
+  %c1 = arith.constant 1 : index
+  %c128 = arith.constant 128 : index
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  scf.for %arg4 = %c1 to %c4 step %c1 {
+    scf.for %arg2 = %c0 to %c128 step %arg4 {
+      %3 = memref.load %arg1[%arg2] : memref<128xi8, 3>
+      %sum = scf.for %arg3 = %c0 to %c96 step %c1 iter_args(%does_not_alias_aggregated = %3) -> (i8) {
+        %2 = memref.load %arg0[%arg3, %arg2] : memref<96x128xi8, 3>
+        %4 = arith.addi %2, %3 : i8
+        scf.yield %4 : i8
+      }
+      memref.store %sum, %arg1[%arg2] : memref<128xi8, 3>
+    }
+    scf.yield
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["memref.store"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    // expected-error @below {{failed to unroll and jam}}
+    transform.loop.unroll_and_jam %1 { factor = 4 } : !transform.op<"scf.for">
+    transform.yield
+  }
+}
+
+// -----
+
+// The upper bound is a function argument, so the trip count is not a
+// compile-time constant and unroll-and-jam is rejected.
+func.func @loop_unroll_and_jam_unsupported_dynamic_trip_count(%upper_bound: index) {
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ scf.for %i = %c0 to %upper_bound step %c2 {
+ arith.addi %i, %i : index
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+ // expected-error @below {{failed to unroll and jam}}
+ transform.loop.unroll_and_jam %1 { factor = 2 } : !transform.op<"scf.for">
+ transform.yield
+ }
+}
+
+// -----
+
func.func private @cond() -> i1
func.func private @body()
diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir
index a4daa86583c3d..a23a6f9e816e8 100644
--- a/mlir/test/Dialect/SCF/transform-ops.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops.mlir
@@ -168,6 +168,221 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: @loop_unroll_and_jam_op
+func.func @loop_unroll_and_jam_op() {
+  // Trip count is (40 - 0) / 2 = 20; unrolling by 4 yields a loop with step 8
+  // whose body is cloned for induction-variable offsets 0, 2, 4 and 6.
+  // CHECK: %[[LB:.*]] = arith.constant 0 : index
+  // CHECK: %[[UB:.*]] = arith.constant 40 : index
+  // CHECK: %[[OLD_STEP:.*]] = arith.constant 2 : index
+  // CHECK: %[[FACTOR:.*]] = arith.constant 4 : index
+  // CHECK: %[[NEW_STEP:.*]] = arith.constant 8 : index
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 40 : index
+  %step = arith.constant 2 : index
+  // CHECK: scf.for %[[IV:.*]] = %[[LB]] to %[[UB]] step %[[NEW_STEP]] {
+  scf.for %iv = %lb to %ub step %step {
+    // CHECK: %[[SUM_0:.*]] = arith.addi %[[IV]], %[[IV]] : index
+    // CHECK: %[[OFF_1:.*]] = arith.constant 2 : index
+    // CHECK: %[[IV_1:.*]] = arith.addi %[[IV]], %[[OFF_1]] : index
+    // CHECK: %[[SUM_1:.*]] = arith.addi %[[IV_1]], %[[IV_1]] : index
+    // CHECK: %[[OFF_2:.*]] = arith.constant 4 : index
+    // CHECK: %[[IV_2:.*]] = arith.addi %[[IV]], %[[OFF_2]] : index
+    // CHECK: %[[SUM_2:.*]] = arith.addi %[[IV_2]], %[[IV_2]] : index
+    // CHECK: %[[OFF_3:.*]] = arith.constant 6 : index
+    // CHECK: %[[IV_3:.*]] = arith.addi %[[IV]], %[[OFF_3]] : index
+    // CHECK: %[[SUM_3:.*]] = arith.addi %[[IV_3]], %[[IV_3]] : index
+    %sum = arith.addi %iv, %iv : index
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  // Unroll-and-jam by 4 the `scf.for` enclosing the `arith.addi`.
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %root : (!transform.any_op) -> !transform.any_op
+    %loop = transform.get_parent_op %addi {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    transform.loop.unroll_and_jam %loop { factor = 4 } : !transform.op<"scf.for">
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @loop_unroll_and_jam_op
+// CHECK: %[[VAL_0:.*]]: memref<96x128xi8, 3>, %[[VAL_1:.*]]: memref<128xi8, 3>) {
+func.func private @loop_unroll_and_jam_op(%arg0: memref<96x128xi8, 3>, %arg1: memref<128xi8, 3>) {
+  // The outer 128-iteration loop is unrolled by 4 and jammed into the inner
+  // 96-iteration loop: the inner loop's single iter_arg becomes four
+  // iter_args (initialized from the four outer loads), and each of the four
+  // inner-loop results is stored by its own unrolled copy.
+  // CHECK: %[[VAL_2:.*]] = arith.constant 96 : index
+  // CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
+  // CHECK: %[[VAL_4:.*]] = arith.constant 128 : index
+  // CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
+  // CHECK: %[[VAL_6:.*]] = arith.constant 4 : index
+  // CHECK: %[[VAL_7:.*]] = arith.constant 4 : index
+  %c96 = arith.constant 96 : index
+  %c1 = arith.constant 1 : index
+  %c128 = arith.constant 128 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: scf.for %[[VAL_8:.*]] = %[[VAL_5]] to %[[VAL_4]] step %[[VAL_7]] {
+  scf.for %arg2 = %c0 to %c128 step %c1 {
+    // CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_8]]] : memref<128xi8, 3>
+    // CHECK: %[[VAL_10:.*]] = arith.constant 1 : index
+    // CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_8]], %[[VAL_10]] : index
+    // CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_11]]] : memref<128xi8, 3>
+    // CHECK: %[[VAL_13:.*]] = arith.constant 2 : index
+    // CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_8]], %[[VAL_13]] : index
+    // CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_14]]] : memref<128xi8, 3>
+    // CHECK: %[[VAL_16:.*]] = arith.constant 3 : index
+    // CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_8]], %[[VAL_16]] : index
+    // CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_17]]] : memref<128xi8, 3>
+    %3 = memref.load %arg1[%arg2] : memref<128xi8, 3>
+    // CHECK: %[[VAL_19:.*]]:4 = scf.for %[[VAL_20:.*]] = %[[VAL_5]] to %[[VAL_2]] step %[[VAL_3]] iter_args(%[[VAL_21:.*]] = %[[VAL_9]], %[[VAL_22:.*]] = %[[VAL_12]], %[[VAL_23:.*]] = %[[VAL_15]], %[[VAL_24:.*]] = %[[VAL_18]]) -> (i8, i8, i8, i8) {
+    %sum = scf.for %arg3 = %c0 to %c96 step %c1 iter_args(%does_not_alias_aggregated = %3) -> (i8) {
+      // CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_20]], %[[VAL_8]]] : memref<96x128xi8, 3>
+      // CHECK: %[[VAL_26:.*]] = arith.addi %[[VAL_25]], %[[VAL_9]] : i8
+      // CHECK: %[[VAL_27:.*]] = arith.constant 1 : index
+      // CHECK: %[[VAL_28:.*]] = arith.addi %[[VAL_8]], %[[VAL_27]] : index
+      // CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_20]], %[[VAL_28]]] : memref<96x128xi8, 3>
+      // CHECK: %[[VAL_30:.*]] = arith.addi %[[VAL_29]], %[[VAL_12]] : i8
+      // CHECK: %[[VAL_31:.*]] = arith.constant 2 : index
+      // CHECK: %[[VAL_32:.*]] = arith.addi %[[VAL_8]], %[[VAL_31]] : index
+      // CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_20]], %[[VAL_32]]] : memref<96x128xi8, 3>
+      // CHECK: %[[VAL_34:.*]] = arith.addi %[[VAL_33]], %[[VAL_15]] : i8
+      // CHECK: %[[VAL_35:.*]] = arith.constant 3 : index
+      // CHECK: %[[VAL_36:.*]] = arith.addi %[[VAL_8]], %[[VAL_35]] : index
+      // CHECK: %[[VAL_37:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_20]], %[[VAL_36]]] : memref<96x128xi8, 3>
+      // CHECK: %[[VAL_38:.*]] = arith.addi %[[VAL_37]], %[[VAL_18]] : i8
+      %2 = memref.load %arg0[%arg3, %arg2] : memref<96x128xi8, 3>
+      %4 = arith.addi %2, %3 : i8
+      // CHECK: scf.yield %[[VAL_26]], %[[VAL_30]], %[[VAL_34]], %[[VAL_38]] : i8, i8, i8, i8
+      scf.yield %4 : i8
+    }
+    memref.store %sum, %arg1[%arg2] : memref<128xi8, 3>
+    // Each store must consume the matching result of the jammed inner loop
+    // captured above, not an arbitrary value.
+    // CHECK: memref.store %[[VAL_19]]#0, %[[VAL_1]]{{\[}}%[[VAL_8]]] : memref<128xi8, 3>
+    // CHECK: %[[VAL_40:.*]] = arith.constant 1 : index
+    // CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_8]], %[[VAL_40]] : index
+    // CHECK: memref.store %[[VAL_19]]#1, %[[VAL_1]]{{\[}}%[[VAL_41]]] : memref<128xi8, 3>
+    // CHECK: %[[VAL_42:.*]] = arith.constant 2 : index
+    // CHECK: %[[VAL_43:.*]] = arith.addi %[[VAL_8]], %[[VAL_42]] : index
+    // CHECK: memref.store %[[VAL_19]]#2, %[[VAL_1]]{{\[}}%[[VAL_43]]] : memref<128xi8, 3>
+    // CHECK: %[[VAL_44:.*]] = arith.constant 3 : index
+    // CHECK: %[[VAL_45:.*]] = arith.addi %[[VAL_8]], %[[VAL_44]] : index
+    // CHECK: memref.store %[[VAL_19]]#3, %[[VAL_1]]{{\[}}%[[VAL_45]]] : memref<128xi8, 3>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  // Unroll-and-jam by 4 the closest `scf.for` around the `memref.store`.
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %store = transform.structured.match ops{["memref.store"]} in %root : (!transform.any_op) -> !transform.any_op
+    %loop = transform.get_parent_op %store {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    transform.loop.unroll_and_jam %loop { factor = 4 } : !transform.op<"scf.for">
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @loop_unroll_and_jam_op
+func.func @loop_unroll_and_jam_op() {
+  // Trip count (4 - 0) / 2 = 2 is below the requested factor 4: the body is
+  // emitted once per iteration, with the induction variable replaced by the
+  // lower-bound constant and by that constant plus 2.
+  // CHECK: %[[LB:.*]] = arith.constant 0 : index
+  // CHECK: %[[UB:.*]] = arith.constant 4 : index
+  // CHECK: %[[STEP:.*]] = arith.constant 2 : index
+  // CHECK: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK: %[[SUM_0:.*]] = arith.addi %[[LB]], %[[LB]] : index
+  // CHECK: %[[OFF_1:.*]] = arith.constant 2 : index
+  // CHECK: %[[IV_1:.*]] = arith.addi %[[LB]], %[[OFF_1]] : index
+  // CHECK: %[[SUM_1:.*]] = arith.addi %[[IV_1]], %[[IV_1]] : index
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 4 : index
+  %step = arith.constant 2 : index
+  scf.for %iv = %lb to %ub step %step {
+    %sum = arith.addi %iv, %iv : index
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  // Unroll-and-jam by 4 the `scf.for` enclosing the `arith.addi`.
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %addi = transform.structured.match ops{["arith.addi"]} in %root : (!transform.any_op) -> !transform.any_op
+    %loop = transform.get_parent_op %addi {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    transform.loop.unroll_and_jam %loop { factor = 4 } : !transform.op<"scf.for">
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @loop_unroll_and_jam_op
+func.func @loop_unroll_and_jam_op() {
+  // The loop body holds nothing but its terminator; the loop must appear in
+  // the output with its original bounds and step.
+  // CHECK: %[[LB:.*]] = arith.constant 0 : index
+  // CHECK: %[[UB:.*]] = arith.constant 4 : index
+  // CHECK: %[[STEP:.*]] = arith.constant 2 : index
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 4 : index
+  %step = arith.constant 2 : index
+  // CHECK: scf.for %[[IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+  scf.for %iv = %lb to %ub step %step {
+    scf.yield
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  // The payload loop in this split contains only `scf.yield`, so anchor on
+  // that terminator and walk up to its `scf.for`. Matching `arith.addi` would
+  // produce an empty handle and the unroll-and-jam would never be applied.
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.yield"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+    transform.loop.unroll_and_jam %1 { factor = 2 } : !transform.op<"scf.for">
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @loop_unroll_and_jam_op
+// CHECK: %[[VAL_0:.*]]: memref<21x30xf32, 1>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32) {
+func.func @loop_unroll_and_jam_op(%arg0: memref<21x30xf32, 1>, %init : f32, %init1 : f32) {
+  // The outer trip count 21 is odd. Unroll-and-jam by 2 produces a loop
+  // `0 to 20 step 2` carrying two copies of each iter_arg. The two outer
+  // results are then combined with arith.mulf, and the leftover iteration
+  // (index 20) is emitted after the loop. Uses below refer back to the
+  // values captured at their definitions so the dataflow is actually checked.
+  // CHECK: %[[VAL_3:.*]] = arith.constant 20 : index
+  // CHECK: %[[VAL_4:.*]]:2 = affine.for %[[VAL_5:.*]] = 0 to 20 step 2 iter_args(%[[VAL_6:.*]] = %[[VAL_1]], %[[VAL_7:.*]] = %[[VAL_1]]) -> (f32, f32) {
+  %0 = affine.for %arg3 = 0 to 21 iter_args(%arg4 = %init) -> (f32) {
+    // CHECK: %[[VAL_8:.*]]:2 = affine.for %[[VAL_9:.*]] = 0 to 30 iter_args(%[[VAL_10:.*]] = %[[VAL_2]], %[[VAL_11:.*]] = %[[VAL_2]]) -> (f32, f32) {
+    %1 = affine.for %arg5 = 0 to 30 iter_args(%arg6 = %init1) -> (f32) {
+      // CHECK: %[[VAL_12:.*]] = affine.load %[[VAL_0]]{{\[}}%[[VAL_5]], %[[VAL_9]]] : memref<21x30xf32, 1>
+      // CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_10]], %[[VAL_12]] : f32
+      // CHECK: %[[VAL_14:.*]] = affine.apply #map(%[[VAL_5]])
+      // CHECK: %[[VAL_15:.*]] = affine.load %[[VAL_0]]{{\[}}%[[VAL_14]], %[[VAL_9]]] : memref<21x30xf32, 1>
+      // CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_11]], %[[VAL_15]] : f32
+      // CHECK: affine.yield %[[VAL_13]], %[[VAL_16]] : f32, f32
+      %3 = affine.load %arg0[%arg3, %arg5] : memref<21x30xf32, 1>
+      %4 = arith.addf %arg6, %3 : f32
+      affine.yield %4 : f32
+    }
+    // CHECK: %[[VAL_17:.*]] = arith.mulf %[[VAL_6]], %[[VAL_8]]#0 : f32
+    // CHECK: %[[VAL_19:.*]] = affine.apply #map(%[[VAL_5]])
+    // CHECK: %[[VAL_20:.*]] = arith.mulf %[[VAL_7]], %[[VAL_8]]#1 : f32
+    // CHECK: affine.yield %[[VAL_17]], %[[VAL_20]] : f32, f32
+    // CHECK: }
+    // CHECK: %[[VAL_21:.*]] = arith.mulf %[[VAL_4]]#0, %[[VAL_4]]#1 : f32
+    // CHECK: %[[VAL_23:.*]] = affine.for %[[VAL_24:.*]] = 0 to 30 iter_args(%[[VAL_25:.*]] = %[[VAL_2]]) -> (f32) {
+    // CHECK: %[[VAL_26:.*]] = affine.load %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_24]]] : memref<21x30xf32, 1>
+    // CHECK: %[[VAL_27:.*]] = arith.addf %[[VAL_25]], %[[VAL_26]] : f32
+    // CHECK: affine.yield %[[VAL_27]] : f32
+    %2 = arith.mulf %arg4, %1 : f32
+    affine.yield %2 : f32
+  }
+  // CHECK: %[[VAL_28:.*]] = arith.mulf %[[VAL_21]], %[[VAL_23]] : f32
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  // Start from the `arith.addf`, step out to its `affine.for`, then one more
+  // level to the enclosing `affine.for`, and unroll-and-jam that outer loop
+  // by 2.
+  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
+    %addf = transform.structured.match ops{["arith.addf"]} in %root : (!transform.any_op) -> !transform.any_op
+    %inner = transform.get_parent_op %addf {op_name = "affine.for"} : (!transform.any_op) -> !transform.op<"affine.for">
+    %outer = transform.get_parent_op %inner {op_name = "affine.for"} : (!transform.op<"affine.for">) -> !transform.op<"affine.for">
+    transform.loop.unroll_and_jam %outer { factor = 2 } : !transform.op<"affine.for">
+    transform.yield
+  }
+}
+
+// -----
+
func.func @loop_unroll_op() {
%c0 = arith.constant 0 : index
%c42 = arith.constant 42 : index
More information about the Mlir-commits
mailing list