[Mlir-commits] [mlir] [MLIR][Affine] Add vector support to affine.linearize_index and affine.delinearize_index (PR #188369)

Keshav Vinayak Jha llvmlistbot at llvm.org
Thu Mar 26 21:35:14 PDT 2026


https://github.com/keshavvinayak01 updated https://github.com/llvm/llvm-project/pull/188369

>From 3f09929fea3253c7ec367e81bb93763bff32bd02 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 24 Mar 2026 22:08:13 +0000
Subject: [PATCH 01/11] [mlir][affine] Add vector support to
 affine.linearize_index and affine.delinearize_index

Allow affine.delinearize_index and affine.linearize_index to operate on
vector<...xindex> types in addition to scalar index. The basis remains
scalar (it describes the shape of the index space, not per-lane data).

This enables expressing element-wise index computations across vector
lanes directly, rather than manually lowering to vector.broadcast +
vector.step + arith patterns.

Changes:
- Add Affine_IndexOrVectorOfIndex type constraint in AffineOps.td
- Implement custom parse/print for both ops
- Add type consistency verifiers (all results/inputs must match)
- Update canonicalizers to produce vector zeros where needed
- Add vector lowering path in AffineExpandIndexOpsAsAffine using
  arith.divsi/muli/subi/addi (which natively support vectors)
- Scalar behavior is completely unchanged

Co-authored-by: Claude Opus 4.6 <noreply at anthropic.com>
Signed-off-by: Keshav Vinayak <keshavvinayak01 at gmail.com>
---
 .../mlir/Dialect/Affine/IR/AffineOps.td       |  27 ++-
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      | 170 +++++++++++++++++-
 .../AffineExpandIndexOpsAsAffine.cpp          | 141 +++++++++++++--
 3 files changed, 301 insertions(+), 37 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 9cb0f3242db17..8a7e49c05f526 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -32,6 +32,12 @@ def Affine_Dialect : Dialect {
 class Affine_Op<string mnemonic, list<Trait> traits = []> :
     Op<Affine_Dialect, mnemonic, traits>;
 
+// Type constraint for index-like types: index or vector of index.
+def Affine_IndexOrVectorOfIndex :
+    Type<Or<[Index.predicate,
+             VectorOfAnyRankOf<[Index]>.predicate]>,
+         "index or vector of index">;
+
 // Require regions to have affine.yield.
 def ImplicitAffineTerminator
     : SingleBlockImplicitTerminator<"AffineYieldOp">;
@@ -1118,16 +1124,12 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
     - that is, the product of all basis elements is positive as an `index` as well.
   }];
 
-  let arguments = (ins Index:$linear_index,
+  let arguments = (ins Affine_IndexOrVectorOfIndex:$linear_index,
     Variadic<Index>:$dynamic_basis,
     DenseI64ArrayAttr:$static_basis);
-  let results = (outs Variadic<Index>:$multi_index);
+  let results = (outs Variadic<Affine_IndexOrVectorOfIndex>:$multi_index);
 
-  let assemblyFormat = [{
-    $linear_index `into`
-    custom<DynamicIndexList>($dynamic_basis, $static_basis, "{}", "::mlir::AsmParser::Delimiter::Paren")
-    attr-dict `:` type($multi_index)
-  }];
+  let hasCustomAssemblyFormat = 1;
 
   let builders = [
     OpBuilder<(ins "Value":$linear_index, "ValueRange":$dynamic_basis, "ArrayRef<int64_t>":$static_basis, CArg<"bool", "true">:$hasOuterBound)>,
@@ -1221,18 +1223,13 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
     ```
   }];
 
-  let arguments = (ins Variadic<Index>:$multi_index,
+  let arguments = (ins Variadic<Affine_IndexOrVectorOfIndex>:$multi_index,
     Variadic<Index>:$dynamic_basis,
     DenseI64ArrayAttr:$static_basis,
     UnitProp:$disjoint);
-  let results = (outs Index:$linear_index);
+  let results = (outs Affine_IndexOrVectorOfIndex:$linear_index);
 
-  let assemblyFormat = [{
-    (`disjoint` $disjoint^)? ` `
-    `[` $multi_index `]` `by`
-    custom<DynamicIndexList>($dynamic_basis, $static_basis, "{}", "::mlir::AsmParser::Delimiter::Paren")
-    attr-dict `:` type($linear_index)
-  }];
+  let hasCustomAssemblyFormat = 1;
 
   let builders = [
     OpBuilder<(ins "ValueRange":$multi_index, "ValueRange":$basis, CArg<"bool", "false">:$disjoint)>,
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 839d34b41cbd4..2060a74b061e7 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -4855,6 +4856,58 @@ LogicalResult AffineVectorStoreOp::verify() {
 // DelinearizeIndexOp
 //===----------------------------------------------------------------------===//
 
+/// Parse format:
+///   affine.delinearize_index %idx into (%c4, %c8)
+///     : index, index         (scalar)
+///   affine.delinearize_index %vec into (%c4, %c8)
+///     : vector<16xindex>, vector<16xindex>  (vector)
+ParseResult AffineDelinearizeIndexOp::parse(OpAsmParser &parser,
+                                            OperationState &result) {
+  OpAsmParser::UnresolvedOperand linearIndex;
+  if (parser.parseOperand(linearIndex) || parser.parseKeyword("into"))
+    return failure();
+
+  SmallVector<OpAsmParser::UnresolvedOperand> dynamicBasis;
+  DenseI64ArrayAttr staticBasis;
+  if (parseDynamicIndexList(parser, dynamicBasis, staticBasis, nullptr,
+                            AsmParser::Delimiter::Paren))
+    return failure();
+
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  if (parser.parseColon())
+    return failure();
+
+  SmallVector<Type> resultTypes;
+  if (parser.parseTypeList(resultTypes))
+    return failure();
+
+  // Infer the linear index type from the first result type. All types must
+  // match (enforced by the verifier).
+  Type indexType = resultTypes.empty() ? IndexType::get(parser.getContext())
+                                       : resultTypes.front();
+  if (parser.resolveOperand(linearIndex, indexType, result.operands))
+    return failure();
+  if (parser.resolveOperands(dynamicBasis, IndexType::get(parser.getContext()),
+                             result.operands))
+    return failure();
+
+  result.addTypes(resultTypes);
+  result.getOrAddProperties<AffineDelinearizeIndexOp::Properties>()
+      .static_basis = staticBasis;
+  return success();
+}
+
+void AffineDelinearizeIndexOp::print(OpAsmPrinter &p) {
+  p << ' ' << getLinearIndex() << " into ";
+  printDynamicIndexList(p, *this, getDynamicBasis(), getStaticBasisAttr(),
+                        /*scalableFlags=*/{}, AsmParser::Delimiter::Paren);
+  p.printOptionalAttrDict((*this)->getAttrs(), {getStaticBasisAttrName()});
+  p << " : ";
+  llvm::interleaveComma(getResultTypes(), p);
+}
+
 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                      OperationState &odsState,
                                      Value linearIndex, ValueRange dynamicBasis,
@@ -4925,6 +4978,14 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
       }))
     return emitOpError("no basis element may be statically non-positive");
 
+  // All result types must match the input type.
+  Type inputType = getLinearIndex().getType();
+  for (Type resultType : getResultTypes()) {
+    if (resultType != inputType)
+      return emitOpError("result types must match the linear index type, got ")
+             << resultType << " vs " << inputType;
+  }
+
   return success();
 }
 
@@ -5036,9 +5097,17 @@ struct DropUnitExtentBasis
     SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
     std::optional<Value> zero = std::nullopt;
     Location loc = delinearizeOp->getLoc();
+    Type indexType = delinearizeOp.getLinearIndex().getType();
     auto getZero = [&]() -> Value {
-      if (!zero)
-        zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+      if (!zero) {
+        Value scalarZero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+        if (auto vecTy = dyn_cast<VectorType>(indexType))
+          zero = arith::ConstantOp::create(
+              rewriter, loc,
+              DenseElementsAttr::get(vecTy, rewriter.getIndexAttr(0)));
+        else
+          zero = scalarZero;
+      }
       return zero.value();
     };
 
@@ -5204,9 +5273,9 @@ struct SplitDelinearizeSpanningLastLinearizeArg final
           "need at least two elements to form the basis product");
 
     Value linearizeWithoutBack = affine::AffineLinearizeIndexOp::create(
-        rewriter, linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
-        linearizeOp.getDynamicBasis(), linearizeOp.getStaticBasis().drop_back(),
-        linearizeOp.getDisjoint());
+        rewriter, linearizeOp.getLoc(), linearizeOp.getLinearIndex().getType(),
+        linearizeOp.getMultiIndex().drop_back(), linearizeOp.getDynamicBasis(),
+        linearizeOp.getStaticBasis().drop_back(), linearizeOp.getDisjoint());
     auto delinearizeWithoutSplitPart = affine::AffineDelinearizeIndexOp::create(
         rewriter, delinearizeOp.getLoc(), linearizeWithoutBack,
         delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
@@ -5236,6 +5305,69 @@ void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
 // LinearizeIndexOp
 //===----------------------------------------------------------------------===//
 
+/// Parse format:
+///   affine.linearize_index [%x, %y] by (%c4, %c8) : index
+///   affine.linearize_index disjoint [%v0, %v1] by (%c4, %c8)
+///     : vector<16xindex>
+ParseResult AffineLinearizeIndexOp::parse(OpAsmParser &parser,
+                                          OperationState &result) {
+  bool disjoint = succeeded(parser.parseOptionalKeyword("disjoint"));
+
+  SmallVector<OpAsmParser::UnresolvedOperand> multiIndex;
+  if (parser.parseOperandList(multiIndex, AsmParser::Delimiter::Square) ||
+      parser.parseKeyword("by"))
+    return failure();
+
+  SmallVector<OpAsmParser::UnresolvedOperand> dynamicBasis;
+  DenseI64ArrayAttr staticBasis;
+  if (parseDynamicIndexList(parser, dynamicBasis, staticBasis, nullptr,
+                            AsmParser::Delimiter::Paren))
+    return failure();
+
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  Type resultType;
+  if (parser.parseColonType(resultType))
+    return failure();
+
+  if (parser.resolveOperands(multiIndex, resultType, result.operands))
+    return failure();
+  if (parser.resolveOperands(dynamicBasis, IndexType::get(parser.getContext()),
+                             result.operands))
+    return failure();
+
+  result.addTypes(resultType);
+  auto &props = result.getOrAddProperties<AffineLinearizeIndexOp::Properties>();
+  props.static_basis = staticBasis;
+  props.disjoint = disjoint;
+  props.operandSegmentSizes = {static_cast<int32_t>(multiIndex.size()),
+                               static_cast<int32_t>(dynamicBasis.size())};
+  return success();
+}
+
+void AffineLinearizeIndexOp::print(OpAsmPrinter &p) {
+  if (getDisjoint())
+    p << " disjoint";
+  p << " [";
+  llvm::interleaveComma(getMultiIndex(), p);
+  p << "] by ";
+  printDynamicIndexList(p, *this, getDynamicBasis(), getStaticBasisAttr(),
+                        /*scalableFlags=*/{}, AsmParser::Delimiter::Paren);
+  p.printOptionalAttrDict(
+      (*this)->getAttrs(),
+      {getStaticBasisAttrName(), getOperandSegmentSizesAttrName()});
+  p << " : " << getLinearIndex().getType();
+}
+
+/// Infer the index type from a set of multi-index values. Returns the common
+/// type (index or vector<...xindex>), or IndexType if the set is empty.
+static Type inferIndexType(MLIRContext *ctx, ValueRange multiIndex) {
+  if (multiIndex.empty())
+    return IndexType::get(ctx);
+  return multiIndex.front().getType();
+}
+
 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
                                    OperationState &odsState,
                                    ValueRange multiIndex, ValueRange basis,
@@ -5246,7 +5378,9 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
   SmallVector<int64_t> staticBasis;
   dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
                              staticBasis);
-  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
+  Type resultType = inferIndexType(odsBuilder.getContext(), multiIndex);
+  build(odsBuilder, odsState, resultType, multiIndex, dynamicBasis, staticBasis,
+        disjoint);
 }
 
 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
@@ -5259,14 +5393,18 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
   SmallVector<Value> dynamicBasis;
   SmallVector<int64_t> staticBasis;
   dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
-  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
+  Type resultType = inferIndexType(odsBuilder.getContext(), multiIndex);
+  build(odsBuilder, odsState, resultType, multiIndex, dynamicBasis, staticBasis,
+        disjoint);
 }
 
 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
                                    OperationState &odsState,
                                    ValueRange multiIndex,
                                    ArrayRef<int64_t> basis, bool disjoint) {
-  build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
+  Type resultType = inferIndexType(odsBuilder.getContext(), multiIndex);
+  build(odsBuilder, odsState, resultType, multiIndex, ValueRange{}, basis,
+        disjoint);
 }
 
 LogicalResult AffineLinearizeIndexOp::verify() {
@@ -5284,6 +5422,14 @@ LogicalResult AffineLinearizeIndexOp::verify() {
         "corresponding dynamic basis entry) -- this can only happen due to an "
         "incorrect fold/rewrite");
 
+  // All multi_index types must match the result type.
+  Type resultType = getLinearIndex().getType();
+  for (Value idx : getMultiIndex()) {
+    if (idx.getType() != resultType)
+      return emitOpError("multi_index types must match the result type, got ")
+             << idx.getType() << " vs " << resultType;
+  }
+
   return success();
 }
 
@@ -5402,7 +5548,13 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
                                          "no unit basis entries to replace");
 
     if (newIndices.empty()) {
-      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
+      Type resultType = op.getLinearIndex().getType();
+      if (auto vecTy = dyn_cast<VectorType>(resultType)) {
+        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+            op, DenseElementsAttr::get(vecTy, rewriter.getIndexAttr(0)));
+      } else {
+        rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
+      }
       return success();
     }
     rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp
index e919bc6d36265..0178c5159df53 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp
@@ -15,7 +15,9 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
 #include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
@@ -29,6 +31,33 @@ using namespace mlir;
 using namespace mlir::affine;
 
 namespace {
+
+/// Create a constant splat of the given type with the given integer value.
+static Value createTypedConstant(OpBuilder &b, Location loc, Type type,
+                                 int64_t value) {
+  if (auto vecTy = dyn_cast<VectorType>(type))
+    return arith::ConstantOp::create(
+        b, loc, DenseElementsAttr::get(vecTy, b.getIndexAttr(value)));
+  return arith::ConstantIndexOp::create(b, loc, value);
+}
+
+/// Materialize an OpFoldResult (which represents a scalar index or constant)
+/// as a Value matching the given target type. For vector target types, scalar
+/// constants are splatted. Returns failure for dynamic basis with vector types
+/// since that requires vector.broadcast which is not available here.
+static FailureOr<Value> materializeBasis(OpBuilder &b, Location loc,
+                                         OpFoldResult ofr, Type targetType) {
+  std::optional<int64_t> cst = getConstantIntValue(ofr);
+  if (cst)
+    return createTypedConstant(b, loc, targetType, *cst);
+  // Dynamic scalar basis value. For scalar target types, return as-is.
+  if (isa<IndexType>(targetType))
+    return getValueOrCreateConstantIndexOp(b, loc, ofr);
+  // Dynamic scalar basis with vector target type -- would need
+  // vector.broadcast, bail out.
+  return failure();
+}
+
 /// Lowers `affine.delinearize_index` into a sequence of division and remainder
 /// operations.
 struct LowerDelinearizeIndexOps
@@ -36,12 +65,51 @@ struct LowerDelinearizeIndexOps
   using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<SmallVector<Value>> multiIndex =
-        delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
-                         op.getEffectiveBasis(), /*hasOuterBound=*/false);
-    if (failed(multiIndex))
-      return failure();
-    rewriter.replaceOp(op, *multiIndex);
+    // For scalar types, use the existing affine lowering path.
+    if (isa<IndexType>(op.getLinearIndex().getType())) {
+      FailureOr<SmallVector<Value>> multiIndex =
+          delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
+                           op.getEffectiveBasis(), /*hasOuterBound=*/false);
+      if (failed(multiIndex))
+        return failure();
+      rewriter.replaceOp(op, *multiIndex);
+      return success();
+    }
+
+    // Vector lowering: emit arith divsi/muli/subi ops (which work
+    // element-wise on vectors); remainder = residual - quotient * divisor.
+    Location loc = op.getLoc();
+    Value linearIndex = op.getLinearIndex();
+    Type type = linearIndex.getType();
+    SmallVector<OpFoldResult> basis = op.getEffectiveBasis();
+
+    // Compute cumulative products of basis from the right. These serve as
+    // divisors: for basis (B0, B1, B2), the divisors are (B1*B2, B2).
+    SmallVector<Value> divisors;
+    Value cumulativeProd = createTypedConstant(rewriter, loc, type, 1);
+    for (OpFoldResult basisElem : llvm::reverse(basis)) {
+      FailureOr<Value> basisVal =
+          materializeBasis(rewriter, loc, basisElem, type);
+      if (failed(basisVal))
+        return failure();
+      cumulativeProd =
+          arith::MulIOp::create(rewriter, loc, cumulativeProd, *basisVal);
+      divisors.push_back(cumulativeProd);
+    }
+
+    // Emit div/mod pairs from the most-significant dimension to the least.
+    SmallVector<Value> results;
+    results.reserve(divisors.size() + 1);
+    Value residual = linearIndex;
+    for (Value divisor : llvm::reverse(divisors)) {
+      Value quotient = arith::DivSIOp::create(rewriter, loc, residual, divisor);
+      Value product = arith::MulIOp::create(rewriter, loc, quotient, divisor);
+      Value remainder = arith::SubIOp::create(rewriter, loc, residual, product);
+      results.push_back(quotient);
+      residual = remainder;
+    }
+    results.push_back(residual);
+    rewriter.replaceOp(op, results);
     return success();
   }
 };
@@ -58,13 +126,60 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
       return success();
     }
 
-    SmallVector<OpFoldResult> multiIndex =
-        getAsOpFoldResult(op.getMultiIndex());
-    OpFoldResult linearIndex =
-        linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
-    Value linearIndexValue =
-        getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
-    rewriter.replaceOp(op, linearIndexValue);
+    // For scalar types, use the existing affine lowering path.
+    if (isa<IndexType>(op.getLinearIndex().getType())) {
+      SmallVector<OpFoldResult> multiIndex =
+          getAsOpFoldResult(op.getMultiIndex());
+      OpFoldResult linearIndex =
+          linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
+      Value linearIndexValue =
+          getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
+      rewriter.replaceOp(op, linearIndexValue);
+      return success();
+    }
+
+    // Vector lowering: emit arith ops (which work element-wise on vectors).
+    //
+    // linearize_index [i0, i1, ..., iN-1] by (B0, B1, ..., BN-1)
+    // = i0 * stride_0 + i1 * stride_1 + ... + iN-1
+    // where stride_k = B_{k+1} * B_{k+2} * ... * B_{N-1}
+    //
+    // We compute from the back: result = iN-1, stride = 1, then for
+    //   k = N-2 down to 0: stride *= B_{k+1}, result += i_k * stride
+    Location loc = op.getLoc();
+    Type type = op.getLinearIndex().getType();
+    SmallVector<OpFoldResult> effectiveBasis = op.getEffectiveBasis();
+    ValueRange indices = op.getMultiIndex();
+
+    // effectiveBasis drops the outer bound. For indices [i0, i1, ..., iN-1]:
+    //   no outer bound:  basis = [B1, B2, ..., BN-1] (N-1 elems), used as-is
+    //   has outer bound: basis = [B0, B1, ..., BN-1] (N elems); B0 is advisory
+    //                    and dropped, so effectiveBasis = [B1, ..., BN-1]
+    //
+    // Computation: result = iN-1 + BN-1 * (iN-2 + BN-2 * (... + B1 * i0))
+    // Or equivalently, accumulate from back:
+    //   result = iN-1
+    //   stride = 1
+    //   for k = numBasis-1 downto 0:
+    //     stride *= effectiveBasis[k]
+    //     result += indices[k] * stride
+    //
+    // This works because effectiveBasis[k] is the "size" of dimension k+1,
+    // and indices[k] is paired with the product of all sizes after it.
+    Value result = indices.back();
+    Value stride = createTypedConstant(rewriter, loc, type, 1);
+
+    for (int i = static_cast<int>(effectiveBasis.size()) - 1; i >= 0; --i) {
+      FailureOr<Value> basisVal =
+          materializeBasis(rewriter, loc, effectiveBasis[i], type);
+      if (failed(basisVal))
+        return failure();
+      stride = arith::MulIOp::create(rewriter, loc, stride, *basisVal);
+      Value term = arith::MulIOp::create(rewriter, loc, indices[i], stride);
+      result = arith::AddIOp::create(rewriter, loc, term, result);
+    }
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };

>From 68c2a367c58b466115430007a16f29d500729bb7 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 24 Mar 2026 23:46:15 +0000
Subject: [PATCH 02/11] [mlir][affine] Use declarative assemblyFormat for
 linearize/delinearize index ops

Replace custom parse/print with declarative assemblyFormat using
TypesMatchWith traits for type inference:
- delinearize: infer linear_index type from first result type
- linearize: infer multi_index types from result type

The assembly format is unchanged - no existing .mlir files need updating.

Co-authored-by: Claude Opus 4.6 <noreply at anthropic.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 .../mlir/Dialect/Affine/IR/AffineOps.td       |  23 +++-
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      | 108 ------------------
 2 files changed, 19 insertions(+), 112 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 8a7e49c05f526..3f59e008b2a7d 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1069,7 +1069,10 @@ def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
 // AffineDelinearizeIndexOp
 //===----------------------------------------------------------------------===//
 
-def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
+def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
+    [Pure, TypesMatchWith<"linear_index type must match result types",
+                          "multi_index", "linear_index",
+                          "$_self[0]">]> {
   let summary = "delinearize an index";
   let description = [{
     The `affine.delinearize_index` operation takes a single index value and
@@ -1129,7 +1132,11 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
     DenseI64ArrayAttr:$static_basis);
   let results = (outs Variadic<Affine_IndexOrVectorOfIndex>:$multi_index);
 
-  let hasCustomAssemblyFormat = 1;
+  let assemblyFormat = [{
+    $linear_index `into`
+    custom<DynamicIndexList>($dynamic_basis, $static_basis, "{}", "::mlir::AsmParser::Delimiter::Paren")
+    attr-dict `:` type($multi_index)
+  }];
 
   let builders = [
     OpBuilder<(ins "Value":$linear_index, "ValueRange":$dynamic_basis, "ArrayRef<int64_t>":$static_basis, CArg<"bool", "true">:$hasOuterBound)>,
@@ -1169,7 +1176,10 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
 // AffineLinearizeIndexOp
 //===----------------------------------------------------------------------===//
 def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
-    [Pure, AttrSizedOperandSegments]> {
+    [Pure, AttrSizedOperandSegments,
+     TypesMatchWith<"multi_index types must match result type",
+                    "linear_index", "multi_index", "$_self",
+                    "[](::mlir::Type a, ::mlir::TypeRange b) { return llvm::all_of(b, [a](::mlir::Type t) { return t == a; }); }">]> {
   let summary = "linearize an index";
   let description = [{
     The `affine.linearize_index` operation takes a sequence of index values and a
@@ -1229,7 +1239,12 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
     UnitProp:$disjoint);
   let results = (outs Affine_IndexOrVectorOfIndex:$linear_index);
 
-  let hasCustomAssemblyFormat = 1;
+  let assemblyFormat = [{
+    (`disjoint` $disjoint^)? ` `
+    `[` $multi_index `]` `by`
+    custom<DynamicIndexList>($dynamic_basis, $static_basis, "{}", "::mlir::AsmParser::Delimiter::Paren")
+    attr-dict `:` type($linear_index)
+  }];
 
   let builders = [
     OpBuilder<(ins "ValueRange":$multi_index, "ValueRange":$basis, CArg<"bool", "false">:$disjoint)>,
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 2060a74b061e7..eacf5a3bb74fb 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -21,7 +21,6 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
-#include "mlir/Interfaces/ViewLikeInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -4856,58 +4855,6 @@ LogicalResult AffineVectorStoreOp::verify() {
 // DelinearizeIndexOp
 //===----------------------------------------------------------------------===//
 
-/// Parse format:
-///   affine.delinearize_index %idx into (%c4, %c8)
-///     : index, index         (scalar)
-///   affine.delinearize_index %vec into (%c4, %c8)
-///     : vector<16xindex>, vector<16xindex>  (vector)
-ParseResult AffineDelinearizeIndexOp::parse(OpAsmParser &parser,
-                                            OperationState &result) {
-  OpAsmParser::UnresolvedOperand linearIndex;
-  if (parser.parseOperand(linearIndex) || parser.parseKeyword("into"))
-    return failure();
-
-  SmallVector<OpAsmParser::UnresolvedOperand> dynamicBasis;
-  DenseI64ArrayAttr staticBasis;
-  if (parseDynamicIndexList(parser, dynamicBasis, staticBasis, nullptr,
-                            AsmParser::Delimiter::Paren))
-    return failure();
-
-  if (parser.parseOptionalAttrDict(result.attributes))
-    return failure();
-
-  if (parser.parseColon())
-    return failure();
-
-  SmallVector<Type> resultTypes;
-  if (parser.parseTypeList(resultTypes))
-    return failure();
-
-  // Infer the linear index type from the first result type. All types must
-  // match (enforced by the verifier).
-  Type indexType = resultTypes.empty() ? IndexType::get(parser.getContext())
-                                       : resultTypes.front();
-  if (parser.resolveOperand(linearIndex, indexType, result.operands))
-    return failure();
-  if (parser.resolveOperands(dynamicBasis, IndexType::get(parser.getContext()),
-                             result.operands))
-    return failure();
-
-  result.addTypes(resultTypes);
-  result.getOrAddProperties<AffineDelinearizeIndexOp::Properties>()
-      .static_basis = staticBasis;
-  return success();
-}
-
-void AffineDelinearizeIndexOp::print(OpAsmPrinter &p) {
-  p << ' ' << getLinearIndex() << " into ";
-  printDynamicIndexList(p, *this, getDynamicBasis(), getStaticBasisAttr(),
-                        /*scalableFlags=*/{}, AsmParser::Delimiter::Paren);
-  p.printOptionalAttrDict((*this)->getAttrs(), {getStaticBasisAttrName()});
-  p << " : ";
-  llvm::interleaveComma(getResultTypes(), p);
-}
-
 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                      OperationState &odsState,
                                      Value linearIndex, ValueRange dynamicBasis,
@@ -5305,61 +5252,6 @@ void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
 // LinearizeIndexOp
 //===----------------------------------------------------------------------===//
 
-/// Parse format:
-///   affine.linearize_index [%x, %y] by (%c4, %c8) : index
-///   affine.linearize_index disjoint [%v0, %v1] by (%c4, %c8)
-///     : vector<16xindex>
-ParseResult AffineLinearizeIndexOp::parse(OpAsmParser &parser,
-                                          OperationState &result) {
-  bool disjoint = succeeded(parser.parseOptionalKeyword("disjoint"));
-
-  SmallVector<OpAsmParser::UnresolvedOperand> multiIndex;
-  if (parser.parseOperandList(multiIndex, AsmParser::Delimiter::Square) ||
-      parser.parseKeyword("by"))
-    return failure();
-
-  SmallVector<OpAsmParser::UnresolvedOperand> dynamicBasis;
-  DenseI64ArrayAttr staticBasis;
-  if (parseDynamicIndexList(parser, dynamicBasis, staticBasis, nullptr,
-                            AsmParser::Delimiter::Paren))
-    return failure();
-
-  if (parser.parseOptionalAttrDict(result.attributes))
-    return failure();
-
-  Type resultType;
-  if (parser.parseColonType(resultType))
-    return failure();
-
-  if (parser.resolveOperands(multiIndex, resultType, result.operands))
-    return failure();
-  if (parser.resolveOperands(dynamicBasis, IndexType::get(parser.getContext()),
-                             result.operands))
-    return failure();
-
-  result.addTypes(resultType);
-  auto &props = result.getOrAddProperties<AffineLinearizeIndexOp::Properties>();
-  props.static_basis = staticBasis;
-  props.disjoint = disjoint;
-  props.operandSegmentSizes = {static_cast<int32_t>(multiIndex.size()),
-                               static_cast<int32_t>(dynamicBasis.size())};
-  return success();
-}
-
-void AffineLinearizeIndexOp::print(OpAsmPrinter &p) {
-  if (getDisjoint())
-    p << " disjoint";
-  p << " [";
-  llvm::interleaveComma(getMultiIndex(), p);
-  p << "] by ";
-  printDynamicIndexList(p, *this, getDynamicBasis(), getStaticBasisAttr(),
-                        /*scalableFlags=*/{}, AsmParser::Delimiter::Paren);
-  p.printOptionalAttrDict(
-      (*this)->getAttrs(),
-      {getStaticBasisAttrName(), getOperandSegmentSizesAttrName()});
-  p << " : " << getLinearIndex().getType();
-}
-
 /// Infer the index type from a set of multi-index values. Returns the common
 /// type (index or vector<...xindex>), or IndexType if the set is empty.
 static Type inferIndexType(MLIRContext *ctx, ValueRange multiIndex) {

>From fd2b27f247a0f8845ec45855a3861cc7ff8aa439 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 24 Mar 2026 23:57:25 +0000
Subject: [PATCH 03/11] [mlir][affine] Add lit tests for vector
 linearize/delinearize index ops

Test coverage for the vector support added earlier in this series:
- Roundtrip: 2D/3D static basis, dynamic basis, disjoint flag
- Canonicalization: linearize/delinearize cancellation, unit-extent
  dropping, single-result/single-index folding
- Lowering: vector delinearize to arith div/mul/sub, vector linearize
  to arith mul/add, 3D case, scalar path unchanged
- Linearize->offset->delinearize pattern (vector.gather use case)

Co-authored-by: Claude Opus 4.6 <noreply at anthropic.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 mlir/test/Dialect/Affine/ops.mlir             |  42 ++++++
 .../test/Dialect/Affine/vector-index-ops.mlir | 128 ++++++++++++++++++
 2 files changed, 170 insertions(+)
 create mode 100644 mlir/test/Dialect/Affine/vector-index-ops.mlir

diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
index 1562f5b1693c0..35b07c1c7fe1f 100644
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -316,6 +316,48 @@ func.func @linearize_mixed(%index0: index, %index1: index, %index2: index, %basi
   return %1 : index
 }
 
+// CHECK-LABEL: @delinearize_vector
+func.func @delinearize_vector(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>) {
+  // CHECK: affine.delinearize_index %{{.+}} into (4, 8) : vector<16xindex>, vector<16xindex>
+  %0:2 = affine.delinearize_index %vec into (4, 8) : vector<16xindex>, vector<16xindex>
+  return %0#0, %0#1 : vector<16xindex>, vector<16xindex>
+}
+
+// CHECK-LABEL: @delinearize_vector_3d
+func.func @delinearize_vector_3d(%vec: vector<8xindex>) -> (vector<8xindex>, vector<8xindex>, vector<8xindex>) {
+  // CHECK: affine.delinearize_index %{{.+}} into (2, 3, 4) : vector<8xindex>, vector<8xindex>, vector<8xindex>
+  %0:3 = affine.delinearize_index %vec into (2, 3, 4) : vector<8xindex>, vector<8xindex>, vector<8xindex>
+  return %0#0, %0#1, %0#2 : vector<8xindex>, vector<8xindex>, vector<8xindex>
+}
+
+// CHECK-LABEL: @delinearize_vector_dynamic_basis
+func.func @delinearize_vector_dynamic_basis(%vec: vector<4xindex>, %b0: index, %b1: index) -> (vector<4xindex>, vector<4xindex>) {
+  // CHECK: affine.delinearize_index %{{.+}} into (%{{.+}}, %{{.+}}) : vector<4xindex>, vector<4xindex>
+  %0:2 = affine.delinearize_index %vec into (%b0, %b1) : vector<4xindex>, vector<4xindex>
+  return %0#0, %0#1 : vector<4xindex>, vector<4xindex>
+}
+
+// CHECK-LABEL: @linearize_vector
+func.func @linearize_vector(%v0: vector<16xindex>, %v1: vector<16xindex>) -> vector<16xindex> {
+  // CHECK: affine.linearize_index [%{{.+}}, %{{.+}}] by (4, 8) : vector<16xindex>
+  %0 = affine.linearize_index [%v0, %v1] by (4, 8) : vector<16xindex>
+  return %0 : vector<16xindex>
+}
+
+// CHECK-LABEL: @linearize_vector_disjoint
+func.func @linearize_vector_disjoint(%v0: vector<16xindex>, %v1: vector<16xindex>) -> vector<16xindex> {
+  // CHECK: affine.linearize_index disjoint [%{{.+}}, %{{.+}}] by (4, 8) : vector<16xindex>
+  %0 = affine.linearize_index disjoint [%v0, %v1] by (4, 8) : vector<16xindex>
+  return %0 : vector<16xindex>
+}
+
+// CHECK-LABEL: @linearize_vector_3d
+func.func @linearize_vector_3d(%v0: vector<8xindex>, %v1: vector<8xindex>, %v2: vector<8xindex>) -> vector<8xindex> {
+  // CHECK: affine.linearize_index [%{{.+}}, %{{.+}}, %{{.+}}] by (2, 3, 4) : vector<8xindex>
+  %0 = affine.linearize_index [%v0, %v1, %v2] by (2, 3, 4) : vector<8xindex>
+  return %0 : vector<8xindex>
+}
+
 // -----
 
 // CHECK-LABEL: @gpu_launch_affine
diff --git a/mlir/test/Dialect/Affine/vector-index-ops.mlir b/mlir/test/Dialect/Affine/vector-index-ops.mlir
new file mode 100644
index 0000000000000..2c6fbc20a9316
--- /dev/null
+++ b/mlir/test/Dialect/Affine/vector-index-ops.mlir
@@ -0,0 +1,128 @@
+// RUN: mlir-opt %s -split-input-file --canonicalize | FileCheck %s --check-prefix=CANON
+// RUN: mlir-opt %s -split-input-file --affine-expand-index-ops-as-affine | FileCheck %s --check-prefix=EXPAND
+
+// Canonicalization: cancel disjoint linearize -> delinearize on vectors.
+
+// CANON-LABEL: @cancel_linearize_delinearize_vector
+// CANON-SAME:    (%[[V0:.+]]: vector<16xindex>, %[[V1:.+]]: vector<16xindex>)
+// CANON:         return %[[V0]], %[[V1]]
+func.func @cancel_linearize_delinearize_vector(%v0: vector<16xindex>, %v1: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>) {
+  %0 = affine.linearize_index disjoint [%v0, %v1] by (4, 8) : vector<16xindex>
+  %1:2 = affine.delinearize_index %0 into (4, 8) : vector<16xindex>, vector<16xindex>
+  return %1#0, %1#1 : vector<16xindex>, vector<16xindex>
+}
+
+// -----
+// Canonicalization: drop unit-extent basis on vector delinearize.
+
+// CANON-LABEL: @drop_unit_extent_vector
+// CANON-SAME:    (%[[VEC:.+]]: vector<16xindex>)
+// CANON-DAG:     %[[ZERO:.+]] = arith.constant dense<0> : vector<16xindex>
+// CANON:         %[[R:.+]]:2 = affine.delinearize_index %[[VEC]] into (4, 8) : vector<16xindex>, vector<16xindex>
+// CANON:         return %[[R]]#0, %[[ZERO]], %[[R]]#1
+func.func @drop_unit_extent_vector(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>, vector<16xindex>) {
+  %0:3 = affine.delinearize_index %vec into (4, 1, 8) : vector<16xindex>, vector<16xindex>, vector<16xindex>
+  return %0#0, %0#1, %0#2 : vector<16xindex>, vector<16xindex>, vector<16xindex>
+}
+
+// -----
+// Canonicalization: drop unit-extent basis on vector linearize.
+
+// CANON-LABEL: @drop_unit_linearize_vector
+// CANON-SAME:    (%[[V0:.+]]: vector<8xindex>, %[[V1:.+]]: vector<8xindex>)
+// CANON:         return %[[V0]]
+func.func @drop_unit_linearize_vector(%v0: vector<8xindex>, %v1: vector<8xindex>) -> vector<8xindex> {
+  %0 = affine.linearize_index disjoint [%v0, %v1] by (4, 1) : vector<8xindex>
+  return %0 : vector<8xindex>
+}
+
+// -----
+// Canonicalization: fold single-result vector delinearize to identity.
+
+// CANON-LABEL: @fold_single_result_vector
+// CANON-SAME:    (%[[VEC:.+]]: vector<4xindex>)
+// CANON:         return %[[VEC]]
+func.func @fold_single_result_vector(%vec: vector<4xindex>) -> vector<4xindex> {
+  %0:1 = affine.delinearize_index %vec into () : vector<4xindex>
+  return %0#0 : vector<4xindex>
+}
+
+// -----
+// Canonicalization: fold single-index vector linearize to identity.
+
+// CANON-LABEL: @fold_single_index_vector
+// CANON-SAME:    (%[[VEC:.+]]: vector<4xindex>)
+// CANON:         return %[[VEC]]
+func.func @fold_single_index_vector(%vec: vector<4xindex>) -> vector<4xindex> {
+  %0 = affine.linearize_index [%vec] by () : vector<4xindex>
+  return %0 : vector<4xindex>
+}
+
+// -----
+// Expansion: vector delinearize lowers to arith div/rem.
+
+// EXPAND-LABEL: @expand_delinearize_vector
+// EXPAND-SAME:    (%[[VEC:.+]]: vector<16xindex>)
+// EXPAND-DAG:     %[[C8:.+]] = arith.constant dense<8> : vector<16xindex>
+// EXPAND:         %[[DIV:.+]] = arith.divsi %[[VEC]], %[[C8]]
+// EXPAND:         %[[MUL:.+]] = arith.muli %[[DIV]], %[[C8]]
+// EXPAND:         %[[REM:.+]] = arith.subi %[[VEC]], %[[MUL]]
+// EXPAND:         return %[[DIV]], %[[REM]]
+func.func @expand_delinearize_vector(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>) {
+  %0:2 = affine.delinearize_index %vec into (4, 8) : vector<16xindex>, vector<16xindex>
+  return %0#0, %0#1 : vector<16xindex>, vector<16xindex>
+}
+
+// -----
+// Expansion: vector linearize lowers to arith mul/add.
+
+// EXPAND-LABEL: @expand_linearize_vector
+// EXPAND-SAME:    (%[[V0:.+]]: vector<16xindex>, %[[V1:.+]]: vector<16xindex>)
+// EXPAND-DAG:     %[[C8:.+]] = arith.constant dense<8> : vector<16xindex>
+// EXPAND:         %[[MUL:.+]] = arith.muli %[[V0]], %[[C8]]
+// EXPAND:         %[[ADD:.+]] = arith.addi %[[MUL]], %[[V1]]
+// EXPAND:         return %[[ADD]]
+func.func @expand_linearize_vector(%v0: vector<16xindex>, %v1: vector<16xindex>) -> vector<16xindex> {
+  %0 = affine.linearize_index [%v0, %v1] by (4, 8) : vector<16xindex>
+  return %0 : vector<16xindex>
+}
+
+// -----
+// Expansion: 3D vector delinearize.
+
+// EXPAND-LABEL: @expand_delinearize_vector_3d
+// EXPAND-SAME:    (%[[VEC:.+]]: vector<16xindex>)
+// EXPAND-DAG:     %[[C4:.+]] = arith.constant dense<4> : vector<16xindex>
+// EXPAND-DAG:     %[[C12:.+]] = arith.constant dense<12> : vector<16xindex>
+// EXPAND:         %[[D0:.+]] = arith.divsi %[[VEC]], %[[C12]]
+// EXPAND:         %[[M0:.+]] = arith.muli %[[D0]], %[[C12]]
+// EXPAND:         %[[R0:.+]] = arith.subi %[[VEC]], %[[M0]]
+// EXPAND:         %[[D1:.+]] = arith.divsi %[[R0]], %[[C4]]
+// EXPAND:         %[[M1:.+]] = arith.muli %[[D1]], %[[C4]]
+// EXPAND:         %[[R1:.+]] = arith.subi %[[R0]], %[[M1]]
+// EXPAND:         return %[[D0]], %[[D1]], %[[R1]]
+func.func @expand_delinearize_vector_3d(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>, vector<16xindex>) {
+  %0:3 = affine.delinearize_index %vec into (2, 3, 4) : vector<16xindex>, vector<16xindex>, vector<16xindex>
+  return %0#0, %0#1, %0#2 : vector<16xindex>, vector<16xindex>, vector<16xindex>
+}
+
+// -----
+// Expansion: vector linearize -> offset -> delinearize pattern
+// (as would be used in vector.gather lowering).
+
+// EXPAND-LABEL: @vector_linearize_offset_delinearize
+// EXPAND-SAME:    (%[[V0:.+]]: vector<4xindex>, %[[V1:.+]]: vector<4xindex>, %[[OFF:.+]]: vector<4xindex>)
+// EXPAND-DAG:     %[[C8:.+]] = arith.constant dense<8> : vector<4xindex>
+// EXPAND:         %[[LIN:.+]] = arith.muli %[[V0]], %[[C8]]
+// EXPAND:         %[[LIN2:.+]] = arith.addi %[[LIN]], %[[V1]]
+// EXPAND:         %[[FLAT:.+]] = arith.addi %[[LIN2]], %[[OFF]]
+// EXPAND:         %[[DIV:.+]] = arith.divsi %[[FLAT]], %[[C8]]
+// EXPAND:         %[[MUL:.+]] = arith.muli %[[DIV]], %[[C8]]
+// EXPAND:         %[[REM:.+]] = arith.subi %[[FLAT]], %[[MUL]]
+// EXPAND:         return %[[DIV]], %[[REM]]
+func.func @vector_linearize_offset_delinearize(%v0: vector<4xindex>, %v1: vector<4xindex>, %offsets: vector<4xindex>) -> (vector<4xindex>, vector<4xindex>) {
+  %0 = affine.linearize_index [%v0, %v1] by (4, 8) : vector<4xindex>
+  %1 = arith.addi %0, %offsets : vector<4xindex>
+  %2:2 = affine.delinearize_index %1 into (4, 8) : vector<4xindex>, vector<4xindex>
+  return %2#0, %2#1 : vector<4xindex>, vector<4xindex>
+}

>From 554a15883d096fceb858e5d851e41772fc6ace2c Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Wed, 25 Mar 2026 06:45:38 +0000
Subject: [PATCH 04/11] [mlir][affine] Add lit tests for vector
 linearize/delinearize index ops

Move the vector test cases out of the standalone vector-index-ops.mlir
and into the existing test files:
- canonicalize.mlir: cancel, unit-extent drop, single-result fold
- affine-expand-index-ops-as-affine.mlir: lowering to arith ops,
  3D case, linearize->offset->delinearize pattern
The ops.mlir roundtrip tests (2D/3D, dynamic basis, disjoint) were
already added by the previous commit and are not touched here.

Co-authored-by: Claude Opus 4.6 <noreply at anthropic.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 .../affine-expand-index-ops-as-affine.mlir    |  67 +++++++++
 mlir/test/Dialect/Affine/canonicalize.mlir    |  53 ++++++++
 .../test/Dialect/Affine/vector-index-ops.mlir | 128 ------------------
 3 files changed, 120 insertions(+), 128 deletions(-)
 delete mode 100644 mlir/test/Dialect/Affine/vector-index-ops.mlir

diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir
index bf9f00da5793a..21595356936fa 100644
--- a/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir
+++ b/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir
@@ -68,3 +68,70 @@ func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: in
   %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4) : index
   func.return %0 : index
 }
+
+// -----
+
+// CHECK-LABEL: @expand_delinearize_vector
+// CHECK-SAME:    (%[[VEC:.+]]: vector<16xindex>)
+// CHECK-DAG:     %[[C8:.+]] = arith.constant dense<8> : vector<16xindex>
+// CHECK:         %[[DIV:.+]] = arith.divsi %[[VEC]], %[[C8]]
+// CHECK:         %[[MUL:.+]] = arith.muli %[[DIV]], %[[C8]]
+// CHECK:         %[[REM:.+]] = arith.subi %[[VEC]], %[[MUL]]
+// CHECK:         return %[[DIV]], %[[REM]]
+func.func @expand_delinearize_vector(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>) {
+  %0:2 = affine.delinearize_index %vec into (4, 8) : vector<16xindex>, vector<16xindex>
+  return %0#0, %0#1 : vector<16xindex>, vector<16xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @expand_linearize_vector
+// CHECK-SAME:    (%[[V0:.+]]: vector<16xindex>, %[[V1:.+]]: vector<16xindex>)
+// CHECK-DAG:     %[[C8:.+]] = arith.constant dense<8> : vector<16xindex>
+// CHECK:         %[[MUL:.+]] = arith.muli %[[V0]], %[[C8]]
+// CHECK:         %[[ADD:.+]] = arith.addi %[[MUL]], %[[V1]]
+// CHECK:         return %[[ADD]]
+func.func @expand_linearize_vector(%v0: vector<16xindex>, %v1: vector<16xindex>) -> vector<16xindex> {
+  %0 = affine.linearize_index [%v0, %v1] by (4, 8) : vector<16xindex>
+  return %0 : vector<16xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @expand_delinearize_vector_3d
+// CHECK-SAME:    (%[[VEC:.+]]: vector<16xindex>)
+// CHECK-DAG:     %[[C4:.+]] = arith.constant dense<4> : vector<16xindex>
+// CHECK-DAG:     %[[C12:.+]] = arith.constant dense<12> : vector<16xindex>
+// CHECK:         %[[D0:.+]] = arith.divsi %[[VEC]], %[[C12]]
+// CHECK:         %[[M0:.+]] = arith.muli %[[D0]], %[[C12]]
+// CHECK:         %[[R0:.+]] = arith.subi %[[VEC]], %[[M0]]
+// CHECK:         %[[D1:.+]] = arith.divsi %[[R0]], %[[C4]]
+// CHECK:         %[[M1:.+]] = arith.muli %[[D1]], %[[C4]]
+// CHECK:         %[[R1:.+]] = arith.subi %[[R0]], %[[M1]]
+// CHECK:         return %[[D0]], %[[D1]], %[[R1]]
+func.func @expand_delinearize_vector_3d(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>, vector<16xindex>) {
+  %0:3 = affine.delinearize_index %vec into (2, 3, 4) : vector<16xindex>, vector<16xindex>, vector<16xindex>
+  return %0#0, %0#1, %0#2 : vector<16xindex>, vector<16xindex>, vector<16xindex>
+}
+
+// -----
+
+// Vector linearize -> offset -> delinearize pattern
+// (as would be used in vector.gather lowering).
+
+// CHECK-LABEL: @vector_linearize_offset_delinearize
+// CHECK-SAME:    (%[[V0:.+]]: vector<4xindex>, %[[V1:.+]]: vector<4xindex>, %[[OFF:.+]]: vector<4xindex>)
+// CHECK-DAG:     %[[C8:.+]] = arith.constant dense<8> : vector<4xindex>
+// CHECK:         %[[LIN:.+]] = arith.muli %[[V0]], %[[C8]]
+// CHECK:         %[[LIN2:.+]] = arith.addi %[[LIN]], %[[V1]]
+// CHECK:         %[[FLAT:.+]] = arith.addi %[[LIN2]], %[[OFF]]
+// CHECK:         %[[DIV:.+]] = arith.divsi %[[FLAT]], %[[C8]]
+// CHECK:         %[[MUL:.+]] = arith.muli %[[DIV]], %[[C8]]
+// CHECK:         %[[REM:.+]] = arith.subi %[[FLAT]], %[[MUL]]
+// CHECK:         return %[[DIV]], %[[REM]]
+func.func @vector_linearize_offset_delinearize(%v0: vector<4xindex>, %v1: vector<4xindex>, %offsets: vector<4xindex>) -> (vector<4xindex>, vector<4xindex>) {
+  %0 = affine.linearize_index [%v0, %v1] by (4, 8) : vector<4xindex>
+  %1 = arith.addi %0, %offsets : vector<4xindex>
+  %2:2 = affine.delinearize_index %1 into (4, 8) : vector<4xindex>, vector<4xindex>
+  return %2#0, %2#1 : vector<4xindex>, vector<4xindex>
+}
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 5a0a2b004433e..008c4babb9408 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -2429,3 +2429,56 @@ func.func @linearize_dont_fold_poison_index(%arg0: index) -> index {
   %ret = affine.linearize_index [%poison, %arg0] by (%c4) : index
   return %ret : index
 }
+
+// -----
+
+// CHECK-LABEL: @cancel_linearize_delinearize_vector
+// CHECK-SAME:    (%[[V0:.+]]: vector<16xindex>, %[[V1:.+]]: vector<16xindex>)
+// CHECK:         return %[[V0]], %[[V1]]
+func.func @cancel_linearize_delinearize_vector(%v0: vector<16xindex>, %v1: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>) {
+  %0 = affine.linearize_index disjoint [%v0, %v1] by (4, 8) : vector<16xindex>
+  %1:2 = affine.delinearize_index %0 into (4, 8) : vector<16xindex>, vector<16xindex>
+  return %1#0, %1#1 : vector<16xindex>, vector<16xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @drop_unit_extent_vector
+// CHECK-SAME:    (%[[VEC:.+]]: vector<16xindex>)
+// CHECK-DAG:     %[[ZERO:.+]] = arith.constant dense<0> : vector<16xindex>
+// CHECK:         %[[R:.+]]:2 = affine.delinearize_index %[[VEC]] into (4, 8) : vector<16xindex>, vector<16xindex>
+// CHECK:         return %[[R]]#0, %[[ZERO]], %[[R]]#1
+func.func @drop_unit_extent_vector(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>, vector<16xindex>) {
+  %0:3 = affine.delinearize_index %vec into (4, 1, 8) : vector<16xindex>, vector<16xindex>, vector<16xindex>
+  return %0#0, %0#1, %0#2 : vector<16xindex>, vector<16xindex>, vector<16xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @drop_unit_linearize_vector
+// CHECK-SAME:    (%[[V0:.+]]: vector<8xindex>, %[[V1:.+]]: vector<8xindex>)
+// CHECK:         return %[[V0]]
+func.func @drop_unit_linearize_vector(%v0: vector<8xindex>, %v1: vector<8xindex>) -> vector<8xindex> {
+  %0 = affine.linearize_index disjoint [%v0, %v1] by (4, 1) : vector<8xindex>
+  return %0 : vector<8xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_single_result_vector
+// CHECK-SAME:    (%[[VEC:.+]]: vector<4xindex>)
+// CHECK:         return %[[VEC]]
+func.func @fold_single_result_vector(%vec: vector<4xindex>) -> vector<4xindex> {
+  %0:1 = affine.delinearize_index %vec into () : vector<4xindex>
+  return %0#0 : vector<4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_single_index_vector
+// CHECK-SAME:    (%[[VEC:.+]]: vector<4xindex>)
+// CHECK:         return %[[VEC]]
+func.func @fold_single_index_vector(%vec: vector<4xindex>) -> vector<4xindex> {
+  %0 = affine.linearize_index [%vec] by () : vector<4xindex>
+  return %0 : vector<4xindex>
+}
diff --git a/mlir/test/Dialect/Affine/vector-index-ops.mlir b/mlir/test/Dialect/Affine/vector-index-ops.mlir
deleted file mode 100644
index 2c6fbc20a9316..0000000000000
--- a/mlir/test/Dialect/Affine/vector-index-ops.mlir
+++ /dev/null
@@ -1,128 +0,0 @@
-// RUN: mlir-opt %s -split-input-file --canonicalize | FileCheck %s --check-prefix=CANON
-// RUN: mlir-opt %s -split-input-file --affine-expand-index-ops-as-affine | FileCheck %s --check-prefix=EXPAND
-
-// Canonicalization: cancel disjoint linearize -> delinearize on vectors.
-
-// CANON-LABEL: @cancel_linearize_delinearize_vector
-// CANON-SAME:    (%[[V0:.+]]: vector<16xindex>, %[[V1:.+]]: vector<16xindex>)
-// CANON:         return %[[V0]], %[[V1]]
-func.func @cancel_linearize_delinearize_vector(%v0: vector<16xindex>, %v1: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>) {
-  %0 = affine.linearize_index disjoint [%v0, %v1] by (4, 8) : vector<16xindex>
-  %1:2 = affine.delinearize_index %0 into (4, 8) : vector<16xindex>, vector<16xindex>
-  return %1#0, %1#1 : vector<16xindex>, vector<16xindex>
-}
-
-// -----
-// Canonicalization: drop unit-extent basis on vector delinearize.
-
-// CANON-LABEL: @drop_unit_extent_vector
-// CANON-SAME:    (%[[VEC:.+]]: vector<16xindex>)
-// CANON-DAG:     %[[ZERO:.+]] = arith.constant dense<0> : vector<16xindex>
-// CANON:         %[[R:.+]]:2 = affine.delinearize_index %[[VEC]] into (4, 8) : vector<16xindex>, vector<16xindex>
-// CANON:         return %[[R]]#0, %[[ZERO]], %[[R]]#1
-func.func @drop_unit_extent_vector(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>, vector<16xindex>) {
-  %0:3 = affine.delinearize_index %vec into (4, 1, 8) : vector<16xindex>, vector<16xindex>, vector<16xindex>
-  return %0#0, %0#1, %0#2 : vector<16xindex>, vector<16xindex>, vector<16xindex>
-}
-
-// -----
-// Canonicalization: drop unit-extent basis on vector linearize.
-
-// CANON-LABEL: @drop_unit_linearize_vector
-// CANON-SAME:    (%[[V0:.+]]: vector<8xindex>, %[[V1:.+]]: vector<8xindex>)
-// CANON:         return %[[V0]]
-func.func @drop_unit_linearize_vector(%v0: vector<8xindex>, %v1: vector<8xindex>) -> vector<8xindex> {
-  %0 = affine.linearize_index disjoint [%v0, %v1] by (4, 1) : vector<8xindex>
-  return %0 : vector<8xindex>
-}
-
-// -----
-// Canonicalization: fold single-result vector delinearize to identity.
-
-// CANON-LABEL: @fold_single_result_vector
-// CANON-SAME:    (%[[VEC:.+]]: vector<4xindex>)
-// CANON:         return %[[VEC]]
-func.func @fold_single_result_vector(%vec: vector<4xindex>) -> vector<4xindex> {
-  %0:1 = affine.delinearize_index %vec into () : vector<4xindex>
-  return %0#0 : vector<4xindex>
-}
-
-// -----
-// Canonicalization: fold single-index vector linearize to identity.
-
-// CANON-LABEL: @fold_single_index_vector
-// CANON-SAME:    (%[[VEC:.+]]: vector<4xindex>)
-// CANON:         return %[[VEC]]
-func.func @fold_single_index_vector(%vec: vector<4xindex>) -> vector<4xindex> {
-  %0 = affine.linearize_index [%vec] by () : vector<4xindex>
-  return %0 : vector<4xindex>
-}
-
-// -----
-// Expansion: vector delinearize lowers to arith div/rem.
-
-// EXPAND-LABEL: @expand_delinearize_vector
-// EXPAND-SAME:    (%[[VEC:.+]]: vector<16xindex>)
-// EXPAND-DAG:     %[[C8:.+]] = arith.constant dense<8> : vector<16xindex>
-// EXPAND:         %[[DIV:.+]] = arith.divsi %[[VEC]], %[[C8]]
-// EXPAND:         %[[MUL:.+]] = arith.muli %[[DIV]], %[[C8]]
-// EXPAND:         %[[REM:.+]] = arith.subi %[[VEC]], %[[MUL]]
-// EXPAND:         return %[[DIV]], %[[REM]]
-func.func @expand_delinearize_vector(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>) {
-  %0:2 = affine.delinearize_index %vec into (4, 8) : vector<16xindex>, vector<16xindex>
-  return %0#0, %0#1 : vector<16xindex>, vector<16xindex>
-}
-
-// -----
-// Expansion: vector linearize lowers to arith mul/add.
-
-// EXPAND-LABEL: @expand_linearize_vector
-// EXPAND-SAME:    (%[[V0:.+]]: vector<16xindex>, %[[V1:.+]]: vector<16xindex>)
-// EXPAND-DAG:     %[[C8:.+]] = arith.constant dense<8> : vector<16xindex>
-// EXPAND:         %[[MUL:.+]] = arith.muli %[[V0]], %[[C8]]
-// EXPAND:         %[[ADD:.+]] = arith.addi %[[MUL]], %[[V1]]
-// EXPAND:         return %[[ADD]]
-func.func @expand_linearize_vector(%v0: vector<16xindex>, %v1: vector<16xindex>) -> vector<16xindex> {
-  %0 = affine.linearize_index [%v0, %v1] by (4, 8) : vector<16xindex>
-  return %0 : vector<16xindex>
-}
-
-// -----
-// Expansion: 3D vector delinearize.
-
-// EXPAND-LABEL: @expand_delinearize_vector_3d
-// EXPAND-SAME:    (%[[VEC:.+]]: vector<16xindex>)
-// EXPAND-DAG:     %[[C4:.+]] = arith.constant dense<4> : vector<16xindex>
-// EXPAND-DAG:     %[[C12:.+]] = arith.constant dense<12> : vector<16xindex>
-// EXPAND:         %[[D0:.+]] = arith.divsi %[[VEC]], %[[C12]]
-// EXPAND:         %[[M0:.+]] = arith.muli %[[D0]], %[[C12]]
-// EXPAND:         %[[R0:.+]] = arith.subi %[[VEC]], %[[M0]]
-// EXPAND:         %[[D1:.+]] = arith.divsi %[[R0]], %[[C4]]
-// EXPAND:         %[[M1:.+]] = arith.muli %[[D1]], %[[C4]]
-// EXPAND:         %[[R1:.+]] = arith.subi %[[R0]], %[[M1]]
-// EXPAND:         return %[[D0]], %[[D1]], %[[R1]]
-func.func @expand_delinearize_vector_3d(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>, vector<16xindex>) {
-  %0:3 = affine.delinearize_index %vec into (2, 3, 4) : vector<16xindex>, vector<16xindex>, vector<16xindex>
-  return %0#0, %0#1, %0#2 : vector<16xindex>, vector<16xindex>, vector<16xindex>
-}
-
-// -----
-// Expansion: vector linearize -> offset -> delinearize pattern
-// (as would be used in vector.gather lowering).
-
-// EXPAND-LABEL: @vector_linearize_offset_delinearize
-// EXPAND-SAME:    (%[[V0:.+]]: vector<4xindex>, %[[V1:.+]]: vector<4xindex>, %[[OFF:.+]]: vector<4xindex>)
-// EXPAND-DAG:     %[[C8:.+]] = arith.constant dense<8> : vector<4xindex>
-// EXPAND:         %[[LIN:.+]] = arith.muli %[[V0]], %[[C8]]
-// EXPAND:         %[[LIN2:.+]] = arith.addi %[[LIN]], %[[V1]]
-// EXPAND:         %[[FLAT:.+]] = arith.addi %[[LIN2]], %[[OFF]]
-// EXPAND:         %[[DIV:.+]] = arith.divsi %[[FLAT]], %[[C8]]
-// EXPAND:         %[[MUL:.+]] = arith.muli %[[DIV]], %[[C8]]
-// EXPAND:         %[[REM:.+]] = arith.subi %[[FLAT]], %[[MUL]]
-// EXPAND:         return %[[DIV]], %[[REM]]
-func.func @vector_linearize_offset_delinearize(%v0: vector<4xindex>, %v1: vector<4xindex>, %offsets: vector<4xindex>) -> (vector<4xindex>, vector<4xindex>) {
-  %0 = affine.linearize_index [%v0, %v1] by (4, 8) : vector<4xindex>
-  %1 = arith.addi %0, %offsets : vector<4xindex>
-  %2:2 = affine.delinearize_index %1 into (4, 8) : vector<4xindex>, vector<4xindex>
-  return %2#0, %2#1 : vector<4xindex>, vector<4xindex>
-}

>From 1393eb0291ea07f5a97d34df8a9df85202f2ac09 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Wed, 25 Mar 2026 16:08:48 +0000
Subject: [PATCH 05/11] Bracketed conditionals

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index eacf5a3bb74fb..92f2015022120 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -5048,12 +5048,13 @@ struct DropUnitExtentBasis
     auto getZero = [&]() -> Value {
       if (!zero) {
         Value scalarZero = arith::ConstantIndexOp::create(rewriter, loc, 0);
-        if (auto vecTy = dyn_cast<VectorType>(indexType))
+        if (auto vecTy = dyn_cast<VectorType>(indexType)) {
           zero = arith::ConstantOp::create(
               rewriter, loc,
               DenseElementsAttr::get(vecTy, rewriter.getIndexAttr(0)));
-        else
+        } else {
           zero = scalarZero;
+        }
       }
       return zero.value();
     };

>From 806e784b08a1eae4a001492f58555d118f6efc21 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Thu, 26 Mar 2026 19:40:09 +0000
Subject: [PATCH 06/11] [mlir][affine] Move type checking to TypesMatchWith
 traits, use declarative assemblyFormat
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Replace custom parse/print and manual verifier type checks with
declarative assemblyFormat and TypesMatchWith traits:
- delinearize: two TypesMatchWith traits — one for parser type
  inference (infers linear_index from first result), one for
  verification (all results match linear_index).
- linearize: one TypesMatchWith trait for both inference and
  verification (all multi_index types match result).

Co-authored-by: Claude Opus 4.6 <noreply at anthropic.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 .../mlir/Dialect/Affine/IR/AffineOps.td        | 18 ++++++++++++++----
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp       | 16 ----------------
 2 files changed, 14 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 3f59e008b2a7d..21290b61dd4e1 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1070,9 +1070,17 @@ def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
 //===----------------------------------------------------------------------===//
 
 def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
-    [Pure, TypesMatchWith<"linear_index type must match result types",
-                          "multi_index", "linear_index",
-                          "$_self[0]">]> {
+    [Pure,
+     // Infer linear_index type from the first result type during parsing.
+     TypesMatchWith<"linear_index type must match result types",
+                    "multi_index", "linear_index", "$_self[0]">,
+     // Verify all result types match the linear_index type.
+     TypesMatchWith<"all result types must match linear_index type",
+                    "linear_index", "multi_index", "$_self",
+                    "[](::mlir::Type a, ::mlir::TypeRange b) { "
+                    "return llvm::all_of(b, [a](::mlir::Type t) { "
+                    "return t == a; }); }">
+    ]> {
   let summary = "delinearize an index";
   let description = [{
     The `affine.delinearize_index` operation takes a single index value and
@@ -1179,7 +1187,9 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
     [Pure, AttrSizedOperandSegments,
      TypesMatchWith<"multi_index types must match result type",
                     "linear_index", "multi_index", "$_self",
-                    "[](::mlir::Type a, ::mlir::TypeRange b) { return llvm::all_of(b, [a](::mlir::Type t) { return t == a; }); }">]> {
+                    "[](::mlir::Type a, ::mlir::TypeRange b) { "
+                    "return llvm::all_of(b, [a](::mlir::Type t) { "
+                    "return t == a; }); }">]> {
   let summary = "linearize an index";
   let description = [{
     The `affine.linearize_index` operation takes a sequence of index values and a
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 92f2015022120..b13c4817f1428 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4925,14 +4925,6 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
       }))
     return emitOpError("no basis element may be statically non-positive");
 
-  // All result types must match the input type.
-  Type inputType = getLinearIndex().getType();
-  for (Type resultType : getResultTypes()) {
-    if (resultType != inputType)
-      return emitOpError("result types must match the linear index type, got ")
-             << resultType << " vs " << inputType;
-  }
-
   return success();
 }
 
@@ -5315,14 +5307,6 @@ LogicalResult AffineLinearizeIndexOp::verify() {
         "corresponding dynamic basis entry) -- this can only happen due to an "
         "incorrect fold/rewrite");
 
-  // All multi_index types must match the result type.
-  Type resultType = getLinearIndex().getType();
-  for (Value idx : getMultiIndex()) {
-    if (idx.getType() != resultType)
-      return emitOpError("multi_index types must match the result type, got ")
-             << idx.getType() << " vs " << resultType;
-  }
-
   return success();
 }
 

>From 45bcb561e8be049c0504a5879f5a954eb8d2ac12 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Thu, 26 Mar 2026 21:50:19 +0000
Subject: [PATCH 07/11] [mlir][affine] Address review: use getZeroAttr, add
 vector support to AffineExpandIndexOps

- Use rewriter.getZeroAttr(Type) instead of manual DenseElementsAttr
  construction in canonicalizers.
- Add vector support to AffineExpandIndexOps.cpp (the primary lowering
  pass) using vector.broadcast to splat scalar strides to vector type.
- Revert AffineExpandIndexOpsAsAffine.cpp vector lowering: this pass
  lowers to affine.apply which is scalar-only, so vector ops are left
  unconverted.
- Move vector expansion tests from affine-expand-index-ops-as-affine.mlir
  to affine-expand-index-ops.mlir.

Co-authored-by: Claude Opus 4.6 <noreply at anthropic.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 .../mlir/Dialect/Affine/Transforms/Passes.td  |   1 +
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      |  27 +---
 .../Transforms/AffineExpandIndexOps.cpp       |  28 +++-
 .../AffineExpandIndexOpsAsAffine.cpp          | 151 +++---------------
 .../affine-expand-index-ops-as-affine.mlir    |  67 --------
 .../Affine/affine-expand-index-ops.mlir       |  47 ++++++
 6 files changed, 104 insertions(+), 217 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Passes.td b/mlir/include/mlir/Dialect/Affine/Transforms/Passes.td
index 430edffc29038..db3dd0544c7c2 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Passes.td
@@ -433,6 +433,7 @@ def SimplifyAffineMinMaxPass : InterfacePass<"affine-simplify-min-max", "Functio
 def AffineExpandIndexOps : Pass<"affine-expand-index-ops"> {
   let summary = "Lower affine operations operating on indices into more fundamental operations";
   let constructor = "mlir::affine::createAffineExpandIndexOpsPass()";
+  let dependentDialects = ["vector::VectorDialect"];
 }
 
 def AffineExpandIndexOpsAsAffine : Pass<"affine-expand-index-ops-as-affine"> {
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index b13c4817f1428..c2d20271a815f 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -3961,9 +3961,8 @@ void AffinePrefetchOp::print(OpAsmPrinter &p) {
       (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
   if (mapAttr)
     p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
-  p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
-    << "locality<" << getLocalityHint() << ">, "
-    << (getIsDataCache() ? "data" : "instr");
+  p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", " << "locality<"
+    << getLocalityHint() << ">, " << (getIsDataCache() ? "data" : "instr");
   p.printOptionalAttrDict(
       (*this)->getAttrs(),
       /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
@@ -5038,16 +5037,9 @@ struct DropUnitExtentBasis
     Location loc = delinearizeOp->getLoc();
     Type indexType = delinearizeOp.getLinearIndex().getType();
     auto getZero = [&]() -> Value {
-      if (!zero) {
-        Value scalarZero = arith::ConstantIndexOp::create(rewriter, loc, 0);
-        if (auto vecTy = dyn_cast<VectorType>(indexType)) {
-          zero = arith::ConstantOp::create(
-              rewriter, loc,
-              DenseElementsAttr::get(vecTy, rewriter.getIndexAttr(0)));
-        } else {
-          zero = scalarZero;
-        }
-      }
+      if (!zero)
+        zero = arith::ConstantOp::create(rewriter, loc,
+                                         rewriter.getZeroAttr(indexType));
       return zero.value();
     };
 
@@ -5425,13 +5417,8 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
                                          "no unit basis entries to replace");
 
     if (newIndices.empty()) {
-      Type resultType = op.getLinearIndex().getType();
-      if (auto vecTy = dyn_cast<VectorType>(resultType)) {
-        rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-            op, DenseElementsAttr::get(vecTy, rewriter.getIndexAttr(0)));
-      } else {
-        rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
-      }
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+          op, rewriter.getZeroAttr(op.getLinearIndex().getType()));
       return success();
     }
     rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index e1317b3f78b05..66c48d89068ec 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -15,6 +15,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
@@ -83,6 +84,15 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
   return result;
 }
 
+/// Broadcast a scalar value to match the given type. If the type is already
+/// scalar, returns the value as-is. For vector types, uses vector.broadcast.
+static Value broadcastToMatchType(RewriterBase &rewriter, Location loc,
+                                  Value value, Type targetType) {
+  if (value.getType() == targetType)
+    return value;
+  return vector::BroadcastOp::create(rewriter, loc, targetType, value);
+}
+
 LogicalResult
 affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
                                       AffineDelinearizeIndexOp op) {
@@ -104,7 +114,14 @@ affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
       computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
                      /*knownNonNegative=*/true);
 
-  Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+  // Broadcast strides and zero to match the linear index type (needed for
+  // vector types where the strides are scalar but the index is a vector).
+  Type indexType = linearIdx.getType();
+  for (Value &stride : strides)
+    stride = broadcastToMatchType(rewriter, loc, stride, indexType);
+
+  Value zero =
+      arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(indexType));
 
   Value initialPart =
       arith::FloorDivSIOp::create(rewriter, loc, linearIdx, strides.front());
@@ -146,12 +163,14 @@ LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
                                                   AffineLinearizeIndexOp op) {
   // Should be folded away, included here for safety.
   if (op.getMultiIndex().empty()) {
-    rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+        op, rewriter.getZeroAttr(op.getLinearIndex().getType()));
     return success();
   }
 
   Location loc = op.getLoc();
   ValueRange multiIndex = op.getMultiIndex();
+  Type indexType = op.getLinearIndex().getType();
   size_t numIndexes = multiIndex.size();
   ArrayRef<int64_t> staticBasis = op.getStaticBasis();
   if (numIndexes == staticBasis.size())
@@ -160,6 +179,11 @@ LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
   SmallVector<Value> strides =
       computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
                      /*knownNonNegative=*/op.getDisjoint());
+
+  // Broadcast strides to match the index type (needed for vector types).
+  for (Value &stride : strides)
+    stride = broadcastToMatchType(rewriter, loc, stride, indexType);
+
   SmallVector<std::pair<Value, int64_t>> scaledValues;
   scaledValues.reserve(numIndexes);
 
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp
index 0178c5159df53..fa3fc03520cc0 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp
@@ -15,9 +15,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
 #include "mlir/Dialect/Affine/Utils.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
@@ -31,33 +29,6 @@ using namespace mlir;
 using namespace mlir::affine;
 
 namespace {
-
-/// Create a constant splat of the given type with the given integer value.
-static Value createTypedConstant(OpBuilder &b, Location loc, Type type,
-                                 int64_t value) {
-  if (auto vecTy = dyn_cast<VectorType>(type))
-    return arith::ConstantOp::create(
-        b, loc, DenseElementsAttr::get(vecTy, b.getIndexAttr(value)));
-  return arith::ConstantIndexOp::create(b, loc, value);
-}
-
-/// Materialize an OpFoldResult (which represents a scalar index or constant)
-/// as a Value matching the given target type. For vector target types, scalar
-/// constants are splatted. Returns failure for dynamic basis with vector types
-/// since that requires vector.broadcast which is not available here.
-static FailureOr<Value> materializeBasis(OpBuilder &b, Location loc,
-                                         OpFoldResult ofr, Type targetType) {
-  std::optional<int64_t> cst = getConstantIntValue(ofr);
-  if (cst)
-    return createTypedConstant(b, loc, targetType, *cst);
-  // Dynamic scalar basis value. For scalar target types, return as-is.
-  if (isa<IndexType>(targetType))
-    return getValueOrCreateConstantIndexOp(b, loc, ofr);
-  // Dynamic scalar basis with vector target type -- would need
-  // vector.broadcast, bail out.
-  return failure();
-}
-
 /// Lowers `affine.delinearize_index` into a sequence of division and remainder
 /// operations.
 struct LowerDelinearizeIndexOps
@@ -65,51 +36,17 @@ struct LowerDelinearizeIndexOps
   using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    // For scalar types, use the existing affine lowering path.
-    if (isa<IndexType>(op.getLinearIndex().getType())) {
-      FailureOr<SmallVector<Value>> multiIndex =
-          delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
-                           op.getEffectiveBasis(), /*hasOuterBound=*/false);
-      if (failed(multiIndex))
-        return failure();
-      rewriter.replaceOp(op, *multiIndex);
-      return success();
-    }
-
-    // Vector lowering: emit arith div/rem ops (which work element-wise on
-    // vectors).
-    Location loc = op.getLoc();
-    Value linearIndex = op.getLinearIndex();
-    Type type = linearIndex.getType();
-    SmallVector<OpFoldResult> basis = op.getEffectiveBasis();
-
-    // Compute cumulative products of basis from the right. These serve as
-    // divisors: for basis (B0, B1, B2), the divisors are (B1*B2, B2).
-    SmallVector<Value> divisors;
-    Value cumulativeProd = createTypedConstant(rewriter, loc, type, 1);
-    for (OpFoldResult basisElem : llvm::reverse(basis)) {
-      FailureOr<Value> basisVal =
-          materializeBasis(rewriter, loc, basisElem, type);
-      if (failed(basisVal))
-        return failure();
-      cumulativeProd =
-          arith::MulIOp::create(rewriter, loc, cumulativeProd, *basisVal);
-      divisors.push_back(cumulativeProd);
-    }
-
-    // Emit div/mod pairs from the most-significant dimension to the least.
-    SmallVector<Value> results;
-    results.reserve(divisors.size() + 1);
-    Value residual = linearIndex;
-    for (Value divisor : llvm::reverse(divisors)) {
-      Value quotient = arith::DivSIOp::create(rewriter, loc, residual, divisor);
-      Value product = arith::MulIOp::create(rewriter, loc, quotient, divisor);
-      Value remainder = arith::SubIOp::create(rewriter, loc, residual, product);
-      results.push_back(quotient);
-      residual = remainder;
-    }
-    results.push_back(residual);
-    rewriter.replaceOp(op, results);
+    // This pass lowers to affine.apply which only supports scalar index types.
+    // Vector types should be lowered using -affine-expand-index-ops instead.
+    if (!isa<IndexType>(op.getLinearIndex().getType()))
+      return rewriter.notifyMatchFailure(op, "expected scalar index type");
+
+    FailureOr<SmallVector<Value>> multiIndex =
+        delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
+                         op.getEffectiveBasis(), /*hasOuterBound=*/false);
+    if (failed(multiIndex))
+      return failure();
+    rewriter.replaceOp(op, *multiIndex);
     return success();
   }
 };
@@ -120,66 +57,24 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
+    // This pass lowers to affine.apply which only supports scalar index types.
+    // Vector types should be lowered using -affine-expand-index-ops instead.
+    if (!isa<IndexType>(op.getLinearIndex().getType()))
+      return rewriter.notifyMatchFailure(op, "expected scalar index type");
+
     // Should be folded away, included here for safety.
     if (op.getMultiIndex().empty()) {
       rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
       return success();
     }
 
-    // For scalar types, use the existing affine lowering path.
-    if (isa<IndexType>(op.getLinearIndex().getType())) {
-      SmallVector<OpFoldResult> multiIndex =
-          getAsOpFoldResult(op.getMultiIndex());
-      OpFoldResult linearIndex =
-          linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
-      Value linearIndexValue =
-          getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
-      rewriter.replaceOp(op, linearIndexValue);
-      return success();
-    }
-
-    // Vector lowering: emit arith ops (which work element-wise on vectors).
-    //
-    // linearize_index [i0, i1, ..., iN-1] by (B0, B1, ..., BN-1)
-    // = i0 * stride_0 + i1 * stride_1 + ... + iN-1
-    // where stride_k = B_{k+1} * B_{k+2} * ... * B_{N-1}
-    //
-    // We compute from the back: result = iN-1, stride = 1, then:
-    //   stride *= B_{k}, result += i_k * stride
-    Location loc = op.getLoc();
-    Type type = op.getLinearIndex().getType();
-    SmallVector<OpFoldResult> effectiveBasis = op.getEffectiveBasis();
-    ValueRange indices = op.getMultiIndex();
-
-    // effectiveBasis drops the outer bound. For indices [i0, i1, ..., iN-1]:
-    //   no outer bound:  effectiveBasis = [B1, B2, ..., BN-1] (N-1 elems)
-    //   has outer bound: effectiveBasis = [B0, B1, ..., BN-1] (N elems,
-    //                    but B0 is advisory, dropped by getEffectiveBasis)
-    //
-    // Computation: result = iN-1 + BN-1 * (iN-2 + BN-2 * (... + B1 * i0))
-    // Or equivalently, accumulate from back:
-    //   result = iN-1
-    //   stride = 1
-    //   for k = numBasis-1 downto 0:
-    //     stride *= effectiveBasis[k]
-    //     result += indices[k] * stride
-    //
-    // This works because effectiveBasis[k] is the "size" of dimension k+1,
-    // and indices[k] is paired with the product of all sizes after it.
-    Value result = indices.back();
-    Value stride = createTypedConstant(rewriter, loc, type, 1);
-
-    for (int i = static_cast<int>(effectiveBasis.size()) - 1; i >= 0; --i) {
-      FailureOr<Value> basisVal =
-          materializeBasis(rewriter, loc, effectiveBasis[i], type);
-      if (failed(basisVal))
-        return failure();
-      stride = arith::MulIOp::create(rewriter, loc, stride, *basisVal);
-      Value term = arith::MulIOp::create(rewriter, loc, indices[i], stride);
-      result = arith::AddIOp::create(rewriter, loc, term, result);
-    }
-
-    rewriter.replaceOp(op, result);
+    SmallVector<OpFoldResult> multiIndex =
+        getAsOpFoldResult(op.getMultiIndex());
+    OpFoldResult linearIndex =
+        linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
+    Value linearIndexValue =
+        getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
+    rewriter.replaceOp(op, linearIndexValue);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir
index 21595356936fa..bf9f00da5793a 100644
--- a/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir
+++ b/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir
@@ -68,70 +68,3 @@ func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: in
   %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4) : index
   func.return %0 : index
 }
-
-// -----
-
-// CHECK-LABEL: @expand_delinearize_vector
-// CHECK-SAME:    (%[[VEC:.+]]: vector<16xindex>)
-// CHECK-DAG:     %[[C8:.+]] = arith.constant dense<8> : vector<16xindex>
-// CHECK:         %[[DIV:.+]] = arith.divsi %[[VEC]], %[[C8]]
-// CHECK:         %[[MUL:.+]] = arith.muli %[[DIV]], %[[C8]]
-// CHECK:         %[[REM:.+]] = arith.subi %[[VEC]], %[[MUL]]
-// CHECK:         return %[[DIV]], %[[REM]]
-func.func @expand_delinearize_vector(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>) {
-  %0:2 = affine.delinearize_index %vec into (4, 8) : vector<16xindex>, vector<16xindex>
-  return %0#0, %0#1 : vector<16xindex>, vector<16xindex>
-}
-
-// -----
-
-// CHECK-LABEL: @expand_linearize_vector
-// CHECK-SAME:    (%[[V0:.+]]: vector<16xindex>, %[[V1:.+]]: vector<16xindex>)
-// CHECK-DAG:     %[[C8:.+]] = arith.constant dense<8> : vector<16xindex>
-// CHECK:         %[[MUL:.+]] = arith.muli %[[V0]], %[[C8]]
-// CHECK:         %[[ADD:.+]] = arith.addi %[[MUL]], %[[V1]]
-// CHECK:         return %[[ADD]]
-func.func @expand_linearize_vector(%v0: vector<16xindex>, %v1: vector<16xindex>) -> vector<16xindex> {
-  %0 = affine.linearize_index [%v0, %v1] by (4, 8) : vector<16xindex>
-  return %0 : vector<16xindex>
-}
-
-// -----
-
-// CHECK-LABEL: @expand_delinearize_vector_3d
-// CHECK-SAME:    (%[[VEC:.+]]: vector<16xindex>)
-// CHECK-DAG:     %[[C4:.+]] = arith.constant dense<4> : vector<16xindex>
-// CHECK-DAG:     %[[C12:.+]] = arith.constant dense<12> : vector<16xindex>
-// CHECK:         %[[D0:.+]] = arith.divsi %[[VEC]], %[[C12]]
-// CHECK:         %[[M0:.+]] = arith.muli %[[D0]], %[[C12]]
-// CHECK:         %[[R0:.+]] = arith.subi %[[VEC]], %[[M0]]
-// CHECK:         %[[D1:.+]] = arith.divsi %[[R0]], %[[C4]]
-// CHECK:         %[[M1:.+]] = arith.muli %[[D1]], %[[C4]]
-// CHECK:         %[[R1:.+]] = arith.subi %[[R0]], %[[M1]]
-// CHECK:         return %[[D0]], %[[D1]], %[[R1]]
-func.func @expand_delinearize_vector_3d(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>, vector<16xindex>) {
-  %0:3 = affine.delinearize_index %vec into (2, 3, 4) : vector<16xindex>, vector<16xindex>, vector<16xindex>
-  return %0#0, %0#1, %0#2 : vector<16xindex>, vector<16xindex>, vector<16xindex>
-}
-
-// -----
-
-// Vector linearize -> offset -> delinearize pattern
-// (as would be used in vector.gather lowering).
-
-// CHECK-LABEL: @vector_linearize_offset_delinearize
-// CHECK-SAME:    (%[[V0:.+]]: vector<4xindex>, %[[V1:.+]]: vector<4xindex>, %[[OFF:.+]]: vector<4xindex>)
-// CHECK-DAG:     %[[C8:.+]] = arith.constant dense<8> : vector<4xindex>
-// CHECK:         %[[LIN:.+]] = arith.muli %[[V0]], %[[C8]]
-// CHECK:         %[[LIN2:.+]] = arith.addi %[[LIN]], %[[V1]]
-// CHECK:         %[[FLAT:.+]] = arith.addi %[[LIN2]], %[[OFF]]
-// CHECK:         %[[DIV:.+]] = arith.divsi %[[FLAT]], %[[C8]]
-// CHECK:         %[[MUL:.+]] = arith.muli %[[DIV]], %[[C8]]
-// CHECK:         %[[REM:.+]] = arith.subi %[[FLAT]], %[[MUL]]
-// CHECK:         return %[[DIV]], %[[REM]]
-func.func @vector_linearize_offset_delinearize(%v0: vector<4xindex>, %v1: vector<4xindex>, %offsets: vector<4xindex>) -> (vector<4xindex>, vector<4xindex>) {
-  %0 = affine.linearize_index [%v0, %v1] by (4, 8) : vector<4xindex>
-  %1 = arith.addi %0, %offsets : vector<4xindex>
-  %2:2 = affine.delinearize_index %1 into (4, 8) : vector<4xindex>, vector<4xindex>
-  return %2#0, %2#1 : vector<4xindex>, vector<4xindex>
-}
diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
index 202050489b7e4..f9c4c561e3a10 100644
--- a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
+++ b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir
@@ -127,3 +127,50 @@ func.func @linearize_sort_adds(%arg0: memref<?xi32>, %arg1: index, %arg2: index)
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: @delinearize_vector
+// CHECK-SAME:    (%[[VEC:.+]]: vector<4xindex>)
+// CHECK-DAG:     %[[C8:.+]] = arith.constant dense<8> : vector<4xindex>
+// CHECK-DAG:     %[[C0:.+]] = arith.constant dense<0> : vector<4xindex>
+// CHECK:         %[[DIV:.+]] = arith.floordivsi %[[VEC]], %[[C8]]
+// CHECK:         %[[REM:.+]] = arith.remsi %[[VEC]], %[[C8]]
+// CHECK:         %[[NEG:.+]] = arith.cmpi slt, %[[REM]], %[[C0]]
+// CHECK:         %[[ADD:.+]] = arith.addi %[[REM]], %[[C8]] overflow<nsw>
+// CHECK:         %[[MOD:.+]] = arith.select %[[NEG]], %[[ADD]], %[[REM]]
+// CHECK:         return %[[DIV]], %[[MOD]]
+func.func @delinearize_vector(%vec: vector<4xindex>) -> (vector<4xindex>, vector<4xindex>) {
+  %0:2 = affine.delinearize_index %vec into (4, 8) : vector<4xindex>, vector<4xindex>
+  return %0#0, %0#1 : vector<4xindex>, vector<4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @linearize_vector
+// CHECK-SAME:    (%[[V0:.+]]: vector<4xindex>, %[[V1:.+]]: vector<4xindex>)
+// CHECK-DAG:     %[[C8:.+]] = arith.constant dense<8> : vector<4xindex>
+// CHECK:         %[[MUL:.+]] = arith.muli %[[V0]], %[[C8]] overflow<nsw>
+// CHECK:         %[[ADD:.+]] = arith.addi %[[MUL]], %[[V1]] overflow<nsw>
+// CHECK:         return %[[ADD]]
+func.func @linearize_vector(%v0: vector<4xindex>, %v1: vector<4xindex>) -> vector<4xindex> {
+  %0 = affine.linearize_index [%v0, %v1] by (4, 8) : vector<4xindex>
+  return %0 : vector<4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @delinearize_vector_dynamic
+// CHECK-SAME:    (%[[VEC:.+]]: vector<4xindex>, %[[B:.+]]: index)
+// CHECK-DAG:     %[[C0:.+]] = arith.constant dense<0> : vector<4xindex>
+// CHECK:         %[[BCAST:.+]] = vector.broadcast %[[B]] : index to vector<4xindex>
+// CHECK:         %[[DIV:.+]] = arith.floordivsi %[[VEC]], %[[BCAST]]
+// CHECK:         %[[REM:.+]] = arith.remsi %[[VEC]], %[[BCAST]]
+// CHECK:         %[[NEG:.+]] = arith.cmpi slt, %[[REM]], %[[C0]]
+// CHECK:         %[[ADD:.+]] = arith.addi %[[REM]], %[[BCAST]] overflow<nsw>
+// CHECK:         %[[MOD:.+]] = arith.select %[[NEG]], %[[ADD]], %[[REM]]
+// CHECK:         return %[[DIV]], %[[MOD]]
+func.func @delinearize_vector_dynamic(%vec: vector<4xindex>, %b: index) -> (vector<4xindex>, vector<4xindex>) {
+  %0:2 = affine.delinearize_index %vec into (4, %b) : vector<4xindex>, vector<4xindex>
+  return %0#0, %0#1 : vector<4xindex>, vector<4xindex>
+}

>From 06cc48d0aeb6ffe42ae530b1b17402a91ac41189 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Thu, 26 Mar 2026 23:03:32 +0000
Subject: [PATCH 08/11] [mlir][affine] Add vector unrolling to
 AffineExpandIndexOpsAsAffine

Implement vector support in AffineExpandIndexOpsAsAffine by unrolling
to per-element scalar affine.apply operations. This preserves the
pass's contract of lowering to affine.apply (which is scalar-only)
while handling fixed-size vector inputs. Scalable vectors are not
unrolled; both patterns report a match failure for them.

For each vector lane:
  extract scalar -> apply scalar delinearize/linearize -> insert back

Add tests for the unrolled vector lowering.

Co-authored-by: Claude Opus 4.6 <noreply at anthropic.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 .../mlir/Dialect/Affine/Transforms/Passes.td  |   1 +
 .../AffineExpandIndexOpsAsAffine.cpp          | 129 ++++++++++++++----
 .../affine-expand-index-ops-as-affine.mlir    |  49 +++++++
 3 files changed, 151 insertions(+), 28 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Passes.td b/mlir/include/mlir/Dialect/Affine/Transforms/Passes.td
index db3dd0544c7c2..2d119638bcd97 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Passes.td
@@ -439,6 +439,7 @@ def AffineExpandIndexOps : Pass<"affine-expand-index-ops"> {
 def AffineExpandIndexOpsAsAffine : Pass<"affine-expand-index-ops-as-affine"> {
   let summary = "Lower affine operations operating on indices into affine.apply operations";
   let constructor = "mlir::affine::createAffineExpandIndexOpsAsAffinePass()";
+  let dependentDialects = ["vector::VectorDialect", "ub::UBDialect"];
 }
 
 def AffineFoldMemRefAliasOps : Pass<"affine-fold-memref-alias-ops"> {
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp
index fa3fc03520cc0..758dfe77e92a0 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp
@@ -16,6 +16,8 @@
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
 #include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir {
@@ -30,51 +32,122 @@ using namespace mlir::affine;
 
 namespace {
 /// Lowers `affine.delinearize_index` into a sequence of division and remainder
-/// operations.
+/// operations via affine.apply. For vector types, unrolls to per-element
+/// scalar affine.apply operations.
 struct LowerDelinearizeIndexOps
     : public OpRewritePattern<AffineDelinearizeIndexOp> {
   using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    // This pass lowers to affine.apply which only supports scalar index types.
-    // Vector types should be lowered using -affine-expand-index-ops instead.
-    if (!isa<IndexType>(op.getLinearIndex().getType()))
-      return rewriter.notifyMatchFailure(op, "expected scalar index type");
-
-    FailureOr<SmallVector<Value>> multiIndex =
-        delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
-                         op.getEffectiveBasis(), /*hasOuterBound=*/false);
-    if (failed(multiIndex))
-      return failure();
-    rewriter.replaceOp(op, *multiIndex);
+    Location loc = op.getLoc();
+    Value linearIndex = op.getLinearIndex();
+    auto vecTy = dyn_cast<VectorType>(linearIndex.getType());
+
+    // Scalar case: use the existing affine lowering path.
+    if (!vecTy) {
+      FailureOr<SmallVector<Value>> multiIndex =
+          delinearizeIndex(rewriter, loc, linearIndex, op.getEffectiveBasis(),
+                           /*hasOuterBound=*/false);
+      if (failed(multiIndex))
+        return failure();
+      rewriter.replaceOp(op, *multiIndex);
+      return success();
+    }
+
+    // Vector case: unroll to per-element scalar affine.apply operations.
+    if (vecTy.isScalable())
+      return rewriter.notifyMatchFailure(op, "scalable vectors not supported");
+
+    int64_t numElems = vecTy.getNumElements();
+    unsigned numResults = op.getNumResults();
+
+    // Initialize result vectors with a poison/undef-like value.
+    SmallVector<Value> resultVecs(numResults);
+    Value poison = ub::PoisonOp::create(rewriter, loc, vecTy);
+    for (unsigned r = 0; r < numResults; ++r)
+      resultVecs[r] = poison;
+
+    for (int64_t i = 0; i < numElems; ++i) {
+      // Extract scalar element.
+      Value idx = arith::ConstantIndexOp::create(rewriter, loc, i);
+      Value scalar = vector::ExtractOp::create(rewriter, loc, linearIndex, idx);
+
+      // Apply scalar delinearization.
+      FailureOr<SmallVector<Value>> scalarResults =
+          delinearizeIndex(rewriter, loc, scalar, op.getEffectiveBasis(),
+                           /*hasOuterBound=*/false);
+      if (failed(scalarResults))
+        return failure();
+
+      // Insert results back into vectors.
+      for (unsigned r = 0; r < numResults; ++r)
+        resultVecs[r] = vector::InsertOp::create(
+            rewriter, loc, (*scalarResults)[r], resultVecs[r], idx);
+    }
+
+    rewriter.replaceOp(op, resultVecs);
     return success();
   }
 };
 
 /// Lowers `affine.linearize_index` into a sequence of multiplications and
-/// additions.
+/// additions via affine.apply. For vector types, unrolls to per-element
+/// scalar affine.apply operations.
 struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
                                 PatternRewriter &rewriter) const override {
-    // This pass lowers to affine.apply which only supports scalar index types.
-    // Vector types should be lowered using -affine-expand-index-ops instead.
-    if (!isa<IndexType>(op.getLinearIndex().getType()))
-      return rewriter.notifyMatchFailure(op, "expected scalar index type");
-
-    // Should be folded away, included here for safety.
-    if (op.getMultiIndex().empty()) {
-      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
+    Location loc = op.getLoc();
+    auto vecTy = dyn_cast<VectorType>(op.getLinearIndex().getType());
+
+    // Scalar case: use the existing affine lowering path.
+    if (!vecTy) {
+      // Should be folded away, included here for safety.
+      if (op.getMultiIndex().empty()) {
+        rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
+        return success();
+      }
+
+      SmallVector<OpFoldResult> multiIndex =
+          getAsOpFoldResult(op.getMultiIndex());
+      OpFoldResult linearIndex =
+          linearizeIndex(rewriter, loc, multiIndex, op.getMixedBasis());
+      Value linearIndexValue =
+          getValueOrCreateConstantIntOp(rewriter, loc, linearIndex);
+      rewriter.replaceOp(op, linearIndexValue);
       return success();
     }
 
-    SmallVector<OpFoldResult> multiIndex =
-        getAsOpFoldResult(op.getMultiIndex());
-    OpFoldResult linearIndex =
-        linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
-    Value linearIndexValue =
-        getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
-    rewriter.replaceOp(op, linearIndexValue);
+    // Vector case: unroll to per-element scalar affine.apply operations.
+    if (vecTy.isScalable())
+      return rewriter.notifyMatchFailure(op, "scalable vectors not supported");
+
+    int64_t numElems = vecTy.getNumElements();
+    ValueRange multiIndex = op.getMultiIndex();
+
+    Value result = ub::PoisonOp::create(rewriter, loc, vecTy);
+
+    for (int64_t i = 0; i < numElems; ++i) {
+      Value idx = arith::ConstantIndexOp::create(rewriter, loc, i);
+
+      // Extract scalar elements from each multi_index vector.
+      SmallVector<OpFoldResult> scalarIndices;
+      for (Value vec : multiIndex)
+        scalarIndices.push_back(
+            vector::ExtractOp::create(rewriter, loc, vec, idx).getResult());
+
+      // Apply scalar linearization.
+      OpFoldResult linearIndex =
+          linearizeIndex(rewriter, loc, scalarIndices, op.getMixedBasis());
+      Value scalarResult =
+          getValueOrCreateConstantIntOp(rewriter, loc, linearIndex);
+
+      // Insert result back into vector.
+      result =
+          vector::InsertOp::create(rewriter, loc, scalarResult, result, idx);
+    }
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir
index bf9f00da5793a..b1739159b5513 100644
--- a/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir
+++ b/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir
@@ -68,3 +68,52 @@ func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: in
   %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4) : index
   func.return %0 : index
 }
+
+// -----
+
+// Vector delinearize: unrolled to per-element affine.apply.
+
+//   CHECK-DAG:   #[[$DIV8:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+//   CHECK-DAG:   #[[$MOD8:.+]] = affine_map<()[s0] -> (s0 mod 8)>
+
+// CHECK-LABEL: @delinearize_vector_unroll
+// CHECK-SAME:    (%[[VEC:.+]]: vector<2xindex>)
+// CHECK:         %[[POISON0:.+]] = ub.poison : vector<2xindex>
+// CHECK:         %[[S0:.+]] = vector.extract %[[VEC]][0]
+// CHECK:         %[[D0:.+]] = affine.apply #[[$DIV8]]()[%[[S0]]]
+// CHECK:         %[[M0:.+]] = affine.apply #[[$MOD8]]()[%[[S0]]]
+// CHECK:         %[[R0_0:.+]] = vector.insert %[[D0]], %[[POISON0]] [0]
+// CHECK:         %[[R1_0:.+]] = vector.insert %[[M0]], %[[POISON0]] [0]
+// CHECK:         %[[S1:.+]] = vector.extract %[[VEC]][1]
+// CHECK:         %[[D1:.+]] = affine.apply #[[$DIV8]]()[%[[S1]]]
+// CHECK:         %[[M1:.+]] = affine.apply #[[$MOD8]]()[%[[S1]]]
+// CHECK:         %[[R0_1:.+]] = vector.insert %[[D1]], %[[R0_0]] [1]
+// CHECK:         %[[R1_1:.+]] = vector.insert %[[M1]], %[[R1_0]] [1]
+// CHECK:         return %[[R0_1]], %[[R1_1]]
+func.func @delinearize_vector_unroll(%vec: vector<2xindex>) -> (vector<2xindex>, vector<2xindex>) {
+  %0:2 = affine.delinearize_index %vec into (4, 8) : vector<2xindex>, vector<2xindex>
+  return %0#0, %0#1 : vector<2xindex>, vector<2xindex>
+}
+
+// -----
+
+// Vector linearize: unrolled to per-element affine.apply.
+
+// CHECK-DAG:   #[[$LIN:.+]] = affine_map<()[s0, s1] -> (s0 * 8 + s1)>
+
+// CHECK-LABEL: @linearize_vector_unroll
+// CHECK-SAME:    (%[[V0:.+]]: vector<2xindex>, %[[V1:.+]]: vector<2xindex>)
+// CHECK:         %[[POISON:.+]] = ub.poison : vector<2xindex>
+// CHECK:         %[[A0:.+]] = vector.extract %[[V0]][0]
+// CHECK:         %[[B0:.+]] = vector.extract %[[V1]][0]
+// CHECK:         %[[L0:.+]] = affine.apply #[[$LIN]]()[%[[A0]], %[[B0]]]
+// CHECK:         %[[R0:.+]] = vector.insert %[[L0]], %[[POISON]] [0]
+// CHECK:         %[[A1:.+]] = vector.extract %[[V0]][1]
+// CHECK:         %[[B1:.+]] = vector.extract %[[V1]][1]
+// CHECK:         %[[L1:.+]] = affine.apply #[[$LIN]]()[%[[A1]], %[[B1]]]
+// CHECK:         %[[R1:.+]] = vector.insert %[[L1]], %[[R0]] [1]
+// CHECK:         return %[[R1]]
+func.func @linearize_vector_unroll(%v0: vector<2xindex>, %v1: vector<2xindex>) -> vector<2xindex> {
+  %0 = affine.linearize_index [%v0, %v1] by (4, 8) : vector<2xindex>
+  return %0 : vector<2xindex>
+}

>From cf3ca25b7e797458ca293cfe2023dbe46538ae54 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Thu, 26 Mar 2026 23:06:39 +0000
Subject: [PATCH 09/11] [mlir][affine] Add vector copies of
 canonicalization/folding tests

Add vector variants of all canonicalization patterns:
- cancel delinearize(linearize_disjoint) full and partial
- cancel linearize(delinearize) exact
- drop all unit bases (delinearize, with/without outer bound)
- linearize all zero unit basis
- linearize one element basis fold
- drop leading zero in linearize
- split delinearize spanning last linearize arg

Co-authored-by: Claude Opus 4.6 <noreply at anthropic.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 mlir/test/Dialect/Affine/canonicalize.mlir | 114 +++++++++++++++++++++
 1 file changed, 114 insertions(+)

diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 008c4babb9408..347eb64086228 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -2482,3 +2482,117 @@ func.func @fold_single_index_vector(%vec: vector<4xindex>) -> vector<4xindex> {
   %0 = affine.linearize_index [%vec] by () : vector<4xindex>
   return %0 : vector<4xindex>
 }
+
+// -----
+
+// Vector: cancel delinearize(linearize_disjoint) when both bases match exactly (full cancellation).
+
+// CHECK-LABEL: func @cancel_delinearize_linearize_disjoint_partial_vector(
+// CHECK-SAME:     %[[V0:.+]]: vector<4xindex>, %[[V1:.+]]: vector<4xindex>, %[[V2:.+]]: vector<4xindex>)
+// CHECK:         return %[[V0]], %[[V1]], %[[V2]]
+func.func @cancel_delinearize_linearize_disjoint_partial_vector(
+    %v0: vector<4xindex>, %v1: vector<4xindex>, %v2: vector<4xindex>)
+    -> (vector<4xindex>, vector<4xindex>, vector<4xindex>) {
+  %0 = affine.linearize_index disjoint [%v0, %v1, %v2] by (3, 2, 32) : vector<4xindex>
+  %1:3 = affine.delinearize_index %0 into (3, 2, 32) : vector<4xindex>, vector<4xindex>, vector<4xindex>
+  return %1#0, %1#1, %1#2 : vector<4xindex>, vector<4xindex>, vector<4xindex>
+}
+
+// -----
+
+// Vector: cancel linearize(delinearize) exact.
+
+// CHECK-LABEL: func @cancel_linearize_delinearize_exact_vector(
+// CHECK-SAME:     %[[ARG:.+]]: vector<4xindex>)
+// CHECK:         return %[[ARG]]
+func.func @cancel_linearize_delinearize_exact_vector(%arg: vector<4xindex>) -> vector<4xindex> {
+  %0:3 = affine.delinearize_index %arg into (2, 4, 8) : vector<4xindex>, vector<4xindex>, vector<4xindex>
+  %1 = affine.linearize_index [%0#0, %0#1, %0#2] by (2, 4, 8) : vector<4xindex>
+  return %1 : vector<4xindex>
+}
+
+// -----
+
+// Vector: drop all unit bases in delinearize.
+
+// CHECK-LABEL: func @drop_all_unit_bases_vector(
+// CHECK-SAME:     %[[VEC:.+]]: vector<4xindex>)
+// CHECK-DAG:     %[[C0:.+]] = arith.constant dense<0> : vector<4xindex>
+// CHECK-NOT:     affine.delinearize_index
+// CHECK:         return %[[C0]], %[[C0]]
+func.func @drop_all_unit_bases_vector(%vec: vector<4xindex>) -> (vector<4xindex>, vector<4xindex>) {
+  %0:2 = affine.delinearize_index %vec into (1, 1) : vector<4xindex>, vector<4xindex>
+  return %0#0, %0#1 : vector<4xindex>, vector<4xindex>
+}
+
+// -----
+
+// Vector: drop all unit bases when the delinearize has no outer bound.
+
+// CHECK-LABEL: func @drop_all_unit_bases_no_outer_bound_vector(
+// CHECK-SAME:     %[[VEC:.+]]: vector<4xindex>)
+// CHECK-DAG:     %[[C0:.+]] = arith.constant dense<0> : vector<4xindex>
+// CHECK-NOT:     affine.delinearize_index
+// CHECK:         return %[[VEC]], %[[C0]], %[[C0]]
+func.func @drop_all_unit_bases_no_outer_bound_vector(%vec: vector<4xindex>) -> (vector<4xindex>, vector<4xindex>, vector<4xindex>) {
+  %0:3 = affine.delinearize_index %vec into (1, 1) : vector<4xindex>, vector<4xindex>, vector<4xindex>
+  return %0#0, %0#1, %0#2 : vector<4xindex>, vector<4xindex>, vector<4xindex>
+}
+
+// -----
+
+// Vector: linearize all zero unit basis.
+
+// CHECK-LABEL: @linearize_all_zero_unit_basis_vector
+// CHECK:         arith.constant dense<0> : vector<4xindex>
+// CHECK-NOT:     affine.linearize_index
+func.func @linearize_all_zero_unit_basis_vector() -> vector<4xindex> {
+  %c0 = arith.constant dense<0> : vector<4xindex>
+  %ret = affine.linearize_index [%c0, %c0] by (1, 1) : vector<4xindex>
+  return %ret : vector<4xindex>
+}
+
+// -----
+
+// Vector: linearize one element basis fold.
+
+// CHECK-LABEL: @linearize_one_element_basis_vector
+// CHECK-SAME:    (%[[V:.+]]: vector<4xindex>)
+// CHECK-NOT:     affine.linearize_index
+// CHECK:         return %[[V]]
+func.func @linearize_one_element_basis_vector(%v: vector<4xindex>) -> vector<4xindex> {
+  %ret = affine.linearize_index [%v] by (8) : vector<4xindex>
+  return %ret : vector<4xindex>
+}
+
+// -----
+
+// Vector: drop leading zero in linearize.
+
+// CHECK-LABEL: @linearize_drop_leading_zero_vector
+// CHECK-SAME:    (%[[V0:.+]]: vector<4xindex>, %[[V1:.+]]: vector<4xindex>)
+// CHECK-NOT:     affine.linearize_index
+// CHECK:         return %[[V1]]
+func.func @linearize_drop_leading_zero_vector(%v0: vector<4xindex>, %v1: vector<4xindex>) -> vector<4xindex> {
+  %c0 = arith.constant dense<0> : vector<4xindex>
+  %ret = affine.linearize_index [%c0, %v1] by (4, 8) : vector<4xindex>
+  return %ret : vector<4xindex>
+}
+
+// -----
+
+// Vector: split delinearize spanning last linearize arg.
+
+// CHECK-LABEL: func @split_delinearize_spanning_final_part_vector
+// CHECK-SAME:    (%[[V0:.+]]: vector<4xindex>, %[[V1:.+]]: vector<4xindex>, %[[V2:.+]]: vector<4xindex>)
+// CHECK:         %[[LIN:.+]] = affine.linearize_index disjoint [%[[V0]], %[[V1]]] by (3, 2) : vector<4xindex>
+// CHECK:         %[[DELIN1:.+]]:2 = affine.delinearize_index %[[LIN]] into (2, 3) : vector<4xindex>, vector<4xindex>
+// CHECK:         %[[DELIN2:.+]]:2 = affine.delinearize_index %[[V2]] into (8, 4) : vector<4xindex>, vector<4xindex>
+// CHECK:         return %[[DELIN1]]#0, %[[DELIN1]]#1, %[[DELIN2]]#0, %[[DELIN2]]#1
+func.func @split_delinearize_spanning_final_part_vector(
+    %v0: vector<4xindex>, %v1: vector<4xindex>, %v2: vector<4xindex>)
+    -> (vector<4xindex>, vector<4xindex>, vector<4xindex>, vector<4xindex>) {
+  %0 = affine.linearize_index disjoint [%v0, %v1, %v2] by (3, 2, 32) : vector<4xindex>
+  %1:4 = affine.delinearize_index %0 into (2, 3, 8, 4) : vector<4xindex>, vector<4xindex>, vector<4xindex>, vector<4xindex>
+  return %1#0, %1#1, %1#2, %1#3 : vector<4xindex>, vector<4xindex>, vector<4xindex>, vector<4xindex>
+}

>From a0a6fc3cbd5d91fbd0d3602a5f6125bcfb1d3dd4 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Fri, 27 Mar 2026 04:28:47 +0000
Subject: [PATCH 10/11] [mlir][affine] Use StaticTileOffsetRange and static
 positions in vector unrolling

Address review comments:
- Use StaticTileOffsetRange to iterate over vector elements, supporting
  multi-dimensional vectors (e.g. vector<2x4xindex>).
- Use static integer positions for vector.extract/vector.insert instead
  of arith.constant index values.

Co-authored-by: Claude Opus 4.6 <noreply at anthropic.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 .../AffineExpandIndexOpsAsAffine.cpp          | 35 ++++++++-----------
 1 file changed, 15 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp
index 758dfe77e92a0..4a5a09b07f28b 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -54,35 +55,32 @@ struct LowerDelinearizeIndexOps
       return success();
     }
 
-    // Vector case: unroll to per-element scalar affine.apply operations.
+    // Vector case: unroll to per-element scalar affine.apply operations
+    // using StaticTileOffsetRange for multi-dimensional vector support.
     if (vecTy.isScalable())
       return rewriter.notifyMatchFailure(op, "scalable vectors not supported");
 
-    int64_t numElems = vecTy.getNumElements();
     unsigned numResults = op.getNumResults();
+    ArrayRef<int64_t> shape = vecTy.getShape();
+    SmallVector<int64_t> tileShape(shape.size(), 1);
 
-    // Initialize result vectors with a poison/undef-like value.
     SmallVector<Value> resultVecs(numResults);
     Value poison = ub::PoisonOp::create(rewriter, loc, vecTy);
     for (unsigned r = 0; r < numResults; ++r)
       resultVecs[r] = poison;
 
-    for (int64_t i = 0; i < numElems; ++i) {
-      // Extract scalar element.
-      Value idx = arith::ConstantIndexOp::create(rewriter, loc, i);
-      Value scalar = vector::ExtractOp::create(rewriter, loc, linearIndex, idx);
+    for (SmallVector<int64_t> pos : StaticTileOffsetRange(shape, tileShape)) {
+      Value scalar = vector::ExtractOp::create(rewriter, loc, linearIndex, pos);
 
-      // Apply scalar delinearization.
       FailureOr<SmallVector<Value>> scalarResults =
           delinearizeIndex(rewriter, loc, scalar, op.getEffectiveBasis(),
                            /*hasOuterBound=*/false);
       if (failed(scalarResults))
         return failure();
 
-      // Insert results back into vectors.
       for (unsigned r = 0; r < numResults; ++r)
         resultVecs[r] = vector::InsertOp::create(
-            rewriter, loc, (*scalarResults)[r], resultVecs[r], idx);
+            rewriter, loc, (*scalarResults)[r], resultVecs[r], pos);
     }
 
     rewriter.replaceOp(op, resultVecs);
@@ -118,33 +116,30 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
       return success();
     }
 
-    // Vector case: unroll to per-element scalar affine.apply operations.
+    // Vector case: unroll to per-element scalar affine.apply operations
+    // using StaticTileOffsetRange for multi-dimensional vector support.
     if (vecTy.isScalable())
       return rewriter.notifyMatchFailure(op, "scalable vectors not supported");
 
-    int64_t numElems = vecTy.getNumElements();
+    ArrayRef<int64_t> shape = vecTy.getShape();
+    SmallVector<int64_t> tileShape(shape.size(), 1);
     ValueRange multiIndex = op.getMultiIndex();
 
     Value result = ub::PoisonOp::create(rewriter, loc, vecTy);
 
-    for (int64_t i = 0; i < numElems; ++i) {
-      Value idx = arith::ConstantIndexOp::create(rewriter, loc, i);
-
-      // Extract scalar elements from each multi_index vector.
+    for (SmallVector<int64_t> pos : StaticTileOffsetRange(shape, tileShape)) {
       SmallVector<OpFoldResult> scalarIndices;
       for (Value vec : multiIndex)
         scalarIndices.push_back(
-            vector::ExtractOp::create(rewriter, loc, vec, idx).getResult());
+            vector::ExtractOp::create(rewriter, loc, vec, pos).getResult());
 
-      // Apply scalar linearization.
       OpFoldResult linearIndex =
           linearizeIndex(rewriter, loc, scalarIndices, op.getMixedBasis());
       Value scalarResult =
           getValueOrCreateConstantIntOp(rewriter, loc, linearIndex);
 
-      // Insert result back into vector.
       result =
-          vector::InsertOp::create(rewriter, loc, scalarResult, result, idx);
+          vector::InsertOp::create(rewriter, loc, scalarResult, result, pos);
     }
 
     rewriter.replaceOp(op, result);

>From fa4ecaadb6cfeffa9164057e945f83ddbff28d55 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Fri, 27 Mar 2026 04:32:35 +0000
Subject: [PATCH 11/11] [mlir][affine] Add multi-dimensional vector unrolling
 test

Add test for vector<2x2xindex> to verify StaticTileOffsetRange
correctly iterates over multi-dimensional vector elements with
static positions in vector.extract/vector.insert.

Co-authored-by: Claude Opus 4.6 <noreply at anthropic.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 .../affine-expand-index-ops-as-affine.mlir    | 27 +++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir
index b1739159b5513..d99b2605147c3 100644
--- a/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir
+++ b/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir
@@ -117,3 +117,30 @@ func.func @linearize_vector_unroll(%v0: vector<2xindex>, %v1: vector<2xindex>) -
   %0 = affine.linearize_index [%v0, %v1] by (4, 8) : vector<2xindex>
   return %0 : vector<2xindex>
 }
+
+// -----
+
+// Multi-dimensional vector delinearize: unrolled with static positions.
+
+//   CHECK-DAG:   #[[$DIV8:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+//   CHECK-DAG:   #[[$MOD8:.+]] = affine_map<()[s0] -> (s0 mod 8)>
+
+// CHECK-LABEL: @delinearize_2d_vector_unroll
+// CHECK-SAME:    (%[[VEC:.+]]: vector<2x2xindex>)
+// CHECK:         %[[POISON:.+]] = ub.poison : vector<2x2xindex>
+// CHECK:         %[[S00:.+]] = vector.extract %[[VEC]][0, 0]
+// CHECK:         %[[D00:.+]] = affine.apply #[[$DIV8]]()[%[[S00]]]
+// CHECK:         %[[M00:.+]] = affine.apply #[[$MOD8]]()[%[[S00]]]
+// CHECK:         %[[R0_00:.+]] = vector.insert %[[D00]], %[[POISON]] [0, 0]
+// CHECK:         %[[R1_00:.+]] = vector.insert %[[M00]], %[[POISON]] [0, 0]
+// CHECK:         %[[S01:.+]] = vector.extract %[[VEC]][0, 1]
+// CHECK:         %[[D01:.+]] = affine.apply #[[$DIV8]]()[%[[S01]]]
+// CHECK:         %[[M01:.+]] = affine.apply #[[$MOD8]]()[%[[S01]]]
+// CHECK:         vector.insert %[[D01]], %[[R0_00]] [0, 1]
+// CHECK:         vector.insert %[[M01]], %[[R1_00]] [0, 1]
+// CHECK:         vector.extract %[[VEC]][1, 0]
+// CHECK:         vector.extract %[[VEC]][1, 1]
+func.func @delinearize_2d_vector_unroll(%vec: vector<2x2xindex>) -> (vector<2x2xindex>, vector<2x2xindex>) {
+  %0:2 = affine.delinearize_index %vec into (4, 8) : vector<2x2xindex>, vector<2x2xindex>
+  return %0#0, %0#1 : vector<2x2xindex>, vector<2x2xindex>
+}



More information about the Mlir-commits mailing list