[Mlir-commits] [mlir] [mlir][Affine] Let affine.[de]linearize_index omit outer bounds (PR #116103)

Krzysztof Drewniak llvmlistbot at llvm.org
Thu Nov 14 09:43:04 PST 2024


https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/116103

From 154127f4423b08dd3833652754ef859536ddda73 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Wed, 13 Nov 2024 19:16:29 +0000
Subject: [PATCH 1/3] [mlir][Affine] Let affine.[de]linearize_index omit outer
 bounds

The affine.delinearize_index and affine.linearize_index operations, as
currently defined, require providing a length N basis to [de]linearize
N values. The first value in this basis is never used during lowering.
(Note that, even though it isn't used during lowering, it can still be
used to, for example, remove length-1 outputs from a delinearize.)

This dead value makes sense in the original context of these
operations, which is linearizing or de-linearizing indexes to
memref<>s, vector<>s, and other shaped types, where that outer bound
is available and may be useful for analysis.

However, other use cases exist where the outer bound is not known. For
example:

    %thread_id_x = gpu.thread_id x : index
    %0:3 = affine.delinearize_index %thread_id_x into (4, 16) : index, index, index

In this code, we don't know the upper bound of the thread ID, but we
do want to construct the ?x4x16 grid of delinearized values in order
to further partition the GPU threads.
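
As a rough illustration of the intended semantics (not the MLIR
implementation), the sketch below delinearizes a plain integer against a
basis with no outer bound, using the same floor-division and modulus
steps as the constant folder in this patch. The names `floorDiv`,
`floorMod`, and `delinearizeNoOuterBound` are hypothetical and exist only
for this example.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Floor division (rounds toward negative infinity), unlike C++'s `/`.
static int64_t floorDiv(int64_t a, int64_t b) {
  int64_t q = a / b;
  if ((a % b != 0) && ((a < 0) != (b < 0)))
    --q;
  return q;
}

// Remainder matching floorDiv; non-negative for a positive divisor.
static int64_t floorMod(int64_t a, int64_t b) {
  return a - floorDiv(a, b) * b;
}

// Hypothetical helper: delinearize `linear` against a basis that has NO
// outer bound. Returns basis.size() + 1 values; the extra value sits in
// position 0 and is unbounded. Basis elements are assumed to be positive.
std::vector<int64_t> delinearizeNoOuterBound(int64_t linear,
                                             const std::vector<int64_t> &basis) {
  std::vector<int64_t> result(basis.size() + 1);
  int64_t high = linear;
  // Peel off the innermost dimension first.
  for (size_t i = basis.size(); i > 0; --i) {
    result[i] = floorMod(high, basis[i - 1]);
    high = floorDiv(high, basis[i - 1]);
  }
  result[0] = high; // Whatever is left over is the outermost index.
  return result;
}
```

With the basis (4, 16) from the example above, a thread ID of 100
would split into (1, 2, 4), since 1 * 64 + 2 * 16 + 4 = 100.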

In order to support such use cases, we broaden the definition of
affine.delinearize_index and affine.linearize_index to make the outer
bound optional.

In the case of affine.delinearize_index, where the number of results
is a function of the size of the passed-in basis, we augment all
existing builders with a `hasOuterBound` argument, which, for
backwards compatibility and to preserve the natural usage of the op,
defaults to `true`. If this flag is true, the op returns one result
per basis element; if it is false, it returns one extra result in
position 0.

We also update existing canonicalization patterns (and move one of
them into the folder) to handle these cases. Note that disagreements
about the outer bound no longer prevent delinearize/linearize
cancellations.
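
To make the last point concrete, here is a minimal standalone sketch
(plain integers rather than `OpFoldResult`s) of comparing two bases
while ignoring any outer bound, in the spirit of the
`getEffectiveBasis()` helper added by this patch. The function names
`effectiveBasis` and `basesCancel` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: drop the outer bound, if present. A basis "has an
// outer bound" when it has as many elements as there are indices.
std::vector<int64_t> effectiveBasis(const std::vector<int64_t> &basis,
                                    size_t numIndices) {
  assert(basis.size() == numIndices || basis.size() + 1 == numIndices);
  bool hasOuterBound = basis.size() == numIndices && !basis.empty();
  return std::vector<int64_t>(basis.begin() + (hasOuterBound ? 1 : 0),
                              basis.end());
}

// Two ops over `numIndices` values can cancel when their bases agree
// everywhere except (possibly) in the outer bound.
bool basesCancel(const std::vector<int64_t> &linearizeBasis,
                 const std::vector<int64_t> &delinearizeBasis,
                 size_t numIndices) {
  return effectiveBasis(linearizeBasis, numIndices) ==
         effectiveBasis(delinearizeBasis, numIndices);
}
```

Under this sketch, the bases (2, 3, 5) and (3, 5) over three indices
compare equal, as do (7, 3, 5) and (2, 3, 5), while (2, 3, 5) and
(2, 4, 5) do not.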
---
 .../mlir/Dialect/Affine/IR/AffineOps.td       |  84 +++++++-
 mlir/include/mlir/Dialect/Affine/Utils.h      |  14 +-
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      | 181 +++++++++---------
 .../Transforms/AffineExpandIndexOps.cpp       |  11 +-
 mlir/lib/Dialect/Affine/Utils/Utils.cpp       |  24 ++-
 .../Affine/affine-expand-index-ops.mlir       |  13 +-
 mlir/test/Dialect/Affine/canonicalize.mlir    | 172 +++++++++++++++++
 mlir/test/Dialect/Affine/invalid.mlir         |  18 +-
 8 files changed, 392 insertions(+), 125 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 6a495e11ae1ad5..53bc7ce0349241 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1059,8 +1059,7 @@ def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
 // AffineDelinearizeIndexOp
 //===----------------------------------------------------------------------===//
 
-def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
-    [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
   let summary = "delinearize an index";
   let description = [{
     The `affine.delinearize_index` operation takes a single index value and
@@ -1082,6 +1081,25 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
     %indices_1 = affine.apply #map1()[%linear_index]
     %indices_2 = affine.apply #map2()[%linear_index]
     ```
+
+    The basis may either contain `N` or `N-1` elements, where `N` is the number of results.
+    If there are N basis elements, the first one will not be used during computations,
+    but may be used during analysis and canonicalization to eliminate terms from
+    the `affine.delinearize_index` or to enable conclusions about the total size of
+    `%linear_index`.
+
+    If the basis is fully provided, the delinearize_index operation is said to "have
+    an outer bound". The builders assume that an `affine.delinearize_index` has
+    an outer bound by default, as this is how the operation was initially defined.
+
+    That is, the example above could also have been written
+    ```mlir
+    %0:3 = affine.delinearize_index %linear_index into (244, 244) : index, index, index
+    ```
+
+    Note that, due to the constraints of affine maps, all the basis elements must
+    be strictly positive. A dynamic basis element being 0 or negative causes
+    undefined behavior.
   }];
 
   let arguments = (ins Index:$linear_index,
@@ -1096,17 +1114,37 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$linear_index, "ValueRange":$basis)>,
-    OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis)>,
-    OpBuilder<(ins "Value":$linear_index, "ArrayRef<int64_t>":$basis)>
+    OpBuilder<(ins "Value":$linear_index, "ValueRange":$dynamic_basis, "ArrayRef<int64_t>":$static_basis, CArg<"bool", "true">:$hasOuterBound)>,
+    OpBuilder<(ins "Value":$linear_index, "ValueRange":$basis, CArg<"bool", "true">:$hasOuterBound)>,
+    OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis, CArg<"bool", "true">:$hasOuterBound)>,
+    OpBuilder<(ins "Value":$linear_index, "ArrayRef<int64_t>":$basis, CArg<"bool", "true">:$hasOuterBound)>
   ];
 
   let extraClassDeclaration = [{
+    /// Return true if the basis includes a bound on the first index result.
+    bool hasOuterBound() {
+      return getMultiIndex().size() == getStaticBasis().size();
+    }
+
     /// Returns a vector with all the static and dynamic basis values.
     SmallVector<OpFoldResult> getMixedBasis() {
       OpBuilder builder(getContext());
       return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
     }
+
+    /// Return a vector that contains the basis of the operation, removing
+    /// the outer bound if one is present.
+    SmallVector<OpFoldResult> getEffectiveBasis() {
+      OpBuilder builder(getContext());
+      if (hasOuterBound()) {
+        if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
+          return ::mlir::getMixedValues(getStaticBasis().drop_front(), getDynamicBasis().drop_front(), builder);
+
+        return ::mlir::getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(), builder);
+      }
+
+      return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+    }
   }];
 
   let hasVerifier = 1;
@@ -1124,13 +1162,21 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
     The `affine.linearize_index` operation takes a sequence of index values and a
     basis of the same length and linearizes the indices using that basis.
 
-    That is, for indices `%idx_1` through `%idx_N` and basis elements `b_1` through `b_N`,
-    it computes
+    That is, for indices `%idx_0` to `%idx_{N-1}` and basis elements `B_0`
+    (or `B_1`) up to `B_{N-1}` it computes
 
     ```
-    sum(i = 1 to N) %idx_i * product(j = i + 1 to N) B_j
+    sum(i = 0 to N-1) %idx_i * product(j = i + 1 to N-1) B_j
     ```
 
+    The basis may either have `N` or `N-1` elements, where `N` is the number of
+    inputs to linearize_index. If `N` basis elements are provided, the first one
+    is not used in computation, but may be used during analysis or
+    canonicalization as a bound on `%idx_0`.
+
+    If all `N` basis elements are provided, the linearize_index operation is said to
+    "have an outer bound".
+
     If the `disjoint` property is present, this is an optimization hint that,
     for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index,
     except that `%idx_0` may be negative to make the index as a whole negative.
@@ -1140,7 +1186,9 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
     Example:
 
     ```mlir
-    %linear_index = affine.linearize_index [%index_0, %index_1, %index_2] (2, 3, 5) : index
+    %linear_index = affine.linearize_index [%index_0, %index_1, %index_2] by (2, 3, 5) : index
+    // Same effect
+    %linear_index = affine.linearize_index [%index_0, %index_1, %index_2] by (3, 5) : index
     ```
 
     In the above example, `%linear_index` conceptually holds the following:
@@ -1171,12 +1219,30 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
   ];
 
   let extraClassDeclaration = [{
+    /// Return true if the basis includes a bound on the first index input.
+    bool hasOuterBound() {
+      return getMultiIndex().size() == getStaticBasis().size();
+    }
+
     /// Return a vector with all the static and dynamic basis values.
     SmallVector<OpFoldResult> getMixedBasis() {
       OpBuilder builder(getContext());
       return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
     }
 
+    /// Return a vector that contains the basis of the operation, removing
+    /// the outer bound if one is present.
+    SmallVector<OpFoldResult> getEffectiveBasis() {
+      OpBuilder builder(getContext());
+      if (hasOuterBound()) {
+        if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
+          return ::mlir::getMixedValues(getStaticBasis().drop_front(), getDynamicBasis().drop_front(), builder);
+
+        return ::mlir::getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(), builder);
+      }
+
+      return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+    }
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index 0e98223969e08c..0f801ebb6f5898 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -307,17 +307,23 @@ struct DivModValue {
 DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
 
 /// Generate the IR to delinearize `linearIndex` given the `basis` and return
-/// the multi-index.
+/// the multi-index. `hasOuterBound` indicates whether `basis` has an entry
+/// giving the size of the first multi-index result. If it is true, the
+/// function will return `basis.size()` values; otherwise, it will return
+/// `basis.size() + 1`.
 FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
                                                Value linearIndex,
-                                               ArrayRef<Value> basis);
+                                               ArrayRef<Value> basis,
+                                               bool hasOuterBound = true);
 
 FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
                                                Value linearIndex,
-                                               ArrayRef<OpFoldResult> basis);
+                                               ArrayRef<OpFoldResult> basis,
+                                               bool hasOuterBound = true);
 
 // Generate IR that extracts the linear index from a multi-index according to
-// a basis/shape.
+// a basis/shape. The basis may contain either `multiIndex.size()` or
+// `multiIndex.size() - 1` elements.
 OpFoldResult linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
                             ArrayRef<OpFoldResult> basis,
                             ImplicitLocOpBuilder &builder);
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index fbc9053a0e273b..3693195c39fecb 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
@@ -4503,62 +4504,81 @@ LogicalResult AffineVectorStoreOp::verify() {
 // DelinearizeIndexOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult AffineDelinearizeIndexOp::inferReturnTypes(
-    MLIRContext *context, std::optional<::mlir::Location> location,
-    ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
-    RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
-  AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties,
-                                          regions);
-  inferredReturnTypes.assign(adaptor.getStaticBasis().size(),
-                             IndexType::get(context));
-  return success();
+void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
+                                     OperationState &odsState,
+                                     Value linearIndex, ValueRange dynamicBasis,
+                                     ArrayRef<int64_t> staticBasis,
+                                     bool hasOuterBound) {
+  SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
+                                              : staticBasis.size() + 1,
+                                linearIndex.getType());
+  build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
+        staticBasis);
 }
 
 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                      OperationState &odsState,
-                                     Value linearIndex, ValueRange basis) {
+                                     Value linearIndex, ValueRange basis,
+                                     bool hasOuterBound) {
   SmallVector<Value> dynamicBasis;
   SmallVector<int64_t> staticBasis;
   dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
                              staticBasis);
-  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
+  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
+        hasOuterBound);
 }
 
 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                      OperationState &odsState,
                                      Value linearIndex,
-                                     ArrayRef<OpFoldResult> basis) {
+                                     ArrayRef<OpFoldResult> basis,
+                                     bool hasOuterBound) {
   SmallVector<Value> dynamicBasis;
   SmallVector<int64_t> staticBasis;
   dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
-  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
+  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
+        hasOuterBound);
 }
 
 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                      OperationState &odsState,
-                                     Value linearIndex,
-                                     ArrayRef<int64_t> basis) {
-  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis);
+                                     Value linearIndex, ArrayRef<int64_t> basis,
+                                     bool hasOuterBound) {
+  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
 }
 
 LogicalResult AffineDelinearizeIndexOp::verify() {
-  if (getStaticBasis().empty())
-    return emitOpError("basis should not be empty");
-  if (getNumResults() != getStaticBasis().size())
-    return emitOpError("should return an index for each basis element");
-  auto dynamicMarkersCount =
-      llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
+  ArrayRef<int64_t> staticBasis = getStaticBasis();
+  if (getNumResults() != staticBasis.size() &&
+      getNumResults() != staticBasis.size() + 1)
+    return emitOpError("should return an index for each basis element and up "
+                       "to one extra index");
+
+  auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
   if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
     return emitOpError(
         "mismatch between dynamic and static basis (kDynamic marker but no "
         "corresponding dynamic basis entry) -- this can only happen due to an "
         "incorrect fold/rewrite");
+
+  if (!llvm::all_of(staticBasis, [](int64_t v) {
+        return v > 0 || ShapedType::isDynamic(v);
+      }))
+    return emitOpError("no basis element may be statically non-positive");
+
   return success();
 }
 
 LogicalResult
 AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &result) {
+  // If we won't be doing any division or modulo (no basis or the one basis
+  // element is purely advisory), simply return the input value.
+  if (getStaticBasis().size() == static_cast<size_t>(hasOuterBound())) {
+    result.push_back(getLinearIndex());
+    return success();
+  }
+
   if (adaptor.getLinearIndex() == nullptr)
     return failure();
 
@@ -4567,7 +4587,11 @@ AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
 
   int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
   Type attrType = getLinearIndex().getType();
-  for (int64_t modulus : llvm::reverse(getStaticBasis().drop_front())) {
+
+  ArrayRef<int64_t> staticBasis = getStaticBasis();
+  if (hasOuterBound())
+    staticBasis = staticBasis.drop_front();
+  for (int64_t modulus : llvm::reverse(staticBasis)) {
     result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
     highPart = llvm::divideFloorSigned(highPart, modulus);
   }
@@ -4594,24 +4618,25 @@ struct DropUnitExtentBasis
       return zero.value();
     };
 
+    bool hasOuterBound = delinearizeOp.hasOuterBound();
     // Replace all indices corresponding to unit-extent basis with 0.
     // Remaining basis can be used to get a new `affine.delinearize_index` op.
-    SmallVector<OpFoldResult> newOperands;
+    SmallVector<OpFoldResult> newBasis;
     for (auto [index, basis] : llvm::enumerate(delinearizeOp.getMixedBasis())) {
       std::optional<int64_t> basisVal = getConstantIntValue(basis);
       if (basisVal && *basisVal == 1)
-        replacements[index] = getZero();
+        replacements[index + (hasOuterBound ? 0 : 1)] = getZero();
       else
-        newOperands.push_back(basis);
+        newBasis.push_back(basis);
     }
 
-    if (newOperands.size() == delinearizeOp.getStaticBasis().size())
+    if (newBasis.size() == delinearizeOp.getStaticBasis().size())
       return rewriter.notifyMatchFailure(delinearizeOp,
                                          "no unit basis elements");
 
-    if (!newOperands.empty()) {
+    if (!newBasis.empty() || !hasOuterBound) {
       auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
-          loc, delinearizeOp.getLinearIndex(), newOperands);
+          loc, delinearizeOp.getLinearIndex(), newBasis, hasOuterBound);
       int newIndex = 0;
       // Map back the new delinearized indices to the values they replace.
       for (auto &replacement : replacements) {
@@ -4626,27 +4651,6 @@ struct DropUnitExtentBasis
   }
 };
 
-/// Drop delinearization with a single basis element
-///
-/// By definition, `delinearize_index %linear into (%basis)` is
-/// `%linear floorDiv 1` (since `1` is the product of the basis elememts,
-/// ignoring the 0th one, and since there is no previous division we need
-/// to use the remainder of). Therefore, a single-element `delinearize`
-/// can be replaced by the underlying linear index.
-struct DropDelinearizeOneBasisElement
-    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
-                                PatternRewriter &rewriter) const override {
-    if (delinearizeOp.getStaticBasis().size() != 1)
-      return rewriter.notifyMatchFailure(delinearizeOp,
-                                         "doesn't have a length-1 basis");
-    rewriter.replaceOp(delinearizeOp, delinearizeOp.getLinearIndex());
-    return success();
-  }
-};
-
 /// If a `affine.delinearize_index`'s input is a `affine.linearize_index
 /// disjoint` and the two operations have the same basis, replace the
 /// delinearizeation results with the inputs of the `affine.linearize_index`
@@ -4668,7 +4672,7 @@ struct CancelDelinearizeOfLinearizeDisjointExact
                                          "index doesn't come from linearize");
 
     if (!linearizeOp.getDisjoint() ||
-        linearizeOp.getMixedBasis() != delinearizeOp.getMixedBasis())
+        linearizeOp.getEffectiveBasis() != delinearizeOp.getEffectiveBasis())
       return rewriter.notifyMatchFailure(
           linearizeOp, "not disjoint or basis doesn't match delinearize");
 
@@ -4680,8 +4684,9 @@ struct CancelDelinearizeOfLinearizeDisjointExact
 
 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
-  patterns.insert<CancelDelinearizeOfLinearizeDisjointExact,
-                  DropDelinearizeOneBasisElement, DropUnitExtentBasis>(context);
+  patterns
+      .insert<CancelDelinearizeOfLinearizeDisjointExact, DropUnitExtentBasis>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -4718,11 +4723,11 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
 }
 
 LogicalResult AffineLinearizeIndexOp::verify() {
-  if (getStaticBasis().empty())
-    return emitOpError("basis should not be empty");
-
-  if (getMultiIndex().size() != getStaticBasis().size())
-    return emitOpError("should be passed an index for each basis element");
+  size_t numIndexes = getMultiIndex().size();
+  size_t numBasisElems = getStaticBasis().size();
+  if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
+    return emitOpError("should be passed a basis element for each index except "
+                       "possibly the first");
 
   auto dynamicMarkersCount =
       llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
@@ -4736,6 +4741,14 @@ LogicalResult AffineLinearizeIndexOp::verify() {
 }
 
 OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
+  // No indices linearizes to zero.
+  if (getMultiIndex().empty())
+    return IntegerAttr::get(getResult().getType(), 0);
+
+  // One single index linearizes to itself.
+  if (getMultiIndex().size() == 1)
+    return getMultiIndex().front();
+
   if (llvm::any_of(adaptor.getMultiIndex(),
                    [](Attribute a) { return a == nullptr; }))
     return nullptr;
@@ -4745,12 +4758,17 @@ OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
 
   int64_t result = 0;
   int64_t stride = 1;
-  for (auto [indexAttr, length] :
-       llvm::zip_equal(llvm::reverse(adaptor.getMultiIndex()),
-                       llvm::reverse(getStaticBasis()))) {
+  for (auto [length, indexAttr] :
+       llvm::zip_first(llvm::reverse(getStaticBasis()),
+                       llvm::reverse(adaptor.getMultiIndex()))) {
     result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
     stride = stride * length;
   }
+  // Handle the index element with no basis element.
+  if (!hasOuterBound())
+    result =
+        result +
+        cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride;
 
   return IntegerAttr::get(getResult().getType(), result);
 }
@@ -4772,14 +4790,20 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
 
   LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    size_t numIndices = op.getMultiIndex().size();
+    ValueRange multiIndex = op.getMultiIndex();
+    size_t numIndices = multiIndex.size();
     SmallVector<Value> newIndices;
     newIndices.reserve(numIndices);
     SmallVector<OpFoldResult> newBasis;
     newBasis.reserve(numIndices);
 
+    if (!op.hasOuterBound()) {
+      newIndices.push_back(multiIndex.front());
+      multiIndex = multiIndex.drop_front();
+    }
+
     SmallVector<OpFoldResult> basis = op.getMixedBasis();
-    for (auto [index, basisElem] : llvm::zip_equal(op.getMultiIndex(), basis)) {
+    for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
       std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
       if (!basisEntry || *basisEntry != 1) {
         newIndices.push_back(index);
@@ -4808,23 +4832,6 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
   }
 };
 
-/// Rewrite `affine.linearize_index [%%x] by (%b)`, into `%x`.
-///
-/// By definition, that operation is `affine.apply affine_map<()[s0] -> (s0)>,`
-/// which is the identity.
-struct DropLinearizeOneBasisElement final
-    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
-                                PatternRewriter &rewriter) const override {
-    if (op.getStaticBasis().size() != 1 || op.getMultiIndex().size() != 1)
-      return rewriter.notifyMatchFailure(op, "doesn't have a a length-1 basis");
-    rewriter.replaceOp(op, op.getMultiIndex().front());
-    return success();
-  }
-};
-
 /// Cancel out linearize_index(delinearize_index(x, B), B).
 ///
 /// That is, rewrite
@@ -4847,10 +4854,10 @@ struct CancelLinearizeOfDelinearizeExact final
       return rewriter.notifyMatchFailure(
           linearizeOp, "last entry doesn't come from a delinearize");
 
-    if (linearizeOp.getMixedBasis() != delinearizeOp.getMixedBasis())
+    if (linearizeOp.getEffectiveBasis() != delinearizeOp.getEffectiveBasis())
       return rewriter.notifyMatchFailure(
-          linearizeOp,
-          "basis of linearize and delinearize don't match exactly");
+          linearizeOp, "basis of linearize and delinearize don't match exactly "
+                       "(excluding outer bounds)");
 
     if (delinearizeOp.getResults() != linearizeOp.getMultiIndex())
       return rewriter.notifyMatchFailure(
@@ -4881,9 +4888,12 @@ struct DropLinearizeLeadingZero final
     }
 
     SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
+    ArrayRef<OpFoldResult> newMixedBasis = mixedBasis;
+    if (op.hasOuterBound())
+      newMixedBasis = newMixedBasis.drop_front();
+
     rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
-        op, op.getMultiIndex().drop_front(),
-        ArrayRef<OpFoldResult>(mixedBasis).drop_front(), op.getDisjoint());
+        op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
     return success();
   }
 };
@@ -4892,7 +4902,6 @@ struct DropLinearizeLeadingZero final
 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
   patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeLeadingZero,
-               DropLinearizeOneBasisElement,
                DropLinearizeUnitComponentsIfDisjointOrZero>(context);
 }
 
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index 1930e987a33ffa..15478e0e1e3a5b 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -36,8 +36,9 @@ struct LowerDelinearizeIndexOps
   using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<SmallVector<Value>> multiIndex = delinearizeIndex(
-        rewriter, op->getLoc(), op.getLinearIndex(), op.getMixedBasis());
+    FailureOr<SmallVector<Value>> multiIndex =
+        delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
+                         op.getEffectiveBasis(), /*hasOuterBound=*/false);
     if (failed(multiIndex))
       return failure();
     rewriter.replaceOp(op, *multiIndex);
@@ -51,6 +52,12 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
+    // Should be folded away, included here for safety.
+    if (op.getMultiIndex().empty()) {
+      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
+      return success();
+    }
+
     SmallVector<OpFoldResult> multiIndex =
         getAsOpFoldResult(op.getMultiIndex());
     OpFoldResult linearIndex =
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 7fe422f75c8fad..3420db771ef426 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1944,11 +1944,14 @@ static FailureOr<OpFoldResult> composedAffineMultiply(OpBuilder &b,
 
 FailureOr<SmallVector<Value>>
 mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
-                               ArrayRef<Value> basis) {
+                               ArrayRef<Value> basis, bool hasOuterBound) {
+  if (hasOuterBound)
+    basis = basis.drop_front();
+
   // Note: the divisors are backwards due to the scan.
   SmallVector<Value> divisors;
   OpFoldResult basisProd = b.getIndexAttr(1);
-  for (OpFoldResult basisElem : llvm::reverse(basis.drop_front())) {
+  for (OpFoldResult basisElem : llvm::reverse(basis)) {
     FailureOr<OpFoldResult> nextProd =
         composedAffineMultiply(b, loc, basisElem, basisProd);
     if (failed(nextProd))
@@ -1971,11 +1974,15 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
 
 FailureOr<SmallVector<Value>>
 mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
-                               ArrayRef<OpFoldResult> basis) {
+                               ArrayRef<OpFoldResult> basis,
+                               bool hasOuterBound) {
+  if (hasOuterBound)
+    basis = basis.drop_front();
+
   // Note: the divisors are backwards due to the scan.
   SmallVector<Value> divisors;
   OpFoldResult basisProd = b.getIndexAttr(1);
-  for (OpFoldResult basisElem : llvm::reverse(basis.drop_front())) {
+  for (OpFoldResult basisElem : llvm::reverse(basis)) {
     FailureOr<OpFoldResult> nextProd =
         composedAffineMultiply(b, loc, basisElem, basisProd);
     if (failed(nextProd))
@@ -2005,8 +2012,15 @@ OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
 OpFoldResult mlir::affine::linearizeIndex(OpBuilder &builder, Location loc,
                                           ArrayRef<OpFoldResult> multiIndex,
                                           ArrayRef<OpFoldResult> basis) {
-  assert(multiIndex.size() == basis.size());
+  assert(multiIndex.size() == basis.size() ||
+         multiIndex.size() == basis.size() + 1);
   SmallVector<AffineExpr> basisAffine;
+
+  // Add a fake initial size in order to make the later index linearization
+  // computations line up if an outer bound is not provided.
+  if (multiIndex.size() == basis.size() + 1)
+    basisAffine.push_back(getAffineConstantExpr(1, builder.getContext()));
+
   for (size_t i = 0; i < basis.size(); ++i) {
     basisAffine.push_back(getAffineSymbolExpr(i, builder.getContext()));
   }
diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
index ded1687ca560b2..650555cfb5fe13 100644
--- a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
+++ b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
@@ -35,10 +35,10 @@ func.func @dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (inde
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
-  %b0 = memref.dim %src, %c0 : memref<?x?x?xf32>
   %b1 = memref.dim %src, %c1 : memref<?x?x?xf32>
   %b2 = memref.dim %src, %c2 : memref<?x?x?xf32>
-  %1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, index
+  // Note: no outer bound.
+  %1:3 = affine.delinearize_index %linear_index into (%b1, %b2) : index, index, index
   return %1#0, %1#1, %1#2 : index, index, index
 }
 
@@ -60,10 +60,11 @@ func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index {
 // CHECK-DAG: #[[$map0:.+]] =  affine_map<()[s0, s1, s2, s3, s4] -> (s1 * s2 + s3 + s0 * (s2 * s4))>
 
 // CHECK-LABEL: @linearize_dynamic
-// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index, %[[arg5:.+]]: index)
-// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg5]], %[[arg2]], %[[arg4]]]
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index)
+// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg4]], %[[arg2]], %[[arg3]]]
 // CHECK: return %[[val_0]]
-func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> index {
-  %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4, %arg5) : index
+func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> index {
+  // Note: no outer bounds
+  %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4) : index
   func.return %0 : index
 }
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index ec00b31258d072..b54a13cffe7771 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1496,6 +1496,20 @@ func.func @delinearize_fold_negative_constant() -> (index, index, index) {
 
 // -----
 
+// CHECK-LABEL: @delinearize_fold_negative_constant_no_outer_bound
+// CHECK-DAG: %[[C_2:.+]] = arith.constant -2 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-NOT: affine.delinearize_index
+// CHECK: return %[[C_2]], %[[C1]], %[[C3]]
+func.func @delinearize_fold_negative_constant_no_outer_bound() -> (index, index, index) {
+  %c_22 = arith.constant -22 : index
+  %0:3 = affine.delinearize_index %c_22 into (3, 5) : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
 // CHECK-LABEL: @delinearize_dont_fold_constant_dynamic_basis
 // CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
 // CHECK: %[[RET:.+]]:3 = affine.delinearize_index %[[C22]]
@@ -1525,6 +1539,23 @@ func.func @drop_unit_basis_in_delinearize(%arg0 : index, %arg1 : index, %arg2 :
 
 // -----
 
+func.func @drop_unit_basis_in_delinearize_no_outer_bound(%arg0 : index, %arg1 : index, %arg2 : index) ->
+    (index, index, index, index, index, index) {
+  %c1 = arith.constant 1 : index
+  %0:6 = affine.delinearize_index %arg0 into (%arg1, 1, 1, %arg2, %c1)
+      : index, index, index, index, index, index
+  return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : index, index, index, index, index, index
+}
+// CHECK-LABEL: func @drop_unit_basis_in_delinearize_no_outer_bound(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[DELINEARIZE:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], %[[ARG2]])
+//       CHECK:   return %[[DELINEARIZE]]#0, %[[DELINEARIZE]]#1, %[[C0]], %[[C0]], %[[DELINEARIZE]]#2, %[[C0]]
+
+// -----
+
 func.func @drop_all_unit_bases(%arg0 : index) -> (index, index) {
   %0:2 = affine.delinearize_index %arg0 into (1, 1) : index, index
   return %0#0, %0#1 : index, index
@@ -1537,6 +1568,18 @@ func.func @drop_all_unit_bases(%arg0 : index) -> (index, index) {
 
 // -----
 
+func.func @drop_all_unit_bases_no_outer_bound(%arg0 : index) -> (index, index, index) {
+  %0:3 = affine.delinearize_index %arg0 into (1, 1) : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+// CHECK-LABEL: func @drop_all_unit_bases_no_outer_bound(
+//  CHECK-SAME:     %[[ARG0:.+]]: index)
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-NOT:   affine.delinearize_index
+//       CHECK:   return %[[ARG0]], %[[C0]], %[[C0]]
+
+// -----
+
 func.func @drop_single_loop_delinearize(%arg0 : index, %arg1 : index) -> index {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -1574,6 +1617,17 @@ func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index
 
 // -----
 
+// CHECK-LABEL: func @delinearize_empty_basis
+// CHECK-SAME: (%[[ARG0:.+]]: index)
+// CHECK-NOT: affine.delinearize
+// CHECK: return %[[ARG0]]
+func.func @delinearize_empty_basis(%arg0: index) -> index {
+  %0 = affine.delinearize_index %arg0 into () : index
+  return %0 : index
+}
+
+// -----
+
 // CHECK-LABEL: @linearize_fold_constants
 // CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
 // CHECK-NOT: affine.linearize
@@ -1588,6 +1642,42 @@ func.func @linearize_fold_constants() -> index {
 
 // -----
 
+// CHECK-LABEL: @linearize_fold_constants_no_outer_bound
+// CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
+// CHECK-NOT: affine.linearize
+// CHECK: return %[[C22]]
+func.func @linearize_fold_constants_no_outer_bound() -> index {
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+
+  %ret = affine.linearize_index [%c1, %c1, %c2] by (3, 5) : index
+  return %ret : index
+}
+
+// -----
+
+// CHECK-LABEL: @linearize_fold_empty_basis
+// CHECK-SAME: (%[[ARG0:.+]]: index)
+// CHECK-NOT: affine.linearize
+// CHECK: return %[[ARG0]]
+func.func @linearize_fold_empty_basis(%arg0: index) -> index {
+  %ret = affine.linearize_index [%arg0] by () : index
+  return %ret : index
+}
+
+// -----
+
+// CHECK-LABEL: @linearize_fold_only_outer_bound
+// CHECK-SAME: (%[[ARG0:.+]]: index)
+// CHECK-NOT: affine.linearize
+// CHECK: return %[[ARG0]]
+func.func @linearize_fold_only_outer_bound(%arg0: index) -> index {
+  %ret = affine.linearize_index [%arg0] by (2) : index
+  return %ret : index
+}
+
+// -----
+
 // CHECK-LABEL: @linearize_dont_fold_dynamic_basis
 // CHECK: %[[RET:.+]] = affine.linearize_index
 // CHECK: return %[[RET]]
@@ -1617,6 +1707,38 @@ func.func @cancel_delinearize_linearize_disjoint_exact(%arg0: index, %arg1: inde
 
 // -----
 
+// CHECK-LABEL: func @cancel_delinearize_linearize_disjoint_linearize_extra_bound(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index)
+//       CHECK:     return %[[ARG0]], %[[ARG1]], %[[ARG2]]
+func.func @cancel_delinearize_linearize_disjoint_linearize_extra_bound(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (4, %arg4) : index
+  %1:3 = affine.delinearize_index %0 into (4, %arg4)
+      : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @cancel_delinearize_linearize_disjoint_delinearize_extra_bound(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index)
+//       CHECK:     return %[[ARG0]], %[[ARG1]], %[[ARG2]]
+func.func @cancel_delinearize_linearize_disjoint_delinearize_extra_bound(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (4, %arg4) : index
+  %1:3 = affine.delinearize_index %0 into (%arg3, 4, %arg4)
+      : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
 // Without `disjoint`, the cancelation isn't guaranteed to be the identity.
 // CHECK-LABEL: func @no_cancel_delinearize_linearize_exact(
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
@@ -1666,6 +1788,17 @@ func.func @linearize_unit_basis_disjoint(%arg0: index, %arg1: index, %arg2: inde
 
 // -----
 
+// CHECK-LABEL: @linearize_unit_basis_disjoint_no_outer_bound
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
+// CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (%[[arg3]]) : index
+// CHECK: return %[[ret]]
+func.func @linearize_unit_basis_disjoint_no_outer_bound(%arg0: index, %arg1: index, %arg2: index, %arg3: index) -> index {
+  %ret = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (1, %arg3) : index
+  return %ret : index
+}
+
+// -----
+
 // CHECK-LABEL: @linearize_unit_basis_zero
 // CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
 // CHECK: %[[ret:.+]] = affine.linearize_index [%[[arg0]], %[[arg1]]] by (3, %[[arg2]]) : index
@@ -1713,6 +1846,32 @@ func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: i
 
 // -----
 
+// CHECK-LABEL: func @cancel_linearize_denearize_linearize_extra_bound(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     return %[[ARG0]]
+func.func @cancel_linearize_denearize_linearize_extra_bound(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0:3 = affine.delinearize_index %arg0 into (4, %arg2) : index, index, index
+  %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 4, %arg2) : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @cancel_linearize_denearize_delinearize_extra_bound(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     return %[[ARG0]]
+func.func @cancel_linearize_denearize_delinearize_extra_bound(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
+  %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (4, %arg2) : index
+  return %1 : index
+}
+
+// -----
+
 // Don't cancel because the values from the delinearize aren't used in order
 // CHECK-LABEL: func @no_cancel_linearize_denearize_permuted(
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
@@ -1756,3 +1915,16 @@ func.func @affine_leading_zero(%arg0: index, %arg1: index) -> index {
   return %ret : index
 }
 
+// -----
+
+// CHECK-LABEL: func @affine_leading_zero_no_outer_bound(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[RET:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (3, 5)
+//       CHECK:     return %[[RET]]
+func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> index {
+  %c0 = arith.constant 0 : index
+  %ret = affine.linearize_index [%c0, %arg0, %arg1] by (3, 5) : index
+  return %ret : index
+}
+
diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
index 2996194170900f..1539b4f4848276 100644
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -533,37 +533,29 @@ func.func @missing_for_min(%arg0: index, %arg1: index, %arg2: memref<100xf32>) {
 // -----
 
 func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) {
-  // expected-error@+1 {{'affine.delinearize_index' op should return an index for each basis element}}
+  // expected-error@+1 {{'affine.delinearize_index' op should return an index for each basis element and up to one extra index}}
   %1 = affine.delinearize_index %idx into (%basis0, %basis1) : index
   return
 }
 
 // -----
 
-func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) {
-  // expected-error@+1 {{'affine.delinearize_index' op basis should not be empty}}
-  affine.delinearize_index %idx into () : index
+func.func @delinearize(%idx: index) {
+  // expected-error@+1 {{'affine.delinearize_index' op no basis element may be statically non-positive}}
+  %1:2 = affine.delinearize_index %idx into (2, -2) : index, index
   return
 }
 
 // -----
 
 func.func @linearize(%idx: index, %basis0: index, %basis1 :index) -> index {
-  // expected-error@+1 {{'affine.linearize_index' op should be passed an index for each basis element}}
+  // expected-error@+1 {{'affine.linearize_index' op should be passed a basis element for each index except possibly the first}}
   %0 = affine.linearize_index [%idx] by (%basis0, %basis1) : index
   return %0 : index
 }
 
 // -----
 
-func.func @linearize_empty() -> index {
-  // expected-error@+1 {{'affine.linearize_index' op basis should not be empty}}
-  %0 = affine.linearize_index [] by () : index
-  return %0 : index
-}
-
-// -----
-
 func.func @dynamic_dimension_index() {
   "unknown.region"() ({
     %idx = "unknown.test"() : () -> (index)

>From 658824c1f3f6cd49dba058629c93c56c1e017722 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Wed, 13 Nov 2024 23:47:29 +0000
Subject: [PATCH 2/3] Python test

---
 mlir/test/python/dialects/affine.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)
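
Note on the explicit result types: with the outer bound optional, a basis of
length K is valid for either K or K+1 results, so the result count presumably
can no longer be inferred from the basis alone, which is why the Python test
now passes `[T.index()] * 2`. The sketch below is a plain-Python model of
that rule and of the arithmetic, not the MLIR Python API; the function names
`delinearize` and `linearize` and the `num_results` parameter are
hypothetical, the at-least-one-result guard is a sketch-only assumption, and
floor division is assumed because it reproduces the folded constants in
canonicalize.mlir.

```python
def delinearize(linear_index, basis, num_results):
    """Model of affine.delinearize_index with an optional outer bound."""
    if num_results < 1:
        # Sketch-only guard, not taken from the real verifier.
        raise ValueError("need at least one result")
    if num_results not in (len(basis), len(basis) + 1):
        raise ValueError("should return an index for each basis element "
                         "and up to one extra index")
    has_outer_bound = num_results == len(basis)
    # The outer bound, when present, is advisory and never divides anything.
    effective = list(basis[1:]) if has_outer_bound else list(basis)

    # strides[i] is the product of effective[i:], built by a reverse scan.
    strides, prod = [], 1
    for elem in reversed(effective):
        prod *= elem
        strides.append(prod)
    strides.reverse()

    results, remainder = [], linear_index
    for stride in strides:
        quotient, remainder = divmod(remainder, stride)  # floor semantics
        results.append(quotient)
    results.append(remainder)
    return tuple(results)


def linearize(multi_index, basis):
    """Model of affine.linearize_index with an optional outer bound."""
    if not multi_index:
        # Sketch-only guard, not taken from the real verifier.
        raise ValueError("need at least one index")
    if len(basis) not in (len(multi_index), len(multi_index) - 1):
        raise ValueError("should be passed a basis element for each index "
                         "except possibly the first")
    has_outer_bound = len(basis) == len(multi_index)
    effective = list(basis[1:]) if has_outer_bound else list(basis)
    result = multi_index[0]
    for index, elem in zip(multi_index[1:], effective):
        result = result * elem + index  # Horner-style accumulation
    return result
```

For example, `delinearize(-22, (3, 5), 3)` yields `(-2, 1, 3)` and
`linearize((1, 1, 2), (3, 5))` yields `22`, matching the
`*_no_outer_bound` folding tests in patch 1.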

diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index 0dc69d7ba522de..fb6908d1769720 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -50,7 +50,7 @@ def testAffineDelinearizeInfer():
     # CHECK: %[[C1:.*]] = arith.constant 1 : index
     c1 = arith.ConstantOp(T.index(), 1)
     # CHECK: %{{.*}}:2 = affine.delinearize_index %[[C1:.*]] into (2, 3) : index, index
-    two_indices = affine.AffineDelinearizeIndexOp(c1, [], [2, 3])
+    two_indices = affine.AffineDelinearizeIndexOp([T.index()] * 2, c1, [], [2, 3])
 
 
 # CHECK-LABEL: TEST: testAffineLoadOp
@@ -157,7 +157,7 @@ def testAffineForOpErrors():
         )
 
     try:
-        two_indices = affine.AffineDelinearizeIndexOp(c1, [], [1, 1])
+        two_indices = affine.AffineDelinearizeIndexOp([T.index()] * 2, c1, [], [1, 1])
         affine.AffineForOp(
             two_indices,
             c2,

>From 11332e3d16eae8b2e7d63fd2e24fcf838e26271e Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak <Krzysztof.Drewniak at amd.com>
Date: Thu, 14 Nov 2024 17:43:42 +0000
Subject: [PATCH 3/3] Address review comments

---
 .../mlir/Dialect/Affine/IR/AffineOps.td       | 26 ++--------------
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      | 30 ++++++++++++++++++-
 2 files changed, 32 insertions(+), 24 deletions(-)
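
Note on getEffectiveBasis(): the function moved out of line here has to keep
the static and dynamic basis lists in sync when it drops the outer bound. If
the leading static entry is the dynamic sentinel, the leading dynamic value
is dropped as well; otherwise only the static list is shortened. A
plain-Python model of that logic, for illustration only: `K_DYNAMIC`,
`get_mixed_values`, and `get_effective_basis` are hypothetical stand-ins for
`ShapedType::kDynamic`, `getMixedValues`, and the C++ member function, and
SSA values are represented as strings.

```python
K_DYNAMIC = object()  # stand-in sentinel for ShapedType::kDynamic


def get_mixed_values(static_basis, dynamic_basis):
    """Merge the two lists: each sentinel consumes the next dynamic value."""
    dynamic_iter = iter(dynamic_basis)
    return [next(dynamic_iter) if entry is K_DYNAMIC else entry
            for entry in static_basis]


def get_effective_basis(static_basis, dynamic_basis, has_outer_bound):
    """Return the basis with the outer bound removed, if one is present."""
    if has_outer_bound:
        if static_basis[0] is K_DYNAMIC:
            # The outer bound is an SSA value: drop it from both lists.
            return get_mixed_values(static_basis[1:], dynamic_basis[1:])
        # The outer bound is a constant: the dynamic list is unaffected.
        return get_mixed_values(static_basis[1:], dynamic_basis)
    return get_mixed_values(static_basis, dynamic_basis)
```

Under this model, the basis `(%arg1, 4, %arg2)` with an outer bound and the
basis `(4, %arg2)` without one both reduce to `[4, "%arg2"]`.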

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 53bc7ce0349241..f8a9f382b3d1ee 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1082,7 +1082,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
     %indices_2 = affine.apply #map2()[%linear_index]
     ```
 
-    The basis may either contain `N` or `N-1` elements, where `N` is the nubrer of results.
+    The basis may either contain `N` or `N-1` elements, where `N` is the number of results.
     If there are N basis elements, the first one will not be used during computations,
     but may be used during analysis and canonicalization to eliminate terms from
     the `affine.delinearize_index` or to enable conclusions about the total size of
@@ -1134,17 +1134,7 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
 
     /// Return a vector that contains the basis of the operation, removing
     /// the outer bound if one is present.
-    SmallVector<OpFoldResult> getEffectiveBasis() {
-      OpBuilder builder(getContext());
-      if (hasOuterBound()) {
-        if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
-          return ::mlir::getMixedValues(getStaticBasis().drop_front(), getDynamicBasis().drop_front(), builder);
-
-        return ::mlir::getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(), builder);
-      }
-
-      return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
-    }
+    SmallVector<OpFoldResult> getEffectiveBasis();
   }];
 
   let hasVerifier = 1;
@@ -1232,17 +1222,7 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
 
     /// Return a vector that contains the basis of the operation, removing
     /// the outer bound if one is present.
-    SmallVector<OpFoldResult> getEffectiveBasis() {
-      OpBuilder builder(getContext());
-      if (hasOuterBound()) {
-        if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
-          return ::mlir::getMixedValues(getStaticBasis().drop_front(), getDynamicBasis().drop_front(), builder);
-
-        return ::mlir::getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(), builder);
-      }
-
-      return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
-    }
+    SmallVector<OpFoldResult> getEffectiveBasis();
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 3693195c39fecb..4cf07bc167eab9 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4574,7 +4574,7 @@ AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &result) {
   // If we won't be doing any division or modulo (no basis or the one basis
   // element is purely advisory), simply return the input value.
-  if (getStaticBasis().size() == static_cast<size_t>(hasOuterBound())) {
+  if (getNumResults() == 1) {
     result.push_back(getLinearIndex());
     return success();
   }
@@ -4600,6 +4600,20 @@ AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
   return success();
 }
 
+SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
+  OpBuilder builder(getContext());
+  if (hasOuterBound()) {
+    if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
+      return getMixedValues(getStaticBasis().drop_front(),
+                            getDynamicBasis().drop_front(), builder);
+
+    return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
+                          builder);
+  }
+
+  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+}
+
 namespace {
 
 // Drops delinearization indices that correspond to unit-extent basis
@@ -4773,6 +4787,20 @@ OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
   return IntegerAttr::get(getResult().getType(), result);
 }
 
+SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
+  OpBuilder builder(getContext());
+  if (hasOuterBound()) {
+    if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
+      return getMixedValues(getStaticBasis().drop_front(),
+                            getDynamicBasis().drop_front(), builder);
+
+    return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
+                          builder);
+  }
+
+  return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+}
+
 namespace {
 /// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
 /// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,


