[Mlir-commits] [mlir] 0069867 - [mlir][affine] Add ValueBounds-based simplification for delinearize(linearize) pairs (#187245)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 30 02:25:10 PDT 2026
Author: Zhewen Yu
Date: 2026-03-30T10:25:04+01:00
New Revision: 00698678e404699f6c776679272a7e3392c46306
URL: https://github.com/llvm/llvm-project/commit/00698678e404699f6c776679272a7e3392c46306
DIFF: https://github.com/llvm/llvm-project/commit/00698678e404699f6c776679272a7e3392c46306.diff
LOG: [mlir][affine] Add ValueBounds-based simplification for delinearize(linearize) pairs (#187245)
The existing canonicalization patterns for `affine.delinearize_index` /
`affine.linearize_index` pairs
(`CancelDelinearizeOfLinearizeDisjointExactTail`) only match when basis
elements are exactly equal as `OpFoldResult` values. This means they
cannot simplify cases where dynamic basis products are semantically
equal but represented by different SSA values or affine expressions.
This patch adds a new pass `affine-simplify-with-bounds` with a rewrite
pattern, `SimplifyDelinearizeOfLinearizeDisjoint`, that uses
`ValueBoundsConstraintSet` to prove equality of basis products. At each
step from the tail it tries two kinds of match:
- **Many-to-one**: multiple consecutive linearize dimensions have a
product equal to a single delinearize dimension.
- **One-to-many**: a single linearize dimension equals the product of
multiple consecutive delinearize dimensions.
Matching scans from the tail (innermost dimensions) and supports
partial matching. Unmatched prefix dimensions are left as residual
linearize/delinearize operations.
Assisted-by: Cursor (Claude)
---------
Signed-off-by: Yu-Zhewen <zhewenyu at amd.com>
Added:
mlir/lib/Dialect/Affine/Transforms/SimplifyAffineWithBounds.cpp
mlir/test/Dialect/Affine/simplify-with-bounds.mlir
Modified:
mlir/include/mlir/Dialect/Affine/Transforms/Passes.td
mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Passes.td b/mlir/include/mlir/Dialect/Affine/Transforms/Passes.td
index 430edffc29038..03f2d532eb016 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Passes.td
@@ -430,6 +430,18 @@ def SimplifyAffineMinMaxPass : InterfacePass<"affine-simplify-min-max", "Functio
}];
}
+// Value-bounds based counterpart of the affine.delinearize_index /
+// affine.linearize_index canonicalizations. `arith` is a dependent dialect
+// because the rewrite can materialize `arith.constant` results.
+def SimplifyAffineWithBounds : Pass<"affine-simplify-with-bounds"> {
+  let dependentDialects = ["affine::AffineDialect", "arith::ArithDialect"];
+  let summary = "Simplify affine index operations using value bounds analysis";
+ let description = [{
+ This pass simplifies `affine.delinearize_index` / `affine.linearize_index`
+ pairs by using value bounds analysis to match basis products. Unlike the
+ built-in canonicalization patterns which only use exact `OpFoldResult`
+ comparisons, this pass can prove equality of dynamic basis products through
+ `ValueBoundsConstraintSet`.
+ }];
+}
+
def AffineExpandIndexOps : Pass<"affine-expand-index-ops"> {
let summary = "Lower affine operations operating on indices into more fundamental operations";
let constructor = "mlir::affine::createAffineExpandIndexOpsPass()";
diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index 272054448374e..84adb8e6a1e6d 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -49,6 +49,10 @@ LogicalResult lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
LogicalResult lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
AffineLinearizeIndexOp op);
+/// Populate patterns that simplify `affine.delinearize_index` /
+/// `affine.linearize_index` pairs using value bounds analysis.
+///
+/// The added rewrite peels trailing groups of dimensions off a
+/// delinearize(linearize disjoint) pair when their basis products can be
+/// proven equal, either many linearize dims to one delinearize dim or one
+/// linearize dim to many delinearize dims. The delinearize/linearize
+/// canonicalization patterns are NOT added here; the
+/// `affine-simplify-with-bounds` pass registers those separately.
+void populateSimplifyAffineWithBoundsPatterns(RewritePatternSet &patterns);
+
/// Populate patterns that expand affine index operations into more fundamental
/// operations (not necessarily restricted to Affine dialect).
void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index 7bce124817032..9d912139810b2 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
RaiseMemrefDialect.cpp
ReifyValueBounds.cpp
SuperVectorize.cpp
+ SimplifyAffineWithBounds.cpp
SimplifyAffineStructures.cpp
SimplifyAffineMinMax.cpp
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineWithBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineWithBounds.cpp
new file mode 100644
index 0000000000000..6e7d5d91334c0
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineWithBounds.cpp
@@ -0,0 +1,255 @@
+//===- SimplifyAffineWithBounds.cpp ---------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements simplification patterns for affine.delinearize_index /
+// affine.linearize_index pairs using value bounds analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/Transforms/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "affine-simplify-with-bounds"
+
+using namespace mlir;
+using namespace mlir::affine;
+
+/// Accumulate a single basis element into the running product expression.
+/// Static values become affine constants, and dynamic values become symbols.
+/// Multiply the running product `productExpr` by one basis element.
+///
+/// A basis element with a constant integer value is folded in as an affine
+/// constant. Any other element is appended to `operands` and referenced as the
+/// affine symbol whose position is its index in `operands`.
+static void buildProductExpr(OpFoldResult basis, AffineExpr &productExpr,
+                             SmallVectorImpl<Value> &operands,
+                             MLIRContext *ctx) {
+  std::optional<int64_t> constVal = getConstantIntValue(basis);
+  if (!constVal) {
+    // Dynamic element: it becomes the next symbol of the product map.
+    unsigned symbolPos = operands.size();
+    operands.push_back(cast<Value>(basis));
+    productExpr = productExpr * getAffineSymbolExpr(symbolPos, ctx);
+    return;
+  }
+  productExpr = productExpr * getAffineConstantExpr(*constVal, ctx);
+}
+
+/// Look for a run of consecutive `lhs` elements whose product equals the next
+/// unmatched `rhs` element.
+///
+/// The run ends just before the `lhsTailConsumed` trailing `lhs` elements that
+/// are already matched. The target is the `rhs` element just before its
+/// `rhsTailConsumed` matched trailing elements. The run grows one element at a
+/// time towards the front of `lhs`, reusing the partial product. Each
+/// candidate is compared to the target with
+/// `ValueBoundsConstraintSet::areEqual`.
+///
+/// Returns the length of the first (shortest) run proven equal, or
+/// std::nullopt if none is.
+static std::optional<size_t> tryMatchProduct(ArrayRef<OpFoldResult> lhs,
+                                             size_t lhsTailConsumed,
+                                             ArrayRef<OpFoldResult> rhs,
+                                             size_t rhsTailConsumed,
+                                             MLIRContext *ctx) {
+  AffineExpr unit = getAffineConstantExpr(1, ctx);
+
+  // The target: the innermost rhs element that is not yet matched.
+  const size_t targetIdx = rhs.size() - rhsTailConsumed - 1;
+  AffineExpr targetExpr = unit;
+  SmallVector<Value> targetSyms;
+  buildProductExpr(rhs[targetIdx], targetExpr, targetSyms, ctx);
+  ValueBoundsConstraintSet::Variable target(
+      AffineMap::get(/*dimCount=*/0, targetSyms.size(), targetExpr, ctx),
+      targetSyms);
+
+  // The candidate: extend the lhs run outward by one element per iteration.
+  AffineExpr runExpr = unit;
+  SmallVector<Value> runSyms;
+  for (size_t len = 1; lhsTailConsumed + len <= lhs.size(); ++len) {
+    const size_t idx = lhs.size() - lhsTailConsumed - len;
+    buildProductExpr(lhs[idx], runExpr, runSyms, ctx);
+    ValueBoundsConstraintSet::Variable candidate(
+        AffineMap::get(/*dimCount=*/0, runSyms.size(), runExpr, ctx),
+        runSyms);
+    FailureOr<bool> same =
+        ValueBoundsConstraintSet::areEqual(candidate, target);
+    if (succeeded(same) && *same)
+      return len;
+  }
+  return std::nullopt;
+}
+
+namespace {
+
+/// Simplify delinearize(linearize) pairs from the tail by matching groups of
+/// dimensions whose basis products are equal via ValueBounds analysis.
+///
+/// Only applies when the delinearized index is produced directly by an
+/// `affine.linearize_index` that carries the `disjoint` flag.
+///
+/// For each step from the tail, tries:
+/// 1. Many-to-one: k linearize dims -> 1 delinearize dim
+/// 2. One-to-many: 1 linearize dim -> k delinearize dims
+/// The first kind that matches is taken, with the shortest k for which
+/// equality can be proven. Scanning stops at the first step where neither
+/// kind matches, or when either basis is exhausted.
+///
+/// Matched trailing dimensions are peeled off. Unmatched prefix dimensions
+/// are left as residual linearize/delinearize operations.
+///
+/// Example (many-to-one, D*E == Z):
+/// %lin = affine.linearize_index disjoint [%a, %b, %c, %d, %e]
+/// by (A, B, C, D, E)
+/// %result:3 = affine.delinearize_index %lin into (X, Y, Z)
+/// ->
+/// %prefix_lin = affine.linearize_index disjoint [%a, %b, %c] by (A, B, C)
+/// %prefix:2 = affine.delinearize_index %prefix_lin into (X, Y)
+/// %tail = affine.linearize_index disjoint [%d, %e] by (D, E)
+/// %result = [%prefix#0, %prefix#1, %tail]
+struct SimplifyDelinearizeOfLinearizeDisjoint final
+ : OpRewritePattern<AffineDelinearizeIndexOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(AffineDelinearizeIndexOp delinearizeOp,
+ PatternRewriter &rewriter) const override {
+ auto linearizeOp =
+ delinearizeOp.getLinearIndex().getDefiningOp<AffineLinearizeIndexOp>();
+ if (!linearizeOp)
+ return rewriter.notifyMatchFailure(delinearizeOp,
+ "index doesn't come from linearize");
+
+ // Disjointness guarantees each linearize input is below its basis
+ // element. The pass-through and re-linearize/re-delinearize rewrites
+ // below rely on that guarantee.
+ if (!linearizeOp.getDisjoint())
+ return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");
+
+ // A linearize without an outer bound takes one more input than it has
+ // basis elements (e.g. `[%a, %b, %c] by (8, 8)`), so `linInputs` may be
+ // one longer than `linBasis`. All indexing below is relative to the tail,
+ // where inputs and basis elements line up either way.
+ SmallVector<OpFoldResult> linBasis = linearizeOp.getMixedBasis();
+ SmallVector<OpFoldResult> delinBasis = delinearizeOp.getMixedBasis();
+ ValueRange linInputs = linearizeOp.getMultiIndex();
+ MLIRContext *ctx = rewriter.getContext();
+
+ // Track how many elements consumed from each tail.
+ size_t linTailConsumed = 0;
+ size_t delinTailConsumed = 0;
+
+ // For each matched group (innermost first), record the number of
+ // linearize and delinearize dimensions it spans. Many-to-one groups
+ // have linCount > 1, one-to-many groups have delinCount > 1.
+ SmallVector<std::pair<size_t, size_t>> matchedGroups;
+
+ // Every successful iteration consumes at least one element from each
+ // basis, so the loop terminates.
+ while (linTailConsumed < linBasis.size() &&
+ delinTailConsumed < delinBasis.size()) {
+ // Try many-to-one: k lin dims -> 1 delin dim.
+ if (std::optional<size_t> k = tryMatchProduct(
+ linBasis, linTailConsumed, delinBasis, delinTailConsumed, ctx)) {
+ matchedGroups.emplace_back(*k, 1);
+ linTailConsumed += *k;
+ delinTailConsumed += 1;
+ continue;
+ }
+ // Try one-to-many: 1 lin dim -> k delin dims.
+ // NOTE(review): the k = 1 candidate here repeats the 1:1 comparison that
+ // the many-to-one attempt just rejected. Starting at k = 2 would save one
+ // areEqual query per step.
+ if (std::optional<size_t> k = tryMatchProduct(
+ delinBasis, delinTailConsumed, linBasis, linTailConsumed, ctx)) {
+ matchedGroups.emplace_back(1, *k);
+ delinTailConsumed += *k;
+ linTailConsumed += 1;
+ continue;
+ }
+ break;
+ }
+
+ if (matchedGroups.empty())
+ return rewriter.notifyMatchFailure(delinearizeOp,
+ "no trailing dimensions matched");
+
+ // Replacement values are collected outermost-first: the prefix results
+ // (if any), then one entry per matched group from outermost to innermost.
+ SmallVector<Value> results;
+
+ // Build residual prefix ops for unmatched dimensions. If the whole
+ // delinearize basis was consumed and the delinearize has an outer bound,
+ // there is no prefix result, and neither branch below emits anything.
+ if (delinTailConsumed < delinBasis.size()) {
+ // Partial match: create residual linearize + delinearize for the
+ // unmatched prefix.
+ // NOTE(review): when every linearize input was consumed, the residual
+ // linearize has no inputs. Presumably folding turns it into 0, which
+ // matches the `arith.constant 0` expected by the tests. TODO confirm.
+ // `linearizeOp.getDisjoint()` is always true at this point.
+ Value residualLinearize = AffineLinearizeIndexOp::create(
+ rewriter, linearizeOp.getLoc(), linInputs.drop_back(linTailConsumed),
+ ArrayRef(linBasis).drop_back(linTailConsumed),
+ linearizeOp.getDisjoint());
+ auto residualDelinearize = AffineDelinearizeIndexOp::create(
+ rewriter, delinearizeOp.getLoc(), residualLinearize,
+ ArrayRef(delinBasis).drop_back(delinTailConsumed),
+ delinearizeOp.hasOuterBound());
+ results.append(residualDelinearize.getResults().begin(),
+ residualDelinearize.getResults().end());
+ } else if (!delinearizeOp.hasOuterBound()) {
+ // All basis elements consumed, but the original delinearize has no outer
+ // bound which requires special handling.
+ // It still produces one extra outermost result: the value formed by the
+ // unconsumed linearize inputs.
+ ValueRange remainingInputs = linInputs.drop_back(linTailConsumed);
+ if (remainingInputs.empty()) {
+ // The outermost delinearize result is guaranteed to be zero.
+ results.push_back(arith::ConstantIndexOp::create(
+ rewriter, delinearizeOp.getLoc(), 0));
+ } else if (remainingInputs.size() == 1) {
+ // Pass through the single remaining input.
+ results.push_back(remainingInputs.front());
+ } else {
+ // Re-linearize the remaining inputs to produce the outermost result.
+ Value newLin = AffineLinearizeIndexOp::create(
+ rewriter, linearizeOp.getLoc(), remainingInputs,
+ ArrayRef(linBasis).drop_back(linTailConsumed),
+ linearizeOp.getDisjoint());
+ results.push_back(newLin);
+ }
+ }
+
+ // Build results for each matched group.
+ // The offsets start at the outermost consumed element of each list and
+ // advance group by group, in reverse of the order the groups were found.
+ size_t linInputOffset = linInputs.size() - linTailConsumed;
+ size_t linBasisOffset = linBasis.size() - linTailConsumed;
+ size_t delinBasisOffset = delinBasis.size() - delinTailConsumed;
+ for (auto [linCount, delinCount] : llvm::reverse(matchedGroups)) {
+ if (linCount == 1 && delinCount == 1) {
+ // Exact 1:1 match: pass through directly.
+ // The input is below its basis element (disjoint), which equals the
+ // delinearize basis element.
+ results.push_back(linInputs[linInputOffset]);
+ } else if (linCount > 1) {
+ // Many-to-one: re-linearize the group's lin inputs.
+ // The sub-linearization is disjoint because the original was.
+ Value newLin = AffineLinearizeIndexOp::create(
+ rewriter, linearizeOp.getLoc(),
+ linInputs.slice(linInputOffset, linCount),
+ ArrayRef(linBasis).slice(linBasisOffset, linCount),
+ /*disjoint=*/true);
+ results.push_back(newLin);
+ } else {
+ // One-to-many: delinearize the single lin input.
+ // The outer bound holds: the input is below its linearize basis
+ // element (disjoint), which equals the product of this delinearize
+ // basis slice.
+ auto newDelin = AffineDelinearizeIndexOp::create(
+ rewriter, delinearizeOp.getLoc(), linInputs[linInputOffset],
+ ArrayRef(delinBasis).slice(delinBasisOffset, delinCount),
+ /*hasOuterBound=*/true);
+ results.append(newDelin.getResults().begin(),
+ newDelin.getResults().end());
+ }
+ linInputOffset += linCount;
+ linBasisOffset += linCount;
+ delinBasisOffset += delinCount;
+ }
+
+ rewriter.replaceOp(delinearizeOp, results);
+ return success();
+ }
+};
+
+} // namespace
+
+/// Register the value-bounds based delinearize(linearize) simplification.
+void affine::populateSimplifyAffineWithBoundsPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<SimplifyDelinearizeOfLinearizeDisjoint>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// Pass definition
+//===----------------------------------------------------------------------===//
+
+// Pull in the tablegen-generated definition of the pass base class
+// (`impl::SimplifyAffineWithBoundsBase`) inside the mlir::affine namespace.
+namespace mlir::affine {
+#define GEN_PASS_DEF_SIMPLIFYAFFINEWITHBOUNDS
+#include "mlir/Dialect/Affine/Transforms/Passes.h.inc"
+} // namespace mlir::affine
+
+namespace {
+/// Runs the delinearize/linearize canonicalization patterns together with
+/// the value-bounds based simplification, using the greedy rewrite driver.
+/// The pass fails if the greedy driver reports failure.
+struct SimplifyAffineWithBoundsPass
+    : affine::impl::SimplifyAffineWithBoundsBase<SimplifyAffineWithBoundsPass> {
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+
+    // Register the canonicalizations before the value-bounds pattern, so
+    // that cheap exact-match cases can be handled without running value
+    // bounds analysis.
+    AffineDelinearizeIndexOp::getCanonicalizationPatterns(patterns, context);
+    AffineLinearizeIndexOp::getCanonicalizationPatterns(patterns, context);
+    populateSimplifyAffineWithBoundsPatterns(patterns);
+
+    LogicalResult rewriteResult =
+        applyPatternsGreedily(getOperation(), std::move(patterns));
+    if (failed(rewriteResult))
+      signalPassFailure();
+  }
+};
+} // namespace
diff --git a/mlir/test/Dialect/Affine/simplify-with-bounds.mlir b/mlir/test/Dialect/Affine/simplify-with-bounds.mlir
new file mode 100644
index 0000000000000..a9c517a5c5add
--- /dev/null
+++ b/mlir/test/Dialect/Affine/simplify-with-bounds.mlir
@@ -0,0 +1,167 @@
+// The `// -----` markers below only isolate cases when -split-input-file is
+// passed; each chunk is then parsed and rewritten as its own module.
+// RUN: mlir-opt -affine-simplify-with-bounds -split-input-file %s | FileCheck %s
+
+// Many-to-one, static: the trailing linearize dims (8, 8) form the single
+// delinearize dim 64, and the leading 4 <-> 4 dim passes straight through.
+// CHECK-LABEL: func @many_to_one_static_tail
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index, %[[C:.*]]: index
+// CHECK-DAG: %[[LIN:.*]] = affine.linearize_index disjoint [%[[B]], %[[C]]] by (8, 8)
+// CHECK-DAG: return %[[A]], %[[LIN]]
+func.func @many_to_one_static_tail(%x0: index, %x1: index, %x2: index) -> (index, index) {
+  %lin = affine.linearize_index disjoint [%x0, %x1, %x2] by (4, 8, 8) : index
+  %parts:2 = affine.delinearize_index %lin into (4, 64) : index, index
+  return %parts#0, %parts#1 : index, index
+}
+
+// -----
+
+// Many-to-one, dynamic: the linearize tail (%size, 8) has the same product as
+// the delinearize dim, which a separate affine.apply computes as %size * 8.
+// Exact OpFoldResult comparison cannot see this equality.
+// CHECK-LABEL: func @many_to_one_dynamic_tail
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index, %[[C:.*]]: index, %[[DYN:.*]]: index
+// CHECK-DAG: %[[LIN:.*]] = affine.linearize_index disjoint [%[[B]], %[[C]]] by (%[[DYN]], 8)
+// CHECK-DAG: return %[[A]], %[[LIN]]
+func.func @many_to_one_dynamic_tail(%x0: index, %x1: index, %x2: index, %size: index) -> (index, index) {
+  %size_x8 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%size]
+  %lin = affine.linearize_index disjoint [%x0, %x1, %x2] by (4, %size, 8) : index
+  %parts:2 = affine.delinearize_index %lin into (4, %size_x8) : index, index
+  return %parts#0, %parts#1 : index, index
+}
+
+// -----
+
+// One-to-many, static: the innermost linearize dim 64 splits into the
+// delinearize dims (8, 8), and the leading 4 <-> 4 dim passes through.
+// CHECK-LABEL: func @one_to_many_static_tail
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index
+// CHECK-DAG: %[[DELIN:.*]]:2 = affine.delinearize_index %[[B]] into (8, 8)
+// CHECK-DAG: return %[[A]], %[[DELIN]]#0, %[[DELIN]]#1
+func.func @one_to_many_static_tail(%x0: index, %x1: index) -> (index, index, index) {
+  %lin = affine.linearize_index disjoint [%x0, %x1] by (4, 64) : index
+  %parts:3 = affine.delinearize_index %lin into (4, 8, 8) : index, index, index
+  return %parts#0, %parts#1, %parts#2 : index, index, index
+}
+
+// -----
+
+// One-to-many, dynamic: the linearize dim %size * 8 (from affine.apply)
+// equals the product of the delinearize dims (%size, 8).
+// CHECK-LABEL: func @one_to_many_dynamic_tail
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index, %[[DYN:.*]]: index
+// CHECK-DAG: %[[DELIN:.*]]:2 = affine.delinearize_index %[[B]] into (%[[DYN]], 8)
+// CHECK-DAG: return %[[A]], %[[DELIN]]#0, %[[DELIN]]#1
+func.func @one_to_many_dynamic_tail(%x0: index, %x1: index, %size: index) -> (index, index, index) {
+  %size_x8 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%size]
+  %lin = affine.linearize_index disjoint [%x0, %x1] by (4, %size_x8) : index
+  %parts:3 = affine.delinearize_index %lin into (4, %size, 8) : index, index, index
+  return %parts#0, %parts#1, %parts#2 : index, index, index
+}
+
+// -----
+
+// One-to-one, dynamic: the two bases end in distinct SSA values that compute
+// the same ceildiv, so both results pass straight through.
+// CHECK-LABEL: func @one_to_one_dynamic_tail
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index, %[[DYN:.*]]: index
+// CHECK-DAG: return %[[A]], %[[B]]
+func.func @one_to_one_dynamic_tail(%x0: index, %x1: index, %size: index) -> (index, index) {
+  %tiles = affine.apply affine_map<()[s0] -> (s0 ceildiv 128)>()[%size]
+  %tiles_again = affine.apply affine_map<()[s0] -> (s0 ceildiv 128)>()[%size]
+  %lin = affine.linearize_index disjoint [%x0, %x1] by (4, %tiles) : index
+  %parts:2 = affine.delinearize_index %lin into (4, %tiles_again) : index, index
+  return %parts#0, %parts#1 : index, index
+}
+
+// -----
+
+// Mixed: one pair contains a many-to-one group (the last two inputs, by
+// (4, 8), form the delinearize dim 32) and two 1:1 dims (2 <-> 2, 3 <-> 3).
+// The 1:1 dims pass through and the 4 x 8 group is re-linearized.
+// CHECK-LABEL: func @mixed_many_to_one_and_one_to_one
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index, %[[C:.*]]: index, %[[D:.*]]: index
+// CHECK-DAG: %[[LIN:.*]] = affine.linearize_index disjoint [%[[C]], %[[D]]] by (4, 8)
+// CHECK-DAG: return %[[A]], %[[B]], %[[LIN]]
+func.func @mixed_many_to_one_and_one_to_one(%x0: index, %x1: index, %x2: index, %x3: index) -> (index, index, index) {
+  %lin = affine.linearize_index disjoint [%x0, %x1, %x2, %x3] by (2, 3, 4, 8) : index
+  %parts:3 = affine.delinearize_index %lin into (2, 3, 32) : index, index, index
+  return %parts#0, %parts#1, %parts#2 : index, index, index
+}
+
+// -----
+
+// Partial match: (4, 8) -> 32 is peeled off the tail. The prefixes (5, 3)
+// and (7, 9) cannot be matched, so they stay behind as a residual
+// linearize/delinearize pair.
+// CHECK-LABEL: func @partial_tail_many_to_one
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index, %[[C:.*]]: index, %[[D:.*]]: index
+// CHECK: %[[RESLIN:.*]] = affine.linearize_index disjoint [%[[A]], %[[B]]] by (5, 3)
+// CHECK: %[[RESDELIN:.*]]:2 = affine.delinearize_index %[[RESLIN]] into (7, 9)
+// CHECK: %[[TAILLIN:.*]] = affine.linearize_index disjoint [%[[C]], %[[D]]] by (4, 8)
+// CHECK: return %[[RESDELIN]]#0, %[[RESDELIN]]#1, %[[TAILLIN]]
+func.func @partial_tail_many_to_one(%x0: index, %x1: index, %x2: index, %x3: index) -> (index, index, index) {
+  %lin = affine.linearize_index disjoint [%x0, %x1, %x2, %x3] by (5, 3, 4, 8) : index
+  %parts:3 = affine.delinearize_index %lin into (7, 9, 32) : index, index, index
+  return %parts#0, %parts#1, %parts#2 : index, index, index
+}
+
+// -----
+
+// Many-to-one where neither op has an outer bound: (8, 8) -> 64 consumes both
+// bases, and the unbounded outermost result is simply the first linearize
+// input.
+// CHECK-LABEL: func @many_to_one_no_outer_bound
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index, %[[C:.*]]: index
+// CHECK-DAG: %[[LIN:.*]] = affine.linearize_index disjoint [%[[B]], %[[C]]] by (8, 8)
+// CHECK-DAG: return %[[A]], %[[LIN]]
+func.func @many_to_one_no_outer_bound(%x0: index, %x1: index, %x2: index) -> (index, index) {
+  %lin = affine.linearize_index disjoint [%x0, %x1, %x2] by (8, 8) : index
+  %parts:2 = affine.delinearize_index %lin into (64) : index, index
+  return %parts#0, %parts#1 : index, index
+}
+
+// -----
+
+// One-to-many where neither op has an outer bound: 64 -> (8, 8) consumes both
+// bases, and the first linearize input becomes the outermost result.
+// CHECK-LABEL: func @one_to_many_no_outer_bound
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index
+// CHECK-DAG: %[[DELIN:.*]]:2 = affine.delinearize_index %[[B]] into (8, 8)
+// CHECK-DAG: return %[[A]], %[[DELIN]]#0, %[[DELIN]]#1
+func.func @one_to_many_no_outer_bound(%x0: index, %x1: index) -> (index, index, index) {
+  %lin = affine.linearize_index disjoint [%x0, %x1] by (64) : index
+  %parts:3 = affine.delinearize_index %lin into (8, 8) : index, index, index
+  return %parts#0, %parts#1, %parts#2 : index, index, index
+}
+
+// -----
+
+// Partial match that consumes every linearize input: the leftover
+// delinearize dim (10) has nothing to decompose, so its result is the
+// constant 0.
+// CHECK-LABEL: func @partial_match_empty_residual_lin
+// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: return %[[C0]], %[[A]], %[[B]]
+func.func @partial_match_empty_residual_lin(%x0: index, %x1: index) -> (index, index, index) {
+  %lin = affine.linearize_index disjoint [%x0, %x1] by (4, 8) : index
+  %parts:3 = affine.delinearize_index %lin into (10, 4, 8) : index, index, index
+  return %parts#0, %parts#1, %parts#2 : index, index, index
+}
+
+// -----
+
+// Negative: the linearize lacks `disjoint`, so the pair is left untouched.
+// CHECK-LABEL: func @no_disjoint
+// CHECK: affine.linearize_index [
+// CHECK: affine.delinearize_index
+func.func @no_disjoint(%x0: index, %x1: index, %x2: index) -> (index, index) {
+  %lin = affine.linearize_index [%x0, %x1, %x2] by (4, 8, 8) : index
+  %parts:2 = affine.delinearize_index %lin into (4, 64) : index, index
+  return %parts#0, %parts#1 : index, index
+}
+
+// -----
+
+// Negative: no run of linearize dims multiplies to 63 (8, 64, 256), and 63
+// is not 8 times anything in the delinearize basis, so nothing is rewritten.
+// CHECK-LABEL: func @products_dont_match
+// CHECK: affine.linearize_index disjoint
+// CHECK: affine.delinearize_index
+func.func @products_dont_match(%x0: index, %x1: index, %x2: index) -> (index, index) {
+  %lin = affine.linearize_index disjoint [%x0, %x1, %x2] by (4, 8, 8) : index
+  %parts:2 = affine.delinearize_index %lin into (4, 63) : index, index
+  return %parts#0, %parts#1 : index, index
+}
+
+// -----
+
+// Negative: the delinearized index is a function argument rather than an
+// affine.linearize_index result.
+// CHECK-LABEL: func @input_not_linearize
+// CHECK: affine.delinearize_index %{{.*}} into
+func.func @input_not_linearize(%idx: index) -> (index, index) {
+  %parts:2 = affine.delinearize_index %idx into (4, 8) : index, index
+  return %parts#0, %parts#1 : index, index
+}
More information about the Mlir-commits
mailing list