[Mlir-commits] [mlir] [mlir][Affine] Cancel delinearize_index ops fully reversed by apply (PR #163440)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Oct 14 12:31:24 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-affine
Author: Krzysztof Drewniak (krzysz00)
<details>
<summary>Changes</summary>
If an `affine.apply` uses every result of an
`affine.delinearize_index` operation in an expression of the form x_0 * S_0 + x_1 * S_1 + ... + x_n * S_n + ..., where S_i is the "stride" of the i-th delinearization result (the value it got divided by), then that chain of additions contains the inverse of the affine.delinearize_index.
We don't want to compose affine.delinearize_index into affine.apply in general, since this leads to "simplifications" (mainly the `x % y => x - (x / y) * y` rewrite) that are bad for code generation and algebraic reasoning. However, if we do see an exact inverse, we should cancel it out.
---
Full diff: https://github.com/llvm/llvm-project/pull/163440.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+141)
- (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+130)
``````````diff
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 7e5ce26b5f733..291b729e07aef 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1125,6 +1125,142 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
return success(*map != initialMap);
}
+/// Search `e` for addition chains that contain every expression in
+/// `exprsToRemove`, and record how to shorten each such chain.
+///
+/// If `e` is not a binary expression nothing happens. If it is a binary
+/// expression other than an addition, both operands are searched recursively.
+/// If it is an addition, the left-nested chain e1 + e2 + ... + eK rooted at
+/// `e` is walked term by term. When every element of `exprsToRemove` occurs
+/// among the terms, the entry
+///   e -> `newVal` + (the terms that were not matched, in their original order)
+/// is inserted into `replacementsMap`. Each element of `exprsToRemove` is
+/// matched at most once per chain, so a repeated term is preserved.
+///
+/// The function returns nothing; if no entries were added to
+/// `replacementsMap`, nothing was found.
+static void shortenAddChainsContainingAll(
+ AffineExpr e, const llvm::SmallDenseSet<AffineExpr, 4> &exprsToRemove,
+ AffineExpr newVal, DenseMap<AffineExpr, AffineExpr> &replacementsMap) {
+ // Only binary expressions can be (or contain) an addition chain.
+ auto binOp = dyn_cast<AffineBinaryOpExpr>(e);
+ if (!binOp)
+ return;
+ AffineExpr lhs = binOp.getLHS();
+ AffineExpr rhs = binOp.getRHS();
+ // Not an addition: the chain, if any, is somewhere inside the operands.
+ if (binOp.getKind() != AffineExprKind::Add) {
+ shortenAddChainsContainingAll(lhs, exprsToRemove, newVal, replacementsMap);
+ shortenAddChainsContainingAll(rhs, exprsToRemove, newVal, replacementsMap);
+ return;
+ }
+ // Terms of this chain that are not part of `exprsToRemove`; they are
+ // collected right-to-left.
+ SmallVector<AffineExpr> toPreserve;
+ // Private copy of the expected terms; an entry is erased once it is seen.
+ llvm::SmallDenseSet<AffineExpr, 4> ourTracker(exprsToRemove);
+ // Walk the chain from its rightmost term: `thisTerm` is the term being
+ // examined and `nextTerm` is the rest of the chain to its left.
+ AffineExpr thisTerm = rhs;
+ AffineExpr nextTerm = lhs;
+
+ while (thisTerm) {
+ // A term that is not an expected one is kept, and is itself searched for
+ // nested chains.
+ if (!ourTracker.erase(thisTerm)) {
+ toPreserve.push_back(thisTerm);
+ shortenAddChainsContainingAll(thisTerm, exprsToRemove, newVal,
+ replacementsMap);
+ }
+ auto nextBinOp = dyn_cast_if_present<AffineBinaryOpExpr>(nextTerm);
+ if (!nextBinOp || nextBinOp.getKind() != AffineExprKind::Add) {
+ // The remainder is not an addition, so it is the last (leftmost) term.
+ // Once it has been processed, the null `thisTerm` ends the loop.
+ thisTerm = nextTerm;
+ nextTerm = AffineExpr();
+ } else {
+ // Peel the next term off the right of the remaining chain.
+ thisTerm = nextBinOp.getRHS();
+ nextTerm = nextBinOp.getLHS();
+ }
+ }
+ // Some expected term never appeared in this chain, so `e` itself is not
+ // rewritten (entries recorded for nested sub-expressions above remain).
+ if (!ourTracker.empty())
+ return;
+ // The preserved terms were collected right-to-left, so they are reversed
+ // here in order to preserve associativity between them.
+ AffineExpr newExpr = newVal;
+ for (AffineExpr preserved : llvm::reverse(toPreserve))
+ newExpr = newExpr + preserved;
+ replacementsMap.insert({e, newExpr});
+}
+
+/// If this map contains the expression `x_0 + x_1 * C_1 + ... + x_n * C_n +
+/// ...` (not necessarily in order) where the set of the `x_i` is the set of
+/// outputs of an `affine.delinearize_index` whose inverse is that expression,
+/// replace that expression with the input of that delinearize_index op.
+///
+/// `resultToReplace` is the operand that was detected as the potential start to
+/// this replacement chain - if it isn't the rightmost result of the
+/// delinearization, this method fails. (This is intended to ensure we don't
+/// have redundant scans over the same expression).
+///
+/// While this currently only handles delinearizations with a constant basis,
+/// that isn't a fundamental limitation.
+///
+/// On success `*map` is rewritten, the linear index of `delinOp` is appended to
+/// `dims` (when every result was bound as a dimension) or to `syms`
+/// (otherwise), and operands of `delinOp` that the new map no longer uses are
+/// set to null. On failure `*map`, `dims` and `syms` are left unchanged.
+///
+/// This is a utility function for `replaceDimOrSym` below.
+static LogicalResult replaceAffineDelinearizeIndexInverseExpression(
+ AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map,
+ SmallVectorImpl<Value> &dims, SmallVectorImpl<Value> &syms) {
+ // Only a fully static basis is handled.
+ if (!delinOp.getDynamicBasis().empty())
+ return failure();
+ // Only start from the last result, so each op is scanned once per map.
+ if (resultToReplace != delinOp.getMultiIndex().back())
+ return failure();
+
+ MLIRContext *ctx = delinOp.getContext();
+ // resToExpr[i] is the dim or symbol expression that result #i of `delinOp`
+ // is bound to in `map`; it stays null if that result is not an operand.
+ SmallVector<AffineExpr> resToExpr(delinOp.getNumResults(), AffineExpr());
+ for (auto [pos, dim] : llvm::enumerate(dims)) {
+ auto asResult = dyn_cast_if_present<OpResult>(dim);
+ if (!asResult)
+ continue;
+ if (asResult.getOwner() == delinOp.getOperation())
+ resToExpr[asResult.getResultNumber()] = getAffineDimExpr(pos, ctx);
+ }
+ // Symbols are scanned after dims, so a result that appears in both lists
+ // ends up bound to its symbol.
+ for (auto [pos, sym] : llvm::enumerate(syms)) {
+ auto asResult = dyn_cast_if_present<OpResult>(sym);
+ if (!asResult)
+ continue;
+ if (asResult.getOwner() == delinOp.getOperation())
+ resToExpr[asResult.getResultNumber()] = getAffineSymbolExpr(pos, ctx);
+ }
+ // Every result must be an operand of the map for an exact inverse to exist.
+ if (llvm::any_of(resToExpr, [](auto e) { return e == AffineExpr(); })) {
+ return failure();
+ }
+
+ // The linear index is added as a dimension only if all results were bound
+ // as dimensions; otherwise it is added as a symbol.
+ bool isDimReplacement = llvm::all_of(resToExpr, llvm::IsaPred<AffineDimExpr>);
+ // Build the terms of the inverse: going from the last result to the first,
+ // each result is multiplied by the product of the basis sizes to its right.
+ int64_t stride = 1;
+ llvm::SmallDenseSet<AffineExpr, 4> expectedExprs;
+ // This isn't zip_equal since sometimes the delinearize basis is missing a
+ // size for the first result.
+ for (auto [binding, size] : llvm::zip(
+ llvm::reverse(resToExpr), llvm::reverse(delinOp.getStaticBasis()))) {
+ expectedExprs.insert(binding * getAffineConstantExpr(stride, ctx));
+ stride *= size;
+ }
+ // With no size for the first result, the loop above stopped before reaching
+ // it; its stride is the product of all the basis sizes.
+ if (resToExpr.size() != delinOp.getStaticBasis().size())
+ expectedExprs.insert(resToExpr[0] * stride);
+
+ DenseMap<AffineExpr, AffineExpr> replacements;
+ // Expression for the linear index, at the position it will occupy once it
+ // is appended to `dims` or `syms` below.
+ AffineExpr delinInExpr = isDimReplacement
+ ? getAffineDimExpr(dims.size(), ctx)
+ : getAffineSymbolExpr(syms.size(), ctx);
+
+ for (AffineExpr e : map->getResults())
+ shortenAddChainsContainingAll(e, expectedExprs, delinInExpr, replacements);
+ // No result of the map contained the full inverse.
+ if (replacements.empty())
+ return failure();
+
+ AffineMap origMap = *map;
+ if (isDimReplacement)
+ dims.push_back(delinOp.getLinearIndex());
+ else
+ syms.push_back(delinOp.getLinearIndex());
+ // Rebuild the map with the enlarged dim/symbol counts.
+ *map = origMap.replace(replacements, dims.size(), syms.size());
+
+ // Blank out dead dimensions and symbols
+ for (AffineExpr e : resToExpr) {
+ if (auto d = dyn_cast<AffineDimExpr>(e)) {
+ unsigned pos = d.getPosition();
+ if (!map->isFunctionOfDim(pos))
+ dims[pos] = nullptr;
+ }
+ if (auto s = dyn_cast<AffineSymbolExpr>(e)) {
+ unsigned pos = s.getPosition();
+ if (!map->isFunctionOfSymbol(pos))
+ syms[pos] = nullptr;
+ }
+ }
+ return success();
+}
+
/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
/// defining AffineApplyOp expression and operands.
/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
@@ -1157,6 +1293,11 @@ static LogicalResult replaceDimOrSym(AffineMap *map,
syms);
}
+ if (auto delinOp = v.getDefiningOp<affine::AffineDelinearizeIndexOp>()) {
+ return replaceAffineDelinearizeIndexInverseExpression(delinOp, v, map, dims,
+ syms);
+ }
+
auto affineApply = v.getDefiningOp<AffineApplyOp>();
if (!affineApply)
return failure();
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index e56079c1cccd4..1169cd1c29d74 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -2235,6 +2235,136 @@ func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> ind
// -----
+// Each apply below recombines every result of one delinearize_index of %arg0
+// with the matching strides, in various term and operand orders, so each
+// store is expected to use %arg0 as both the stored value and the index.
+// CHECK-LABEL: func @delin_apply_cancel_exact
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
+// CHECK-COUNT-6: memref.store %[[ARG0]], %[[ARG1]][%[[ARG0]]]
+// CHECK-NOT: memref.store
+// CHECK: return
+func.func @delin_apply_cancel_exact(%arg0: index, %arg1: memref<?xindex>) {
+ %a:3 = affine.delinearize_index %arg0 into (4, 5) : index, index, index
+ %b:3 = affine.delinearize_index %arg0 into (3, 4, 5) : index, index, index
+ %c:2 = affine.delinearize_index %arg0 into (20) : index, index
+
+ %t1 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 5 + s2 * 20)>()[%a#2, %a#1, %a#0]
+ memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+ %t2 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s2 * 20 + s1 * 5)>()[%a#2, %a#1, %a#0]
+ memref.store %t2, %arg1[%t2] : memref<?xindex>
+
+ %t3 = affine.apply affine_map<()[s0, s1, s2] -> (s1 * 20 + s2 * 5 + s0)>()[%a#2, %a#0, %a#1]
+ memref.store %t3, %arg1[%t3] : memref<?xindex>
+
+ %t4 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 5 + s2 * 20)>()[%b#2, %b#1, %b#0]
+ memref.store %t4, %arg1[%t4] : memref<?xindex>
+
+ %t5 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 20)>()[%c#1, %c#0]
+ memref.store %t5, %arg1[%t5] : memref<?xindex>
+
+ %t6 = affine.apply affine_map<()[s0, s1] -> (s1 * 20 + s0)>()[%c#1, %c#0]
+ memref.store %t6, %arg1[%t6] : memref<?xindex>
+
+ return
+}
+
+// -----
+
+// Same cancellation with the delinearize results passed as dimension operands:
+// the apply uses all three results of the (2, 2, 64) split with coefficients
+// 128, 64 and 1, so the store is expected to use the loop induction variable.
+// CHECK-LABEL: func @delin_apply_cancel_exact_dim
+// CHECK: affine.for %[[arg1:.+]] = 0 to 256
+// CHECK: memref.store %[[arg1]]
+// CHECK: return
+func.func @delin_apply_cancel_exact_dim(%arg0: memref<?xindex>) {
+ affine.for %arg1 = 0 to 256 {
+ %a:3 = affine.delinearize_index %arg1 into (2, 2, 64) : index, index, index
+ %i = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 * 128 + d2 * 64)>(%a#2, %a#0, %a#1)
+ memref.store %i, %arg0[%i] : memref<?xindex>
+ }
+ return
+}
+
+// -----
+
+// A constant addend next to the full inverse is kept: the expected map is the
+// linear index plus 512.
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 512)>
+// CHECK-LABEL: func @delin_apply_cancel_const_term
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
+// CHECK: affine.apply #[[$MAP]]()[%[[ARG0]]]
+// CHECK: return
+func.func @delin_apply_cancel_const_term(%arg0: index, %arg1: memref<?xindex>) {
+ %a:3 = affine.delinearize_index %arg0 into (2, 2, 64) : index, index, index
+
+ %t1 = affine.apply affine_map<()[s0, s1, s2] -> (s0 + s1 * 128 + s2 * 64 + 512)>()[%a#2, %a#0, %a#1]
+ memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+ return
+}
+
+// -----
+
+// An unrelated symbol operand (%arg2) and a constant next to the full inverse
+// are both kept; the expected apply takes %arg2 and the linear index.
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 + 512)>
+// CHECK-LABEL: func @delin_apply_cancel_var_term
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>, %[[ARG2:.+]]: index)
+// CHECK: affine.apply #[[$MAP]]()[%[[ARG2]], %[[ARG0]]]
+// CHECK: return
+func.func @delin_apply_cancel_var_term(%arg0: index, %arg1: memref<?xindex>, %arg2: index) {
+ %a:3 = affine.delinearize_index %arg0 into (2, 2, 64) : index, index, index
+
+ %t1 = affine.apply affine_map<()[s0, s1, s2, s3] -> (s0 + s1 * 128 + s2 * 64 + s3 + 512)>()[%a#2, %a#0, %a#1, %arg2]
+ memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+ return
+}
+
+// -----
+
+// The inverse appears twice, once under a ceildiv and once under a
+// multiplication; both occurrences are expected to become the linear index.
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 2 + s0 ceildiv 4)>
+// CHECK-LABEL: func @delin_apply_cancel_nested_exprs
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
+// CHECK: affine.apply #[[$MAP]]()[%[[ARG0]]]
+// CHECK: return
+func.func @delin_apply_cancel_nested_exprs(%arg0: index, %arg1: memref<?xindex>) {
+ %a:2 = affine.delinearize_index %arg0 into (20) : index, index
+
+ %t1 = affine.apply affine_map<()[s0, s1] -> ((s0 + s1 * 20) ceildiv 4 + (s1 * 20 + s0) * 2)>()[%a#1, %a#0]
+ memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+ return
+}
+
+// -----
+
+// The last result (%a#1) appears twice in the sum. The expected output still
+// contains the delinearize op and adds its second result to the linear index,
+// i.e. only one occurrence is absorbed by the cancellation.
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-LABEL: func @delin_apply_cancel_preserve_rotation
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
+// CHECK: %[[A:.+]]:2 = affine.delinearize_index %[[ARG0]] into (20)
+// CHECK: affine.apply #[[$MAP]]()[%[[A]]#1, %[[ARG0]]]
+// CHECK: return
+func.func @delin_apply_cancel_preserve_rotation(%arg0: index, %arg1: memref<?xindex>) {
+ %a:2 = affine.delinearize_index %arg0 into (20) : index, index
+
+ %t1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 20 + s0)>()[%a#1, %a#0]
+ memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+ return
+}
+
+// -----
+
+// Only two of the three results of the (3, 4, 5) delinearization are used, so
+// no cancellation is expected: the delinearize op and the original map remain.
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 5)>
+// CHECK-LABEL: func @delin_apply_dont_cancel_partial
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: memref<?xindex>)
+// CHECK: %[[A:.+]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 5)
+// CHECK: affine.apply #[[$MAP]]()[%[[A]]#2, %[[A]]#1]
+// CHECK: return
+func.func @delin_apply_dont_cancel_partial(%arg0: index, %arg1: memref<?xindex>) {
+ %a:3 = affine.delinearize_index %arg0 into (3, 4, 5) : index, index, index
+
+ %t1 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 5)>()[%a#2, %a#1]
+ memref.store %t1, %arg1[%t1] : memref<?xindex>
+
+ return
+}
+
+// -----
+
// CHECK-LABEL: @cst_value_to_cst_attr_basis_delinearize_index
// CHECK-SAME: (%[[ARG0:.*]]: index)
// CHECK: %[[RET:.*]]:3 = affine.delinearize_index %[[ARG0]] into (3, 4, 2) : index, index
``````````
</details>
https://github.com/llvm/llvm-project/pull/163440
More information about the Mlir-commits
mailing list