[Mlir-commits] [mlir] cb8bd6f - Introduce new Unroll And Jam loop transform for SCF/Affine loops (#94142)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 21 05:48:14 PDT 2024
Author: Aviad Cohen
Date: 2024-06-21T15:48:11+03:00
New Revision: cb8bd6f77235ee15bbe549c8f3486392b8966447
URL: https://github.com/llvm/llvm-project/commit/cb8bd6f77235ee15bbe549c8f3486392b8966447
DIFF: https://github.com/llvm/llvm-project/commit/cb8bd6f77235ee15bbe549c8f3486392b8966447.diff
LOG: Introduce new Unroll And Jam loop transform for SCF/Affine loops (#94142)
Unroll And Jam has long been supported in the affine dialect through a pass.
This commit exposes it as a transform op and, in addition, adds
partial support for SCF loops.
Added:
Modified:
mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
mlir/include/mlir/Dialect/SCF/Utils/Utils.h
mlir/include/mlir/Interfaces/LoopLikeInterface.h
mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
mlir/lib/Dialect/SCF/Utils/Utils.cpp
mlir/test/Dialect/SCF/transform-ops-invalid.mlir
mlir/test/Dialect/SCF/transform-ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
index 3d7fe7b0f093f..7bf914f6456ce 100644
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -269,7 +269,41 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
in the return. If all the operations referred to by the `target` operand
unroll properly, the transform succeeds. Otherwise the transform produces a
- silencebale failure.
+ silenceable failure.
+
+ Does not return handles as the operation may result in the loop being
+ removed after a full unrolling.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ ConfinedAttr<I64Attr, [IntPositive]>:$factor);
+
+ let assemblyFormat = "$target attr-dict `:` type($target)";
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure applyToOne(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::Operation *target,
+ ::mlir::transform::ApplyToEachResultList &results,
+ ::mlir::transform::TransformState &state);
+ }];
+}
+
+def LoopUnrollAndJamOp : Op<Transform_Dialect, "loop.unroll_and_jam",
+ [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+ TransformOpInterface, TransformEachOpTrait]> {
+ let summary = "Unrolls and jam the given loop with the given unroll factor";
+ let description = [{
+ Unrolls & jams each loop associated with the given handle to have up to the given
+ number of loop body copies per iteration. If the unroll factor is larger
+ than the loop trip count, the latter is used as the unroll factor instead.
+
+ #### Return modes
+
+ This operation ignores non-`scf.for`, non-`affine.for` ops and drops them
+ in the return. If all the operations referred to by the `target` operand
+ unroll properly, the transform succeeds. Otherwise the transform produces a
+ silenceable failure.
Does not return handles as the operation may result in the loop being
removed after a full unrolling.
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index da3fe3ceb86be..fea151b393152 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -120,6 +120,14 @@ LogicalResult loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
+/// Unrolls and jams this `scf.for` operation by the specified unroll factor.
+/// Returns failure if the loop cannot be unrolled either due to restrictions or
+/// due to invalid unroll factors. In case of unroll factor of 1, the function
+/// bails out without doing anything (returns success). Currently, only constant
+/// trip counts that are divisible by the unroll factor are supported, and
+/// `scf.for` operations with results are not supported.
+LogicalResult loopUnrollJamByFactor(scf::ForOp forOp, uint64_t unrollFactor);
+
/// Transform a loop with a strictly positive step
/// for %i = %lb to %ub step %s
/// into a 0-based loop with step 1
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
index 42609e824c86a..9925fc6ce6ca9 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
@@ -48,6 +48,39 @@ class HasParallelRegion : public TraitBase<ConcreteType, HasParallelRegion> {
};
} // namespace OpTrait
+
+// Gathers all maximal sub-blocks of operations that do not themselves
+// include an `OpTy` (an operation could have a descendant `OpTy` though
+// in its tree). Ignores the block terminators.
+//
+// Usage: call `walk(op)` on the root operation, then read `subBlocks`.
+template <typename OpTy>
+struct JamBlockGatherer {
+  // Store iterators to the first and last op of each sub-block found.
+  // Both iterators are inclusive bounds of the sub-block.
+  SmallVector<std::pair<Block::iterator, Block::iterator>> subBlocks;
+
+  // Visits every block of every region of `op`. This is a linear time walk.
+  void walk(Operation *op) {
+    for (Region &region : op->getRegions())
+      for (Block &block : region)
+        walk(block);
+  }
+
+  // Splits `block` (excluding its terminator) into runs of non-`OpTy`
+  // operations, recording each run, and recurses into every `OpTy` that
+  // separates two runs.
+  void walk(Block &block) {
+    assert(!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>() &&
+           "expected block to have a terminator");
+    // `e` points at the terminator, so the terminator is never part of a
+    // sub-block.
+    for (Block::iterator it = block.begin(), e = std::prev(block.end());
+         it != e;) {
+      Block::iterator subBlockStart = it;
+      while (it != e && !isa<OpTy>(&*it))
+        ++it;
+      if (it != subBlockStart)
+        subBlocks.emplace_back(subBlockStart, std::prev(it));
+      // Process all for ops that appear next.
+      while (it != e && isa<OpTy>(&*it))
+        walk(&*it++);
+    }
+  }
+};
+
+
} // namespace mlir
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index c934f229bb6c4..605737542e9fc 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -37,16 +37,6 @@ using namespace affine;
using namespace presburger;
using llvm::SmallMapVector;
-namespace {
-// This structure is to pass and return sets of loop parameters without
-// confusing the order.
-struct LoopParams {
- Value lowerBound;
- Value upperBound;
- Value step;
-};
-} // namespace
-
/// Computes the cleanup loop lower bound of the loop being unrolled with
/// the specified unroll factor; this bound will also be upper bound of the main
/// part of the unrolled loop. Computes the bound as an AffineMap with its
@@ -1101,34 +1091,6 @@ static bool areInnerBoundsInvariant(AffineForOp forOp) {
return !walkResult.wasInterrupted();
}
-// Gathers all maximal sub-blocks of operations that do not themselves
-// include a for op (a operation could have a descendant for op though
-// in its tree). Ignore the block terminators.
-struct JamBlockGatherer {
-  // Store iterators to the first and last op of each sub-block found.
-  std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
-
-  // This is a linear time walk.
-  void walk(Operation *op) {
-    for (auto &region : op->getRegions())
-      for (auto &block : region)
-        walk(block);
-  }
-
-  void walk(Block &block) {
-    for (auto it = block.begin(), e = std::prev(block.end()); it != e;) {
-      auto subBlockStart = it;
-      while (it != e && !isa<AffineForOp>(&*it))
-        ++it;
-      if (it != subBlockStart)
-        subBlocks.emplace_back(subBlockStart, std::prev(it));
-      // Process all for ops that appear next.
-      while (it != e && isa<AffineForOp>(&*it))
-        walk(&*it++);
-    }
-  }
-};
-
/// Unrolls and jams this loop by the specified factor.
LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
uint64_t unrollJamFactor) {
@@ -1158,7 +1120,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
return failure();
// Gather all sub-blocks to jam upon the loop being unrolled.
- JamBlockGatherer jbg;
+ JamBlockGatherer<AffineForOp> jbg;
jbg.walk(forOp);
auto &subBlocks = jbg.subBlocks;
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 6919de393b668..56ff2709a589e 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -348,12 +348,36 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
result = loopUnrollByFactor(scfFor, getFactor());
else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
result = loopUnrollByFactor(affineFor, getFactor());
+ else
+ return emitSilenceableError()
+ << "failed to unroll, incorrect type of payload";
+
+ if (failed(result))
+ return emitSilenceableError() << "failed to unroll";
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// LoopUnrollAndJamOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::LoopUnrollAndJamOp::applyToOne(
+ transform::TransformRewriter &rewriter, Operation *op,
+ transform::ApplyToEachResultList &results,
+ transform::TransformState &state) {
+ LogicalResult result(failure());
+ if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
+ result = loopUnrollJamByFactor(scfFor, getFactor());
+ else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
+ result = loopUnrollJamByFactor(affineFor, getFactor());
+ else
+ return emitSilenceableError()
+ << "failed to unroll and jam, incorrect type of payload";
+
+ if (failed(result))
+ return emitSilenceableError() << "failed to unroll and jam";
- if (failed(result)) {
- DiagnosedSilenceableFailure diag = emitSilenceableError()
- << "failed to unroll";
- return diag;
- }
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index ff5e3a002263d..c0ee9d2afe91c 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -26,10 +26,16 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
+#include <cstdint>
using namespace mlir;
+#define DEBUG_TYPE "scf-utils"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
@@ -287,6 +293,25 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
return builder.create<arith::DivUIOp>(loc, sum, divisor);
}
+/// Computes the static trip count of `forOp` as
+/// ceilDiv(upperBound - lowerBound, step). Yields std::nullopt when the lower
+/// bound, the upper bound or the step is not a constant.
+static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
+  std::optional<int64_t> lowerBound =
+      getConstantIntValue(forOp.getLowerBound());
+  std::optional<int64_t> upperBound =
+      getConstantIntValue(forOp.getUpperBound());
+  std::optional<int64_t> stride = getConstantIntValue(forOp.getStep());
+  // A single non-constant control operand makes the trip count unknown.
+  if (!(lowerBound && upperBound && stride))
+    return std::nullopt;
+
+  // All three control operands are constants from here on.
+  assert(*lowerBound >= 0 && *upperBound >= 0 && *stride > 0 &&
+         "expected positive loop bounds and step");
+  return llvm::divideCeilSigned(*upperBound - *lowerBound, *stride);
+}
+
/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
/// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
@@ -363,25 +388,21 @@ LogicalResult mlir::loopUnrollByFactor(
Value stepUnrolled;
bool generateEpilogueLoop = true;
- std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
- std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
- std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
- if (lbCstOp && ubCstOp && stepCstOp) {
+ std::optional<int64_t> constTripCount = getConstantTripCount(forOp);
+ if (constTripCount) {
// Constant loop bounds computation.
- int64_t lbCst = lbCstOp.value();
- int64_t ubCst = ubCstOp.value();
- int64_t stepCst = stepCstOp.value();
- assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
- "expected positive loop bounds and step");
- int64_t tripCount = llvm::divideCeilSigned(ubCst - lbCst, stepCst);
-
+ int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
+ int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value();
+ int64_t stepCst = getConstantIntValue(forOp.getStep()).value();
if (unrollFactor == 1) {
- if (tripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter)))
+ if (*constTripCount == 1 &&
+ failed(forOp.promoteIfSingleIteration(rewriter)))
return failure();
return success();
}
- int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
+ int64_t tripCountEvenMultiple =
+ *constTripCount - (*constTripCount % unrollFactor);
int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
int64_t stepUnrolledCst = stepCst * unrollFactor;
@@ -464,6 +485,185 @@ LogicalResult mlir::loopUnrollByFactor(
return success();
}
+/// Returns true only when the lower bound, upper bound and step of every
+/// `scf.for` visited by walking `forOp` are defined outside of `forOp`.
+static bool areInnerBoundsInvariant(scf::ForOp forOp) {
+  WalkResult status = forOp.walk([&](scf::ForOp nestedLoop) {
+    // Stop at the first loop whose control operands depend on `forOp`.
+    bool boundsAreInvariant =
+        forOp.isDefinedOutsideOfLoop(nestedLoop.getLowerBound()) &&
+        forOp.isDefinedOutsideOfLoop(nestedLoop.getUpperBound()) &&
+        forOp.isDefinedOutsideOfLoop(nestedLoop.getStep());
+    return boundsAreInvariant ? WalkResult::advance()
+                              : WalkResult::interrupt();
+  });
+  return !status.wasInterrupted();
+}
+
+/// Unrolls and jams this loop by the specified factor.
+///
+/// Returns success without touching the IR when `unrollJamFactor` is 1, when
+/// the loop has a constant trip count that is not positive (the body never
+/// runs), or when the body holds nothing but its terminator. Returns failure
+/// when inner loop bounds are not invariant, when `forOp` has results, when
+/// the trip count is not constant, or when the trip count is not a multiple
+/// of the factor. A factor larger than the trip count is clamped to the trip
+/// count.
+LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
+                                          uint64_t unrollJamFactor) {
+  assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
+
+  if (unrollJamFactor == 1)
+    return success();
+
+  // If any control operand of any inner loop of `forOp` is defined within
+  // `forOp`, no unroll jam.
+  if (!areInnerBoundsInvariant(forOp)) {
+    LDBG("failed to unroll and jam: inner bounds are not invariant");
+    return failure();
+  }
+
+  // Currently, `scf.for` operations with results are not supported.
+  if (forOp->getNumResults() > 0) {
+    LDBG("failed to unroll and jam: unsupported loop with results");
+    return failure();
+  }
+
+  // Currently, only constant trip counts that are divisible by the unroll
+  // factor are supported.
+  std::optional<int64_t> constTripCount = getConstantTripCount(forOp);
+  if (!constTripCount.has_value()) {
+    // If the trip count is dynamic, do not unroll & jam.
+    LDBG("failed to unroll and jam: trip count could not be determined");
+    return failure();
+  }
+  // Keep the trip count signed until it is known to be positive. A loop that
+  // never executes has nothing to unroll, and clamping the factor to a trip
+  // count of zero would underflow `unrollJamFactor - 1` and divide by zero
+  // further down.
+  if (*constTripCount <= 0) {
+    LDBG("nothing to unroll and jam: loop has no iterations");
+    return success();
+  }
+  const uint64_t tripCount = static_cast<uint64_t>(*constTripCount);
+  if (unrollJamFactor > tripCount) {
+    LDBG("unroll and jam factor is greater than trip count, set factor to trip "
+         "count");
+    unrollJamFactor = tripCount;
+  } else if (tripCount % unrollJamFactor != 0) {
+    LDBG("failed to unroll and jam: unsupported trip count that is not a "
+         "multiple of unroll jam factor");
+    return failure();
+  }
+
+  // Nothing in the loop body other than the terminator.
+  if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
+    return success();
+
+  // Gather all sub-blocks to jam upon the loop being unrolled.
+  JamBlockGatherer<scf::ForOp> jbg;
+  jbg.walk(forOp);
+  auto &subBlocks = jbg.subBlocks;
+
+  // Collect inner loops.
+  SmallVector<scf::ForOp> innerLoops;
+  forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
+
+  // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
+  // iteration. There are (`unrollJamFactor` - 1) iterations.
+  SmallVector<IRMapping> operandMaps(unrollJamFactor - 1);
+
+  // For any loop with iter_args, replace it with a new loop that has
+  // `unrollJamFactor` copies of its iterOperands, iter_args and yield
+  // operands.
+  SmallVector<scf::ForOp> newInnerLoops;
+  IRRewriter rewriter(forOp.getContext());
+  for (scf::ForOp oldForOp : innerLoops) {
+    SmallVector<Value> dupIterOperands, dupYieldOperands;
+    ValueRange oldIterOperands = oldForOp.getInits();
+    ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
+    ValueRange oldYieldOperands =
+        cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
+    // Get additional iterOperands, iterArgs, and yield operands. We will
+    // fix iterOperands and yield operands after cloning of sub-blocks.
+    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+      dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
+      dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
+    }
+    // Create a new loop with additional iterOperands, iter_args and yield
+    // operands. This new loop will take the loop body of the original loop.
+    bool forOpReplaced = oldForOp == forOp;
+    scf::ForOp newForOp =
+        cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
+            rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
+            [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
+              return dupYieldOperands;
+            }));
+    newInnerLoops.push_back(newForOp);
+    // `forOp` has been replaced with a new loop.
+    if (forOpReplaced)
+      forOp = newForOp;
+    // Update `operandMaps` for `newForOp` iterArgs and results.
+    ValueRange newIterArgs = newForOp.getRegionIterArgs();
+    unsigned oldNumIterArgs = oldIterArgs.size();
+    ValueRange newResults = newForOp.getResults();
+    unsigned oldNumResults = newResults.size() / unrollJamFactor;
+    assert(oldNumIterArgs == oldNumResults &&
+           "oldNumIterArgs must be the same as oldNumResults");
+    for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+      for (unsigned j = 0; j < oldNumIterArgs; ++j) {
+        // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
+        // results. Update `operandMaps[i - 1]` to map old iterArgs and results
+        // to those in the `i`th new set.
+        operandMaps[i - 1].map(newIterArgs[j],
+                               newIterArgs[i * oldNumIterArgs + j]);
+        operandMaps[i - 1].map(newResults[j],
+                               newResults[i * oldNumResults + j]);
+      }
+    }
+  }
+
+  // Scale the step of loop being unroll-jammed by the unroll-jam factor.
+  rewriter.setInsertionPoint(forOp);
+  int64_t step = forOp.getConstantStep()->getSExtValue();
+  auto newStep = rewriter.createOrFold<arith::MulIOp>(
+      forOp.getLoc(), forOp.getStep(),
+      rewriter.createOrFold<arith::ConstantOp>(
+          forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
+  forOp.setStep(newStep);
+  auto forOpIV = forOp.getInductionVar();
+
+  // Unroll and jam (appends unrollJamFactor - 1 additional copies).
+  for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+    for (auto &subBlock : subBlocks) {
+      // Builder to insert unroll-jammed bodies. Insert right at the end of
+      // sub-block.
+      OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
+
+      // If the induction variable is used, create a remapping to the value for
+      // this unrolled instance.
+      if (!forOpIV.use_empty()) {
+        // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
+        auto ivTag = builder.createOrFold<arith::ConstantOp>(
+            forOp.getLoc(), builder.getIndexAttr(step * i));
+        auto ivUnroll =
+            builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
+        operandMaps[i - 1].map(forOpIV, ivUnroll);
+      }
+      // Clone the sub-block being unroll-jammed.
+      for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
+        builder.clone(*it, operandMaps[i - 1]);
+    }
+    // Fix iterOperands and yield op operands of newly created loops.
+    for (auto newForOp : newInnerLoops) {
+      unsigned oldNumIterOperands =
+          newForOp.getNumRegionIterArgs() / unrollJamFactor;
+      unsigned numControlOperands = newForOp.getNumControlOperands();
+      auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+      unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
+      assert(oldNumIterOperands == oldNumYieldOperands &&
+             "oldNumIterOperands must be the same as oldNumYieldOperands");
+      for (unsigned j = 0; j < oldNumIterOperands; ++j) {
+        // The `i`th duplication of an old iterOperand or yield op operand
+        // needs to be replaced with a mapped value from `operandMaps[i - 1]`
+        // if such mapped value exists.
+        newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
+                            operandMaps[i - 1].lookupOrDefault(
+                                newForOp.getOperand(numControlOperands + j)));
+        yieldOp.setOperand(
+            i * oldNumYieldOperands + j,
+            operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
+      }
+    }
+  }
+
+  // Promote the loop body up if this has turned into a single iteration loop.
+  (void)forOp.promoteIfSingleIteration(rewriter);
+  return success();
+}
+
+
Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
OpFoldResult lb, OpFoldResult ub,
OpFoldResult step) {
diff --git a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
index a01a7995d6189..742b8a2861839 100644
--- a/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops-invalid.mlir
@@ -41,6 +41,105 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @loop_unroll_and_jam_unsupported_trip_count_not_multiple_of_factor() {
+ %c0 = arith.constant 0 : index
+ %c40 = arith.constant 40 : index
+ %c2 = arith.constant 2 : index
+ scf.for %i = %c0 to %c40 step %c2 {
+ arith.addi %i, %i : index
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+ // expected-error @below {{failed to unroll and jam}}
+ transform.loop.unroll_and_jam %1 { factor = 3 } : !transform.op<"scf.for">
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @loop_unroll_and_jam_unsupported_loop_with_results() -> index {
+ %c0 = arith.constant 0 : index
+ %c40 = arith.constant 40 : index
+ %c2 = arith.constant 2 : index
+ %sum = scf.for %i = %c0 to %c40 step %c2 iter_args(%does_not_alias_aggregated = %c0) -> (index) {
+ %sum = arith.addi %i, %i : index
+ scf.yield %sum : index
+ }
+ return %sum : index
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+ // expected-error @below {{failed to unroll and jam}}
+ transform.loop.unroll_and_jam %1 { factor = 4 } : !transform.op<"scf.for">
+ transform.yield
+ }
+}
+
+// -----
+
+func.func private @loop_unroll_and_jam_unsupported_dynamic_trip_count(%arg0: memref<96x128xi8, 3>, %arg1: memref<128xi8, 3>) {
+ %c96 = arith.constant 96 : index
+ %c1 = arith.constant 1 : index
+ %c128 = arith.constant 128 : index
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ scf.for %arg4 = %c0 to %c4 step %c1 {
+ scf.for %arg2 = %c0 to %c128 step %arg4 {
+ %3 = memref.load %arg1[%arg2] : memref<128xi8, 3>
+ %sum = scf.for %arg3 = %c0 to %c96 step %c1 iter_args(%does_not_alias_aggregated = %3) -> (i8) {
+ %2 = memref.load %arg0[%arg3, %arg2] : memref<96x128xi8, 3>
+ %4 = arith.addi %2, %3 : i8
+ scf.yield %4 : i8
+ }
+ memref.store %sum, %arg1[%arg2] : memref<128xi8, 3>
+ }
+ scf.yield
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["memref.store"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+ // expected-error @below {{failed to unroll and jam}}
+ transform.loop.unroll_and_jam %1 { factor = 4 } : !transform.op<"scf.for">
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @loop_unroll_and_jam_unsupported_dynamic_trip_count(%upper_bound: index) {
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ scf.for %i = %c0 to %upper_bound step %c2 {
+ arith.addi %i, %i : index
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+ // expected-error @below {{failed to unroll and jam}}
+ transform.loop.unroll_and_jam %1 { factor = 2 } : !transform.op<"scf.for">
+ transform.yield
+ }
+}
+
+// -----
+
func.func private @cond() -> i1
func.func private @body()
diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir
index b91225bf45b96..d9445182769e7 100644
--- a/mlir/test/Dialect/SCF/transform-ops.mlir
+++ b/mlir/test/Dialect/SCF/transform-ops.mlir
@@ -168,6 +168,221 @@ module attributes {transform.with_named_sequence} {
// -----
+// Single-loop case: a loop from 0 to 40 with step 2 is unrolled-and-jammed by
+// the factor 4 requested in the transform script below. The patterns expect
+// the loop step to become 8 and the body to appear four times, at offsets 0,
+// 2, 4 and 6 from the new induction variable.
+// CHECK-LABEL: @loop_unroll_and_jam_op
+func.func @loop_unroll_and_jam_op() {
+ // CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[VAL_1:.*]] = arith.constant 40 : index
+ // CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
+ // CHECK: %[[FACTOR:.*]] = arith.constant 4 : index
+ // CHECK: %[[STEP:.*]] = arith.constant 8 : index
+ %c0 = arith.constant 0 : index
+ %c40 = arith.constant 40 : index
+ %c2 = arith.constant 2 : index
+ // CHECK: scf.for %[[VAL_5:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[STEP]] {
+ scf.for %i = %c0 to %c40 step %c2 {
+ // CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_5]] : index
+ // CHECK: %[[VAL_7:.*]] = arith.constant 2 : index
+ // CHECK: %[[VAL_8:.*]] = arith.addi %[[VAL_5]], %[[VAL_7]] : index
+ // CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_8]] : index
+ // CHECK: %[[VAL_10:.*]] = arith.constant 4 : index
+ // CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_5]], %[[VAL_10]] : index
+ // CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_11]] : index
+ // CHECK: %[[VAL_13:.*]] = arith.constant 6 : index
+ // CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_5]], %[[VAL_13]] : index
+ // CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_14]] : index
+ arith.addi %i, %i : index
+ }
+ return
+}
+
+// Transform script for the function above: select the `scf.for` that encloses
+// the `arith.addi` and unroll-and-jam it with factor 4.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+ transform.loop.unroll_and_jam %1 { factor = 4 } : !transform.op<"scf.for">
+ transform.yield
+ }
+}
+
+// -----
+
+// Nested-loop case: the outer loop (the one enclosing the `memref.store`) is
+// unrolled-and-jammed by the factor 4 requested in the transform script below.
+// The patterns expect the four copies of the inner loop to be fused into one
+// `scf.for` carrying four iter_args. The store patterns reuse that loop's
+// captured result name, so they verify that results #0..#3 of the jammed loop
+// are what gets stored instead of accepting any value.
+// CHECK-LABEL: @loop_unroll_and_jam_op
+// CHECK: %[[VAL_0:.*]]: memref<96x128xi8, 3>, %[[VAL_1:.*]]: memref<128xi8, 3>) {
+func.func private @loop_unroll_and_jam_op(%arg0: memref<96x128xi8, 3>, %arg1: memref<128xi8, 3>) {
+ // CHECK: %[[UB_INNER:.*]] = arith.constant 96
+ // CHECK: %[[STEP_INNER:.*]] = arith.constant 1
+ // CHECK: %[[UB_OUTER:.*]] = arith.constant 128
+ // CHECK: %[[LB:.*]] = arith.constant 0
+ // CHECK: %[[UNUSED:.*]] = arith.constant 4
+ // CHECK: %[[UNROLL_FACTOR:.*]] = arith.constant 4
+ %c96 = arith.constant 96 : index
+ %c1 = arith.constant 1 : index
+ %c128 = arith.constant 128 : index
+ %c0 = arith.constant 0 : index
+ // CHECK: scf.for %[[OUTER_I:.*]] = %[[LB]] to %[[UB_OUTER]] step %[[UNROLL_FACTOR]] {
+ scf.for %arg2 = %c0 to %c128 step %c1 {
+ // CHECK: %[[LOAD0:.*]] = memref.load %[[VAL_1]]{{\[}}%[[OUTER_I]]]
+ // CHECK: %[[ONE_0:.*]] = arith.constant 1
+ // CHECK: %[[INC_LOAD1:.*]] = arith.addi %[[OUTER_I]], %[[ONE_0]]
+ // CHECK: %[[LOAD1:.*]] = memref.load %[[VAL_1]]{{\[}}%[[INC_LOAD1]]]
+ // CHECK: %[[TWO_0:.*]] = arith.constant 2
+ // CHECK: %[[INC_LOAD2:.*]] = arith.addi %[[OUTER_I]], %[[TWO_0]]
+ // CHECK: %[[LOAD2:.*]] = memref.load %[[VAL_1]]{{\[}}%[[INC_LOAD2]]]
+ // CHECK: %[[THREE_0:.*]] = arith.constant 3
+ // CHECK: %[[INC_LOAD3:.*]] = arith.addi %[[OUTER_I]], %[[THREE_0]]
+ // CHECK: %[[LOAD3:.*]] = memref.load %[[VAL_1]]{{\[}}%[[INC_LOAD3]]]
+ %3 = memref.load %arg1[%arg2] : memref<128xi8, 3>
+ // CHECK: %[[VAL_19:.*]]:4 = scf.for %[[VAL_20:.*]] = %[[LB]] to %[[UB_INNER]] step %[[STEP_INNER]] iter_args(%[[VAL_21:.*]] = %[[LOAD0]], %[[VAL_22:.*]] = %[[LOAD1]], %[[VAL_23:.*]] = %[[LOAD2]], %[[VAL_24:.*]] = %[[LOAD3]])
+ %sum = scf.for %arg3 = %c0 to %c96 step %c1 iter_args(%does_not_alias_aggregated = %3) -> (i8) {
+ // CHECK: %[[LOAD0_INNER:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_20]], %[[OUTER_I]]]
+ // CHECK: %[[SUM_0:.*]] = arith.addi %[[LOAD0_INNER]], %[[LOAD0]]
+ // CHECK: %[[ONE_1:.*]] = arith.constant 1
+ // CHECK: %[[INC1_INNER:.*]] = arith.addi %[[OUTER_I]], %[[ONE_1]]
+ // CHECK: %[[LOAD1_INNER:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_20]], %[[INC1_INNER]]]
+ // CHECK: %[[SUM_1:.*]] = arith.addi %[[LOAD1_INNER]], %[[LOAD1]]
+ // CHECK: %[[TWO_1:.*]] = arith.constant 2
+ // CHECK: %[[INC2_INNER:.*]] = arith.addi %[[OUTER_I]], %[[TWO_1]]
+ // CHECK: %[[LOAD2_INNER:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_20]], %[[INC2_INNER]]]
+ // CHECK: %[[SUM_2:.*]] = arith.addi %[[LOAD2_INNER]], %[[LOAD2]]
+ // CHECK: %[[THREE_1:.*]] = arith.constant 3
+ // CHECK: %[[INC3_INNER:.*]] = arith.addi %[[OUTER_I]], %[[THREE_1]]
+ // CHECK: %[[LOAD3_INNER:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_20]], %[[INC3_INNER]]]
+ // CHECK: %[[SUM_3:.*]] = arith.addi %[[LOAD3_INNER]], %[[LOAD3]]
+ %2 = memref.load %arg0[%arg3, %arg2] : memref<96x128xi8, 3>
+ %4 = arith.addi %2, %3 : i8
+ // CHECK: scf.yield %[[SUM_0]], %[[SUM_1]], %[[SUM_2]], %[[SUM_3]]
+ scf.yield %4 : i8
+ }
+ memref.store %sum, %arg1[%arg2] : memref<128xi8, 3>
+ // CHECK: memref.store %[[VAL_19]]#0, %[[VAL_1]]{{\[}}%[[OUTER_I]]]
+ // CHECK: %[[ONE_2:.*]] = arith.constant 1
+ // CHECK: %[[INC_STORE1:.*]] = arith.addi %[[OUTER_I]], %[[ONE_2]]
+ // CHECK: memref.store %[[VAL_19]]#1, %[[VAL_1]]{{\[}}%[[INC_STORE1]]]
+ // CHECK: %[[TWO_2:.*]] = arith.constant 2
+ // CHECK: %[[INC_STORE2:.*]] = arith.addi %[[OUTER_I]], %[[TWO_2]]
+ // CHECK: memref.store %[[VAL_19]]#2, %[[VAL_1]]{{\[}}%[[INC_STORE2]]]
+ // CHECK: %[[THREE_2:.*]] = arith.constant 3
+ // CHECK: %[[INC_STORE3:.*]] = arith.addi %[[OUTER_I]], %[[THREE_2]]
+ // CHECK: memref.store %[[VAL_19]]#3, %[[VAL_1]]{{\[}}%[[INC_STORE3]]]
+ }
+ return
+}
+
+// Transform script for the nested-loop function above: the `memref.store` sits
+// directly in the outer loop, so its parent `scf.for` is the outer loop, which
+// is unrolled-and-jammed with factor 4.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["memref.store"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+ transform.loop.unroll_and_jam %1 { factor = 4 } : !transform.op<"scf.for">
+ transform.yield
+ }
+}
+
+// -----
+
+// Factor-larger-than-trip-count case: the loop runs twice (0 to 4 step 2)
+// while the transform script below requests factor 4. The patterns look for
+// two straight-line copies of the body, one using the lower bound directly
+// and one using the lower bound plus 2.
+// CHECK-LABEL: @loop_unroll_and_jam_op
+func.func @loop_unroll_and_jam_op() {
+ // CHECK: %[[ZERO:.*]] = arith.constant 0
+ // CHECK: %[[UNUSED_FACTOR:.*]] = arith.constant 4
+ // CHECK: %[[UNUSED_STEP:.*]] = arith.constant 2
+ // CHECK: %[[UNUSED_STEP2:.*]] = arith.constant 2
+ // CHECK: %[[UNUSED_UB:.*]] = arith.constant 4
+ // CHECK: %[[ITER_0_RES:.*]] = arith.addi %[[ZERO]], %[[ZERO]]
+ // CHECK: %[[TWO:.*]] = arith.constant 2
+ // CHECK: %[[STEP_1_I:.*]] = arith.addi %[[ZERO]], %[[TWO]]
+ // CHECK: %[[ITER_1_RES:.*]] = arith.addi %[[STEP_1_I]], %[[STEP_1_I]]
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %c2 = arith.constant 2 : index
+ scf.for %i = %c0 to %c4 step %c2 {
+ arith.addi %i, %i : index
+ }
+ return
+}
+
+// Transform script for the function above: select the `scf.for` that encloses
+// the `arith.addi` and request unroll-and-jam with factor 4, which exceeds the
+// loop's two iterations.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.addi"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+ transform.loop.unroll_and_jam %1 { factor = 4 } : !transform.op<"scf.for">
+ transform.yield
+ }
+}
+
+// -----
+
+// Empty-body case: the loop contains only its `scf.yield` terminator. The
+// patterns expect the loop to still be present with its original bounds
+// (0 to 4) and step (2).
+// CHECK-LABEL: @loop_unroll_and_jam_op
+func.func @loop_unroll_and_jam_op() {
+ // CHECK: %[[LB:.*]] = arith.constant 0
+ // CHECK: %[[UB:.*]] = arith.constant 4
+ // CHECK: %[[STEP:.*]] = arith.constant 2
+ %c0 = arith.constant 0 : index
+ %c4 = arith.constant 4 : index
+ %c2 = arith.constant 2 : index
+ // CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
+ scf.for %i = %c0 to %c4 step %c2 {
+ scf.yield
+ }
+ return
+}
+
+// Transform script for the empty-body loop above. The loop is selected through
+// its `scf.yield` terminator: the payload contains no `arith.addi`, so matching
+// on that op yields an empty handle and the unroll-and-jam request would never
+// be applied to the loop, making the test pass vacuously.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["scf.yield"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {op_name = "scf.for"} : (!transform.any_op) -> !transform.op<"scf.for">
+ transform.loop.unroll_and_jam %1 { factor = 2 } : !transform.op<"scf.for">
+ transform.yield
+ }
+}
+
+// -----
+
+// Affine case with iter_args on both loops: the outer loop (0 to 21) is
+// unrolled-and-jammed by the factor 2 requested in the transform script below.
+// The patterns expect a main loop over 0 to 20 with step 2 whose two copies of
+// the inner loop are fused, followed by a combining `arith.mulf` and the
+// leftover iteration at index 20. Uses of loop results refer back to the
+// names captured at the defining loops, so the patterns verify the data flow
+// between the loops and the `arith.mulf` ops.
+// CHECK-LABEL: @loop_unroll_and_jam_op
+// CHECK: %[[VAL_0:.*]]: memref<21x30xf32, 1>, %[[INIT0:.*]]: f32, %[[INIT1:.*]]: f32) {
+func.func @loop_unroll_and_jam_op(%arg0: memref<21x30xf32, 1>, %init : f32, %init1 : f32) {
+ // CHECK: %[[LAST_OUT_ITER:.*]] = arith.constant 20
+ // CHECK: %[[VAL_4:.*]]:2 = affine.for %[[OUTER_I:.*]] = 0 to 20 step 2 iter_args(%[[VAL_6:.*]] = %[[INIT0]], %[[VAL_7:.*]] = %[[INIT0]])
+ %0 = affine.for %arg3 = 0 to 21 iter_args(%arg4 = %init) -> (f32) {
+ // CHECK: %[[VAL_8:.*]]:2 = affine.for %[[INNER_I:.*]] = 0 to 30 iter_args(%[[SUM0:.*]] = %[[INIT1]], %[[SUM1:.*]] = %[[INIT1]])
+ %1 = affine.for %arg5 = 0 to 30 iter_args(%arg6 = %init1) -> (f32) {
+ // CHECK: %[[LOAD0:.*]] = affine.load %[[VAL_0]]{{\[}}%[[OUTER_I]], %[[INNER_I]]]
+ // CHECK: %[[ITER_SUM0:.*]] = arith.addf %[[SUM0]], %[[LOAD0]]
+ // CHECK: %[[APPLY_OUTER_I:.*]] = affine.apply #map(%[[OUTER_I]])
+ // CHECK: %[[LOAD1:.*]] = affine.load %[[VAL_0]]{{\[}}%[[APPLY_OUTER_I]], %[[INNER_I]]]
+ // CHECK: %[[ITER_SUM1:.*]] = arith.addf %[[SUM1]], %[[LOAD1]]
+ // CHECK: affine.yield %[[ITER_SUM0]], %[[ITER_SUM1]]
+ %3 = affine.load %arg0[%arg3, %arg5] : memref<21x30xf32, 1>
+ %4 = arith.addf %arg6, %3 : f32
+ affine.yield %4 : f32
+ }
+ // CHECK: %[[MUL0:.*]] = arith.mulf %[[VAL_6]], %[[VAL_8]]#0
+ // CHECK: %[[VAL_19:.*]] = affine.apply #map(%[[OUTER_I]])
+ // CHECK: %[[MUL1:.*]] = arith.mulf %[[VAL_7]], %[[VAL_8]]#1
+ // CHECK: affine.yield %[[MUL0]], %[[MUL1]]
+ // CHECK: }
+ // CHECK: %[[VAL_21:.*]] = arith.mulf %[[VAL_4]]#0, %[[VAL_4]]#1
+ // CHECK: %[[VAL_23:.*]] = affine.for %[[SUFFIX_I:.*]] = 0 to 30 iter_args(%[[ITER_I:.*]] = %[[INIT1]])
+ // CHECK: %[[LOAD_SUFFIX:.*]] = affine.load %[[VAL_0]]{{\[}}%[[LAST_OUT_ITER]], %[[SUFFIX_I]]]
+ // CHECK: %[[RES:.*]] = arith.addf %[[ITER_I]], %[[LOAD_SUFFIX]]
+ // CHECK: affine.yield %[[RES]]
+ %2 = arith.mulf %arg4, %1 : f32
+ affine.yield %2 : f32
+ }
+ // CHECK: %[[VAL_28:.*]] = arith.mulf %[[VAL_21]], %[[VAL_23]]
+ return
+}
+
+// Transform script for the affine function above: starting from the
+// `arith.addf` in the inner loop, walk up two `affine.for` levels to reach the
+// outer loop and unroll-and-jam it with factor 2.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.addf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {op_name = "affine.for"} : (!transform.any_op) -> !transform.op<"affine.for">
+ %2 = transform.get_parent_op %1 {op_name = "affine.for"} : (!transform.op<"affine.for">) -> !transform.op<"affine.for">
+ transform.loop.unroll_and_jam %2 { factor = 2 } : !transform.op<"affine.for">
+ transform.yield
+ }
+}
+
+// -----
+
func.func @loop_unroll_op() {
%c0 = arith.constant 0 : index
%c42 = arith.constant 42 : index
More information about the Mlir-commits
mailing list