[Mlir-commits] [mlir] [mlir][Affine] Generalize the linearize(delinearize()) simplifications (PR #117637)
llvmlistbot at llvm.org
Mon Nov 25 14:39:10 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Krzysztof Drewniak (krzysz00)
<details>
<summary>Changes</summary>
The existing canonicalization patterns would only cancel out cases where the entire result list of an affine.delinearize_index was passed to an affine.linearize_index and the basis elements matched exactly (except possibly for the outer bounds).
This was correct, but limited, and left open many cases where a delinearize_index would take a series of divisions and modulos only for a subsequent linearize_index to use additions and multiplications to undo all that work.
This sort of simplification is reasonably easy to observe at the level of splitting and merging indexes, but difficult to perform once the underlying arithmetic operations have been created.
Therefore, this commit generalizes the existing simplification logic.
Now, any run of two or more delinearize_index results that appears within the argument list of a linearize_index operation with the same basis (or where they're both at the outermost position and so can be unbounded, or when `linearize_index disjoint` implies a bound not present on the `delinearize_index`) will be reduced to one single delinearize_index output, whose basis element (that is, size or length) is equal to the product of the sizes that were simplified away.
That is, we can now simplify
%0:2 = affine.delinearize_index %n into (8, 8) : index, index
%1 = affine.linearize_index [%x, %0#0, %0#1, %y] by (3, 8, 8, 5) : index
to the simpler
%1 = affine.linearize_index [%x, %n, %y] by (3, 64, 5) : index
This new pattern also works with dynamically-sized basis values.
While I'm here, I fixed a bunch of typos in existing tests, and added a new getPaddedBasis() method to make processing potentially-underspecified basis elements simpler in some cases.
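As a rough illustration of what a padded basis buys, here is a standalone sketch that does not use the MLIR API. `BasisElem` and `paddedBasis` are hypothetical names, `std::nullopt` stands in for the null `OpFoldResult` that the patch places in the leading slot, and all basis elements are assumed to be static integers.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// std::nullopt plays the role of a missing outer bound in this model.
using BasisElem = std::optional<int64_t>;

// Return one basis entry per result. If the op was written with only
// numResults - 1 basis elements, the leading entry is left empty.
std::vector<BasisElem> paddedBasis(const std::vector<int64_t> &basis,
                                   size_t numResults) {
  assert(basis.size() == numResults || basis.size() + 1 == numResults);
  std::vector<BasisElem> padded;
  padded.reserve(numResults);
  if (basis.size() + 1 == numResults)
    padded.push_back(std::nullopt);
  padded.insert(padded.end(), basis.begin(), basis.end());
  return padded;
}
```

With this shape, entry `i` of the padded basis always lines up with result `i`, so callers no longer need to shift indices by one depending on whether an outer bound is present.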
---
Patch is 33.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/117637.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+26-1)
- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+221-29)
- (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+212-12)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 76d97f106dcb88..0d1e4ede795ce5 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1083,6 +1083,9 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
%indices_2 = affine.apply #map2()[%linear_index]
```
+ In other words, `%0:3 = affine.delinearize_index %x into (B, C)` produces
+ `%0 = {%x / (B * C), (%x mod (B * C)) / C, %x mod C}`.
+
The basis may either contain `N` or `N-1` elements, where `N` is the number of results.
If there are N basis elements, the first one will not be used during computations,
but may be used during analysis and canonicalization to eliminate terms from
@@ -1098,7 +1101,12 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
%0:3 = affine.delinearize_index %linear_index into (244, 244) : index, index
```
- Note that, due to the constraints of affine maps, all the basis elements must
+ Note that, for symmetry with `getPaddedBasis()`, if `hasOuterBound` is `true`
+ when one of the `OpFoldResult` builders is called but the first element of the
+ basis is `nullptr`, that first element is ignored and the builder proceeds as if
+ there was no outer bound.
+
+ Due to the constraints of affine maps, all the basis elements must
be strictly positive. A dynamic basis element being 0 or negative causes
undefined behavior.
}];
@@ -1136,6 +1144,11 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
/// Return a vector that contains the basis of the operation, removing
/// the outer bound if one is present.
SmallVector<OpFoldResult> getEffectiveBasis();
+
+ /// Return the vector with one basis element per result of the operation. If
+ /// there is no outer bound specified, the leading entry of this result will be
+ /// nullptr.
+ SmallVector<OpFoldResult> getPaddedBasis();
}];
let hasVerifier = 1;
@@ -1160,6 +1173,9 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
sum(i = 0 to N-1) %idx_i * product(j = i + 1 to N-1) B_j
```
+ In other words, `%0 = affine.linearize_index [%z, %y, %x] by (Z, Y, X)`
+ gives `%0 = %x + %y * X + %z * X * Y`, or `%0 = %x + X * (%y + Y * (%z))`.
+
The basis may either have `N` or `N-1` elements, where `N` is the number of
inputs to linearize_index. If `N` inputs are provided, the first one is not used
in computation, but may be used during analysis or canonicalization as a bound
@@ -1168,6 +1184,10 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
If all `N` basis elements are provided, the linearize_index operation is said to
"have an outer bound".
+ As a convenience, and for symmetry with `getPaddedBasis()`, if the first
+ element of a set of `OpFoldResult`s passed to the builders of this operation is
+ `nullptr`, that element is ignored.
+
If the `disjoint` property is present, this is an optimization hint that,
for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index,
except that `%idx_0` may be negative to make the index as a whole negative.
@@ -1224,6 +1244,11 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
/// Return a vector that contains the basis of the operation, removing
/// the outer bound if one is present.
SmallVector<OpFoldResult> getEffectiveBasis();
+
+ /// Return the vector with one basis element per index operand of the operation.
+ /// If there is no outer bound specified, the leading entry of this basis will be
+ /// nullptr.
+ SmallVector<OpFoldResult> getPaddedBasis();
}];
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 1c5466730a5589..b7c5e8eff8a8cd 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4520,6 +4520,10 @@ void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
Value linearIndex, ValueRange basis,
bool hasOuterBound) {
+ if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
+ hasOuterBound = false;
+ basis = basis.drop_front();
+ }
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
@@ -4533,6 +4537,10 @@ void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
Value linearIndex,
ArrayRef<OpFoldResult> basis,
bool hasOuterBound) {
+ if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
+ hasOuterBound = false;
+ basis = basis.drop_front();
+ }
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
@@ -4614,6 +4622,13 @@ SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
}
+SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
+ SmallVector<OpFoldResult> ret = getMixedBasis();
+ if (!hasOuterBound())
+ ret.insert(ret.begin(), OpFoldResult());
+ return ret;
+}
+
namespace {
// Drops delinearization indices that correspond to unit-extent basis
@@ -4632,25 +4647,27 @@ struct DropUnitExtentBasis
return zero.value();
};
- bool hasOuterBound = delinearizeOp.hasOuterBound();
// Replace all indices corresponding to unit-extent basis with 0.
// Remaining basis can be used to get a new `affine.delinearize_index` op.
SmallVector<OpFoldResult> newBasis;
- for (auto [index, basis] : llvm::enumerate(delinearizeOp.getMixedBasis())) {
- std::optional<int64_t> basisVal = getConstantIntValue(basis);
+ for (auto [index, basis] :
+ llvm::enumerate(delinearizeOp.getPaddedBasis())) {
+ std::optional<int64_t> basisVal =
+ basis ? getConstantIntValue(basis) : std::nullopt;
if (basisVal && *basisVal == 1)
- replacements[index + (hasOuterBound ? 0 : 1)] = getZero();
+ replacements[index] = getZero();
else
newBasis.push_back(basis);
}
- if (newBasis.size() == delinearizeOp.getStaticBasis().size())
+ if (newBasis.size() == delinearizeOp.getNumResults())
return rewriter.notifyMatchFailure(delinearizeOp,
"no unit basis elements");
- if (!newBasis.empty() || !hasOuterBound) {
+ if (!newBasis.empty()) {
+ // Will drop the leading nullptr from `basis` if there was no outer bound.
auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
- loc, delinearizeOp.getLinearIndex(), newBasis, hasOuterBound);
+ loc, delinearizeOp.getLinearIndex(), newBasis);
int newIndex = 0;
// Map back the new delinearized indices to the values they replace.
for (auto &replacement : replacements) {
@@ -4831,6 +4848,8 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
ValueRange multiIndex, ValueRange basis,
bool disjoint) {
+ if (!basis.empty() && basis.front() == Value())
+ basis = basis.drop_front();
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
@@ -4843,6 +4862,8 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
ValueRange multiIndex,
ArrayRef<OpFoldResult> basis,
bool disjoint) {
+ if (!basis.empty() && basis.front() == OpFoldResult())
+ basis = basis.drop_front();
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
@@ -4918,7 +4939,14 @@ SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
builder);
}
- return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+ return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+}
+
+SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
+ SmallVector<OpFoldResult> ret = getMixedBasis();
+ if (!hasOuterBound())
+ ret.insert(ret.begin(), OpFoldResult());
+ return ret;
}
namespace {
@@ -4980,38 +5008,202 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
}
};
-/// Cancel out linearize_index(delinearize_index(x, B), B).
+/// Return the product of `terms`, creating an `affine.apply` if any of them are
+/// non-constant values. If any of `terms` is `nullptr`, return `nullptr`.
+static OpFoldResult computeProduct(Location loc, OpBuilder &builder,
+ ArrayRef<OpFoldResult> terms) {
+ int64_t nDynamic = 0;
+ SmallVector<Value> dynamicPart;
+ AffineExpr result = builder.getAffineConstantExpr(1);
+ for (OpFoldResult term : terms) {
+ if (!term)
+ return term;
+ std::optional<int64_t> maybeConst = getConstantIntValue(term);
+ if (maybeConst) {
+ result = result * builder.getAffineConstantExpr(*maybeConst);
+ } else {
+ dynamicPart.push_back(term.get<Value>());
+ result = result * builder.getAffineSymbolExpr(nDynamic++);
+ }
+ }
+ if (auto constant = dyn_cast<AffineConstantExpr>(result))
+ return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
+ return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
+}
+
+/// If consecutive outputs of a delinearize_index are linearized with the same
+/// bounds, canonicalize away the redundant arithmetic.
+///
+/// That is, if we have
+/// ```
+/// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b)
+/// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d]
+/// by (...e, B1, B2, ..., BK, ...f)
+/// ```
///
-/// That is, rewrite
+/// We can rewrite this to
+/// ```
+/// B = B1 * B2 ... BK
+/// %sMerged:(N-K+1) = affine.delinearize_index %x into (...a, B, ...b)
+/// %t = affine.linearize_index [...c, %sMerged#I, ...d] by (...e, B, ...f)
/// ```
-/// %0:N = affine.delinearize_index %x by (%b1, %b2, ... %bN)
-/// %y = affine.linearize_index [%0#0, %0#1, ... %0#(N-1)] by (%b1, %b2, ...
-/// %bN)
+/// where we replace all results of %s unaffected by the change with results
+/// from %sMerged.
+///
+/// As a special case, if all results of the delinearize are merged in this way
+/// we can replace those usages with %x, thus cancelling the delinearization
+/// entirely, as in
/// ```
-/// to replacing `%y` with `%x`.
-struct CancelLinearizeOfDelinearizeExact final
+/// %s:3 = affine.delinearize_index %x into (2, 4, 8)
+/// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16)
+/// ```
+/// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)`
+struct CancelLinearizeOfDelinearizePortion final
: OpRewritePattern<affine::AffineLinearizeIndexOp> {
using OpRewritePattern::OpRewritePattern;
+ struct Match {
+ AffineDelinearizeIndexOp delinearize;
+ unsigned linStart = 0;
+ unsigned delinStart = 0;
+ unsigned length = 0;
+ };
+
LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
PatternRewriter &rewriter) const override {
- auto delinearizeOp = linearizeOp.getMultiIndex()
- .front()
- .getDefiningOp<affine::AffineDelinearizeIndexOp>();
- if (!delinearizeOp)
- return rewriter.notifyMatchFailure(
- linearizeOp, "last entry doesn't come from a delinearize");
+ SmallVector<Match> matches;
+
+ const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
+ ArrayRef<OpFoldResult> linBasisRef = linBasis;
+
+ ValueRange multiIndex = linearizeOp.getMultiIndex();
+ unsigned numLinArgs = multiIndex.size();
+ unsigned linArgIdx = 0;
+ // We only want to replace one run from the same delinearize op per
+ // pattern invocation lest we run into invalidation issues.
+ llvm::SmallPtrSet<Operation *, 2> seen;
+ while (linArgIdx < numLinArgs) {
+ auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
+ if (!asResult) {
+ linArgIdx++;
+ continue;
+ }
- if (linearizeOp.getEffectiveBasis() != delinearizeOp.getEffectiveBasis())
- return rewriter.notifyMatchFailure(
- linearizeOp, "basis of linearize and delinearize don't match exactly "
- "(excluding outer bounds)");
+ auto delinearizeOp =
+ dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
+ if (!delinearizeOp) {
+ linArgIdx++;
+ continue;
+ }
+
+ /// Result 0 of the delinearize and argument 0 of the linearize can
+ /// leave their maximum value unspecified. However, even if this happens
+ /// we can still sometimes start the match process. Specifically, if
+ /// - The argument we're matching is result 0 and argument 0 (so the
+ /// bounds don't matter). For example,
+ ///
+ /// %s:2 = affine.delinearize_index %x into (8) : index, index
+ /// %1 = affine.linearize_index [%s#0, %s#1, ...] by (8, ...)
+ /// allows cancellation
+ /// - The delinearization doesn't specify a bound, but the linearization
+ /// is `disjoint`, which asserts that the bound on the linearization is
+ /// correct.
+ unsigned firstDelinArg = asResult.getResultNumber();
+ SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
+ OpFoldResult firstDelinBound = delinBasis[firstDelinArg];
+ OpFoldResult firstLinBound = linBasis[linArgIdx];
+ bool boundsMatch = firstDelinBound == firstLinBound;
+ bool bothAtFront = linArgIdx == 0 && firstDelinArg == 0;
+ bool knownByDisjoint =
+ linearizeOp.getDisjoint() && firstDelinArg == 0 && !firstDelinBound;
+ if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
+ linArgIdx++;
+ continue;
+ }
- if (delinearizeOp.getResults() != linearizeOp.getMultiIndex())
+ unsigned j = 1;
+ unsigned numDelinOuts = delinearizeOp.getNumResults();
+ for (; j + linArgIdx < numLinArgs && j + firstDelinArg < numDelinOuts;
+ ++j) {
+ if (multiIndex[linArgIdx + j] !=
+ delinearizeOp.getResult(firstDelinArg + j))
+ break;
+ if (linBasis[linArgIdx + j] != delinBasis[firstDelinArg + j])
+ break;
+ }
+ // If there're multiple matches against the same delinearize_index,
+ // only rewrite the first one we find to prevent invalidations. The next
+ // ones will be taken care of by subsequent pattern invocations.
+ if (j <= 1 || !seen.insert(delinearizeOp).second) {
+ linArgIdx++;
+ continue;
+ }
+ matches.push_back(Match{delinearizeOp, linArgIdx, firstDelinArg, j});
+ linArgIdx += j;
+ }
+
+ if (matches.empty())
return rewriter.notifyMatchFailure(
- linearizeOp, "not all indices come from delinearize");
+ linearizeOp, "no run of delinearize outputs to deal with");
+
+ SmallVector<std::tuple<Value, Value>> delinearizeReplacements;
+ SmallVector<Value> newIndex;
+ newIndex.reserve(numLinArgs);
+ SmallVector<OpFoldResult> newBasis;
+ newBasis.reserve(numLinArgs);
+ unsigned prevMatchEnd = 0;
+ for (Match m : matches) {
+ unsigned gap = m.linStart - prevMatchEnd;
+ llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
+ llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
+ // Update here so we don't forget this during early continues
+ prevMatchEnd = m.linStart + m.length;
+
+ // We use the slice from the linearize's basis above because of the
+ // "bounds inferred from `disjoint`" case above.
+ OpFoldResult newSize =
+ computeProduct(linearizeOp.getLoc(), rewriter,
+ linBasisRef.slice(m.linStart, m.length));
+
+ // Trivial case where we can just skip past the delinearize altogether
+ if (m.length == m.delinearize.getNumResults()) {
+ newIndex.push_back(m.delinearize.getLinearIndex());
+ newBasis.push_back(newSize);
+ continue;
+ }
+ SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
+ newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
+ newDelinBasis.begin() + m.delinStart + m.length);
+ newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
+ auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
+ m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
+ newDelinBasis);
+
+ // Swap all the uses of the unaffected delinearize outputs to the new
+ // delinearization so that the old code can be removed if this
+ // linearize_index is the only user of the merged results.
+ llvm::append_range(
+ delinearizeReplacements,
+ llvm::zip_equal(
+ m.delinearize.getResults().take_front(m.delinStart),
+ newDelinearize.getResults().take_front(m.delinStart)));
+ llvm::append_range(
+ delinearizeReplacements,
+ llvm::zip_equal(
+ m.delinearize.getResults().drop_front(m.delinStart + m.length),
+ newDelinearize.getResults().drop_front(m.delinStart + 1)));
+
+ Value newLinArg = newDelinearize.getResult(m.delinStart);
+ newIndex.push_back(newLinArg);
+ newBasis.push_back(newSize);
+ }
+ llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
+ llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
+ rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
+ linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());
- rewriter.replaceOp(linearizeOp, delinearizeOp.getLinearIndex());
+ for (auto [from, to] : delinearizeReplacements)
+ rewriter.replaceAllUsesWith(from, to);
return success();
}
};
@@ -5049,7 +5241,7 @@ struct DropLinearizeLeadingZero final
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
- patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeLeadingZero,
+ patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
DropLinearizeUnitComponentsIfDisjointOrZero>(context);
}
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index d3f61f7e503f9b..c153d32670d574 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1917,12 +1917,12 @@ func.func @linearize_one_element_basis(%arg0: index, %arg1: index) -> index {
// -----
-// CHECK-LABEL: func @cancel_linearize_denearize_exact(
+// CHECK-LABEL: func @cancel_linearize_delinearize_exact(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index)
// CHECK: return %[[ARG0]]
-func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
+func.func @cancel_linearize_delinearize_exact(%arg0: index, %arg1: index, %arg2: index) -> index {
%0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
%1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 4, %arg2) : index
return %1 : index
@@ -1930,12 +1930,12 @@ func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: i
// -----
-// CHECK-LABEL: func @cancel_linearize_denearize_linearize_extra_bound(
+// CHECK-LABEL: func @cancel_linearize_delinearize_linearize_extra_bound(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]...
[truncated]
``````````
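The run-matching loop in `CancelLinearizeOfDelinearizePortion` can be approximated by the simplified model below. It is a sketch under several assumptions and is not the patch's implementation: `LinArg`, `Run`, and `findRuns` are hypothetical names, every basis element is a static integer with the outer bound present, and the special cases for unbounded outermost entries, `disjoint`, and the `seen` set that limits rewrites to one run per delinearize op are all left out.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// One linearize operand: the delinearize op that produced it (by id, or -1 if
// it is some other value), its result number, and the linearize basis element
// for this position.
struct LinArg {
  int source;
  unsigned resultNo;
  int64_t linBound;
};

// A run of consecutive delinearize results that can be merged into one.
struct Run {
  unsigned linStart;
  unsigned length;
  int64_t mergedSize;
};

// delinBasis[s] is the full (padded) basis of delinearize op s.
std::vector<Run> findRuns(const std::vector<LinArg> &args,
                          const std::vector<std::vector<int64_t>> &delinBasis) {
  std::vector<Run> runs;
  const unsigned n = static_cast<unsigned>(args.size());
  unsigned i = 0;
  while (i < n) {
    const LinArg &first = args[i];
    // A run can only start where the two bounds agree (simplified rule).
    if (first.source < 0 ||
        delinBasis[first.source][first.resultNo] != first.linBound) {
      ++i;
      continue;
    }
    const std::vector<int64_t> &basis = delinBasis[first.source];
    unsigned j = 1;
    int64_t product = first.linBound;
    for (; i + j < n && first.resultNo + j < basis.size(); ++j) {
      const LinArg &next = args[i + j];
      // The next operand must be the next result of the same delinearize...
      if (next.source != first.source || next.resultNo != first.resultNo + j)
        break;
      // ...and must be linearized with the same basis element.
      if (next.linBound != basis[first.resultNo + j])
        break;
      product *= next.linBound;
    }
    // Runs of length one give nothing to merge.
    if (j <= 1) {
      ++i;
      continue;
    }
    runs.push_back({i, j, product});
    i += j;
  }
  return runs;
}
```

For the example in the summary, the operands `[%x, %0#0, %0#1, %y]` with basis `(3, 8, 8, 5)` yield a single run that starts at position 1, has length 2, and has merged size 64.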
</details>
https://github.com/llvm/llvm-project/pull/117637