[Mlir-commits] [mlir] 31aa7f3 - [mlir][Affine] Let affine.[de]linearize_index omit outer bounds (#116103)

llvmlistbot at llvm.org
Mon Nov 18 13:41:58 PST 2024


Author: Krzysztof Drewniak
Date: 2024-11-18T15:41:54-06:00
New Revision: 31aa7f34e07c901773993dac0f33568307f96da6

URL: https://github.com/llvm/llvm-project/commit/31aa7f34e07c901773993dac0f33568307f96da6
DIFF: https://github.com/llvm/llvm-project/commit/31aa7f34e07c901773993dac0f33568307f96da6.diff

LOG: [mlir][Affine] Let affine.[de]linearize_index omit outer bounds (#116103)

The affine.delinearize_index and affine.linearize_index operations, as
currently defined, require providing a length-N basis to [de]linearize N
values. The first value in this basis is never used during lowering. (Note
that, even though it isn't used during lowering, it can still be used to,
for example, remove length-1 outputs from a delinearize.)
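To make the "never used during lowering" point concrete, the following is a minimal standalone C++ sketch (not MLIR code) of the delinearization arithmetic on a static basis. It follows the shape of the constant folder in this patch: floor modulo and floor division, scanning the basis from the innermost element outward. The names `floorDiv`, `floorMod`, and `delinearize` are hypothetical, and the sketch assumes every basis element is a known, strictly positive integer.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: division rounding toward negative infinity.
int64_t floorDiv(int64_t a, int64_t b) {
  int64_t q = a / b;
  if ((a % b != 0) && ((a < 0) != (b < 0)))
    --q;
  return q;
}

// Hypothetical helper: modulo whose result has the sign of the divisor.
int64_t floorMod(int64_t a, int64_t b) {
  int64_t r = a % b;
  if (r != 0 && ((r < 0) != (b < 0)))
    r += b;
  return r;
}

// Sketch of delinearize_index on a static basis. When hasOuterBound is true,
// basis[0] is only a bound and never enters the arithmetic.
std::vector<int64_t> delinearize(int64_t linear, std::vector<int64_t> basis,
                                 bool hasOuterBound) {
  if (hasOuterBound) {
    assert(!basis.empty() && "an outer bound needs at least one element");
    basis.erase(basis.begin()); // Drop the advisory outer bound.
  }
  std::vector<int64_t> result(basis.size() + 1);
  int64_t highPart = linear;
  // Walk the basis from the innermost (last) element outward.
  for (size_t i = basis.size(); i > 0; --i) {
    assert(basis[i - 1] > 0 && "basis elements must be strictly positive");
    result[i] = floorMod(highPart, basis[i - 1]);
    highPart = floorDiv(highPart, basis[i - 1]);
  }
  result[0] = highPart; // The outermost index gets whatever is left.
  return result;
}
```

With an outer bound, `delinearize(22, {2, 3, 5}, true)` yields `{1, 1, 2}` (22 = 1*15 + 1*5 + 2); replacing the leading 2 with any other positive value, or dropping it and passing `hasOuterBound = false`, gives the same result.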

This dead value makes sense in the original context of these operations,
which is linearizing or de-linearizing indexes to memref<>s, vector<>s,
and other shaped types, where that outer bound is available and may be
useful for analysis.

However, other use cases exist where the outer bound is not known. For
example:

    %thread_id_x = gpu.thread_id x : index
    %0:3 = affine.delinearize_index %thread_id_x into (4, 16) : index, index, index

In this code, we don't know the upper bound of the thread ID, but we do
want to construct the ?x4x16 grid of delinearized values in order to
further partition the GPU threads.
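As a rough illustration of the GPU example, the sketch below hand-expands that `affine.delinearize_index` for a single thread ID in C++. It is hypothetical (`threadGridCoords` is not an MLIR or GPU API) and assumes a non-negative thread ID, so plain `/` and `%` agree with the floor semantics of the op. No total thread count appears anywhere: the outermost coordinate is just the quotient left over after dividing by 4 * 16.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical hand expansion of
//   %0:3 = affine.delinearize_index %thread_id_x into (4, 16)
// for a non-negative thread ID. Only the two inner sizes are known.
std::array<int64_t, 3> threadGridCoords(int64_t threadIdX) {
  assert(threadIdX >= 0 && "GPU thread IDs are non-negative");
  const int64_t inner = 16, middle = 4;
  int64_t x = threadIdX % inner;  // innermost coordinate, in [0, 16)
  int64_t rest = threadIdX / inner;
  int64_t y = rest % middle;      // middle coordinate, in [0, 4)
  int64_t z = rest / middle;      // outermost: unbounded, no size needed
  return {z, y, x};
}
```

For instance, thread 100 maps to (1, 2, 4) and thread 1000 maps to (15, 2, 8), whatever the actual block size is.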

In order to support such use cases, we broaden the definition of
affine.delinearize_index and affine.linearize_index to make the outer
bound optional.

In the case of affine.delinearize_index, where the number of results is
a function of the size of the passed-in basis, we augment all existing
builders with a `hasOuterBound` argument, which, for backwards
compatibility and to preserve the natural usage of the op, defaults to
`true`. If this flag is true, the op returns one result per basis
element; if it is false, it returns one extra result in position 0.
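The result-count rule above, together with the relaxed verifier condition in this patch (N or N-1 basis elements for N indices), can be summarized by three small predicates. This is a hypothetical C++ sketch for illustration only; these free functions do not exist in MLIR, although `hasOuterBound` mirrors the logic of the new `hasOuterBound()` accessors on both ops.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical mirror of the builders' result-count rule: with an outer
// bound there is one result per basis element; without one, an extra
// (unbounded) result is prepended at position 0.
size_t numDelinearizeResults(size_t basisSize, bool hasOuterBound) {
  return hasOuterBound ? basisSize : basisSize + 1;
}

// An op has an outer bound exactly when its basis is as long as its
// list of indices (results for delinearize, inputs for linearize).
bool hasOuterBound(size_t numIndices, size_t basisSize) {
  return numIndices == basisSize;
}

// The relaxed verifier condition: N or N-1 basis elements for N indices.
bool isValidBasisLength(size_t numIndices, size_t basisSize) {
  return numIndices == basisSize || numIndices == basisSize + 1;
}
```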

We also update existing canonicalization patterns (and move one of them
into the folder) to handle these cases. Note that disagreements about
the outer bound no longer prevent delinearize/linearize
cancellations.
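The cancellation rule can be sketched in the same standalone C++ style. Below, the hypothetical `effectiveBasis` plays the role of the new `getEffectiveBasis()` methods (drop the first basis element when it is an outer bound), `linearize` computes the sum of each index times the product of the later basis elements, and `basesCancel` is the basis comparison the cancellation patterns now perform. The sketch assumes static bases and omits the other conditions the real patterns check (for example, the `disjoint` flag and that the results are used in order).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical analogue of getEffectiveBasis(): drop basis[0] when it is an
// outer bound, i.e. when the basis is as long as the index list.
std::vector<int64_t> effectiveBasis(const std::vector<int64_t> &basis,
                                    size_t numIndices) {
  if (basis.size() == numIndices && !basis.empty())
    return std::vector<int64_t>(basis.begin() + 1, basis.end());
  return basis;
}

// Sketch of linearize_index: sum of idx[i] * product(basis[j], j > i),
// computed from the effective basis, so the outer bound never contributes.
int64_t linearize(const std::vector<int64_t> &indices,
                  const std::vector<int64_t> &basis) {
  assert(indices.size() == basis.size() ||
         indices.size() == basis.size() + 1);
  std::vector<int64_t> eff = effectiveBasis(basis, indices.size());
  int64_t result = 0, stride = 1;
  // eff[k] is the size of indices[k + 1]; walk from innermost outward.
  for (size_t i = indices.size(); i > 0; --i) {
    result += indices[i - 1] * stride;
    if (i >= 2)
      stride *= eff[i - 2];
  }
  return result;
}

// linearize(delinearize(x)) may cancel when the effective bases agree;
// the outer bounds are allowed to differ (or be absent).
bool basesCancel(const std::vector<int64_t> &linBasis,
                 const std::vector<int64_t> &delinBasis, size_t numIndices) {
  return effectiveBasis(linBasis, numIndices) ==
         effectiveBasis(delinBasis, numIndices);
}
```

Because the outer bound never enters the arithmetic, `linearize({1, 2, 4}, {4, 16})` and `linearize({1, 2, 4}, {8, 4, 16})` both give 100, which is why differing outer bounds need not block cancellation.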

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
    mlir/include/mlir/Dialect/Affine/Utils.h
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
    mlir/lib/Dialect/Affine/Utils/Utils.cpp
    mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
    mlir/test/Dialect/Affine/canonicalize.mlir
    mlir/test/Dialect/Affine/invalid.mlir
    mlir/test/python/dialects/affine.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index ea65911af43a1e..76d97f106dcb88 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1060,8 +1060,7 @@ def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
 // AffineDelinearizeIndexOp
 //===----------------------------------------------------------------------===//
 
-def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
-    [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
   let summary = "delinearize an index";
   let description = [{
     The `affine.delinearize_index` operation takes a single index value and
@@ -1083,6 +1082,25 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
     %indices_1 = affine.apply #map1()[%linear_index]
     %indices_2 = affine.apply #map2()[%linear_index]
     ```
+
+    The basis may either contain `N` or `N-1` elements, where `N` is the number of results.
+    If there are N basis elements, the first one will not be used during computations,
+    but may be used during analysis and canonicalization to eliminate terms from
+    the `affine.delinearize_index` or to enable conclusions about the total size of
+    `%linear_index`.
+
+    If the basis is fully provided, the delinearize_index operation is said to "have
+    an outer bound". The builders assume that an `affine.delinearize_index` has
+    an outer bound by default, as this is how the operation was initially defined.
+
+    That is, the example above could also have been written
+    ```mlir
+    %0:3 = affine.delinearize_index %linear_index into (244, 244) : index, index
+    ```
+
+    Note that, due to the constraints of affine maps, all the basis elements must
+    be strictly positive. A dynamic basis element being 0 or negative causes
+    undefined behavior.
   }];
 
   let arguments = (ins Index:$linear_index,
@@ -1097,17 +1115,27 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$linear_index, "ValueRange":$basis)>,
-    OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis)>,
-    OpBuilder<(ins "Value":$linear_index, "ArrayRef<int64_t>":$basis)>
+    OpBuilder<(ins "Value":$linear_index, "ValueRange":$dynamic_basis, "ArrayRef<int64_t>":$static_asis, CArg<"bool", "true">:$hasOuterBound)>,
+    OpBuilder<(ins "Value":$linear_index, "ValueRange":$basis, CArg<"bool", "true">:$hasOuterBound)>,
+    OpBuilder<(ins "Value":$linear_index, "ArrayRef<OpFoldResult>":$basis, CArg<"bool", "true">:$hasOuterBound)>,
+    OpBuilder<(ins "Value":$linear_index, "ArrayRef<int64_t>":$basis, CArg<"bool", "true">:$hasOuterBound)>
   ];
 
   let extraClassDeclaration = [{
+    /// Return true if the basis includes a bound on the first index input.
+    bool hasOuterBound() {
+      return getMultiIndex().size() == getStaticBasis().size();
+    }
+
     /// Returns a vector with all the static and dynamic basis values.
     SmallVector<OpFoldResult> getMixedBasis() {
       OpBuilder builder(getContext());
       return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
     }
+
+    /// Return a vector that contains the basis of the operation, removing
+    /// the outer bound if one is present.
+    SmallVector<OpFoldResult> getEffectiveBasis();
   }];
 
   let hasVerifier = 1;
@@ -1125,13 +1153,21 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
     The `affine.linearize_index` operation takes a sequence of index values and a
     basis of the same length and linearizes the indices using that basis.
 
-    That is, for indices `%idx_1` through `%idx_N` and basis elements `b_1` through `b_N`,
-    it computes
+    That is, for indices `%idx_0` to `%idx_{N-1}` and basis elements `b_0`
+    (or `b_1`) up to `b_{N-1}` it computes
 
     ```
-    sum(i = 1 to N) %idx_i * product(j = i + 1 to N) B_j
+    sum(i = 0 to N-1) %idx_i * product(j = i + 1 to N-1) B_j
     ```
 
+    The basis may either have `N` or `N-1` elements, where `N` is the number of
+    inputs to linearize_index. If `N` inputs are provided, the first one is not used
+    in computation, but may be used during analysis or canonicalization as a bound
+    on `%idx_0`.
+
+    If all `N` basis elements are provided, the linearize_index operation is said to
+    "have an outer bound".
+
     If the `disjoint` property is present, this is an optimization hint that,
     for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index,
     except that `%idx_0` may be negative to make the index as a whole negative.
@@ -1141,7 +1177,9 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
     Example:
 
     ```mlir
-    %linear_index = affine.linearize_index [%index_0, %index_1, %index_2] (2, 3, 5) : index
+    %linear_index = affine.linearize_index [%index_0, %index_1, %index_2] by (2, 3, 5) : index
+    // Same effect
+    %linear_index = affine.linearize_index [%index_0, %index_1, %index_2] by (3, 5) : index
     ```
 
     In the above example, `%linear_index` conceptually holds the following:
@@ -1172,12 +1210,20 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
   ];
 
   let extraClassDeclaration = [{
+    /// Return true if the basis includes a bound on the first index input.
+    bool hasOuterBound() {
+      return getMultiIndex().size() == getStaticBasis().size();
+    }
+
     /// Return a vector with all the static and dynamic basis values.
     SmallVector<OpFoldResult> getMixedBasis() {
       OpBuilder builder(getContext());
       return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
     }
 
+    /// Return a vector that contains the basis of the operation, removing
+    /// the outer bound if one is present.
+    SmallVector<OpFoldResult> getEffectiveBasis();
   }];
 
   let hasVerifier = 1;

diff  --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index 0e98223969e08c..0f801ebb6f5898 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -307,17 +307,23 @@ struct DivModValue {
 DivModValue getDivMod(OpBuilder &b, Location loc, Value lhs, Value rhs);
 
 /// Generate the IR to delinearize `linearIndex` given the `basis` and return
-/// the multi-index.
+/// the multi-index. `hasOuterBound` indicates whether `basis` has an entry
+/// given the size of the first multi-index result - if it is true, the function
+/// will return `basis.size()` values, otherwise, it will return `basis.size() +
+/// 1`.
 FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
                                                Value linearIndex,
-                                               ArrayRef<Value> basis);
+                                               ArrayRef<Value> basis,
+                                               bool hasOuterBound = true);
 
 FailureOr<SmallVector<Value>> delinearizeIndex(OpBuilder &b, Location loc,
                                                Value linearIndex,
-                                               ArrayRef<OpFoldResult> basis);
+                                               ArrayRef<OpFoldResult> basis,
+                                               bool hasOuterBound = true);
 
 // Generate IR that extracts the linear index from a multi-index according to
-// a basis/shape.
+// a basis/shape. The basis may contain either `multiIndex.size()` or
+// `multiIndex.size() - 1` elements.
 OpFoldResult linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
                             ArrayRef<OpFoldResult> basis,
                             ImplicitLocOpBuilder &builder);

diff  --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index fbc9053a0e273b..4cf07bc167eab9 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
@@ -4503,62 +4504,81 @@ LogicalResult AffineVectorStoreOp::verify() {
 // DelinearizeIndexOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult AffineDelinearizeIndexOp::inferReturnTypes(
-    MLIRContext *context, std::optional<::mlir::Location> location,
-    ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
-    RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
-  AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties,
-                                          regions);
-  inferredReturnTypes.assign(adaptor.getStaticBasis().size(),
-                             IndexType::get(context));
-  return success();
+void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
+                                     OperationState &odsState,
+                                     Value linearIndex, ValueRange dynamicBasis,
+                                     ArrayRef<int64_t> staticBasis,
+                                     bool hasOuterBound) {
+  SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
+                                              : staticBasis.size() + 1,
+                                linearIndex.getType());
+  build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
+        staticBasis);
 }
 
 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                      OperationState &odsState,
-                                     Value linearIndex, ValueRange basis) {
+                                     Value linearIndex, ValueRange basis,
+                                     bool hasOuterBound) {
   SmallVector<Value> dynamicBasis;
   SmallVector<int64_t> staticBasis;
   dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
                              staticBasis);
-  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
+  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
+        hasOuterBound);
 }
 
 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                      OperationState &odsState,
                                      Value linearIndex,
-                                     ArrayRef<OpFoldResult> basis) {
+                                     ArrayRef<OpFoldResult> basis,
+                                     bool hasOuterBound) {
   SmallVector<Value> dynamicBasis;
   SmallVector<int64_t> staticBasis;
   dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
-  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
+  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
+        hasOuterBound);
 }
 
 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                      OperationState &odsState,
-                                     Value linearIndex,
-                                     ArrayRef<int64_t> basis) {
-  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis);
+                                     Value linearIndex, ArrayRef<int64_t> basis,
+                                     bool hasOuterBound) {
+  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
 }
 
 LogicalResult AffineDelinearizeIndexOp::verify() {
-  if (getStaticBasis().empty())
-    return emitOpError("basis should not be empty");
-  if (getNumResults() != getStaticBasis().size())
-    return emitOpError("should return an index for each basis element");
-  auto dynamicMarkersCount =
-      llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
+  ArrayRef<int64_t> staticBasis = getStaticBasis();
+  if (getNumResults() != staticBasis.size() &&
+      getNumResults() != staticBasis.size() + 1)
+    return emitOpError("should return an index for each basis element and up "
+                       "to one extra index");
+
+  auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
   if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
     return emitOpError(
         "mismatch between dynamic and static basis (kDynamic marker but no "
         "corresponding dynamic basis entry) -- this can only happen due to an "
         "incorrect fold/rewrite");
+
+  if (!llvm::all_of(staticBasis, [](int64_t v) {
+        return v > 0 || ShapedType::isDynamic(v);
+      }))
+    return emitOpError("no basis element may be statically non-positive");
+
   return success();
 }
 
 LogicalResult
 AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &result) {
+  // If we won't be doing any division or modulo (no basis or the one basis
+  // element is purely advisory), simply return the input value.
+  if (getNumResults() == 1) {
+    result.push_back(getLinearIndex());
+    return success();
+  }
+
   if (adaptor.getLinearIndex() == nullptr)
     return failure();
 
@@ -4567,7 +4587,11 @@ AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
 
   int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
   Type attrType = getLinearIndex().getType();
-  for (int64_t modulus : llvm::reverse(getStaticBasis().drop_front())) {
+
+  ArrayRef<int64_t> staticBasis = getStaticBasis();
+  if (hasOuterBound())
+    staticBasis = staticBasis.drop_front();
+  for (int64_t modulus : llvm::reverse(staticBasis)) {
     result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
     highPart = llvm::divideFloorSigned(highPart, modulus);
   }
@@ -4576,6 +4600,20 @@ AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
   return success();
 }
 
+SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
+  OpBuilder builder(getContext());
+  if (hasOuterBound()) {
+    if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
+      return getMixedValues(getStaticBasis().drop_front(),
+                            getDynamicBasis().drop_front(), builder);
+
+    return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
+                          builder);
+  }
+
+  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+}
+
 namespace {
 
 // Drops delinearization indices that correspond to unit-extent basis
@@ -4594,24 +4632,25 @@ struct DropUnitExtentBasis
       return zero.value();
     };
 
+    bool hasOuterBound = delinearizeOp.hasOuterBound();
     // Replace all indices corresponding to unit-extent basis with 0.
     // Remaining basis can be used to get a new `affine.delinearize_index` op.
-    SmallVector<OpFoldResult> newOperands;
+    SmallVector<OpFoldResult> newBasis;
     for (auto [index, basis] : llvm::enumerate(delinearizeOp.getMixedBasis())) {
       std::optional<int64_t> basisVal = getConstantIntValue(basis);
       if (basisVal && *basisVal == 1)
-        replacements[index] = getZero();
+        replacements[index + (hasOuterBound ? 0 : 1)] = getZero();
       else
-        newOperands.push_back(basis);
+        newBasis.push_back(basis);
     }
 
-    if (newOperands.size() == delinearizeOp.getStaticBasis().size())
+    if (newBasis.size() == delinearizeOp.getStaticBasis().size())
       return rewriter.notifyMatchFailure(delinearizeOp,
                                          "no unit basis elements");
 
-    if (!newOperands.empty()) {
+    if (!newBasis.empty() || !hasOuterBound) {
       auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
-          loc, delinearizeOp.getLinearIndex(), newOperands);
+          loc, delinearizeOp.getLinearIndex(), newBasis, hasOuterBound);
       int newIndex = 0;
       // Map back the new delinearized indices to the values they replace.
       for (auto &replacement : replacements) {
@@ -4626,27 +4665,6 @@ struct DropUnitExtentBasis
   }
 };
 
-/// Drop delinearization with a single basis element
-///
-/// By definition, `delinearize_index %linear into (%basis)` is
-/// `%linear floorDiv 1` (since `1` is the product of the basis elememts,
-/// ignoring the 0th one, and since there is no previous division we need
-/// to use the remainder of). Therefore, a single-element `delinearize`
-/// can be replaced by the underlying linear index.
-struct DropDelinearizeOneBasisElement
-    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
-                                PatternRewriter &rewriter) const override {
-    if (delinearizeOp.getStaticBasis().size() != 1)
-      return rewriter.notifyMatchFailure(delinearizeOp,
-                                         "doesn't have a length-1 basis");
-    rewriter.replaceOp(delinearizeOp, delinearizeOp.getLinearIndex());
-    return success();
-  }
-};
-
 /// If a `affine.delinearize_index`'s input is a `affine.linearize_index
 /// disjoint` and the two operations have the same basis, replace the
 /// delinearizeation results with the inputs of the `affine.linearize_index`
@@ -4668,7 +4686,7 @@ struct CancelDelinearizeOfLinearizeDisjointExact
                                          "index doesn't come from linearize");
 
     if (!linearizeOp.getDisjoint() ||
-        linearizeOp.getMixedBasis() != delinearizeOp.getMixedBasis())
+        linearizeOp.getEffectiveBasis() != delinearizeOp.getEffectiveBasis())
       return rewriter.notifyMatchFailure(
           linearizeOp, "not disjoint or basis doesn't match delinearize");
 
@@ -4680,8 +4698,9 @@ struct CancelDelinearizeOfLinearizeDisjointExact
 
 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
-  patterns.insert<CancelDelinearizeOfLinearizeDisjointExact,
-                  DropDelinearizeOneBasisElement, DropUnitExtentBasis>(context);
+  patterns
+      .insert<CancelDelinearizeOfLinearizeDisjointExact, DropUnitExtentBasis>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -4718,11 +4737,11 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
 }
 
 LogicalResult AffineLinearizeIndexOp::verify() {
-  if (getStaticBasis().empty())
-    return emitOpError("basis should not be empty");
-
-  if (getMultiIndex().size() != getStaticBasis().size())
-    return emitOpError("should be passed an index for each basis element");
+  size_t numIndexes = getMultiIndex().size();
+  size_t numBasisElems = getStaticBasis().size();
+  if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
+    return emitOpError("should be passed a basis element for each index except "
+                       "possibly the first");
 
   auto dynamicMarkersCount =
       llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
@@ -4736,6 +4755,14 @@ LogicalResult AffineLinearizeIndexOp::verify() {
 }
 
 OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
+  // No indices linearizes to zero.
+  if (getMultiIndex().empty())
+    return IntegerAttr::get(getResult().getType(), 0);
+
+  // One single index linearizes to itself.
+  if (getMultiIndex().size() == 1)
+    return getMultiIndex().front();
+
   if (llvm::any_of(adaptor.getMultiIndex(),
                    [](Attribute a) { return a == nullptr; }))
     return nullptr;
@@ -4745,16 +4772,35 @@ OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
 
   int64_t result = 0;
   int64_t stride = 1;
-  for (auto [indexAttr, length] :
-       llvm::zip_equal(llvm::reverse(adaptor.getMultiIndex()),
-                       llvm::reverse(getStaticBasis()))) {
+  for (auto [length, indexAttr] :
+       llvm::zip_first(llvm::reverse(getStaticBasis()),
+                       llvm::reverse(adaptor.getMultiIndex()))) {
     result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
     stride = stride * length;
   }
+  // Handle the index element with no basis element.
+  if (!hasOuterBound())
+    result =
+        result +
+        cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride;
 
   return IntegerAttr::get(getResult().getType(), result);
 }
 
+SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
+  OpBuilder builder(getContext());
+  if (hasOuterBound()) {
+    if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
+      return getMixedValues(getStaticBasis().drop_front(),
+                            getDynamicBasis().drop_front(), builder);
+
+    return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
+                          builder);
+  }
+
+  return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
+}
+
 namespace {
 /// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
 /// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
@@ -4772,14 +4818,20 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
 
   LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    size_t numIndices = op.getMultiIndex().size();
+    ValueRange multiIndex = op.getMultiIndex();
+    size_t numIndices = multiIndex.size();
     SmallVector<Value> newIndices;
     newIndices.reserve(numIndices);
     SmallVector<OpFoldResult> newBasis;
     newBasis.reserve(numIndices);
 
+    if (!op.hasOuterBound()) {
+      newIndices.push_back(multiIndex.front());
+      multiIndex = multiIndex.drop_front();
+    }
+
     SmallVector<OpFoldResult> basis = op.getMixedBasis();
-    for (auto [index, basisElem] : llvm::zip_equal(op.getMultiIndex(), basis)) {
+    for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
       std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
       if (!basisEntry || *basisEntry != 1) {
         newIndices.push_back(index);
@@ -4808,23 +4860,6 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
   }
 };
 
-/// Rewrite `affine.linearize_index [%%x] by (%b)`, into `%x`.
-///
-/// By definition, that operation is `affine.apply affine_map<()[s0] -> (s0)>,`
-/// which is the identity.
-struct DropLinearizeOneBasisElement final
-    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
-                                PatternRewriter &rewriter) const override {
-    if (op.getStaticBasis().size() != 1 || op.getMultiIndex().size() != 1)
-      return rewriter.notifyMatchFailure(op, "doesn't have a a length-1 basis");
-    rewriter.replaceOp(op, op.getMultiIndex().front());
-    return success();
-  }
-};
-
 /// Cancel out linearize_index(delinearize_index(x, B), B).
 ///
 /// That is, rewrite
@@ -4847,10 +4882,10 @@ struct CancelLinearizeOfDelinearizeExact final
       return rewriter.notifyMatchFailure(
           linearizeOp, "last entry doesn't come from a delinearize");
 
-    if (linearizeOp.getMixedBasis() != delinearizeOp.getMixedBasis())
+    if (linearizeOp.getEffectiveBasis() != delinearizeOp.getEffectiveBasis())
       return rewriter.notifyMatchFailure(
-          linearizeOp,
-          "basis of linearize and delinearize don't match exactly");
+          linearizeOp, "basis of linearize and delinearize don't match exactly "
+                       "(excluding outer bounds)");
 
     if (delinearizeOp.getResults() != linearizeOp.getMultiIndex())
       return rewriter.notifyMatchFailure(
@@ -4881,9 +4916,12 @@ struct DropLinearizeLeadingZero final
     }
 
     SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
+    ArrayRef<OpFoldResult> newMixedBasis = mixedBasis;
+    if (op.hasOuterBound())
+      newMixedBasis = newMixedBasis.drop_front();
+
     rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
-        op, op.getMultiIndex().drop_front(),
-        ArrayRef<OpFoldResult>(mixedBasis).drop_front(), op.getDisjoint());
+        op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
     return success();
   }
 };
@@ -4892,7 +4930,6 @@ struct DropLinearizeLeadingZero final
 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
   patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeLeadingZero,
-               DropLinearizeOneBasisElement,
                DropLinearizeUnitComponentsIfDisjointOrZero>(context);
 }
 

diff  --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index 1930e987a33ffa..15478e0e1e3a5b 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -36,8 +36,9 @@ struct LowerDelinearizeIndexOps
   using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<SmallVector<Value>> multiIndex = delinearizeIndex(
-        rewriter, op->getLoc(), op.getLinearIndex(), op.getMixedBasis());
+    FailureOr<SmallVector<Value>> multiIndex =
+        delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
+                         op.getEffectiveBasis(), /*hasOuterBound=*/false);
     if (failed(multiIndex))
       return failure();
     rewriter.replaceOp(op, *multiIndex);
@@ -51,6 +52,12 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
+    // Should be folded away, included here for safety.
+    if (op.getMultiIndex().empty()) {
+      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
+      return success();
+    }
+
     SmallVector<OpFoldResult> multiIndex =
         getAsOpFoldResult(op.getMultiIndex());
     OpFoldResult linearIndex =

diff  --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 7fe422f75c8fad..3420db771ef426 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1944,11 +1944,14 @@ static FailureOr<OpFoldResult> composedAffineMultiply(OpBuilder &b,
 
 FailureOr<SmallVector<Value>>
 mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
-                               ArrayRef<Value> basis) {
+                               ArrayRef<Value> basis, bool hasOuterBound) {
+  if (hasOuterBound)
+    basis = basis.drop_front();
+
   // Note: the divisors are backwards due to the scan.
   SmallVector<Value> divisors;
   OpFoldResult basisProd = b.getIndexAttr(1);
-  for (OpFoldResult basisElem : llvm::reverse(basis.drop_front())) {
+  for (OpFoldResult basisElem : llvm::reverse(basis)) {
     FailureOr<OpFoldResult> nextProd =
         composedAffineMultiply(b, loc, basisElem, basisProd);
     if (failed(nextProd))
@@ -1971,11 +1974,15 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
 
 FailureOr<SmallVector<Value>>
 mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex,
-                               ArrayRef<OpFoldResult> basis) {
+                               ArrayRef<OpFoldResult> basis,
+                               bool hasOuterBound) {
+  if (hasOuterBound)
+    basis = basis.drop_front();
+
   // Note: the divisors are backwards due to the scan.
   SmallVector<Value> divisors;
   OpFoldResult basisProd = b.getIndexAttr(1);
-  for (OpFoldResult basisElem : llvm::reverse(basis.drop_front())) {
+  for (OpFoldResult basisElem : llvm::reverse(basis)) {
     FailureOr<OpFoldResult> nextProd =
         composedAffineMultiply(b, loc, basisElem, basisProd);
     if (failed(nextProd))
@@ -2005,8 +2012,15 @@ OpFoldResult mlir::affine::linearizeIndex(ArrayRef<OpFoldResult> multiIndex,
 OpFoldResult mlir::affine::linearizeIndex(OpBuilder &builder, Location loc,
                                           ArrayRef<OpFoldResult> multiIndex,
                                           ArrayRef<OpFoldResult> basis) {
-  assert(multiIndex.size() == basis.size());
+  assert(multiIndex.size() == basis.size() ||
+         multiIndex.size() == basis.size() + 1);
   SmallVector<AffineExpr> basisAffine;
+
+  // Add a fake initial size in order to make the later index linearization
+  // computations line up if an outer bound is not provided.
+  if (multiIndex.size() == basis.size() + 1)
+    basisAffine.push_back(getAffineConstantExpr(1, builder.getContext()));
+
   for (size_t i = 0; i < basis.size(); ++i) {
     basisAffine.push_back(getAffineSymbolExpr(i, builder.getContext()));
   }
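To make the delinearizeIndex change easier to follow, here is a hypothetical constant evaluator. The names `delinearize`, `floorDiv` and `floorMod` are illustrative and are not MLIR APIs. It assumes a fully static, positive basis and floor-division semantics. Like the new `hasOuterBound` handling, it drops the first basis element when the flag is set. In either case it returns one more result than the remaining basis has elements.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Floor division and modulo (round toward negative infinity), as used by
// affine floordiv/mod; C++ `/` and `%` truncate toward zero instead.
std::int64_t floorDiv(std::int64_t a, std::int64_t b) {
  std::int64_t q = a / b;
  if (a % b != 0 && ((a < 0) != (b < 0)))
    --q;
  return q;
}

std::int64_t floorMod(std::int64_t a, std::int64_t b) {
  return a - floorDiv(a, b) * b;
}

// Hypothetical constant evaluator for affine.delinearize_index with a
// fully static, positive basis (illustration only, not an MLIR API).
std::vector<std::int64_t> delinearize(std::int64_t linearIndex,
                                      std::vector<std::int64_t> basis,
                                      bool hasOuterBound) {
  // Mirrors `basis = basis.drop_front()`: the outer bound never affects
  // the computed values.
  if (hasOuterBound && !basis.empty())
    basis.erase(basis.begin());

  // One result per remaining basis element, plus the unbounded outermost one.
  std::vector<std::int64_t> results(basis.size() + 1);
  std::int64_t rest = linearIndex;
  for (std::size_t i = basis.size(); i > 0; --i) {
    results[i] = floorMod(rest, basis[i - 1]);
    rest = floorDiv(rest, basis[i - 1]);
  }
  results[0] = rest;
  return results;
}
```

With the basis (3, 5) and no outer bound, -22 evaluates to (-2, 1, 3). This matches the expectation in the @delinearize_fold_negative_constant_no_outer_bound test in canonicalize.mlir.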

diff  --git a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
index ded1687ca560b2..650555cfb5fe13 100644
--- a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
+++ b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
@@ -35,10 +35,10 @@ func.func @dynamic_basis(%linear_index: index, %src: memref<?x?x?xf32>) -> (inde
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
-  %b0 = memref.dim %src, %c0 : memref<?x?x?xf32>
   %b1 = memref.dim %src, %c1 : memref<?x?x?xf32>
   %b2 = memref.dim %src, %c2 : memref<?x?x?xf32>
-  %1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, index
+  // Note: no outer bound.
+  %1:3 = affine.delinearize_index %linear_index into (%b1, %b2) : index, index, index
   return %1#0, %1#1, %1#2 : index, index, index
 }
 
@@ -60,10 +60,11 @@ func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index {
 // CHECK-DAG: #[[$map0:.+]] =  affine_map<()[s0, s1, s2, s3, s4] -> (s1 * s2 + s3 + s0 * (s2 * s4))>
 
 // CHECK-LABEL: @linearize_dynamic
-// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index, %[[arg5:.+]]: index)
-// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg5]], %[[arg2]], %[[arg4]]]
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index)
+// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg4]], %[[arg2]], %[[arg3]]]
 // CHECK: return %[[val_0]]
-func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> index {
-  %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4, %arg5) : index
+func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> index {
+  // Note: no outer bounds
+  %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4) : index
   func.return %0 : index
 }
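The linearizeIndex change pads a missing outer bound with a constant 1, and @linearize_dynamic now exercises that path. The hypothetical constant-only model below shows why the padding value is harmless. The `linearize` function is illustrative and is not an MLIR API. The outermost basis element only scales the stride after the last index has been accumulated, so it never contributes to the result.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical constant evaluator for affine.linearize_index (illustration
// only, not an MLIR API). `basis` has either one element per index or
// omits the outer bound.
std::int64_t linearize(const std::vector<std::int64_t> &multiIndex,
                       std::vector<std::int64_t> basis) {
  assert(multiIndex.size() == basis.size() ||
         multiIndex.size() == basis.size() + 1);
  // Same idea as the patch: pad a missing outer bound with a fake size of 1.
  if (multiIndex.size() == basis.size() + 1)
    basis.insert(basis.begin(), 1);

  // Row-major linearization: walk from the innermost index outward,
  // growing the stride by each basis element.
  std::int64_t result = 0;
  std::int64_t stride = 1;
  for (std::size_t i = multiIndex.size(); i > 0; --i) {
    result += multiIndex[i - 1] * stride;
    stride *= basis[i - 1];
  }
  // The outermost basis element only fed the final, unused stride update.
  return result;
}
```

For example, linearize({1, 1, 2}, {3, 5}) gives 1 * 15 + 1 * 5 + 2 = 22. That is the constant expected by @linearize_fold_constants_no_outer_bound in canonicalize.mlir.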

diff  --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index ec00b31258d072..b54a13cffe7771 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1496,6 +1496,20 @@ func.func @delinearize_fold_negative_constant() -> (index, index, index) {
 
 // -----
 
+// CHECK-LABEL: @delinearize_fold_negative_constant_no_outer_bound
+// CHECK-DAG: %[[C_2:.+]] = arith.constant -2 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
+// CHECK-NOT: affine.delinearize_index
+// CHECK: return %[[C_2]], %[[C1]], %[[C3]]
+func.func @delinearize_fold_negative_constant_no_outer_bound() -> (index, index, index) {
+  %c_22 = arith.constant -22 : index
+  %0:3 = affine.delinearize_index %c_22 into (3, 5) : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
 // CHECK-LABEL: @delinearize_dont_fold_constant_dynamic_basis
 // CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
 // CHECK: %[[RET:.+]]:3 = affine.delinearize_index %[[C22]]
@@ -1525,6 +1539,23 @@ func.func @drop_unit_basis_in_delinearize(%arg0 : index, %arg1 : index, %arg2 :
 
 // -----
 
+func.func @drop_unit_basis_in_delinearize_no_outer_bound(%arg0 : index, %arg1 : index, %arg2 : index) ->
+    (index, index, index, index, index, index) {
+  %c1 = arith.constant 1 : index
+  %0:6 = affine.delinearize_index %arg0 into (%arg1, 1, 1, %arg2, %c1)
+      : index, index, index, index, index, index
+  return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : index, index, index, index, index, index
+}
+// CHECK-LABEL: func @drop_unit_basis_in_delinearize_no_outer_bound(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[DELINEARIZE:.+]]:3 = affine.delinearize_index %[[ARG0]] into (%[[ARG1]], %[[ARG2]])
+//       CHECK:   return %[[DELINEARIZE]]#0, %[[DELINEARIZE]]#1, %[[C0]], %[[C0]], %[[DELINEARIZE]]#2, %[[C0]]
+
+// -----
+
 func.func @drop_all_unit_bases(%arg0 : index) -> (index, index) {
   %0:2 = affine.delinearize_index %arg0 into (1, 1) : index, index
   return %0#0, %0#1 : index, index
@@ -1537,6 +1568,18 @@ func.func @drop_all_unit_bases(%arg0 : index) -> (index, index) {
 
 // -----
 
+func.func @drop_all_unit_bases_no_outer_bound(%arg0 : index) -> (index, index, index) {
+  %0:3 = affine.delinearize_index %arg0 into (1, 1) : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+// CHECK-LABEL: func @drop_all_unit_bases_no_outer_bound(
+//  CHECK-SAME:     %[[ARG0:.+]]: index)
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-NOT:   affine.delinearize_index
+//       CHECK:   return %[[ARG0]], %[[C0]], %[[C0]]
+
+// -----
+
 func.func @drop_single_loop_delinearize(%arg0 : index, %arg1 : index) -> index {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -1574,6 +1617,17 @@ func.func @delinearize_non_loop_like(%arg0: memref<?xi32>, %i : index) -> index
 
 // -----
 
+// CHECK-LABEL: func @delinearize_empty_basis
+// CHECK-SAME: (%[[ARG0:.+]]: index)
+// CHECK-NOT: affine.delinearize
+// CHECK: return %[[ARG0]]
+func.func @delinearize_empty_basis(%arg0: index) -> index {
+  %0 = affine.delinearize_index %arg0 into () : index
+  return %0 : index
+}
+
+// -----
+
 // CHECK-LABEL: @linearize_fold_constants
 // CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
 // CHECK-NOT: affine.linearize
@@ -1588,6 +1642,42 @@ func.func @linearize_fold_constants() -> index {
 
 // -----
 
+// CHECK-LABEL: @linearize_fold_constants_no_outer_bound
+// CHECK-DAG: %[[C22:.+]] = arith.constant 22 : index
+// CHECK-NOT: affine.linearize
+// CHECK: return %[[C22]]
+func.func @linearize_fold_constants_no_outer_bound() -> index {
+  %c2 = arith.constant 2 : index
+  %c1 = arith.constant 1 : index
+
+  %ret = affine.linearize_index [%c1, %c1, %c2] by (3, 5) : index
+  return %ret : index
+}
+
+// -----
+
+// CHECK-LABEL: @linearize_fold_empty_basis
+// CHECK-SAME: (%[[ARG0:.+]]: index)
+// CHECK-NOT: affine.linearize
+// CHECK: return %[[ARG0]]
+func.func @linearize_fold_empty_basis(%arg0: index) -> index {
+  %ret = affine.linearize_index [%arg0] by () : index
+  return %ret : index
+}
+
+// -----
+
+// CHECK-LABEL: @linearize_fold_only_outer_bound
+// CHECK-SAME: (%[[ARG0:.+]]: index)
+// CHECK-NOT: affine.linearize
+// CHECK: return %[[ARG0]]
+func.func @linearize_fold_only_outer_bound(%arg0: index) -> index {
+  %ret = affine.linearize_index [%arg0] by (2) : index
+  return %ret : index
+}
+
+// -----
+
 // CHECK-LABEL: @linearize_dont_fold_dynamic_basis
 // CHECK: %[[RET:.+]] = affine.linearize_index
 // CHECK: return %[[RET]]
@@ -1617,6 +1707,38 @@ func.func @cancel_delinearize_linearize_disjoint_exact(%arg0: index, %arg1: inde
 
 // -----
 
+// CHECK-LABEL: func @cancel_delinearize_linearize_disjoint_linearize_extra_bound(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index)
+//       CHECK:     return %[[ARG0]], %[[ARG1]], %[[ARG2]]
+func.func @cancel_delinearize_linearize_disjoint_linearize_extra_bound(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (4, %arg4) : index
+  %1:3 = affine.delinearize_index %0 into (4, %arg4)
+      : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @cancel_delinearize_linearize_disjoint_delinearize_extra_bound(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: index)
+//       CHECK:     return %[[ARG0]], %[[ARG1]], %[[ARG2]]
+func.func @cancel_delinearize_linearize_disjoint_delinearize_extra_bound(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
+  %0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (4, %arg4) : index
+  %1:3 = affine.delinearize_index %0 into (%arg3, 4, %arg4)
+      : index, index, index
+  return %1#0, %1#1, %1#2 : index, index, index
+}
+
+// -----
+
 // Without `disjoint`, the cancelation isn't guaranteed to be the identity.
 // CHECK-LABEL: func @no_cancel_delinearize_linearize_exact(
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
@@ -1666,6 +1788,17 @@ func.func @linearize_unit_basis_disjoint(%arg0: index, %arg1: index, %arg2: inde
 
 // -----
 
+// CHECK-LABEL: @linearize_unit_basis_disjoint_no_outer_bound
+// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index)
+// CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (%[[arg3]]) : index
+// CHECK: return %[[ret]]
+func.func @linearize_unit_basis_disjoint_no_outer_bound(%arg0: index, %arg1: index, %arg2: index, %arg3: index) -> index {
+  %ret = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (1, %arg3) : index
+  return %ret : index
+}
+
+// -----
+
 // CHECK-LABEL: @linearize_unit_basis_zero
 // CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index)
 // CHECK: %[[ret:.+]] = affine.linearize_index [%[[arg0]], %[[arg1]]] by (3, %[[arg2]]) : index
@@ -1713,6 +1846,32 @@ func.func @cancel_linearize_denearize_exact(%arg0: index, %arg1: index, %arg2: i
 
 // -----
 
+// CHECK-LABEL: func @cancel_linearize_denearize_linearize_extra_bound(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     return %[[ARG0]]
+func.func @cancel_linearize_denearize_linearize_extra_bound(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0:3 = affine.delinearize_index %arg0 into (4, %arg2) : index, index, index
+  %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (%arg1, 4, %arg2) : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @cancel_linearize_denearize_delinearize_extra_bound(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index)
+//       CHECK:     return %[[ARG0]]
+func.func @cancel_linearize_denearize_delinearize_extra_bound(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0:3 = affine.delinearize_index %arg0 into (%arg1, 4, %arg2) : index, index, index
+  %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (4, %arg2) : index
+  return %1 : index
+}
+
+// -----
+
 // Don't cancel because the values from the delinearize aren't used in order
 // CHECK-LABEL: func @no_cancel_linearize_denearize_permuted(
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
@@ -1756,3 +1915,16 @@ func.func @affine_leading_zero(%arg0: index, %arg1: index) -> index {
   return %ret : index
 }
 
+// -----
+
+// CHECK-LABEL: func @affine_leading_zero_no_outer_bound(
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: index,
+//  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index)
+//       CHECK:     %[[RET:.+]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (3, 5)
+//       CHECK:     return %[[RET]]
+func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> index {
+  %c0 = arith.constant 0 : index
+  %ret = affine.linearize_index [%c0, %arg0, %arg1] by (3, 5) : index
+  return %ret : index
+}
+
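The unit-basis tests earlier in canonicalize.mlir can be read as a simple index remapping. The following is a hypothetical model of that remapping and not the canonicalization pattern in this patch. The `dropUnitBasis` helper and the `UnitBasisDrop` struct are invented for illustration. Under this model, a static size-1 basis element always yields 0 and is removed. Without an outer bound, the leading, unbounded result is always kept.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// A basis element is either a known static size or dynamic (std::nullopt).
using BasisElem = std::optional<std::int64_t>;

// Hypothetical result of dropping unit basis elements from a delinearize.
struct UnitBasisDrop {
  std::vector<BasisElem> newBasis;
  // For each original result: the index of the new result that replaces it,
  // or -1 when the original result is the constant 0.
  std::vector<int> resultMap;
};

UnitBasisDrop dropUnitBasis(const std::vector<BasisElem> &basis,
                            bool hasOuterBound) {
  UnitBasisDrop out;
  int nextResult = 0;
  // Without an outer bound the leading result has no basis element, so it
  // can never be a unit dimension and always survives.
  if (!hasOuterBound)
    out.resultMap.push_back(nextResult++);
  for (const BasisElem &elem : basis) {
    if (elem && *elem == 1) {
      // A size-1 dimension can only hold index 0 (for an outer bound of 1
      // this assumes the input index is in range).
      out.resultMap.push_back(-1);
      continue;
    }
    out.newBasis.push_back(elem);
    out.resultMap.push_back(nextResult++);
  }
  return out;
}
```

Take the basis (%arg1, 1, 1, %arg2, %c1) with no outer bound, treating %c1 as the static value 1. The model returns a two-element basis and the map {0, 1, -1, -1, 2, -1}, which corresponds to the results #0, #1, 0, 0, #2, 0 checked in @drop_unit_basis_in_delinearize_no_outer_bound.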

diff  --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir
index 2996194170900f..1539b4f4848276 100644
--- a/mlir/test/Dialect/Affine/invalid.mlir
+++ b/mlir/test/Dialect/Affine/invalid.mlir
@@ -533,37 +533,29 @@ func.func @missing_for_min(%arg0: index, %arg1: index, %arg2: memref<100xf32>) {
 // -----
 
 func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) {
-  // expected-error@+1 {{'affine.delinearize_index' op should return an index for each basis element}}
+  // expected-error@+1 {{'affine.delinearize_index' op should return an index for each basis element and up to one extra index}}
   %1 = affine.delinearize_index %idx into (%basis0, %basis1) : index
   return
 }
 
 // -----
 
-func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) {
-  // expected-error@+1 {{'affine.delinearize_index' op basis should not be empty}}
-  affine.delinearize_index %idx into () : index
+func.func @delinearize(%idx: index) {
+  // expected-error@+1 {{'affine.delinearize_index' op no basis element may be statically non-positive}}
+  %1:2 = affine.delinearize_index %idx into (2, -2) : index, index
   return
 }
 
 // -----
 
 func.func @linearize(%idx: index, %basis0: index, %basis1 :index) -> index {
-  // expected-error@+1 {{'affine.linearize_index' op should be passed an index for each basis element}}
+  // expected-error@+1 {{'affine.linearize_index' op should be passed a basis element for each index except possibly the first}}
   %0 = affine.linearize_index [%idx] by (%basis0, %basis1) : index
   return %0 : index
 }
 
 // -----
 
-func.func @linearize_empty() -> index {
-  // expected-error@+1 {{'affine.linearize_index' op basis should not be empty}}
-  %0 = affine.linearize_index [] by () : index
-  return %0 : index
-}
-
-// -----
-
 func.func @dynamic_dimension_index() {
   "unknown.region"() ({
     %idx = "unknown.test"() : () -> (index)
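The updated invalid.mlir expectations encode the relaxed arity rules. The C++ below restates them as a hypothetical check. Only the diagnostic texts come from the tests. The function names, the `BasisElem` encoding, and the order of the checks are assumptions, and this is not the verifier in AffineOps.cpp. Each op accepts either one index per basis element or exactly one extra index, and a delinearize basis may not contain a statically non-positive size.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// A basis element is either a known static size or dynamic (std::nullopt).
using BasisElem = std::optional<std::int64_t>;

// Hypothetical delinearize check: returns a diagnostic or std::nullopt.
std::optional<std::string>
verifyDelinearize(std::size_t numResults, const std::vector<BasisElem> &basis) {
  // One result per basis element, or one extra when the outer bound is omitted.
  if (numResults != basis.size() && numResults != basis.size() + 1)
    return "should return an index for each basis element and up to one "
           "extra index";
  for (const BasisElem &elem : basis)
    if (elem && *elem <= 0)
      return "no basis element may be statically non-positive";
  return std::nullopt;
}

// Hypothetical linearize check: the basis may omit only the first bound.
std::optional<std::string>
verifyLinearize(std::size_t numIndices, const std::vector<BasisElem> &basis) {
  if (numIndices != basis.size() && numIndices != basis.size() + 1)
    return "should be passed a basis element for each index except possibly "
           "the first";
  return std::nullopt;
}
```

Under these rules, the previously rejected empty basis is now legal: `into () : index` has one result for zero basis elements.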

diff  --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py
index 7faae6ccedc972..7ef128c1724c4a 100644
--- a/mlir/test/python/dialects/affine.py
+++ b/mlir/test/python/dialects/affine.py
@@ -50,7 +50,7 @@ def testAffineDelinearizeInfer():
     # CHECK: %[[C1:.*]] = arith.constant 1 : index
     c1 = arith.ConstantOp(T.index(), 1)
     # CHECK: %{{.*}}:2 = affine.delinearize_index %[[C1:.*]] into (2, 3) : index, index
-    two_indices = affine.AffineDelinearizeIndexOp(c1, [], [2, 3])
+    two_indices = affine.AffineDelinearizeIndexOp([T.index()] * 2, c1, [], [2, 3])
 
 
 # CHECK-LABEL: TEST: testAffineLoadOp
@@ -157,7 +157,7 @@ def testAffineForOpErrors():
         )
 
     try:
-        two_indices = affine.AffineDelinearizeIndexOp(c1, [], [1, 1])
+        two_indices = affine.AffineDelinearizeIndexOp([T.index()] * 2, c1, [], [1, 1])
         affine.AffineForOp(
             two_indices,
             c2,


        

