[Mlir-commits] [mlir] [mlir][affine] Add ValueBounds-based simplification for delinearize(linearize) pairs (PR #187245)

Zhewen Yu llvmlistbot at llvm.org
Fri Mar 27 05:14:38 PDT 2026


https://github.com/Yu-Zhewen updated https://github.com/llvm/llvm-project/pull/187245

>From 2d55c6e1161f6dc1c49ad21c7a69d3cfa5ebb7b1 Mon Sep 17 00:00:00 2001
From: Yu-Zhewen <zhewenyu at amd.com>
Date: Wed, 18 Mar 2026 04:38:40 -0700
Subject: [PATCH 1/3] simplify affine with bounds

Signed-off-by: Yu-Zhewen <zhewenyu at amd.com>
---
 .../mlir/Dialect/Affine/Transforms/Passes.td  |  12 +
 .../Dialect/Affine/Transforms/Transforms.h    |   4 +
 .../Dialect/Affine/Transforms/CMakeLists.txt  |   1 +
 .../Transforms/SimplifyAffineWithBounds.cpp   | 373 ++++++++++++++++++
 .../Dialect/Affine/simplify-with-bounds.mlir  | 156 ++++++++
 5 files changed, 546 insertions(+)
 create mode 100644 mlir/lib/Dialect/Affine/Transforms/SimplifyAffineWithBounds.cpp
 create mode 100644 mlir/test/Dialect/Affine/simplify-with-bounds.mlir

diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Passes.td b/mlir/include/mlir/Dialect/Affine/Transforms/Passes.td
index 430edffc29038..03f2d532eb016 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Passes.td
@@ -430,6 +430,18 @@ def SimplifyAffineMinMaxPass : InterfacePass<"affine-simplify-min-max", "Functio
   }];
 }
 
+def SimplifyAffineWithBounds : Pass<"affine-simplify-with-bounds"> {
+  let summary = "Simplify affine index operations using value bounds analysis";
+  let description = [{
+    This pass simplifies `affine.delinearize_index` / `affine.linearize_index`
+    pairs by using value bounds analysis to match basis products. Unlike the
+    built-in canonicalization patterns, which only use exact `OpFoldResult`
+    comparisons, this pass can prove equality of dynamic basis products through
+    `ValueBoundsConstraintSet`. Those canonicalizations are also applied.
+  }];
+  let dependentDialects = ["affine::AffineDialect", "arith::ArithDialect"];
+}
+
 def AffineExpandIndexOps : Pass<"affine-expand-index-ops"> {
   let summary = "Lower affine operations operating on indices into more fundamental operations";
   let constructor = "mlir::affine::createAffineExpandIndexOpsPass()";
diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index 272054448374e..84adb8e6a1e6d 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -49,6 +49,10 @@ LogicalResult lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
 LogicalResult lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
                                           AffineLinearizeIndexOp op);
 
+/// Populate patterns that simplify `affine.delinearize_index` /
+/// `affine.linearize_index` pairs using value bounds analysis.
+void populateSimplifyAffineWithBoundsPatterns(RewritePatternSet &patterns);
+
 /// Populate patterns that expand affine index operations into more fundamental
 /// operations (not necessarily restricted to Affine dialect).
 void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index 7bce124817032..9d912139810b2 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -17,6 +17,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
   RaiseMemrefDialect.cpp
   ReifyValueBounds.cpp
   SuperVectorize.cpp
+  SimplifyAffineWithBounds.cpp
   SimplifyAffineStructures.cpp
   SimplifyAffineMinMax.cpp
 
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineWithBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineWithBounds.cpp
new file mode 100644
index 0000000000000..4a0d199c543f3
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineWithBounds.cpp
@@ -0,0 +1,373 @@
+//===- SimplifyAffineIndexOps.cpp - Simplify affine index ops -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements simplification patterns for affine.delinearize_index /
+// affine.linearize_index pairs using value bounds analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Transforms/Passes.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "affine-simplify-with-bounds"
+
+using namespace mlir;
+using namespace mlir::affine;
+
+/// Build a ValueBoundsConstraintSet::Variable representing the product of
+/// the given basis elements. Static elements become constants in an affine
+/// expression; dynamic elements become symbols.
+static ValueBoundsConstraintSet::Variable
+buildProductVariable(ArrayRef<OpFoldResult> bases, MLIRContext *ctx) {
+  AffineExpr productExpr = getAffineConstantExpr(1, ctx);
+  SmallVector<Value> operands;
+  for (OpFoldResult basis : bases) {
+    if (auto attr = dyn_cast<Attribute>(basis)) {
+      int64_t val = cast<IntegerAttr>(attr).getInt();
+      productExpr = productExpr * getAffineConstantExpr(val, ctx);
+    } else {
+      Value val = cast<Value>(basis);
+      operands.push_back(val);
+      productExpr = productExpr * getAffineSymbolExpr(operands.size() - 1, ctx);
+    }
+  }
+  AffineMap productMap = AffineMap::get(0, operands.size(), productExpr, ctx);
+  return ValueBoundsConstraintSet::Variable(productMap, operands);
+}
+
+/// Check if two groups of basis elements have equal products using value bounds
+/// analysis.
+static bool areProductsEqual(ArrayRef<OpFoldResult> lhs,
+                             ArrayRef<OpFoldResult> rhs, MLIRContext *ctx) {
+  auto lhsVar = buildProductVariable(lhs, ctx);
+  auto rhsVar = buildProductVariable(rhs, ctx);
+  FailureOr<bool> result = ValueBoundsConstraintSet::areEqual(lhsVar, rhsVar);
+  return succeeded(result) && *result;
+}
+
+namespace {
+
+/// Simplify delinearize(linearize) pairs from the tail by matching multiple
+/// linearize dimensions whose product equals a single delinearize dimension
+/// (many-to-one).
+///
+/// Scans from the rightmost basis elements. For each trailing delinearize
+/// dimension, accumulates consecutive linearize dimension products until an
+/// equal product is found via ValueBounds. Matched trailing dimensions are
+/// peeled off, and residual ops are created for unmatched prefixes.
+///
+/// Example:
+///   %lin = affine.linearize_index disjoint [%a, %b, %c, %d, %e]
+///              by (A, B, C, D, E)
+///   %result:3 = affine.delinearize_index %lin into (X, Y, Z)
+///
+/// If D*E == Z but neither C, B*C, nor A*B*C equals Y, scanning stops
+/// and the unmatched prefix is left as residual ops:
+///   %prefix_lin = affine.linearize_index disjoint [%a, %b, %c] by (A, B, C)
+///   %prefix:2 = affine.delinearize_index %prefix_lin into (X, Y)
+///   %tail = affine.linearize_index disjoint [%d, %e] by (D, E)
+///   %result = [%prefix#0, %prefix#1, %tail]
+struct SimplifyDelinearizeOfLinearizeDisjointManyToOneTail final
+    : OpRewritePattern<AffineDelinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineDelinearizeIndexOp delinearizeOp,
+                                PatternRewriter &rewriter) const override {
+    auto linearizeOp =
+        delinearizeOp.getLinearIndex().getDefiningOp<AffineLinearizeIndexOp>();
+    if (!linearizeOp)
+      return rewriter.notifyMatchFailure(delinearizeOp,
+                                         "index doesn't come from linearize");
+
+    if (!linearizeOp.getDisjoint())
+      return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");
+
+    SmallVector<OpFoldResult> linBasis = linearizeOp.getMixedBasis();
+    SmallVector<OpFoldResult> delinBasis = delinearizeOp.getMixedBasis();
+    ValueRange linInputs = linearizeOp.getMultiIndex();
+    MLIRContext *ctx = rewriter.getContext();
+
+    // Track how many elements consumed from each tail.
+    size_t linTailConsumed = 0;
+    size_t delinTailConsumed = 0;
+
+    // For each matched delinearize dimension (innermost first), store the
+    // number of linearize dimensions that map to it.
+    SmallVector<size_t> groupLinCounts;
+
+    while (linTailConsumed < linBasis.size() &&
+           delinTailConsumed < delinBasis.size()) {
+      // Try matching k linearize dimensions to one delinearize dimension.
+      bool found = false;
+      for (size_t k = 1; k + linTailConsumed <= linBasis.size(); ++k) {
+        // Get the next k linearize dimensions from the tail.
+        ArrayRef<OpFoldResult> linSlice =
+            ArrayRef(linBasis).slice(linBasis.size() - linTailConsumed - k, k);
+        // Get the next one delinearize dimension from the tail.
+        ArrayRef<OpFoldResult> delinSlice =
+            ArrayRef(delinBasis)
+                .slice(delinBasis.size() - delinTailConsumed - 1, 1);
+
+        if (areProductsEqual(linSlice, delinSlice, ctx)) {
+          groupLinCounts.push_back(k);
+          linTailConsumed += k;
+          delinTailConsumed += 1;
+          found = true;
+          break;
+        }
+      }
+      if (!found)
+        break;
+    }
+
+    if (delinTailConsumed == 0)
+      return rewriter.notifyMatchFailure(delinearizeOp,
+                                         "no trailing dimensions matched");
+
+    SmallVector<Value> results;
+    if (delinTailConsumed < delinBasis.size()) {
+      // Partial match: create residual linearize + delinearize for the
+      // unmatched prefix.
+      Value residualLinearize = AffineLinearizeIndexOp::create(
+          rewriter, linearizeOp.getLoc(), linInputs.drop_back(linTailConsumed),
+          ArrayRef(linBasis).drop_back(linTailConsumed),
+          linearizeOp.getDisjoint());
+      auto residualDelinearize = AffineDelinearizeIndexOp::create(
+          rewriter, delinearizeOp.getLoc(), residualLinearize,
+          ArrayRef(delinBasis).drop_back(delinTailConsumed),
+          delinearizeOp.hasOuterBound());
+      results.append(residualDelinearize.getResults().begin(),
+                     residualDelinearize.getResults().end());
+    } else if (!delinearizeOp.hasOuterBound()) {
+      // All basis elements consumed, but the original delinearize has no outer
+      // bound which requires special handling.
+      ValueRange remainingInputs = linInputs.drop_back(linTailConsumed);
+      if (remainingInputs.empty()) {
+        // The outermost delinearize result is guaranteed to be zero.
+        results.push_back(arith::ConstantIndexOp::create(
+            rewriter, delinearizeOp.getLoc(), 0));
+      } else if (remainingInputs.size() == 1) {
+        // Pass through the single remaining input.
+        results.push_back(remainingInputs.front());
+      } else {
+        // Re-linearize the remaining inputs to produce the outermost result.
+        Value newLin = AffineLinearizeIndexOp::create(
+            rewriter, linearizeOp.getLoc(), remainingInputs,
+            ArrayRef(linBasis).drop_back(linTailConsumed),
+            linearizeOp.getDisjoint());
+        results.push_back(newLin);
+      }
+    }
+
+    // Produce one result per matched group. If the group size is 1,
+    // the input passes through directly. Otherwise, a smaller linearize is
+    // created over just that group's basis elements.
+    ValueRange matchedInputs = linInputs.take_back(linTailConsumed);
+    ArrayRef<OpFoldResult> matchedBasis =
+        ArrayRef(linBasis).take_back(linTailConsumed);
+    size_t offset = 0;
+    for (size_t count : llvm::reverse(groupLinCounts)) {
+      if (count == 1) {
+        results.push_back(matchedInputs[offset]);
+      } else {
+        Value newLin = AffineLinearizeIndexOp::create(
+            rewriter, linearizeOp.getLoc(), matchedInputs.slice(offset, count),
+            matchedBasis.slice(offset, count),
+            /*disjoint=*/true);
+        results.push_back(newLin);
+      }
+      offset += count;
+    }
+
+    rewriter.replaceOp(delinearizeOp, results);
+    return success();
+  }
+};
+
+/// Simplify delinearize(linearize) pairs from the tail by matching a single
+/// linearize dimension whose basis equals the product of multiple delinearize
+/// dimensions (one-to-many).
+///
+/// Scans from the rightmost basis elements. For each trailing linearize
+/// dimension, accumulates consecutive delinearize dimension products until an
+/// equal product is found via ValueBounds. Matched trailing dimensions are
+/// peeled off, and residual ops are created for unmatched prefixes.
+///
+/// Example:
+///   %lin = affine.linearize_index disjoint [%a, %b, %c] by (A, B, C)
+///   %result:5 = affine.delinearize_index %lin into (X, Y, Z, W, V)
+///
+/// If C == W*V but neither Z, Y*Z, nor X*Y*Z equals B, scanning stops
+/// and the unmatched prefix is left as residual ops:
+///   %prefix_lin = affine.linearize_index disjoint [%a, %b] by (A, B)
+///   %prefix:3 = affine.delinearize_index %prefix_lin into (X, Y, Z)
+///   %tail:2 = affine.delinearize_index %c into (W, V)
+///   %result = [%prefix#0, %prefix#1, %prefix#2, %tail#0, %tail#1]
+struct SimplifyDelinearizeOfLinearizeDisjointOneToManyTail final
+    : OpRewritePattern<AffineDelinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineDelinearizeIndexOp delinearizeOp,
+                                PatternRewriter &rewriter) const override {
+    auto linearizeOp =
+        delinearizeOp.getLinearIndex().getDefiningOp<AffineLinearizeIndexOp>();
+    if (!linearizeOp)
+      return rewriter.notifyMatchFailure(delinearizeOp,
+                                         "index doesn't come from linearize");
+
+    if (!linearizeOp.getDisjoint())
+      return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");
+
+    SmallVector<OpFoldResult> linBasis = linearizeOp.getMixedBasis();
+    SmallVector<OpFoldResult> delinBasis = delinearizeOp.getMixedBasis();
+    ValueRange linInputs = linearizeOp.getMultiIndex();
+    MLIRContext *ctx = rewriter.getContext();
+
+    // Track how many elements consumed from each tail.
+    size_t linTailConsumed = 0;
+    size_t delinTailConsumed = 0;
+
+    // For each matched linearize dimension (innermost first), store the
+    // number of delinearize dimensions it expands to.
+    SmallVector<size_t> groupDelinCounts;
+
+    while (linTailConsumed < linBasis.size() &&
+           delinTailConsumed < delinBasis.size()) {
+      // Try matching k delinearize dimensions to one linearize dimension.
+      bool found = false;
+      for (size_t k = 1; k + delinTailConsumed <= delinBasis.size(); ++k) {
+        // Get the next one linearize dimension from the tail.
+        ArrayRef<OpFoldResult> linSlice =
+            ArrayRef(linBasis).slice(linBasis.size() - linTailConsumed - 1, 1);
+        // Get the next k delinearize dimensions from the tail.
+        ArrayRef<OpFoldResult> delinSlice =
+            ArrayRef(delinBasis)
+                .slice(delinBasis.size() - delinTailConsumed - k, k);
+
+        if (areProductsEqual(linSlice, delinSlice, ctx)) {
+          groupDelinCounts.push_back(k);
+          linTailConsumed += 1;
+          delinTailConsumed += k;
+          found = true;
+          break;
+        }
+      }
+      if (!found)
+        break;
+    }
+
+    if (linTailConsumed == 0)
+      return rewriter.notifyMatchFailure(delinearizeOp,
+                                         "no trailing dimensions matched");
+
+    SmallVector<Value> results;
+
+    if (delinTailConsumed < delinBasis.size()) {
+      // Partial match: create residual linearize + delinearize for the
+      // unmatched prefix.
+      Value residualLinearize = AffineLinearizeIndexOp::create(
+          rewriter, linearizeOp.getLoc(), linInputs.drop_back(linTailConsumed),
+          ArrayRef(linBasis).drop_back(linTailConsumed),
+          linearizeOp.getDisjoint());
+      auto residualDelinearize = AffineDelinearizeIndexOp::create(
+          rewriter, delinearizeOp.getLoc(), residualLinearize,
+          ArrayRef(delinBasis).drop_back(delinTailConsumed),
+          delinearizeOp.hasOuterBound());
+      results.append(residualDelinearize.getResults().begin(),
+                     residualDelinearize.getResults().end());
+    } else if (!delinearizeOp.hasOuterBound()) {
+      // All basis elements consumed, but the original delinearize has no outer
+      // bound which requires special handling.
+      ValueRange remainingInputs = linInputs.drop_back(linTailConsumed);
+      if (remainingInputs.empty()) {
+        // The outermost delinearize result is guaranteed to be zero.
+        results.push_back(arith::ConstantIndexOp::create(
+            rewriter, delinearizeOp.getLoc(), 0));
+      } else if (remainingInputs.size() == 1) {
+        // Pass through the single remaining input.
+        results.push_back(remainingInputs.front());
+      } else {
+        // Re-linearize the remaining inputs to produce the outermost result.
+        Value newLin = AffineLinearizeIndexOp::create(
+            rewriter, linearizeOp.getLoc(), remainingInputs,
+            ArrayRef(linBasis).drop_back(linTailConsumed),
+            linearizeOp.getDisjoint());
+        results.push_back(newLin);
+      }
+    }
+
+    // Produce results for each matched group. If the group size is 1, the
+    // input passes through directly. Otherwise, a smaller delinearize is
+    // created over just that group's basis elements.
+    ValueRange matchedInputs = linInputs.take_back(linTailConsumed);
+    ArrayRef<OpFoldResult> matchedDelinBasis =
+        ArrayRef(delinBasis).take_back(delinTailConsumed);
+    size_t inputOffset = 0;
+    size_t delinOffset = 0;
+    for (size_t count : llvm::reverse(groupDelinCounts)) {
+      if (count == 1) {
+        results.push_back(matchedInputs[inputOffset]);
+      } else {
+        auto newDelin = AffineDelinearizeIndexOp::create(
+            rewriter, delinearizeOp.getLoc(), matchedInputs[inputOffset],
+            matchedDelinBasis.slice(delinOffset, count),
+            /*hasOuterBound=*/true);
+        results.append(newDelin.getResults().begin(),
+                       newDelin.getResults().end());
+      }
+      inputOffset += 1;
+      delinOffset += count;
+    }
+
+    rewriter.replaceOp(delinearizeOp, results);
+    return success();
+  }
+};
+
+} // namespace
+
+void affine::populateSimplifyAffineWithBoundsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<SimplifyDelinearizeOfLinearizeDisjointManyToOneTail,
+               SimplifyDelinearizeOfLinearizeDisjointOneToManyTail>(
+      patterns.getContext());
+}
+
+//===----------------------------------------------------------------------===//
+// Pass definition
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace affine {
+#define GEN_PASS_DEF_SIMPLIFYAFFINEWITHBOUNDS
+#include "mlir/Dialect/Affine/Transforms/Passes.h.inc"
+} // namespace affine
+} // namespace mlir
+
+namespace {
+struct SimplifyAffineWithBoundsPass
+    : affine::impl::SimplifyAffineWithBoundsBase<SimplifyAffineWithBoundsPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    // Add canonicalization patterns first so cheap exact-match cases are
+    // handled without invoking value bounds analysis.
+    AffineDelinearizeIndexOp::getCanonicalizationPatterns(patterns,
+                                                          &getContext());
+    AffineLinearizeIndexOp::getCanonicalizationPatterns(patterns,
+                                                        &getContext());
+    populateSimplifyAffineWithBoundsPatterns(patterns);
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
diff --git a/mlir/test/Dialect/Affine/simplify-with-bounds.mlir b/mlir/test/Dialect/Affine/simplify-with-bounds.mlir
new file mode 100644
index 0000000000000..7d2e8068b916f
--- /dev/null
+++ b/mlir/test/Dialect/Affine/simplify-with-bounds.mlir
@@ -0,0 +1,156 @@
+// RUN: mlir-opt -affine-simplify-with-bounds %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: func @many_to_one_static_tail
+// CHECK-SAME:    %[[A:.*]]: index, %[[B:.*]]: index, %[[C:.*]]: index
+// CHECK-DAG:     %[[LIN:.*]] = affine.linearize_index disjoint [%[[B]], %[[C]]] by (8, 8)
+// CHECK-DAG:     return %[[A]], %[[LIN]]
+func.func @many_to_one_static_tail(%a: index, %b: index, %c: index) -> (index, index) {
+  %0 = affine.linearize_index disjoint [%a, %b, %c] by (4, 8, 8) : index
+  %1:2 = affine.delinearize_index %0 into (4, 64) : index, index
+  return %1#0, %1#1 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @many_to_one_dynamic_tail
+// CHECK-SAME:    %[[A:.*]]: index, %[[B:.*]]: index, %[[C:.*]]: index, %[[DYN:.*]]: index
+// CHECK-DAG:     %[[LIN:.*]] = affine.linearize_index disjoint [%[[B]], %[[C]]] by (%[[DYN]], 8)
+// CHECK-DAG:     return %[[A]], %[[LIN]]
+func.func @many_to_one_dynamic_tail(%a: index, %b: index, %c: index, %dyn: index) -> (index, index) {
+  %dyn_times_8 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%dyn]
+  %0 = affine.linearize_index disjoint [%a, %b, %c] by (4, %dyn, 8) : index
+  %1:2 = affine.delinearize_index %0 into (4, %dyn_times_8) : index, index
+  return %1#0, %1#1 : index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @one_to_many_static_tail
+// CHECK-SAME:    %[[A:.*]]: index, %[[B:.*]]: index
+// CHECK-DAG:     %[[DELIN:.*]]:2 = affine.delinearize_index %[[B]] into (8, 8)
+// CHECK-DAG:     return %[[A]], %[[DELIN]]#0, %[[DELIN]]#1
+func.func @one_to_many_static_tail(%a: index, %b: index) -> (index, index, index) {
+  %0 = affine.linearize_index disjoint [%a, %b] by (4, 64) : index
+  %1:3 = affine.delinearize_index %0 into (4, 8, 8) : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @one_to_many_dynamic_tail
+// CHECK-SAME:    %[[A:.*]]: index, %[[B:.*]]: index, %[[DYN:.*]]: index
+// CHECK-DAG:     %[[DELIN:.*]]:2 = affine.delinearize_index %[[B]] into (%[[DYN]], 8)
+// CHECK-DAG:     return %[[A]], %[[DELIN]]#0, %[[DELIN]]#1
+func.func @one_to_many_dynamic_tail(%a: index, %b: index, %dyn: index) -> (index, index, index) {
+  %dyn_times_8 = affine.apply affine_map<()[s0] -> (s0 * 8)>()[%dyn]
+  %0 = affine.linearize_index disjoint [%a, %b] by (4, %dyn_times_8) : index
+  %1:3 = affine.delinearize_index %0 into (4, %dyn, 8) : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @one_to_one_dynamic_tail
+// CHECK-SAME:    %[[A:.*]]: index, %[[B:.*]]: index, %[[DYN:.*]]: index
+// CHECK-DAG:     return %[[A]], %[[B]]
+func.func @one_to_one_dynamic_tail(%a: index, %b: index, %dyn: index) -> (index, index) {
+  %dyn_ceildiv = affine.apply affine_map<()[s0] -> (s0 ceildiv 128)>()[%dyn]
+  %dyn_ceildiv_dup = affine.apply affine_map<()[s0] -> (s0 ceildiv 128)>()[%dyn]
+  %0 = affine.linearize_index disjoint [%a, %b] by (4, %dyn_ceildiv) : index
+  %1:2 = affine.delinearize_index %0 into (4, %dyn_ceildiv_dup) : index, index
+  return %1#0, %1#1 : index, index
+}
+
+// -----
+
+// Mixed: many-to-one and one-to-one in the same pair.
+// The many-to-one pattern should match [%c, %d] -> 1 delin dim,
+// and canonicalization handles the 1:1 prefix.
+// CHECK-LABEL: func @mixed_many_to_one_and_one_to_one
+// CHECK-SAME:    %[[A:.*]]: index, %[[B:.*]]: index, %[[C:.*]]: index, %[[D:.*]]: index
+// CHECK-DAG:     %[[LIN:.*]] = affine.linearize_index disjoint [%[[C]], %[[D]]] by (4, 8)
+// CHECK-DAG:     return %[[A]], %[[B]], %[[LIN]]
+func.func @mixed_many_to_one_and_one_to_one(%a: index, %b: index, %c: index, %d: index) -> (index, index, index) {
+  %0 = affine.linearize_index disjoint [%a, %b, %c, %d] by (2, 3, 4, 8) : index
+  %1:3 = affine.delinearize_index %0 into (2, 3, 32) : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+// Partial match: only tail matches, prefix is left as residual.
+// CHECK-LABEL: func @partial_tail_many_to_one
+// CHECK-SAME:    %[[A:.*]]: index, %[[B:.*]]: index, %[[C:.*]]: index, %[[D:.*]]: index
+// CHECK:         %[[RESLIN:.*]] = affine.linearize_index disjoint [%[[A]], %[[B]]] by (5, 3)
+// CHECK:         %[[RESDELIN:.*]]:2 = affine.delinearize_index %[[RESLIN]] into (7, 9)
+// CHECK:         %[[TAILLIN:.*]] = affine.linearize_index disjoint [%[[C]], %[[D]]] by (4, 8)
+// CHECK:         return %[[RESDELIN]]#0, %[[RESDELIN]]#1, %[[TAILLIN]]
+func.func @partial_tail_many_to_one(%a: index, %b: index, %c: index, %d: index) -> (index, index, index) {
+  %0 = affine.linearize_index disjoint [%a, %b, %c, %d] by (5, 3, 4, 8) : index
+  %1:3 = affine.delinearize_index %0 into (7, 9, 32) : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+// Many-to-one with no outer bound: all basis elements consumed.
+// The outermost delinearize result (unbounded) passes through from the
+// outermost linearize input.
+// CHECK-LABEL: func @many_to_one_no_outer_bound
+// CHECK-SAME:    %[[A:.*]]: index, %[[B:.*]]: index, %[[C:.*]]: index
+// CHECK-DAG:     %[[LIN:.*]] = affine.linearize_index disjoint [%[[B]], %[[C]]] by (8, 8)
+// CHECK-DAG:     return %[[A]], %[[LIN]]
+func.func @many_to_one_no_outer_bound(%a: index, %b: index, %c: index) -> (index, index) {
+  %0 = affine.linearize_index disjoint [%a, %b, %c] by (8, 8) : index
+  %1:2 = affine.delinearize_index %0 into (64) : index, index
+  return %1#0, %1#1 : index, index
+}
+
+// -----
+
+// One-to-many with no outer bound: all basis elements consumed.
+// CHECK-LABEL: func @one_to_many_no_outer_bound
+// CHECK-SAME:    %[[A:.*]]: index, %[[B:.*]]: index
+// CHECK-DAG:     %[[DELIN:.*]]:2 = affine.delinearize_index %[[B]] into (8, 8)
+// CHECK-DAG:     return %[[A]], %[[DELIN]]#0, %[[DELIN]]#1
+func.func @one_to_many_no_outer_bound(%a: index, %b: index) -> (index, index, index) {
+  %0 = affine.linearize_index disjoint [%a, %b] by (64) : index
+  %1:3 = affine.delinearize_index %0 into (8, 8) : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+// Negative test: no disjoint flag.
+// CHECK-LABEL: func @no_disjoint
+// CHECK:         affine.linearize_index [
+// CHECK:         affine.delinearize_index
+func.func @no_disjoint(%a: index, %b: index, %c: index) -> (index, index) {
+  %0 = affine.linearize_index [%a, %b, %c] by (4, 8, 8) : index
+  %1:2 = affine.delinearize_index %0 into (4, 64) : index, index
+  return %1#0, %1#1 : index, index
+}
+
+// -----
+
+// Negative test: products don't match.
+// CHECK-LABEL: func @products_dont_match
+// CHECK:         affine.linearize_index disjoint
+// CHECK:         affine.delinearize_index
+func.func @products_dont_match(%a: index, %b: index, %c: index) -> (index, index) {
+  %0 = affine.linearize_index disjoint [%a, %b, %c] by (4, 8, 8) : index
+  %1:2 = affine.delinearize_index %0 into (4, 63) : index, index
+  return %1#0, %1#1 : index, index
+}
+
+// -----
+
+// Negative test: input not from linearize.
+// CHECK-LABEL: func @input_not_linearize
+// CHECK:         affine.delinearize_index %{{.*}} into
+func.func @input_not_linearize(%x: index) -> (index, index) {
+  %0:2 = affine.delinearize_index %x into (4, 8) : index, index
+  return %0#0, %0#1 : index, index
+}

>From 64fa934aff5d8e6866770125ed109588a9a446d1 Mon Sep 17 00:00:00 2001
From: Yu-Zhewen <zhewenyu at amd.com>
Date: Wed, 25 Mar 2026 16:52:51 -0700
Subject: [PATCH 2/3] address review comments

Signed-off-by: Yu-Zhewen <zhewenyu at amd.com>
---
 .../Transforms/SimplifyAffineWithBounds.cpp   | 120 +++++++++---------
 .../Dialect/Affine/simplify-with-bounds.mlir  |  15 ++-
 2 files changed, 72 insertions(+), 63 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineWithBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineWithBounds.cpp
index 4a0d199c543f3..620515191ef7e 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineWithBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineWithBounds.cpp
@@ -1,4 +1,4 @@
-//===- SimplifyAffineIndexOps.cpp - Simplify affine index ops -------------===//
+//===- SimplifyAffineWithBounds.cpp ---------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -45,14 +45,38 @@ buildProductVariable(ArrayRef<OpFoldResult> bases, MLIRContext *ctx) {
   return ValueBoundsConstraintSet::Variable(productMap, operands);
 }
 
-/// Check if two groups of basis elements have equal products using value bounds
-/// analysis.
-static bool areProductsEqual(ArrayRef<OpFoldResult> lhs,
-                             ArrayRef<OpFoldResult> rhs, MLIRContext *ctx) {
-  auto lhsVar = buildProductVariable(lhs, ctx);
-  auto rhsVar = buildProductVariable(rhs, ctx);
-  FailureOr<bool> result = ValueBoundsConstraintSet::areEqual(lhsVar, rhsVar);
-  return succeeded(result) && *result;
+/// Try to find k consecutive elements from `lhs` (starting from tail offset)
+/// whose product equals the single next element from `rhs`.
+/// The product is accumulated incrementally to avoid redundant computation.
+/// Returns the number of matched elements k, or std::nullopt if no match.
+static std::optional<size_t> tryMatchProduct(ArrayRef<OpFoldResult> lhs,
+                                             size_t lhsTailConsumed,
+                                             ArrayRef<OpFoldResult> rhs,
+                                             size_t rhsTailConsumed,
+                                             MLIRContext *ctx) {
+  auto rhsVar =
+      buildProductVariable(rhs.slice(rhs.size() - rhsTailConsumed - 1, 1), ctx);
+
+  AffineExpr productExpr = getAffineConstantExpr(1, ctx);
+  SmallVector<Value> operands;
+
+  for (size_t k = 1; k + lhsTailConsumed <= lhs.size(); ++k) {
+    OpFoldResult basis = lhs[lhs.size() - lhsTailConsumed - k];
+    if (auto attr = dyn_cast<Attribute>(basis)) {
+      int64_t val = cast<IntegerAttr>(attr).getInt();
+      productExpr = productExpr * getAffineConstantExpr(val, ctx);
+    } else {
+      operands.push_back(cast<Value>(basis));
+      productExpr = productExpr * getAffineSymbolExpr(operands.size() - 1, ctx);
+    }
+
+    AffineMap productMap = AffineMap::get(0, operands.size(), productExpr, ctx);
+    ValueBoundsConstraintSet::Variable lhsVar(productMap, operands);
+    FailureOr<bool> result = ValueBoundsConstraintSet::areEqual(lhsVar, rhsVar);
+    if (succeeded(result) && *result)
+      return k;
+  }
+  return std::nullopt;
 }
 
 namespace {
@@ -79,7 +103,7 @@ namespace {
 ///   %result = [%prefix#0, %prefix#1, %tail]
 struct SimplifyDelinearizeOfLinearizeDisjointManyToOneTail final
     : OpRewritePattern<AffineDelinearizeIndexOp> {
-  using OpRewritePattern::OpRewritePattern;
+  using Base::Base;
 
   LogicalResult matchAndRewrite(AffineDelinearizeIndexOp delinearizeOp,
                                 PatternRewriter &rewriter) const override {
@@ -108,26 +132,13 @@ struct SimplifyDelinearizeOfLinearizeDisjointManyToOneTail final
     while (linTailConsumed < linBasis.size() &&
            delinTailConsumed < delinBasis.size()) {
       // Try matching k linearize dimensions to one delinearize dimension.
-      bool found = false;
-      for (size_t k = 1; k + linTailConsumed <= linBasis.size(); ++k) {
-        // Get the next k linearize dimensions from the tail.
-        ArrayRef<OpFoldResult> linSlice =
-            ArrayRef(linBasis).slice(linBasis.size() - linTailConsumed - k, k);
-        // Get the next one delinearize dimension from the tail.
-        ArrayRef<OpFoldResult> delinSlice =
-            ArrayRef(delinBasis)
-                .slice(delinBasis.size() - delinTailConsumed - 1, 1);
-
-        if (areProductsEqual(linSlice, delinSlice, ctx)) {
-          groupLinCounts.push_back(k);
-          linTailConsumed += k;
-          delinTailConsumed += 1;
-          found = true;
-          break;
-        }
-      }
-      if (!found)
+      std::optional<size_t> k = tryMatchProduct(
+          linBasis, linTailConsumed, delinBasis, delinTailConsumed, ctx);
+      if (!k)
         break;
+      groupLinCounts.push_back(*k);
+      linTailConsumed += *k;
+      delinTailConsumed += 1;
     }
 
     if (delinTailConsumed == 0)
@@ -172,17 +183,17 @@ struct SimplifyDelinearizeOfLinearizeDisjointManyToOneTail final
     // Produce one result per matched group. If the group size is 1,
     // the input passes through directly. Otherwise, a smaller linearize is
     // created over just that group's basis elements.
-    ValueRange matchedInputs = linInputs.take_back(linTailConsumed);
-    ArrayRef<OpFoldResult> matchedBasis =
-        ArrayRef(linBasis).take_back(linTailConsumed);
+    size_t inputMatchStart = linInputs.size() - linTailConsumed;
+    size_t basisMatchStart = linBasis.size() - linTailConsumed;
     size_t offset = 0;
     for (size_t count : llvm::reverse(groupLinCounts)) {
       if (count == 1) {
-        results.push_back(matchedInputs[offset]);
+        results.push_back(linInputs[inputMatchStart + offset]);
       } else {
         Value newLin = AffineLinearizeIndexOp::create(
-            rewriter, linearizeOp.getLoc(), matchedInputs.slice(offset, count),
-            matchedBasis.slice(offset, count),
+            rewriter, linearizeOp.getLoc(),
+            linInputs.slice(inputMatchStart + offset, count),
+            ArrayRef(linBasis).slice(basisMatchStart + offset, count),
             /*disjoint=*/true);
         results.push_back(newLin);
       }
@@ -215,7 +226,7 @@ struct SimplifyDelinearizeOfLinearizeDisjointManyToOneTail final
 ///   %result = [%prefix#0, %prefix#1, %prefix#2, %tail#0, %tail#1]
 struct SimplifyDelinearizeOfLinearizeDisjointOneToManyTail final
     : OpRewritePattern<AffineDelinearizeIndexOp> {
-  using OpRewritePattern::OpRewritePattern;
+  using Base::Base;
 
   LogicalResult matchAndRewrite(AffineDelinearizeIndexOp delinearizeOp,
                                 PatternRewriter &rewriter) const override {
@@ -244,26 +255,13 @@ struct SimplifyDelinearizeOfLinearizeDisjointOneToManyTail final
     while (linTailConsumed < linBasis.size() &&
            delinTailConsumed < delinBasis.size()) {
       // Try matching k delinearize dimensions to one linearize dimension.
-      bool found = false;
-      for (size_t k = 1; k + delinTailConsumed <= delinBasis.size(); ++k) {
-        // Get the next one linearize dimension from the tail.
-        ArrayRef<OpFoldResult> linSlice =
-            ArrayRef(linBasis).slice(linBasis.size() - linTailConsumed - 1, 1);
-        // Get the next k delinearize dimensions from the tail.
-        ArrayRef<OpFoldResult> delinSlice =
-            ArrayRef(delinBasis)
-                .slice(delinBasis.size() - delinTailConsumed - k, k);
-
-        if (areProductsEqual(linSlice, delinSlice, ctx)) {
-          groupDelinCounts.push_back(k);
-          linTailConsumed += 1;
-          delinTailConsumed += k;
-          found = true;
-          break;
-        }
-      }
-      if (!found)
+      std::optional<size_t> k = tryMatchProduct(delinBasis, delinTailConsumed,
+                                                linBasis, linTailConsumed, ctx);
+      if (!k)
         break;
+      groupDelinCounts.push_back(*k);
+      delinTailConsumed += *k;
+      linTailConsumed += 1;
     }
 
     if (linTailConsumed == 0)
@@ -309,18 +307,18 @@ struct SimplifyDelinearizeOfLinearizeDisjointOneToManyTail final
     // Produce results for each matched group. If the group size is 1, the
     // input passes through directly. Otherwise, a smaller delinearize is
     // created over just that group's basis elements.
-    ValueRange matchedInputs = linInputs.take_back(linTailConsumed);
-    ArrayRef<OpFoldResult> matchedDelinBasis =
-        ArrayRef(delinBasis).take_back(delinTailConsumed);
+    size_t linMatchStart = linInputs.size() - linTailConsumed;
+    size_t delinMatchStart = delinBasis.size() - delinTailConsumed;
     size_t inputOffset = 0;
     size_t delinOffset = 0;
     for (size_t count : llvm::reverse(groupDelinCounts)) {
       if (count == 1) {
-        results.push_back(matchedInputs[inputOffset]);
+        results.push_back(linInputs[linMatchStart + inputOffset]);
       } else {
         auto newDelin = AffineDelinearizeIndexOp::create(
-            rewriter, delinearizeOp.getLoc(), matchedInputs[inputOffset],
-            matchedDelinBasis.slice(delinOffset, count),
+            rewriter, delinearizeOp.getLoc(),
+            linInputs[linMatchStart + inputOffset],
+            ArrayRef(delinBasis).slice(delinMatchStart + delinOffset, count),
             /*hasOuterBound=*/true);
         results.append(newDelin.getResults().begin(),
                        newDelin.getResults().end());
diff --git a/mlir/test/Dialect/Affine/simplify-with-bounds.mlir b/mlir/test/Dialect/Affine/simplify-with-bounds.mlir
index 7d2e8068b916f..a9c517a5c5add 100644
--- a/mlir/test/Dialect/Affine/simplify-with-bounds.mlir
+++ b/mlir/test/Dialect/Affine/simplify-with-bounds.mlir
@@ -1,7 +1,5 @@
 // RUN: mlir-opt -affine-simplify-with-bounds %s | FileCheck %s
 
-// -----
-
 // CHECK-LABEL: func @many_to_one_static_tail
 // CHECK-SAME:    %[[A:.*]]: index, %[[B:.*]]: index, %[[C:.*]]: index
 // CHECK-DAG:     %[[LIN:.*]] = affine.linearize_index disjoint [%[[B]], %[[C]]] by (8, 8)
@@ -123,6 +121,19 @@ func.func @one_to_many_no_outer_bound(%a: index, %b: index) -> (index, index, in
 
 // -----
 
+// Partial match with empty residual linearize.
+// CHECK-LABEL: func @partial_match_empty_residual_lin
+// CHECK-SAME:    %[[A:.*]]: index, %[[B:.*]]: index
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         return %[[C0]], %[[A]], %[[B]]
+func.func @partial_match_empty_residual_lin(%a: index, %b: index) -> (index, index, index) {
+  %0 = affine.linearize_index disjoint [%a, %b] by (4, 8) : index
+  %1:3 = affine.delinearize_index %0 into (10, 4, 8) : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
 // Negative test: no disjoint flag.
 // CHECK-LABEL: func @no_disjoint
 // CHECK:         affine.linearize_index [

>From d7c88596e0c35db88bcf3431f72ae72a16783268 Mon Sep 17 00:00:00 2001
From: Yu-Zhewen <zhewenyu at amd.com>
Date: Fri, 27 Mar 2026 05:14:11 -0700
Subject: [PATCH 3/3] unify into one pattern

Signed-off-by: Yu-Zhewen <zhewenyu at amd.com>
---
 .../Transforms/SimplifyAffineWithBounds.cpp   | 280 +++++-------------
 1 file changed, 82 insertions(+), 198 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineWithBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineWithBounds.cpp
index 620515191ef7e..6e7d5d91334c0 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineWithBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineWithBounds.cpp
@@ -11,10 +11,12 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Transforms/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -24,25 +26,17 @@
 using namespace mlir;
 using namespace mlir::affine;
 
-/// Build a ValueBoundsConstraintSet::Variable representing the product of
-/// the given basis elements. Static elements become constants in an affine
-/// expression; dynamic elements become symbols.
-static ValueBoundsConstraintSet::Variable
-buildProductVariable(ArrayRef<OpFoldResult> bases, MLIRContext *ctx) {
-  AffineExpr productExpr = getAffineConstantExpr(1, ctx);
-  SmallVector<Value> operands;
-  for (OpFoldResult basis : bases) {
-    if (auto attr = dyn_cast<Attribute>(basis)) {
-      int64_t val = cast<IntegerAttr>(attr).getInt();
-      productExpr = productExpr * getAffineConstantExpr(val, ctx);
-    } else {
-      Value val = cast<Value>(basis);
-      operands.push_back(val);
-      productExpr = productExpr * getAffineSymbolExpr(operands.size() - 1, ctx);
-    }
+/// Multiply `productExpr` by one basis element. Constant values become affine
+/// constants; other values are appended to `operands` as new symbols.
+static void buildProductExpr(OpFoldResult basis, AffineExpr &productExpr,
+                             SmallVectorImpl<Value> &operands,
+                             MLIRContext *ctx) {
+  if (auto val = getConstantIntValue(basis)) {
+    productExpr = productExpr * getAffineConstantExpr(*val, ctx);
+  } else {
+    operands.push_back(cast<Value>(basis));
+    productExpr = productExpr * getAffineSymbolExpr(operands.size() - 1, ctx);
   }
-  AffineMap productMap = AffineMap::get(0, operands.size(), productExpr, ctx);
-  return ValueBoundsConstraintSet::Variable(productMap, operands);
 }
 
 /// Try to find k consecutive elements from `lhs` (starting from tail offset)
@@ -54,24 +48,22 @@ static std::optional<size_t> tryMatchProduct(ArrayRef<OpFoldResult> lhs,
                                              ArrayRef<OpFoldResult> rhs,
                                              size_t rhsTailConsumed,
                                              MLIRContext *ctx) {
-  auto rhsVar =
-      buildProductVariable(rhs.slice(rhs.size() - rhsTailConsumed - 1, 1), ctx);
-
-  AffineExpr productExpr = getAffineConstantExpr(1, ctx);
-  SmallVector<Value> operands;
-
+  // Build a Variable for the single rhs element.
+  AffineExpr rhsExpr = getAffineConstantExpr(1, ctx);
+  SmallVector<Value> rhsOperands;
+  buildProductExpr(rhs[rhs.size() - rhsTailConsumed - 1], rhsExpr, rhsOperands,
+                   ctx);
+  ValueBoundsConstraintSet::Variable rhsVar(
+      AffineMap::get(0, rhsOperands.size(), rhsExpr, ctx), rhsOperands);
+
+  // Incrementally accumulate lhs product and check for equality.
+  AffineExpr lhsExpr = getAffineConstantExpr(1, ctx);
+  SmallVector<Value> lhsOperands;
   for (size_t k = 1; k + lhsTailConsumed <= lhs.size(); ++k) {
-    OpFoldResult basis = lhs[lhs.size() - lhsTailConsumed - k];
-    if (auto attr = dyn_cast<Attribute>(basis)) {
-      int64_t val = cast<IntegerAttr>(attr).getInt();
-      productExpr = productExpr * getAffineConstantExpr(val, ctx);
-    } else {
-      operands.push_back(cast<Value>(basis));
-      productExpr = productExpr * getAffineSymbolExpr(operands.size() - 1, ctx);
-    }
-
-    AffineMap productMap = AffineMap::get(0, operands.size(), productExpr, ctx);
-    ValueBoundsConstraintSet::Variable lhsVar(productMap, operands);
+    buildProductExpr(lhs[lhs.size() - lhsTailConsumed - k], lhsExpr,
+                     lhsOperands, ctx);
+    AffineMap lhsMap = AffineMap::get(0, lhsOperands.size(), lhsExpr, ctx);
+    ValueBoundsConstraintSet::Variable lhsVar(lhsMap, lhsOperands);
     FailureOr<bool> result = ValueBoundsConstraintSet::areEqual(lhsVar, rhsVar);
     if (succeeded(result) && *result)
       return k;
@@ -81,27 +73,26 @@ static std::optional<size_t> tryMatchProduct(ArrayRef<OpFoldResult> lhs,
 
 namespace {
 
-/// Simplify delinearize(linearize) pairs from the tail by matching multiple
-/// linearize dimensions whose product equals a single delinearize dimension
-/// (many-to-one).
+/// Simplify delinearize(linearize) pairs from the tail by matching groups of
+/// dimensions whose basis products are equal via ValueBounds analysis.
+///
+/// For each step from the tail, tries:
+///   1. Many-to-one: k linearize dims -> 1 delinearize dim
+///   2. One-to-many: 1 linearize dim -> k delinearize dims
 ///
-/// Scans from the rightmost basis elements. For each trailing delinearize
-/// dimension, accumulates consecutive linearize dimension products until an
-/// equal product is found via ValueBounds. Matched trailing dimensions are
-/// peeled off, and residual ops are created for unmatched prefixes.
+/// Matched trailing dimensions are peeled off. Unmatched prefix dimensions
+/// are left as residual linearize/delinearize operations.
 ///
-/// Example:
+/// Example (many-to-one D*E == Z; matching then stops at C vs. Y):
 ///   %lin = affine.linearize_index disjoint [%a, %b, %c, %d, %e]
 ///              by (A, B, C, D, E)
 ///   %result:3 = affine.delinearize_index %lin into (X, Y, Z)
-///
-/// If D*E == Z but neither C, B*C, nor A*B*C equals Y, scanning stops
-/// and the unmatched prefix is left as residual ops:
+/// ->
 ///   %prefix_lin = affine.linearize_index disjoint [%a, %b, %c] by (A, B, C)
 ///   %prefix:2 = affine.delinearize_index %prefix_lin into (X, Y)
 ///   %tail = affine.linearize_index disjoint [%d, %e] by (D, E)
 ///   %result = [%prefix#0, %prefix#1, %tail]
-struct SimplifyDelinearizeOfLinearizeDisjointManyToOneTail final
+struct SimplifyDelinearizeOfLinearizeDisjoint final
     : OpRewritePattern<AffineDelinearizeIndexOp> {
   using Base::Base;
 
@@ -125,27 +116,39 @@ struct SimplifyDelinearizeOfLinearizeDisjointManyToOneTail final
     size_t linTailConsumed = 0;
     size_t delinTailConsumed = 0;
 
-    // For each matched delinearize dimension (innermost first), store the
-    // number of linearize dimensions that map to it.
-    SmallVector<size_t> groupLinCounts;
+    // For each matched group (innermost first), record the number of
+    // linearize and delinearize dimensions it spans. Many-to-one groups
+    // have linCount > 1, one-to-many groups have delinCount > 1.
+    SmallVector<std::pair<size_t, size_t>> matchedGroups;
 
     while (linTailConsumed < linBasis.size() &&
            delinTailConsumed < delinBasis.size()) {
-      // Try matching k linearize dimensions to one delinearize dimension.
-      std::optional<size_t> k = tryMatchProduct(
-          linBasis, linTailConsumed, delinBasis, delinTailConsumed, ctx);
-      if (!k)
-        break;
-      groupLinCounts.push_back(*k);
-      linTailConsumed += *k;
-      delinTailConsumed += 1;
+      // Try many-to-one: k lin dims -> 1 delin dim.
+      if (std::optional<size_t> k = tryMatchProduct(
+              linBasis, linTailConsumed, delinBasis, delinTailConsumed, ctx)) {
+        matchedGroups.emplace_back(*k, 1);
+        linTailConsumed += *k;
+        delinTailConsumed += 1;
+        continue;
+      }
+      // Try one-to-many: 1 lin dim -> k delin dims.
+      if (std::optional<size_t> k = tryMatchProduct(
+              delinBasis, delinTailConsumed, linBasis, linTailConsumed, ctx)) {
+        matchedGroups.emplace_back(1, *k);
+        delinTailConsumed += *k;
+        linTailConsumed += 1;
+        continue;
+      }
+      break;
     }
 
-    if (delinTailConsumed == 0)
+    if (matchedGroups.empty())
       return rewriter.notifyMatchFailure(delinearizeOp,
                                          "no trailing dimensions matched");
 
     SmallVector<Value> results;
+
+    // Build residual prefix ops for unmatched dimensions.
     if (delinTailConsumed < delinBasis.size()) {
       // Partial match: create residual linearize + delinearize for the
       // unmatched prefix.
@@ -180,151 +183,34 @@ struct SimplifyDelinearizeOfLinearizeDisjointManyToOneTail final
       }
     }
 
-    // Produce one result per matched group. If the group size is 1,
-    // the input passes through directly. Otherwise, a smaller linearize is
-    // created over just that group's basis elements.
-    size_t inputMatchStart = linInputs.size() - linTailConsumed;
-    size_t basisMatchStart = linBasis.size() - linTailConsumed;
-    size_t offset = 0;
-    for (size_t count : llvm::reverse(groupLinCounts)) {
-      if (count == 1) {
-        results.push_back(linInputs[inputMatchStart + offset]);
-      } else {
+    // Build results for each matched group.
+    size_t linInputOffset = linInputs.size() - linTailConsumed;
+    size_t linBasisOffset = linBasis.size() - linTailConsumed;
+    size_t delinBasisOffset = delinBasis.size() - delinTailConsumed;
+    for (auto [linCount, delinCount] : llvm::reverse(matchedGroups)) {
+      if (linCount == 1 && delinCount == 1) {
+        // Exact 1:1 match: pass through directly.
+        results.push_back(linInputs[linInputOffset]);
+      } else if (linCount > 1) {
+        // Many-to-one: re-linearize the group's lin inputs.
         Value newLin = AffineLinearizeIndexOp::create(
             rewriter, linearizeOp.getLoc(),
-            linInputs.slice(inputMatchStart + offset, count),
-            ArrayRef(linBasis).slice(basisMatchStart + offset, count),
+            linInputs.slice(linInputOffset, linCount),
+            ArrayRef(linBasis).slice(linBasisOffset, linCount),
             /*disjoint=*/true);
         results.push_back(newLin);
-      }
-      offset += count;
-    }
-
-    rewriter.replaceOp(delinearizeOp, results);
-    return success();
-  }
-};
-
-/// Simplify delinearize(linearize) pairs from the tail by matching a single
-/// linearize dimension whose basis equals the product of multiple delinearize
-/// dimensions (one-to-many).
-///
-/// Scans from the rightmost basis elements. For each trailing linearize
-/// dimension, accumulates consecutive delinearize dimension products until an
-/// equal product is found via ValueBounds. Matched trailing dimensions are
-/// peeled off, and residual ops are created for unmatched prefixes.
-///
-/// Example:
-///   %lin = affine.linearize_index disjoint [%a, %b, %c] by (A, B, C)
-///   %result:5 = affine.delinearize_index %lin into (X, Y, Z, W, V)
-///
-/// If C == W*V but neither Z, Y*Z, nor X*Y*Z equals B, scanning stops
-/// and the unmatched prefix is left as residual ops:
-///   %prefix_lin = affine.linearize_index disjoint [%a, %b] by (A, B)
-///   %prefix:3 = affine.delinearize_index %prefix_lin into (X, Y, Z)
-///   %tail:2 = affine.delinearize_index %c into (W, V)
-///   %result = [%prefix#0, %prefix#1, %prefix#2, %tail#0, %tail#1]
-struct SimplifyDelinearizeOfLinearizeDisjointOneToManyTail final
-    : OpRewritePattern<AffineDelinearizeIndexOp> {
-  using Base::Base;
-
-  LogicalResult matchAndRewrite(AffineDelinearizeIndexOp delinearizeOp,
-                                PatternRewriter &rewriter) const override {
-    auto linearizeOp =
-        delinearizeOp.getLinearIndex().getDefiningOp<AffineLinearizeIndexOp>();
-    if (!linearizeOp)
-      return rewriter.notifyMatchFailure(delinearizeOp,
-                                         "index doesn't come from linearize");
-
-    if (!linearizeOp.getDisjoint())
-      return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");
-
-    SmallVector<OpFoldResult> linBasis = linearizeOp.getMixedBasis();
-    SmallVector<OpFoldResult> delinBasis = delinearizeOp.getMixedBasis();
-    ValueRange linInputs = linearizeOp.getMultiIndex();
-    MLIRContext *ctx = rewriter.getContext();
-
-    // Track how many elements consumed from each tail.
-    size_t linTailConsumed = 0;
-    size_t delinTailConsumed = 0;
-
-    // For each matched linearize dimension (innermost first), store the
-    // number of delinearize dimensions it expands to.
-    SmallVector<size_t> groupDelinCounts;
-
-    while (linTailConsumed < linBasis.size() &&
-           delinTailConsumed < delinBasis.size()) {
-      // Try matching k delinearize dimensions to one linearize dimension.
-      std::optional<size_t> k = tryMatchProduct(delinBasis, delinTailConsumed,
-                                                linBasis, linTailConsumed, ctx);
-      if (!k)
-        break;
-      groupDelinCounts.push_back(*k);
-      delinTailConsumed += *k;
-      linTailConsumed += 1;
-    }
-
-    if (linTailConsumed == 0)
-      return rewriter.notifyMatchFailure(delinearizeOp,
-                                         "no trailing dimensions matched");
-
-    SmallVector<Value> results;
-
-    if (delinTailConsumed < delinBasis.size()) {
-      // Partial match: create residual linearize + delinearize for the
-      // unmatched prefix.
-      Value residualLinearize = AffineLinearizeIndexOp::create(
-          rewriter, linearizeOp.getLoc(), linInputs.drop_back(linTailConsumed),
-          ArrayRef(linBasis).drop_back(linTailConsumed),
-          linearizeOp.getDisjoint());
-      auto residualDelinearize = AffineDelinearizeIndexOp::create(
-          rewriter, delinearizeOp.getLoc(), residualLinearize,
-          ArrayRef(delinBasis).drop_back(delinTailConsumed),
-          delinearizeOp.hasOuterBound());
-      results.append(residualDelinearize.getResults().begin(),
-                     residualDelinearize.getResults().end());
-    } else if (!delinearizeOp.hasOuterBound()) {
-      // All basis elements consumed, but the original delinearize has no outer
-      // bound which requires special handling.
-      ValueRange remainingInputs = linInputs.drop_back(linTailConsumed);
-      if (remainingInputs.empty()) {
-        // The outermost delinearize result is guaranteed to be zero.
-        results.push_back(arith::ConstantIndexOp::create(
-            rewriter, delinearizeOp.getLoc(), 0));
-      } else if (remainingInputs.size() == 1) {
-        // Pass through the single remaining input.
-        results.push_back(remainingInputs.front());
-      } else {
-        // Re-linearize the remaining inputs to produce the outermost result.
-        Value newLin = AffineLinearizeIndexOp::create(
-            rewriter, linearizeOp.getLoc(), remainingInputs,
-            ArrayRef(linBasis).drop_back(linTailConsumed),
-            linearizeOp.getDisjoint());
-        results.push_back(newLin);
-      }
-    }
-
-    // Produce results for each matched group. If the group size is 1, the
-    // input passes through directly. Otherwise, a smaller delinearize is
-    // created over just that group's basis elements.
-    size_t linMatchStart = linInputs.size() - linTailConsumed;
-    size_t delinMatchStart = delinBasis.size() - delinTailConsumed;
-    size_t inputOffset = 0;
-    size_t delinOffset = 0;
-    for (size_t count : llvm::reverse(groupDelinCounts)) {
-      if (count == 1) {
-        results.push_back(linInputs[linMatchStart + inputOffset]);
       } else {
+        // One-to-many: delinearize the single lin input.
         auto newDelin = AffineDelinearizeIndexOp::create(
-            rewriter, delinearizeOp.getLoc(),
-            linInputs[linMatchStart + inputOffset],
-            ArrayRef(delinBasis).slice(delinMatchStart + delinOffset, count),
+            rewriter, delinearizeOp.getLoc(), linInputs[linInputOffset],
+            ArrayRef(delinBasis).slice(delinBasisOffset, delinCount),
             /*hasOuterBound=*/true);
         results.append(newDelin.getResults().begin(),
                        newDelin.getResults().end());
       }
-      inputOffset += 1;
-      delinOffset += count;
+      linInputOffset += linCount;
+      linBasisOffset += linCount;
+      delinBasisOffset += delinCount;
     }
 
     rewriter.replaceOp(delinearizeOp, results);
@@ -336,9 +222,7 @@ struct SimplifyDelinearizeOfLinearizeDisjointOneToManyTail final
 
 void affine::populateSimplifyAffineWithBoundsPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<SimplifyDelinearizeOfLinearizeDisjointManyToOneTail,
-               SimplifyDelinearizeOfLinearizeDisjointOneToManyTail>(
-      patterns.getContext());
+  patterns.add<SimplifyDelinearizeOfLinearizeDisjoint>(patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list