[Mlir-commits] [mlir] [mlir][Affine] Generalize the linearize(delinearize()) simplifications (PR #117637)
Krzysztof Drewniak
llvmlistbot at llvm.org
Mon Dec 2 10:00:02 PST 2024
https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/117637
>From fdef27e22bebc7b43f365ae55bc9e9e3c507ae48 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Mon, 25 Nov 2024 19:40:06 +0000
Subject: [PATCH 1/2] [mlir][Affine] Generalize the linearize(delinearize())
simplifications
The existing canonicalization patterns would only cancel out cases
where the entire result list of an affine.delinearize_index was passed
to an affine.linearize_index and the basis elements matched
exactly (except possibly for the outer bounds).
This was correct, but limited, and left open many cases where a
delinearize_index would take a series of divisions and modulos only
for a subsequent linearize_index to use additions and multiplications
to undo all that work.
This sort of simplification is reasonably easy to observe at the level
of splitting and merging indexes, but difficult to perform once the
underlying arithmetic operations have been created.
Therefore, this commit generalizes the existing simplification logic.
Now, any run of two or more delinearize_index results that appears
within the argument list of a linearize_index operation with the same
basis (or where they're both at the outermost position and so can be
unbounded, or when `linearize_index disjoint` implies a bound not
present on the `delinearize_index`) will be reduced to one single
delinearize_index output, whose basis element (that is, size or
length) is equal to the product of the sizes that were simplified
away.
That is, we can now simplify
%0:2 = affine.delinearize_index %n into (8, 8) : index, index
%1 = affine.linearize_index [%x, %0#0, %0#1, %y] by (3, 8, 8, 5) : index
to the simpler
%1 = affine.linearize_index [%x, %n, %y] by (3, 64, 5) : index
This new pattern also works with dynamically-sized basis values.
While I'm here, I fixed a bunch of typos in existing tests, and added
a new getPaddedBasis() method to make processing
potentially-underspecified basis elements simpler in some cases.
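To illustrate the padding convention, here is a minimal stand-alone
sketch. It is not the MLIR implementation: `Bound`, `paddedBasis`, and
`dropLeadingNull` are hypothetical names, and `std::nullopt` stands in
for a null `OpFoldResult`. It shows a missing outer bound being padded
with a leading null entry, and a builder-style helper ignoring that
entry again.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// std::nullopt models a null OpFoldResult (an unspecified outer bound).
using Bound = std::optional<int64_t>;

// One entry per index: prepend a null entry when the basis has only
// numIndices - 1 elements, i.e. when no outer bound was given.
std::vector<Bound> paddedBasis(const std::vector<int64_t> &basis,
                               size_t numIndices) {
  assert(basis.size() == numIndices || basis.size() + 1 == numIndices);
  std::vector<Bound> ret(basis.begin(), basis.end());
  if (basis.size() + 1 == numIndices)
    ret.insert(ret.begin(), Bound());
  return ret;
}

// Mirrors the builder convenience described above: a leading null entry
// is dropped, every other entry must be present.
std::vector<int64_t> dropLeadingNull(const std::vector<Bound> &padded) {
  std::vector<int64_t> ret;
  for (size_t i = 0; i < padded.size(); ++i) {
    if (i == 0 && !padded[i])
      continue;
    ret.push_back(padded[i].value());
  }
  return ret;
}
```

With this shape, position `i` of the padded basis always lines up with
result or operand `i`, so callers need no outer-bound offset.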
---
.../mlir/Dialect/Affine/IR/AffineOps.td | 27 +-
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 250 ++++++++++++++++--
mlir/test/Dialect/Affine/canonicalize.mlir | 224 +++++++++++++++-
3 files changed, 459 insertions(+), 42 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 03172f7ce00e4b..5e55a4c3a9b2d0 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1083,6 +1083,9 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
%indices_2 = affine.apply #map2()[%linear_index]
```
+ In other words, `%0:3 = affine.delinearize_index %x into (B, C)` produces
+ `%0 = {%x / (B * C), (%x mod (B * C)) / C, %x mod C}`.
+
The basis may either contain `N` or `N-1` elements, where `N` is the number of results.
If there are N basis elements, the first one will not be used during computations,
but may be used during analysis and canonicalization to eliminate terms from
@@ -1098,7 +1101,12 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
%0:3 = affine.delinearize_index %linear_index into (244, 244) : index, index
```
- Note that, due to the constraints of affine maps, all the basis elements must
+ Note that, for symmetry with `getPaddedBasis()`, if `hasOuterBound` is `true`
+ when one of the `OpFoldResult` builders is called but the first element of the
+ basis is `nullptr`, that first element is ignored and the builder proceeds as if
+ there was no outer bound.
+
+ Due to the constraints of affine maps, all the basis elements must
be strictly positive. A dynamic basis element being 0 or negative causes
undefined behavior.
}];
@@ -1136,6 +1144,11 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
/// Return a vector that contains the basis of the operation, removing
/// the outer bound if one is present.
SmallVector<OpFoldResult> getEffectiveBasis();
+
+ /// Return the vector with one basis element per result of the operation. If
+ /// there is no outer bound specified, the leading entry of this result will be
+ /// nullptr.
+ SmallVector<OpFoldResult> getPaddedBasis();
}];
let hasVerifier = 1;
@@ -1160,6 +1173,9 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
sum(i = 0 to N-1) %idx_i * product(j = i + 1 to N-1) B_j
```
+ In other words, `%0 = affine.linearize_index [%z, %y, %x] by (Z, Y, X)`
+ gives `%0 = %x + %y * X + %z * X * Y`, or `%0 = %x + X * (%y + Y * (%z))`.
+
The basis may either have `N` or `N-1` elements, where `N` is the number of
inputs to linearize_index. If `N` inputs are provided, the first one is not used
in computation, but may be used during analysis or canonicalization as a bound
@@ -1168,6 +1184,10 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
If all `N` basis elements are provided, the linearize_index operation is said to
"have an outer bound".
+ As a convenience, and for symmetry with `getPaddedBasis()`, if the first
+ element of a set of `OpFoldResult`s passed to the builders of this operation is
+ `nullptr`, that element is ignored.
+
If the `disjoint` property is present, this is an optimization hint that,
for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index,
except that `%idx_0` may be negative to make the index as a whole negative.
@@ -1224,6 +1244,11 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
/// Return a vector that contains the basis of the operation, removing
/// the outer bound if one is present.
SmallVector<OpFoldResult> getEffectiveBasis();
+
+ /// Return the vector with one basis element per index operand of the operation.
+ /// If there is no outer bound specified, the leading entry of this basis will be
+ /// nullptr.
+ SmallVector<OpFoldResult> getPaddedBasis();
}];
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index dceebbfec586c8..394a1395d4de1c 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4520,6 +4520,10 @@ void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
Value linearIndex, ValueRange basis,
bool hasOuterBound) {
+ if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
+ hasOuterBound = false;
+ basis = basis.drop_front();
+ }
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
@@ -4533,6 +4537,10 @@ void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
Value linearIndex,
ArrayRef<OpFoldResult> basis,
bool hasOuterBound) {
+ if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
+ hasOuterBound = false;
+ basis = basis.drop_front();
+ }
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
@@ -4654,6 +4662,13 @@ SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
}
+SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
+ SmallVector<OpFoldResult> ret = getMixedBasis();
+ if (!hasOuterBound())
+ ret.insert(ret.begin(), OpFoldResult());
+ return ret;
+}
+
namespace {
// Drops delinearization indices that correspond to unit-extent basis
@@ -4672,25 +4687,27 @@ struct DropUnitExtentBasis
return zero.value();
};
- bool hasOuterBound = delinearizeOp.hasOuterBound();
// Replace all indices corresponding to unit-extent basis with 0.
// Remaining basis can be used to get a new `affine.delinearize_index` op.
SmallVector<OpFoldResult> newBasis;
- for (auto [index, basis] : llvm::enumerate(delinearizeOp.getMixedBasis())) {
- std::optional<int64_t> basisVal = getConstantIntValue(basis);
+ for (auto [index, basis] :
+ llvm::enumerate(delinearizeOp.getPaddedBasis())) {
+ std::optional<int64_t> basisVal =
+ basis ? getConstantIntValue(basis) : std::nullopt;
if (basisVal && *basisVal == 1)
- replacements[index + (hasOuterBound ? 0 : 1)] = getZero();
+ replacements[index] = getZero();
else
newBasis.push_back(basis);
}
- if (newBasis.size() == delinearizeOp.getStaticBasis().size())
+ if (newBasis.size() == delinearizeOp.getNumResults())
return rewriter.notifyMatchFailure(delinearizeOp,
"no unit basis elements");
- if (!newBasis.empty() || !hasOuterBound) {
+ if (!newBasis.empty()) {
+ // Will drop the leading nullptr from `basis` if there was no outer bound.
auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
- loc, delinearizeOp.getLinearIndex(), newBasis, hasOuterBound);
+ loc, delinearizeOp.getLinearIndex(), newBasis);
int newIndex = 0;
// Map back the new delinearized indices to the values they replace.
for (auto &replacement : replacements) {
@@ -4871,6 +4888,8 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
ValueRange multiIndex, ValueRange basis,
bool disjoint) {
+ if (!basis.empty() && basis.front() == Value())
+ basis = basis.drop_front();
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
@@ -4883,6 +4902,8 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
ValueRange multiIndex,
ArrayRef<OpFoldResult> basis,
bool disjoint) {
+ if (!basis.empty() && basis.front() == OpFoldResult())
+ basis = basis.drop_front();
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
@@ -4965,7 +4986,14 @@ SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
builder);
}
- return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+ return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+}
+
+SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
+ SmallVector<OpFoldResult> ret = getMixedBasis();
+ if (!hasOuterBound())
+ ret.insert(ret.begin(), OpFoldResult());
+ return ret;
}
namespace {
@@ -5027,38 +5055,202 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
}
};
-/// Cancel out linearize_index(delinearize_index(x, B), B).
+/// Return the product of `terms`, creating an `affine.apply` if any of them are
+/// non-constant values. If any of `terms` is `nullptr`, return `nullptr`.
+static OpFoldResult computeProduct(Location loc, OpBuilder &builder,
+ ArrayRef<OpFoldResult> terms) {
+ int64_t nDynamic = 0;
+ SmallVector<Value> dynamicPart;
+ AffineExpr result = builder.getAffineConstantExpr(1);
+ for (OpFoldResult term : terms) {
+ if (!term)
+ return term;
+ std::optional<int64_t> maybeConst = getConstantIntValue(term);
+ if (maybeConst) {
+ result = result * builder.getAffineConstantExpr(*maybeConst);
+ } else {
+ dynamicPart.push_back(term.get<Value>());
+ result = result * builder.getAffineSymbolExpr(nDynamic++);
+ }
+ }
+ if (auto constant = dyn_cast<AffineConstantExpr>(result))
+ return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
+ return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
+}
+
+/// If consecutive outputs of a delinearize_index are linearized with the same
+/// bounds, canonicalize away the redundant arithmetic.
+///
+/// That is, if we have
+/// ```
+/// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b)
+/// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d]
+/// by (...e, B1, B2, ..., BK, ...f)
+/// ```
///
-/// That is, rewrite
+/// We can rewrite this to
+/// ```
+/// B = B1 * B2 ... BK
+/// %sMerged:(N-K+1) = affine.delinearize_index %x into (...a, B, ...b)
+/// %t = affine.linearize_index [...c, %sMerged#I, ...d] by (...e, B, ...f)
/// ```
-/// %0:N = affine.delinearize_index %x by (%b1, %b2, ... %bN)
-/// %y = affine.linearize_index [%0#0, %0#1, ... %0#(N-1)] by (%b1, %b2, ...
-/// %bN)
+/// where we replace all results of %s unaffected by the change with results
+/// from %sMerged.
+///
+/// As a special case, if all results of the delinearize are merged in this way
+/// we can replace those usages with %x, thus cancelling the delinearization
+/// entirely, as in
/// ```
-/// to replacing `%y` with `%x`.
-struct CancelLinearizeOfDelinearizeExact final
+/// %s:3 = affine.delinearize_index %x into (2, 4, 8)
+/// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16)
+/// ```
+/// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)`
+struct CancelLinearizeOfDelinearizePortion final
: OpRewritePattern<affine::AffineLinearizeIndexOp> {
using OpRewritePattern::OpRewritePattern;
+ struct Match {
+ AffineDelinearizeIndexOp delinearize;
+ unsigned linStart = 0;
+ unsigned delinStart = 0;
+ unsigned length = 0;
+ };
+
LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
PatternRewriter &rewriter) const override {
- auto delinearizeOp = linearizeOp.getMultiIndex()
- .front()
- .getDefiningOp<affine::AffineDelinearizeIndexOp>();
- if (!delinearizeOp)
- return rewriter.notifyMatchFailure(
- linearizeOp, "last entry doesn't come from a delinearize");
+ SmallVector<Match> matches;
+
+ const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
+ ArrayRef<OpFoldResult> linBasisRef = linBasis;
+
+ ValueRange multiIndex = linearizeOp.getMultiIndex();
+ unsigned numLinArgs = multiIndex.size();
+ unsigned linArgIdx = 0;
+ // We only want to replace one run from the same delinearize op per
+ // pattern invocation lest we run into invalidation issues.
+ llvm::SmallPtrSet<Operation *, 2> seen;
+ while (linArgIdx < numLinArgs) {
+ auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
+ if (!asResult) {
+ linArgIdx++;
+ continue;
+ }
- if (linearizeOp.getEffectiveBasis() != delinearizeOp.getEffectiveBasis())
- return rewriter.notifyMatchFailure(
- linearizeOp, "basis of linearize and delinearize don't match exactly "
- "(excluding outer bounds)");
+ auto delinearizeOp =
+ dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
+ if (!delinearizeOp) {
+ linArgIdx++;
+ continue;
+ }
+
+ /// Result 0 of the delinearize and argument 0 of the linearize can
+ /// leave their maximum value unspecified. However, even if this happens
+ /// we can still sometimes start the match process. Specifically, if
+ /// - The argument we're matching is result 0 and argument 0 (so the
+ /// bounds don't matter). For example,
+ ///
+ /// %0:2 = affine.delinearize_index %x into (8) : index, index
+ /// %1 = affine.linearize_index [%0#0, %0#1, ...] by (8, ...)
+ /// allows cancellation
+ /// - The delinearization doesn't specify a bound, but the linearization
+ /// is `disjoint`, which asserts that the bound on the linearization is
+ /// correct.
+ unsigned firstDelinArg = asResult.getResultNumber();
+ SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
+ OpFoldResult firstDelinBound = delinBasis[firstDelinArg];
+ OpFoldResult firstLinBound = linBasis[linArgIdx];
+ bool boundsMatch = firstDelinBound == firstLinBound;
+ bool bothAtFront = linArgIdx == 0 && firstDelinArg == 0;
+ bool knownByDisjoint =
+ linearizeOp.getDisjoint() && firstDelinArg == 0 && !firstDelinBound;
+ if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
+ linArgIdx++;
+ continue;
+ }
- if (delinearizeOp.getResults() != linearizeOp.getMultiIndex())
+ unsigned j = 1;
+ unsigned numDelinOuts = delinearizeOp.getNumResults();
+ for (; j + linArgIdx < numLinArgs && j + firstDelinArg < numDelinOuts;
+ ++j) {
+ if (multiIndex[linArgIdx + j] !=
+ delinearizeOp.getResult(firstDelinArg + j))
+ break;
+ if (linBasis[linArgIdx + j] != delinBasis[firstDelinArg + j])
+ break;
+ }
+ // If there're multiple matches against the same delinearize_index,
+ // only rewrite the first one we find to prevent invalidations. The next
+ // ones will be taken care of by subsequent pattern invocations.
+ if (j <= 1 || !seen.insert(delinearizeOp).second) {
+ linArgIdx++;
+ continue;
+ }
+ matches.push_back(Match{delinearizeOp, linArgIdx, firstDelinArg, j});
+ linArgIdx += j;
+ }
+
+ if (matches.empty())
return rewriter.notifyMatchFailure(
- linearizeOp, "not all indices come from delinearize");
+ linearizeOp, "no run of delinearize outputs to deal with");
+
+ SmallVector<std::tuple<Value, Value>> delinearizeReplacements;
+ SmallVector<Value> newIndex;
+ newIndex.reserve(numLinArgs);
+ SmallVector<OpFoldResult> newBasis;
+ newBasis.reserve(numLinArgs);
+ unsigned prevMatchEnd = 0;
+ for (Match m : matches) {
+ unsigned gap = m.linStart - prevMatchEnd;
+ llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
+ llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
+ // Update here so we don't forget this during early continues
+ prevMatchEnd = m.linStart + m.length;
+
+ // We use the slice from the linearize's basis above because of the
+ // "bounds inferred from `disjoint`" case above.
+ OpFoldResult newSize =
+ computeProduct(linearizeOp.getLoc(), rewriter,
+ linBasisRef.slice(m.linStart, m.length));
+
+ // Trivial case where we can just skip past the delinearize altogether
+ if (m.length == m.delinearize.getNumResults()) {
+ newIndex.push_back(m.delinearize.getLinearIndex());
+ newBasis.push_back(newSize);
+ continue;
+ }
+ SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
+ newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
+ newDelinBasis.begin() + m.delinStart + m.length);
+ newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
+ auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
+ m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
+ newDelinBasis);
+
+ // Swap all the uses of the unaffected delinearize outputs to the new
+ // delinearization so that the old code can be removed if this
+ // linearize_index is the only user of the merged results.
+ llvm::append_range(
+ delinearizeReplacements,
+ llvm::zip_equal(
+ m.delinearize.getResults().take_front(m.delinStart),
+ newDelinearize.getResults().take_front(m.delinStart)));
+ llvm::append_range(
+ delinearizeReplacements,
+ llvm::zip_equal(
+ m.delinearize.getResults().drop_front(m.delinStart + m.length),
+ newDelinearize.getResults().drop_front(m.delinStart + 1)));
+
+ Value newLinArg = newDelinearize.getResult(m.delinStart);
+ newIndex.push_back(newLinArg);
+ newBasis.push_back(newSize);
+ }
+ llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
+ llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
+ rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
+ linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());
- rewriter.replaceOp(linearizeOp, delinearizeOp.getLinearIndex());
+ for (auto [from, to] : delinearizeReplacements)
+ rewriter.replaceAllUsesWith(from, to);
return success();
}
};
@@ -5096,7 +5288,7 @@ struct DropLinearizeLeadingZero final
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
- patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeLeadingZero,
+ patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
DropLinearizeUnitComponentsIfDisjointOrZero>(context);
}
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 717004eb50c0fc..223d9c0996439a 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1917,12 +1917,12 @@ func.func @linearize_one_element_basis(%arg0: index, %arg1: index) -> index {
// -----
-// CHECK-LABEL: func @cancel_linearize_denearize_exact(
+// CHECK-LABEL: func @cancel_linearize_delinearize_exact(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
// CHECK: return %[[ARG0]]
-func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
+func.func @cancel_linearize_delinearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
%0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
%1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 4, %arg2) : index
return %1 : index
@@ -1930,12 +1930,12 @@ func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: i
// -----
-// CHECK-LABEL: func @cancel_linearize_denearize_linearize_extra_bound(
+// CHECK-LABEL: func @cancel_linearize_delinearize_linearize_extra_bound(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
// CHECK: return %[[ARG0]]
-func.func @cancel_linearize_denearize_linearize_extra_bound(%arg0: index, %arg1: index, %arg2: index) -> index {
+func.func @cancel_linearize_delinearize_linearize_extra_bound(%arg0: index, %arg1: index, %arg2: index) -> index {
%0:3 = affine.delinearize_index %arg0 into (4, %arg2) : index, index, index
%1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 4, %arg2) : index
return %1 : index
@@ -1943,12 +1943,12 @@ func.func @cancel_linearize_denearize_linearize_extra_bound(%arg0: index, %arg1:
// -----
-// CHECK-LABEL: func @cancel_linearize_denearize_delinearize_extra_bound(
+// CHECK-LABEL: func @cancel_linearize_delinearize_delinearize_extra_bound(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
// CHECK: return %[[ARG0]]
-func.func @cancel_linearize_denearize_delinearize_extra_bound(%arg0: index, %arg1: index, %arg2: index) -> index {
+func.func @cancel_linearize_delinearize_delinearize_extra_bound(%arg0: index, %arg1: index, %arg2: index) -> index {
%0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
%1 = affine.linearize_index [%0#0, %0#1, %0#2] by (4, %arg2) : index
return %1 : index
@@ -1956,31 +1956,231 @@ func.func @cancel_linearize_denearize_delinearize_extra_bound(%arg0: index, %arg
// -----
+// CHECK-LABEL: func @cancel_linearize_delinearize_head(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (12, 8)
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[ARG1]]] by (12, 16)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_head(%arg0: index, %arg1: index) -> index {
+ %0:3 = affine.delinearize_index %arg0 into (3, 4, 8) : index, index, index
+ %1 = affine.linearize_index [%0#0, %0#1, %arg1] by (3, 4, 16) : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @cancel_linearize_delinearize_head_delinearize_unbounded(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (12, 8)
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[ARG1]]] by (12, 16)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_head_delinearize_unbounded(%arg0: index, %arg1: index) -> index {
+ %0:3 = affine.delinearize_index %arg0 into (4, 8) : index, index, index
+ %1 = affine.linearize_index [%0#0, %0#1, %arg1] by (3, 4, 16) : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @cancel_linearize_delinearize_head_linearize_unbounded(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (8)
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[ARG1]]] by (16)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_head_linearize_unbounded(%arg0: index, %arg1: index) -> index {
+ %0:3 = affine.delinearize_index %arg0 into (3, 4, 8) : index, index, index
+ %1 = affine.linearize_index [%0#0, %0#1, %arg1] by (4, 16) : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @cancel_linearize_delinearize_head_both_unbounded(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (8)
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[ARG1]]] by (16)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_head_both_unbounded(%arg0: index, %arg1: index) -> index {
+ %0:3 = affine.delinearize_index %arg0 into (4, 8) : index, index, index
+ %1 = affine.linearize_index [%0#0, %0#1, %arg1] by (4, 16) : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @cancel_linearize_delinearize_tail(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (3, 32)
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[ARG1]], %[[DELIN]]#1] by (5, 32)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_tail(%arg0: index, %arg1: index) -> index {
+ %0:3 = affine.delinearize_index %arg0 into (3, 4, 8) : index, index, index
+ %1 = affine.linearize_index [%arg1, %0#1, %0#2] by (5, 4, 8) : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @cancel_linearize_delinearize_middle_exact(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[ARG1]], %[[ARG0]], %[[ARG2]]] by (9, 30, 7)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_middle_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0:3 = affine.delinearize_index %arg0 into (2, 3, 5) : index, index, index
+ %1 = affine.linearize_index [%arg1, %0#0, %0#1, %0#2, %arg2] by (9, 2, 3, 5, 7) : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) * 16)>
+
+// CHECK-LABEL: func @cancel_linearize_delinearize_middle_exact_dynamic_basis(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[SIZEPROD:.+]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[C1]], %[[ARG0]], %[[C1]]] by (3, %[[SIZEPROD]], 4)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_middle_exact_dynamic_basis(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %c1 = arith.constant 1 : index
+ %0:4 = affine.delinearize_index %arg0 into (2, %arg1, %arg2, 8) : index, index, index, index
+ %1 = affine.linearize_index [%c1, %0#0, %0#1, %0#2, %0#3, %c1] by (3, 2, %arg1, %arg2, 8, 4) : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @cancel_linearize_delinearize_middle_exact_delinearize_unbounded_disjoint(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG0]], %[[ARG2]]] by (9, 30, 7)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_middle_exact_delinearize_unbounded_disjoint(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0:3 = affine.delinearize_index %arg0 into (3, 5) : index, index, index
+ %1 = affine.linearize_index disjoint [%arg1, %0#0, %0#1, %0#2, %arg2] by (9, 2, 3, 5, 7) : index
+ return %1 : index
+}
+
+// -----
+
+// Unlike in the test above, the linearize indices aren't asserted to be disjoint, so
+// we can't know if the `2` from the basis is a correct bound.
+// CHECK-LABEL: func @dont_cancel_linearize_delinearize_middle_exact_delinearize_unbounded(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (3)
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[ARG1]], %[[DELIN]]#0, %[[DELIN]]#1, %[[ARG2]]] by (9, 2, 3, 7)
+// CHECK: return %[[LIN]]
+
+func.func @dont_cancel_linearize_delinearize_middle_exact_delinearize_unbounded(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0:2 = affine.delinearize_index %arg0 into (3) : index, index
+ %1 = affine.linearize_index [%arg1, %0#0, %0#1, %arg2] by (9, 2, 3, 7) : index
+ return %1 : index
+}
+
+// -----
+
+// The presence of a `disjoint` here tells us that the "unbounded" term on the
+// delinearization can't have been above 2.
+// CHECK-LABEL: func @cancel_linearize_delinearize_middle_delinearize_unbounded_disjoint_implied_bound(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[ARG0]] into (6, 5)
+// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG1]], %[[DELIN]]#0, %[[ARG2]]] by (9, 6, 7)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_middle_delinearize_unbounded_disjoint_implied_bound(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0:3 = affine.delinearize_index %arg0 into (3, 5) : index, index, index
+ %1 = affine.linearize_index disjoint [%arg1, %0#0, %0#1, %arg2] by (9, 2, 3, 7) : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @cancel_linearize_delinearize_multiple_matches(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[C0:.+]] = arith.constant 0
+// CHECK: %[[DELIN:.+]]:4 = affine.delinearize_index %[[ARG0]] into (4, 16, 4, 64)
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[ARG1]], %[[DELIN]]#1, %[[C0]], %[[DELIN]]#3] by (4, 16, 4, 64)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_multiple_matches(%arg0: index, %arg1: index) -> index {
+ %c0 = arith.constant 0 : index
+ %0:7 = affine.delinearize_index %arg0 into (4, 4, 4, 4, 4, 4, 4) : index, index, index, index, index, index, index
+ %1 = affine.linearize_index [%arg1, %0#1, %0#2, %c0, %0#4, %0#5, %0#6] by (4, 4, 4, 4, 4, 4, 4) : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @cancel_linearize_delinearize_multiple_delinearizes(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (32, 32)
+// CHECK: return %[[LIN]]
+func.func @cancel_linearize_delinearize_multiple_delinearizes(%arg0: index, %arg1: index) -> index {
+ %0:2 = affine.delinearize_index %arg0 into (4, 8) : index, index
+ %1:2 = affine.delinearize_index %arg1 into (2, 16) : index, index
+ %2 = affine.linearize_index [%0#0, %0#1, %1#0, %1#1] by (4, 8, 2, 16) : index
+ return %2 : index
+}
+
+// -----
+
// Don't cancel because the values from the delinearize aren't used in order
-// CHECK-LABEL: func @no_cancel_linearize_denearize_permuted(
+// CHECK-LABEL: func @no_cancel_linearize_delinearize_permuted(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], 4, %[[ARG2]])
-// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#2, %[[DELIN]]#1] by (%[[ARG1]], 4, %[[ARG2]])
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#2, %[[DELIN]]#1] by (%[[ARG1]], %[[ARG2]], 4)
// CHECK: return %[[LIN]]
-func.func @no_cancel_linearize_denearize_permuted(%arg0: index, %arg1: index, %arg2: index) -> index {
+func.func @no_cancel_linearize_delinearize_permuted(%arg0: index, %arg1: index, %arg2: index) -> index {
%0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
- %1 = affine.linearize_index [%0#0, %0#2, %0#1] by (%arg1, 4, %arg2) : index
+ %1 = affine.linearize_index [%0#0, %0#2, %0#1] by (%arg1, %arg2, 4) : index
+ return %1 : index
+}
+
+// -----
+
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 3)>
+// But these cancel because they're a contiguous segment
+// CHECK-LABEL: func @partial_cancel_linearize_delinearize_not_fully_permuted(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[SIZEPROD:.+]] = affine.apply #[[$MAP]]()[%[[ARG2]]]
+// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], 4, %[[SIZEPROD]])
+// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#2, %[[DELIN]]#1] by (%[[ARG1]], %[[SIZEPROD]], 4)
+// CHECK: return %[[LIN]]
+func.func @partial_cancel_linearize_delinearize_not_fully_permuted(%arg0: index, %arg1: index, %arg2: index) -> index {
+ %0:4 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2, 3) : index, index, index, index
+ %1 = affine.linearize_index [%0#0, %0#2, %0#3, %0#1] by (%arg1, %arg2, 3, 4) : index
return %1 : index
}
// -----
// Won't cancel because the linearize and delinearize are using a different basis
-// CHECK-LABEL: func @no_cancel_linearize_denearize_different_basis(
+// CHECK-LABEL: func @no_cancel_linearize_delinearize_different_basis(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], 4, %[[ARG2]])
// CHECK: %[[LIN:.+]] = affine.linearize_index [%[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2] by (%[[ARG1]], 8, %[[ARG2]])
// CHECK: return %[[LIN]]
-func.func @no_cancel_linearize_denearize_different_basis(%arg0: index, %arg1: index, %arg2: index) -> index {
+func.func @no_cancel_linearize_delinearize_different_basis(%arg0: index, %arg1: index, %arg2: index) -> index {
%0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
%1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 8, %arg2) : index
return %1 : index
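
For reference, the arithmetic behind the @cancel_linearize_delinearize_multiple_delinearizes test above can be checked outside MLIR. What follows is a minimal sketch in plain C++ and is not part of the patch. `delinearize`, `linearize`, and `multipleDelinearizesCancel` are hypothetical helpers that model the ops on non-negative integers only, and the sketch assumes the outermost delinearize result is left unbounded (it is not reduced modulo the first basis element).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical model of affine.delinearize_index on non-negative integers.
// Splits `n` into one digit per basis element, innermost digit last.
// Assumption: the outermost digit is whatever remains (no modulo applied).
std::vector<int64_t> delinearize(int64_t n, const std::vector<int64_t> &basis) {
  assert(!basis.empty() && n >= 0);
  std::vector<int64_t> out(basis.size());
  for (std::size_t i = basis.size(); i-- > 1;) {
    out[i] = n % basis[i];
    n /= basis[i];
  }
  out[0] = n;
  return out;
}

// Hypothetical model of affine.linearize_index: Horner-style accumulation.
// basis[0] never affects the result because the accumulator starts at zero.
int64_t linearize(const std::vector<int64_t> &idx,
                  const std::vector<int64_t> &basis) {
  assert(idx.size() == basis.size());
  int64_t acc = 0;
  for (std::size_t i = 0; i < idx.size(); ++i)
    acc = acc * basis[i] + idx[i];
  return acc;
}

// Checks that the original and simplified forms of the test agree for all
// in-range inputs: (4, 8) merges to 32 and (2, 16) merges to 32.
bool multipleDelinearizesCancel() {
  for (int64_t a = 0; a < 32; ++a) {
    for (int64_t b = 0; b < 32; ++b) {
      std::vector<int64_t> d0 = delinearize(a, {4, 8});
      std::vector<int64_t> d1 = delinearize(b, {2, 16});
      int64_t before = linearize({d0[0], d0[1], d1[0], d1[1]}, {4, 8, 2, 16});
      int64_t after = linearize({a, b}, {32, 32});
      if (before != after)
        return false;
    }
  }
  return true;
}
```

Under these assumptions both forms compute %arg0 * 32 + %arg1, which is why the two runs of delinearize results can each collapse to one operand with basis element 32.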
>From 6a8d42d3b79e5c8180323525b93ddc0c9bfef903 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Mon, 2 Dec 2024 17:59:45 +0000
Subject: [PATCH 2/2] Review feedback
---
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 28 ++++++++++++++----------
1 file changed, 17 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 394a1395d4de1c..5bec002fffee22 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -5109,6 +5109,11 @@ struct CancelLinearizeOfDelinearizePortion final
: OpRewritePattern<affine::AffineLinearizeIndexOp> {
using OpRewritePattern::OpRewritePattern;
+private:
+ // Struct representing a case where the cancellation pattern
+ // applies. A `Match` means that `length` inputs to the linearize operation
+ // starting at `linStart` can be cancelled with `length` outputs of
+ // `delinearize`, starting from `delinStart`.
struct Match {
AffineDelinearizeIndexOp delinearize;
unsigned linStart = 0;
@@ -5116,6 +5121,7 @@ struct CancelLinearizeOfDelinearizePortion final
unsigned length = 0;
};
+public:
LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
PatternRewriter &rewriter) const override {
SmallVector<Match> matches;
@@ -5128,7 +5134,7 @@ struct CancelLinearizeOfDelinearizePortion final
unsigned linArgIdx = 0;
// We only want to replace one run from the same delinearize op per
// pattern invocation lest we run into invalidation issues.
- llvm::SmallPtrSet<Operation *, 2> seen;
+ llvm::SmallPtrSet<Operation *, 2> alreadyMatchedDelinearize;
while (linArgIdx < numLinArgs) {
auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
if (!asResult) {
@@ -5155,14 +5161,14 @@ struct CancelLinearizeOfDelinearizePortion final
/// - The delinearization doesn't specify a bound, but the linearization
/// is `disjoint`, which asserts that the bound on the linearization is
/// correct.
- unsigned firstDelinArg = asResult.getResultNumber();
+ unsigned delinArgIdx = asResult.getResultNumber();
SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
- OpFoldResult firstDelinBound = delinBasis[firstDelinArg];
+ OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
OpFoldResult firstLinBound = linBasis[linArgIdx];
bool boundsMatch = firstDelinBound == firstLinBound;
- bool bothAtFront = linArgIdx == 0 && firstDelinArg == 0;
+ bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
bool knownByDisjoint =
- linearizeOp.getDisjoint() && firstDelinArg == 0 && !firstDelinBound;
+ linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
linArgIdx++;
continue;
@@ -5170,22 +5176,22 @@ struct CancelLinearizeOfDelinearizePortion final
unsigned j = 1;
unsigned numDelinOuts = delinearizeOp.getNumResults();
- for (; j + linArgIdx < numLinArgs && j + firstDelinArg < numDelinOuts;
+ for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
++j) {
if (multiIndex[linArgIdx + j] !=
- delinearizeOp.getResult(firstDelinArg + j))
+ delinearizeOp.getResult(delinArgIdx + j))
break;
- if (linBasis[linArgIdx + j] != delinBasis[firstDelinArg + j])
+ if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
break;
}
// If there're multiple matches against the same delinearize_index,
// only rewrite the first one we find to prevent invalidations. The next
- // ones will be taken caer of by subsequent pattern invocations.
- if (j <= 1 || !seen.insert(delinearizeOp).second) {
+ // ones will be taken care of by subsequent pattern invocations.
+ if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {
linArgIdx++;
continue;
}
- matches.push_back(Match{delinearizeOp, linArgIdx, firstDelinArg, j});
+ matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});
linArgIdx += j;
}
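
The run detection touched by this hunk can be sketched without MLIR types. The code below is a simplified model and not the code in AffineOps.cpp: `Arg`, `Match`, and `findRuns` are hypothetical names, basis elements are plain integers, and the `disjoint` case (`knownByDisjoint`) is deliberately left out. It only illustrates how `linArgIdx`, `delinArgIdx`, and the already-matched set interact.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

// Hypothetical stand-in for a linearize operand. `producer < 0` means the
// value is not a delinearize result; otherwise it indexes `delinBases`.
struct Arg {
  int producer;
  unsigned resultIdx;
};

// Mirrors the fields of the pattern's Match struct, with an int id in place
// of the delinearize op.
struct Match {
  int producer;
  unsigned linStart;
  unsigned delinStart;
  unsigned length;
};

std::vector<Match> findRuns(const std::vector<Arg> &args,
                            const std::vector<long> &linBasis,
                            const std::vector<std::vector<long>> &delinBases) {
  std::vector<Match> matches;
  std::set<int> alreadyMatched; // at most one run per producer per call
  unsigned linArgIdx = 0;
  while (linArgIdx < args.size()) {
    const Arg &first = args[linArgIdx];
    if (first.producer < 0) {
      ++linArgIdx;
      continue;
    }
    const std::vector<long> &delinBasis =
        delinBases[static_cast<std::size_t>(first.producer)];
    unsigned delinArgIdx = first.resultIdx;
    bool boundsMatch = linBasis[linArgIdx] == delinBasis[delinArgIdx];
    bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
    if (!boundsMatch && !bothAtFront) {
      ++linArgIdx;
      continue;
    }
    // Extend the run while operands stay consecutive results of the same
    // producer and the basis elements agree.
    unsigned j = 1;
    for (; linArgIdx + j < args.size() && delinArgIdx + j < delinBasis.size();
         ++j) {
      const Arg &next = args[linArgIdx + j];
      if (next.producer != first.producer ||
          next.resultIdx != delinArgIdx + j)
        break;
      if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
        break;
    }
    // Runs of length one are not worth rewriting; repeated producers are
    // left for a later invocation.
    if (j <= 1 || !alreadyMatched.insert(first.producer).second) {
      ++linArgIdx;
      continue;
    }
    matches.push_back(Match{first.producer, linArgIdx, delinArgIdx, j});
    linArgIdx += j;
  }
  return matches;
}
```

Fed the example from the commit message ([%x, %0#0, %0#1, %y] by (3, 8, 8, 5) with a delinearize into (8, 8)), this model reports one run starting at linearize operand 1 and delinearize result 0 with length 2. The permuted operand order from the no-cancel test yields no run.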