[Mlir-commits] [mlir] [mlir][Affine] Generalize the linearize(delinearize()) simplifications (PR #117637)
Krzysztof Drewniak
llvmlistbot at llvm.org
Mon Dec 2 09:31:03 PST 2024
================
@@ -4980,38 +5008,202 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
}
};
-/// Cancel out linearize_index(delinearize_index(x, B), B).
+/// Return the product of `terms`, creating an `affine.apply` if any of them are
+/// non-constant values. If any of `terms` is `nullptr`, return `nullptr`.
+static OpFoldResult computeProduct(Location loc, OpBuilder &builder,
+ ArrayRef<OpFoldResult> terms) {
+ int64_t nDynamic = 0;
+ SmallVector<Value> dynamicPart;
+ AffineExpr result = builder.getAffineConstantExpr(1);
+ for (OpFoldResult term : terms) {
+ if (!term)
+ return term;
+ std::optional<int64_t> maybeConst = getConstantIntValue(term);
+ if (maybeConst) {
+ result = result * builder.getAffineConstantExpr(*maybeConst);
+ } else {
+ dynamicPart.push_back(term.get<Value>());
+ result = result * builder.getAffineSymbolExpr(nDynamic++);
+ }
+ }
+ if (auto constant = dyn_cast<AffineConstantExpr>(result))
+ return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
+ return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
+}
+
+/// If consecutive outputs of a delinearize_index are linearized with the same
+/// bounds, canonicalize away the redundant arithmetic.
+///
+/// That is, if we have
+/// ```
+/// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b)
+/// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d]
+/// by (...e, B1, B2, ..., BK, ...f)
+/// ```
///
-/// That is, rewrite
+/// We can rewrite this to
+/// ```
+/// B = B1 * B2 ... BK
+/// %sMerged:(N-K+1) = affine.delinearize_index %x into (...a, B, ...b)
+/// %t = affine.linearize_index [...c, %sMerged#I, ...d] by (...e, B, ...f)
/// ```
-/// %0:N = affine.delinearize_index %x by (%b1, %b2, ... %bN)
-/// %y = affine.linearize_index [%0#0, %0#1, ... %0#(N-1)] by (%b1, %b2, ...
-/// %bN)
+/// where we replace all results of %s unaffected by the change with results
+/// from %sMerged.
+///
+/// As a special case, if all results of the delinearize are merged in this way
+/// we can replace those usages with %x, thus cancelling the delinearization
+/// entirely, as in
/// ```
-/// to replacing `%y` with `%x`.
-struct CancelLinearizeOfDelinearizeExact final
+/// %s:3 = affine.delinearize_index %x into (2, 4, 8)
+/// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16)
+/// ```
+/// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)`
+struct CancelLinearizeOfDelinearizePortion final
: OpRewritePattern<affine::AffineLinearizeIndexOp> {
using OpRewritePattern::OpRewritePattern;
+ struct Match {
+ AffineDelinearizeIndexOp delinearize;
+ unsigned linStart = 0;
+ unsigned delinStart = 0;
+ unsigned length = 0;
+ };
+
LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
PatternRewriter &rewriter) const override {
- auto delinearizeOp = linearizeOp.getMultiIndex()
- .front()
- .getDefiningOp<affine::AffineDelinearizeIndexOp>();
- if (!delinearizeOp)
- return rewriter.notifyMatchFailure(
- linearizeOp, "last entry doesn't come from a delinearize");
+ SmallVector<Match> matches;
+
+ const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
+ ArrayRef<OpFoldResult> linBasisRef = linBasis;
+
+ ValueRange multiIndex = linearizeOp.getMultiIndex();
+ unsigned numLinArgs = multiIndex.size();
+ unsigned linArgIdx = 0;
+ // We only want to replace one run from the same delinearize op per
+ // pattern invocation lest we run into invalidation issues.
+ llvm::SmallPtrSet<Operation *, 2> seen;
+ while (linArgIdx < numLinArgs) {
+ auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
+ if (!asResult) {
+ linArgIdx++;
+ continue;
+ }
- if (linearizeOp.getEffectiveBasis() != delinearizeOp.getEffectiveBasis())
- return rewriter.notifyMatchFailure(
- linearizeOp, "basis of linearize and delinearize don't match exactly "
- "(excluding outer bounds)");
+ auto delinearizeOp =
+ dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
+ if (!delinearizeOp) {
+ linArgIdx++;
+ continue;
+ }
+
+ /// Result 0 of the delinearize and argument 0 of the linearize can
+ /// leave their maximum value unspecified. However, even if this happens
+ /// we can still sometimes start the match process. Specifically, if
+ /// - The argument we're matching is result 0 and argument 0 (so the
+ /// bounds don't matter). For example,
+ ///
+ /// %s:2 = affine.delinearize_index %x into (8) : index, index
+ /// %1 = affine.linearize_index [%s#0, %s#1, ...] by (8, ...)
+ /// allows cancellation
+ /// - The delinearization doesn't specify a bound, but the linearization
+ /// is `disjoint`, which asserts that the bound on the linearization is
+ /// correct.
+ unsigned firstDelinArg = asResult.getResultNumber();
----------------
krzysz00 wrote:
Yeah, solid point. I went with `delinArgIdx`
https://github.com/llvm/llvm-project/pull/117637
More information about the Mlir-commits
mailing list