[Mlir-commits] [mlir] [MLIR][Affine] Add vector support to affine.linearize_index and affine.delinearize_index (PR #188369)
Keshav Vinayak Jha
llvmlistbot at llvm.org
Wed Mar 25 09:09:37 PDT 2026
https://github.com/keshavvinayak01 updated https://github.com/llvm/llvm-project/pull/188369
>From 3f09929fea3253c7ec367e81bb93763bff32bd02 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 24 Mar 2026 22:08:13 +0000
Subject: [PATCH 1/5] [mlir][affine] Add vector support to
affine.linearize_index and affine.delinearize_index
Allow affine.delinearize_index and affine.linearize_index to operate on
vector<...xindex> types in addition to scalar index. The basis remains
scalar (it describes the shape of the index space, not per-lane data).
This enables expressing element-wise index computations across vector
lanes directly, rather than manually lowering to vector.broadcast +
vector.step + arith patterns.
Changes:
- Add Affine_IndexOrVectorOfIndex type constraint in AffineOps.td
- Implement custom parse/print for both ops
- Add type consistency verifiers (all delinearize results and all linearize
multi-index operands must have the same type as the linear index)
- Update canonicalizers to produce vector zeros where needed
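- Add vector lowering path in AffineExpandIndexOpsAsAffine using arith
division/multiply/subtract/add ops (which natively support vectors); only
constant basis entries can be splatted, so vector ops with a dynamic basis
are left unlowered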
- Scalar behavior is completely unchanged
Co-authored-by: Claude Opus 4.6 <noreply at anthropic.com>
Signed-off-by: Keshav Vinayak <keshavvinayak01 at gmail.com>
---
.../mlir/Dialect/Affine/IR/AffineOps.td | 27 ++-
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 170 +++++++++++++++++-
.../AffineExpandIndexOpsAsAffine.cpp | 141 +++++++++++++--
3 files changed, 301 insertions(+), 37 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 9cb0f3242db17..8a7e49c05f526 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -32,6 +32,12 @@ def Affine_Dialect : Dialect {
class Affine_Op<string mnemonic, list<Trait> traits = []> :
Op<Affine_Dialect, mnemonic, traits>;
+// Type constraint for index-like types: index or vector of index.
+def Affine_IndexOrVectorOfIndex :
+ Type<Or<[Index.predicate,
+ VectorOfAnyRankOf<[Index]>.predicate]>,
+ "index or vector of index">;
+
// Require regions to have affine.yield.
def ImplicitAffineTerminator
: SingleBlockImplicitTerminator<"AffineYieldOp">;
@@ -1118,16 +1124,12 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
- that is, the product of all basis elements is positive as an `index` as well.
}];
- let arguments = (ins Index:$linear_index,
+ let arguments = (ins Affine_IndexOrVectorOfIndex:$linear_index,
Variadic<Index>:$dynamic_basis,
DenseI64ArrayAttr:$static_basis);
- let results = (outs Variadic<Index>:$multi_index);
+ let results = (outs Variadic<Affine_IndexOrVectorOfIndex>:$multi_index);
- let assemblyFormat = [{
- $linear_index `into`
- custom<DynamicIndexList>($dynamic_basis, $static_basis, "{}", "::mlir::AsmParser::Delimiter::Paren")
- attr-dict `:` type($multi_index)
- }];
+ let hasCustomAssemblyFormat = 1;
let builders = [
OpBuilder<(ins "Value":$linear_index, "ValueRange":$dynamic_basis, "ArrayRef<int64_t>":$static_basis, CArg<"bool", "true">:$hasOuterBound)>,
@@ -1221,18 +1223,13 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
```
}];
- let arguments = (ins Variadic<Index>:$multi_index,
+ let arguments = (ins Variadic<Affine_IndexOrVectorOfIndex>:$multi_index,
Variadic<Index>:$dynamic_basis,
DenseI64ArrayAttr:$static_basis,
UnitProp:$disjoint);
- let results = (outs Index:$linear_index);
+ let results = (outs Affine_IndexOrVectorOfIndex:$linear_index);
- let assemblyFormat = [{
- (`disjoint` $disjoint^)? ` `
- `[` $multi_index `]` `by`
- custom<DynamicIndexList>($dynamic_basis, $static_basis, "{}", "::mlir::AsmParser::Delimiter::Paren")
- attr-dict `:` type($linear_index)
- }];
+ let hasCustomAssemblyFormat = 1;
let builders = [
OpBuilder<(ins "ValueRange":$multi_index, "ValueRange":$basis, CArg<"bool", "false">:$disjoint)>,
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 839d34b41cbd4..2060a74b061e7 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -21,6 +21,7 @@
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
@@ -4855,6 +4856,58 @@ LogicalResult AffineVectorStoreOp::verify() {
// DelinearizeIndexOp
//===----------------------------------------------------------------------===//
+/// Parse format:
+/// affine.delinearize_index %idx into (%c4, %c8)
+/// : index, index (scalar)
+/// affine.delinearize_index %vec into (%c4, %c8)
+/// : vector<16xindex>, vector<16xindex> (vector)
+ParseResult AffineDelinearizeIndexOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ OpAsmParser::UnresolvedOperand linearIndex;
+ if (parser.parseOperand(linearIndex) || parser.parseKeyword("into"))
+ return failure();
+
+ SmallVector<OpAsmParser::UnresolvedOperand> dynamicBasis;
+ DenseI64ArrayAttr staticBasis;
+ if (parseDynamicIndexList(parser, dynamicBasis, staticBasis, nullptr,
+ AsmParser::Delimiter::Paren))
+ return failure();
+
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+
+ if (parser.parseColon())
+ return failure();
+
+ SmallVector<Type> resultTypes;
+ if (parser.parseTypeList(resultTypes))
+ return failure();
+
+ // Infer the linear index type from the first result type. All types must
+ // match (enforced by the verifier).
+ Type indexType = resultTypes.empty() ? IndexType::get(parser.getContext())
+ : resultTypes.front();
+ if (parser.resolveOperand(linearIndex, indexType, result.operands))
+ return failure();
+ if (parser.resolveOperands(dynamicBasis, IndexType::get(parser.getContext()),
+ result.operands))
+ return failure();
+
+ result.addTypes(resultTypes);
+ result.getOrAddProperties<AffineDelinearizeIndexOp::Properties>()
+ .static_basis = staticBasis;
+ return success();
+}
+
+void AffineDelinearizeIndexOp::print(OpAsmPrinter &p) {
+ p << ' ' << getLinearIndex() << " into ";
+ printDynamicIndexList(p, *this, getDynamicBasis(), getStaticBasisAttr(),
+ /*scalableFlags=*/{}, AsmParser::Delimiter::Paren);
+ p.printOptionalAttrDict((*this)->getAttrs(), {getStaticBasisAttrName()});
+ p << " : ";
+ llvm::interleaveComma(getResultTypes(), p);
+}
+
void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
Value linearIndex, ValueRange dynamicBasis,
@@ -4925,6 +4978,14 @@ LogicalResult AffineDelinearizeIndexOp::verify() {
}))
return emitOpError("no basis element may be statically non-positive");
+ // All result types must match the input type.
+ Type inputType = getLinearIndex().getType();
+ for (Type resultType : getResultTypes()) {
+ if (resultType != inputType)
+ return emitOpError("result types must match the linear index type, got ")
+ << resultType << " vs " << inputType;
+ }
+
return success();
}
@@ -5036,9 +5097,17 @@ struct DropUnitExtentBasis
SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
std::optional<Value> zero = std::nullopt;
Location loc = delinearizeOp->getLoc();
+ Type indexType = delinearizeOp.getLinearIndex().getType();
auto getZero = [&]() -> Value {
- if (!zero)
- zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ if (!zero) {
+ // Build only the constant that is used: no dead scalar op for vectors.
+ if (auto vecTy = dyn_cast<VectorType>(indexType))
+ zero = arith::ConstantOp::create(
+ rewriter, loc,
+ DenseElementsAttr::get(vecTy, rewriter.getIndexAttr(0)));
+ else
+ zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ }
return zero.value();
};
@@ -5204,9 +5273,9 @@ struct SplitDelinearizeSpanningLastLinearizeArg final
"need at least two elements to form the basis product");
Value linearizeWithoutBack = affine::AffineLinearizeIndexOp::create(
- rewriter, linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
- linearizeOp.getDynamicBasis(), linearizeOp.getStaticBasis().drop_back(),
- linearizeOp.getDisjoint());
+ rewriter, linearizeOp.getLoc(), linearizeOp.getLinearIndex().getType(),
+ linearizeOp.getMultiIndex().drop_back(), linearizeOp.getDynamicBasis(),
+ linearizeOp.getStaticBasis().drop_back(), linearizeOp.getDisjoint());
auto delinearizeWithoutSplitPart = affine::AffineDelinearizeIndexOp::create(
rewriter, delinearizeOp.getLoc(), linearizeWithoutBack,
delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
@@ -5236,6 +5305,69 @@ void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
// LinearizeIndexOp
//===----------------------------------------------------------------------===//
+/// Parse format:
+/// affine.linearize_index [%x, %y] by (%c4, %c8) : index
+/// affine.linearize_index disjoint [%v0, %v1] by (%c4, %c8)
+/// : vector<16xindex>
+ParseResult AffineLinearizeIndexOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ bool disjoint = succeeded(parser.parseOptionalKeyword("disjoint"));
+
+ SmallVector<OpAsmParser::UnresolvedOperand> multiIndex;
+ if (parser.parseOperandList(multiIndex, AsmParser::Delimiter::Square) ||
+ parser.parseKeyword("by"))
+ return failure();
+
+ SmallVector<OpAsmParser::UnresolvedOperand> dynamicBasis;
+ DenseI64ArrayAttr staticBasis;
+ if (parseDynamicIndexList(parser, dynamicBasis, staticBasis, nullptr,
+ AsmParser::Delimiter::Paren))
+ return failure();
+
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+
+ Type resultType;
+ if (parser.parseColonType(resultType))
+ return failure();
+
+ if (parser.resolveOperands(multiIndex, resultType, result.operands))
+ return failure();
+ if (parser.resolveOperands(dynamicBasis, IndexType::get(parser.getContext()),
+ result.operands))
+ return failure();
+
+ result.addTypes(resultType);
+ auto &props = result.getOrAddProperties<AffineLinearizeIndexOp::Properties>();
+ props.static_basis = staticBasis;
+ props.disjoint = disjoint;
+ props.operandSegmentSizes = {static_cast<int32_t>(multiIndex.size()),
+ static_cast<int32_t>(dynamicBasis.size())};
+ return success();
+}
+
+void AffineLinearizeIndexOp::print(OpAsmPrinter &p) {
+ if (getDisjoint())
+ p << " disjoint";
+ p << " [";
+ llvm::interleaveComma(getMultiIndex(), p);
+ p << "] by ";
+ printDynamicIndexList(p, *this, getDynamicBasis(), getStaticBasisAttr(),
+ /*scalableFlags=*/{}, AsmParser::Delimiter::Paren);
+ p.printOptionalAttrDict(
+ (*this)->getAttrs(),
+ {getStaticBasisAttrName(), getOperandSegmentSizesAttrName()});
+ p << " : " << getLinearIndex().getType();
+}
+
+/// Infer the index type from a set of multi-index values. Returns the common
+/// type (index or vector<...xindex>), or IndexType if the set is empty.
+static Type inferIndexType(MLIRContext *ctx, ValueRange multiIndex) {
+ if (multiIndex.empty())
+ return IndexType::get(ctx);
+ return multiIndex.front().getType();
+}
+
void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
ValueRange multiIndex, ValueRange basis,
@@ -5246,7 +5378,9 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
staticBasis);
- build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
+ Type resultType = inferIndexType(odsBuilder.getContext(), multiIndex);
+ build(odsBuilder, odsState, resultType, multiIndex, dynamicBasis, staticBasis,
+ disjoint);
}
void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
@@ -5259,14 +5393,18 @@ void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
SmallVector<Value> dynamicBasis;
SmallVector<int64_t> staticBasis;
dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
- build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
+ Type resultType = inferIndexType(odsBuilder.getContext(), multiIndex);
+ build(odsBuilder, odsState, resultType, multiIndex, dynamicBasis, staticBasis,
+ disjoint);
}
void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
ValueRange multiIndex,
ArrayRef<int64_t> basis, bool disjoint) {
- build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
+ Type resultType = inferIndexType(odsBuilder.getContext(), multiIndex);
+ build(odsBuilder, odsState, resultType, multiIndex, ValueRange{}, basis,
+ disjoint);
}
LogicalResult AffineLinearizeIndexOp::verify() {
@@ -5284,6 +5422,14 @@ LogicalResult AffineLinearizeIndexOp::verify() {
"corresponding dynamic basis entry) -- this can only happen due to an "
"incorrect fold/rewrite");
+ // All multi_index types must match the result type.
+ Type resultType = getLinearIndex().getType();
+ for (Value idx : getMultiIndex()) {
+ if (idx.getType() != resultType)
+ return emitOpError("multi_index types must match the result type, got ")
+ << idx.getType() << " vs " << resultType;
+ }
+
return success();
}
@@ -5402,7 +5548,13 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final
"no unit basis entries to replace");
if (newIndices.empty()) {
- rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
+ Type resultType = op.getLinearIndex().getType();
+ if (auto vecTy = dyn_cast<VectorType>(resultType)) {
+ rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+ op, DenseElementsAttr::get(vecTy, rewriter.getIndexAttr(0)));
+ } else {
+ rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
+ }
return success();
}
rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp
index e919bc6d36265..0178c5159df53 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp
@@ -15,7 +15,9 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
@@ -29,6 +31,33 @@ using namespace mlir;
using namespace mlir::affine;
namespace {
+
+/// Returns `value` as an `index` constant, or as a splat of it when `type` is
+/// a vector type.
+static Value createTypedConstant(OpBuilder &b, Location loc, Type type,
+ int64_t value) {
+ auto vecTy = dyn_cast<VectorType>(type);
+ if (!vecTy)
+ return arith::ConstantIndexOp::create(b, loc, value);
+ return arith::ConstantOp::create(
+ b, loc, DenseElementsAttr::get(vecTy, b.getIndexAttr(value)));
+}
+
+/// Turns one basis entry into a Value of `targetType`. Constant entries become
+/// an `index` constant or a vector splat of it. A dynamic (SSA) entry can only
+/// be used as-is with a scalar `index` target; splatting it onto a vector would
+/// need vector.broadcast, which this pass does not emit, so that case fails.
+static FailureOr<Value> materializeBasis(OpBuilder &b, Location loc,
+ OpFoldResult ofr, Type targetType) {
+ if (std::optional<int64_t> cst = getConstantIntValue(ofr))
+ return createTypedConstant(b, loc, targetType, *cst);
+ if (!isa<IndexType>(targetType))
+ return failure();
+ return getValueOrCreateConstantIndexOp(b, loc, ofr);
+}
+
/// Lowers `affine.delinearize_index` into a sequence of division and remainder
/// operations.
struct LowerDelinearizeIndexOps
@@ -36,12 +65,51 @@ struct LowerDelinearizeIndexOps
using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
PatternRewriter &rewriter) const override {
- FailureOr<SmallVector<Value>> multiIndex =
- delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
- op.getEffectiveBasis(), /*hasOuterBound=*/false);
- if (failed(multiIndex))
- return failure();
- rewriter.replaceOp(op, *multiIndex);
+ // For scalar types, use the existing affine lowering path.
+ if (isa<IndexType>(op.getLinearIndex().getType())) {
+ FailureOr<SmallVector<Value>> multiIndex =
+ delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
+ op.getEffectiveBasis(), /*hasOuterBound=*/false);
+ if (failed(multiIndex))
+ return failure();
+ rewriter.replaceOp(op, *multiIndex);
+ return success();
+ }
+
+ // Vector lowering: emit arith ops (which work element-wise on vectors),
+ // using floor division to match the scalar path's affine floordiv/mod.
+ Location loc = op.getLoc();
+ Value linearIndex = op.getLinearIndex();
+ Type type = linearIndex.getType();
+ SmallVector<OpFoldResult> basis = op.getEffectiveBasis();
+
+ // Suffix products of the effective basis (outer bound already dropped):
+ // for an effective basis (B1, B2) this pushes (B2, B1*B2).
+ SmallVector<Value> divisors;
+ Value cumulativeProd = createTypedConstant(rewriter, loc, type, 1);
+ for (OpFoldResult basisElem : llvm::reverse(basis)) {
+ FailureOr<Value> basisVal =
+ materializeBasis(rewriter, loc, basisElem, type);
+ if (failed(basisVal))
+ return failure();
+ cumulativeProd =
+ arith::MulIOp::create(rewriter, loc, cumulativeProd, *basisVal);
+ divisors.push_back(cumulativeProd);
+ }
+
+ // Emit floordiv/mod pairs from the most-significant dimension to the least.
+ SmallVector<Value> results;
+ results.reserve(divisors.size() + 1);
+ Value residual = linearIndex;
+ for (Value divisor : llvm::reverse(divisors)) {
+ Value quotient = arith::FloorDivSIOp::create(rewriter, loc, residual, divisor);
+ Value product = arith::MulIOp::create(rewriter, loc, quotient, divisor);
+ // Floor modulus: lies in [0, divisor) even when `residual` is negative.
+ residual = arith::SubIOp::create(rewriter, loc, residual, product);
+ results.push_back(quotient);
+ }
+ results.push_back(residual);
+ rewriter.replaceOp(op, results);
return success();
}
};
@@ -58,13 +126,60 @@ struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
return success();
}
- SmallVector<OpFoldResult> multiIndex =
- getAsOpFoldResult(op.getMultiIndex());
- OpFoldResult linearIndex =
- linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
- Value linearIndexValue =
- getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
- rewriter.replaceOp(op, linearIndexValue);
+ // For scalar types, use the existing affine lowering path.
+ if (isa<IndexType>(op.getLinearIndex().getType())) {
+ SmallVector<OpFoldResult> multiIndex =
+ getAsOpFoldResult(op.getMultiIndex());
+ OpFoldResult linearIndex =
+ linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
+ Value linearIndexValue =
+ getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
+ rewriter.replaceOp(op, linearIndexValue);
+ return success();
+ }
+
+ // Vector lowering: emit arith ops (which work element-wise on vectors).
+ //
+ // linearize_index [i0, i1, ..., iN-1] by (B0, B1, ..., BN-1)
+ // = i0 * stride_0 + i1 * stride_1 + ... + iN-1
+ // where stride_k = B_{k+1} * B_{k+2} * ... * B_{N-1}
+ //
+ // We compute from the back: result = iN-1, stride = 1, then:
+ // stride *= B_{k}, result += i_k * stride
+ Location loc = op.getLoc();
+ Type type = op.getLinearIndex().getType();
+ SmallVector<OpFoldResult> effectiveBasis = op.getEffectiveBasis();
+ ValueRange indices = op.getMultiIndex();
+
+ // effectiveBasis drops the outer bound. For indices [i0, i1, ..., iN-1]:
+ // no outer bound: effectiveBasis = [B1, B2, ..., BN-1] (N-1 elems)
+ // has outer bound: effectiveBasis = [B0, B1, ..., BN-1] (N elems,
+ // but B0 is advisory, dropped by getEffectiveBasis)
+ //
+ // Computation: result = iN-1 + BN-1 * (iN-2 + BN-2 * (... + B1 * i0))
+ // Or equivalently, accumulate from back:
+ // result = iN-1
+ // stride = 1
+ // for k = numBasis-1 downto 0:
+ // stride *= effectiveBasis[k]
+ // result += indices[k] * stride
+ //
+ // This works because effectiveBasis[k] is the "size" of dimension k+1,
+ // and indices[k] is paired with the product of all sizes after it.
+ Value result = indices.back();
+ Value stride = createTypedConstant(rewriter, loc, type, 1);
+
+ for (int i = static_cast<int>(effectiveBasis.size()) - 1; i >= 0; --i) {
+ FailureOr<Value> basisVal =
+ materializeBasis(rewriter, loc, effectiveBasis[i], type);
+ if (failed(basisVal))
+ return failure();
+ stride = arith::MulIOp::create(rewriter, loc, stride, *basisVal);
+ Value term = arith::MulIOp::create(rewriter, loc, indices[i], stride);
+ result = arith::AddIOp::create(rewriter, loc, term, result);
+ }
+
+ rewriter.replaceOp(op, result);
return success();
}
};
>From 68c2a367c58b466115430007a16f29d500729bb7 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 24 Mar 2026 23:46:15 +0000
Subject: [PATCH 2/5] [mlir][affine] Use declarative assemblyFormat for
linearize/delinearize index ops
Replace custom parse/print with declarative assemblyFormat using
TypesMatchWith traits for type inference:
- delinearize: infer linear_index type from first result type
- linearize: infer multi_index types from result type
The assembly format is unchanged - no existing .mlir files need updating.
Co-authored-by: Claude Opus 4.6 <noreply at anthropic.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
.../mlir/Dialect/Affine/IR/AffineOps.td | 23 +++-
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 108 ------------------
2 files changed, 19 insertions(+), 112 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 8a7e49c05f526..3f59e008b2a7d 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -1069,7 +1069,10 @@ def AffineVectorStoreOp : AffineStoreOpBase<"vector_store"> {
// AffineDelinearizeIndexOp
//===----------------------------------------------------------------------===//
-def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
+def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index",
+ [Pure, TypesMatchWith<"linear_index type must match result types",
+ "multi_index", "linear_index",
+ "$_self[0]">]> {
let summary = "delinearize an index";
let description = [{
The `affine.delinearize_index` operation takes a single index value and
@@ -1129,7 +1132,11 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
DenseI64ArrayAttr:$static_basis);
let results = (outs Variadic<Affine_IndexOrVectorOfIndex>:$multi_index);
- let hasCustomAssemblyFormat = 1;
+ let assemblyFormat = [{
+ $linear_index `into`
+ custom<DynamicIndexList>($dynamic_basis, $static_basis, "{}", "::mlir::AsmParser::Delimiter::Paren")
+ attr-dict `:` type($multi_index)
+ }];
let builders = [
OpBuilder<(ins "Value":$linear_index, "ValueRange":$dynamic_basis, "ArrayRef<int64_t>":$static_basis, CArg<"bool", "true">:$hasOuterBound)>,
@@ -1169,7 +1176,10 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", [Pure]> {
// AffineLinearizeIndexOp
//===----------------------------------------------------------------------===//
def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
- [Pure, AttrSizedOperandSegments]> {
+ [Pure, AttrSizedOperandSegments,
+ TypesMatchWith<"multi_index types must match result type",
+ "linear_index", "multi_index", "$_self",
+ "[](::mlir::Type a, ::mlir::TypeRange b) { return llvm::all_of(b, [a](::mlir::Type t) { return t == a; }); }">]> {
let summary = "linearize an index";
let description = [{
The `affine.linearize_index` operation takes a sequence of index values and a
@@ -1229,7 +1239,12 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index",
UnitProp:$disjoint);
let results = (outs Affine_IndexOrVectorOfIndex:$linear_index);
- let hasCustomAssemblyFormat = 1;
+ let assemblyFormat = [{
+ (`disjoint` $disjoint^)? ` `
+ `[` $multi_index `]` `by`
+ custom<DynamicIndexList>($dynamic_basis, $static_basis, "{}", "::mlir::AsmParser::Delimiter::Paren")
+ attr-dict `:` type($linear_index)
+ }];
let builders = [
OpBuilder<(ins "ValueRange":$multi_index, "ValueRange":$basis, CArg<"bool", "false">:$disjoint)>,
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 2060a74b061e7..eacf5a3bb74fb 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -21,7 +21,6 @@
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
-#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
@@ -4856,58 +4855,6 @@ LogicalResult AffineVectorStoreOp::verify() {
// DelinearizeIndexOp
//===----------------------------------------------------------------------===//
-/// Parse format:
-/// affine.delinearize_index %idx into (%c4, %c8)
-/// : index, index (scalar)
-/// affine.delinearize_index %vec into (%c4, %c8)
-/// : vector<16xindex>, vector<16xindex> (vector)
-ParseResult AffineDelinearizeIndexOp::parse(OpAsmParser &parser,
- OperationState &result) {
- OpAsmParser::UnresolvedOperand linearIndex;
- if (parser.parseOperand(linearIndex) || parser.parseKeyword("into"))
- return failure();
-
- SmallVector<OpAsmParser::UnresolvedOperand> dynamicBasis;
- DenseI64ArrayAttr staticBasis;
- if (parseDynamicIndexList(parser, dynamicBasis, staticBasis, nullptr,
- AsmParser::Delimiter::Paren))
- return failure();
-
- if (parser.parseOptionalAttrDict(result.attributes))
- return failure();
-
- if (parser.parseColon())
- return failure();
-
- SmallVector<Type> resultTypes;
- if (parser.parseTypeList(resultTypes))
- return failure();
-
- // Infer the linear index type from the first result type. All types must
- // match (enforced by the verifier).
- Type indexType = resultTypes.empty() ? IndexType::get(parser.getContext())
- : resultTypes.front();
- if (parser.resolveOperand(linearIndex, indexType, result.operands))
- return failure();
- if (parser.resolveOperands(dynamicBasis, IndexType::get(parser.getContext()),
- result.operands))
- return failure();
-
- result.addTypes(resultTypes);
- result.getOrAddProperties<AffineDelinearizeIndexOp::Properties>()
- .static_basis = staticBasis;
- return success();
-}
-
-void AffineDelinearizeIndexOp::print(OpAsmPrinter &p) {
- p << ' ' << getLinearIndex() << " into ";
- printDynamicIndexList(p, *this, getDynamicBasis(), getStaticBasisAttr(),
- /*scalableFlags=*/{}, AsmParser::Delimiter::Paren);
- p.printOptionalAttrDict((*this)->getAttrs(), {getStaticBasisAttrName()});
- p << " : ";
- llvm::interleaveComma(getResultTypes(), p);
-}
-
void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
OperationState &odsState,
Value linearIndex, ValueRange dynamicBasis,
@@ -5305,61 +5252,6 @@ void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
// LinearizeIndexOp
//===----------------------------------------------------------------------===//
-/// Parse format:
-/// affine.linearize_index [%x, %y] by (%c4, %c8) : index
-/// affine.linearize_index disjoint [%v0, %v1] by (%c4, %c8)
-/// : vector<16xindex>
-ParseResult AffineLinearizeIndexOp::parse(OpAsmParser &parser,
- OperationState &result) {
- bool disjoint = succeeded(parser.parseOptionalKeyword("disjoint"));
-
- SmallVector<OpAsmParser::UnresolvedOperand> multiIndex;
- if (parser.parseOperandList(multiIndex, AsmParser::Delimiter::Square) ||
- parser.parseKeyword("by"))
- return failure();
-
- SmallVector<OpAsmParser::UnresolvedOperand> dynamicBasis;
- DenseI64ArrayAttr staticBasis;
- if (parseDynamicIndexList(parser, dynamicBasis, staticBasis, nullptr,
- AsmParser::Delimiter::Paren))
- return failure();
-
- if (parser.parseOptionalAttrDict(result.attributes))
- return failure();
-
- Type resultType;
- if (parser.parseColonType(resultType))
- return failure();
-
- if (parser.resolveOperands(multiIndex, resultType, result.operands))
- return failure();
- if (parser.resolveOperands(dynamicBasis, IndexType::get(parser.getContext()),
- result.operands))
- return failure();
-
- result.addTypes(resultType);
- auto &props = result.getOrAddProperties<AffineLinearizeIndexOp::Properties>();
- props.static_basis = staticBasis;
- props.disjoint = disjoint;
- props.operandSegmentSizes = {static_cast<int32_t>(multiIndex.size()),
- static_cast<int32_t>(dynamicBasis.size())};
- return success();
-}
-
-void AffineLinearizeIndexOp::print(OpAsmPrinter &p) {
- if (getDisjoint())
- p << " disjoint";
- p << " [";
- llvm::interleaveComma(getMultiIndex(), p);
- p << "] by ";
- printDynamicIndexList(p, *this, getDynamicBasis(), getStaticBasisAttr(),
- /*scalableFlags=*/{}, AsmParser::Delimiter::Paren);
- p.printOptionalAttrDict(
- (*this)->getAttrs(),
- {getStaticBasisAttrName(), getOperandSegmentSizesAttrName()});
- p << " : " << getLinearIndex().getType();
-}
-
/// Infer the index type from a set of multi-index values. Returns the common
/// type (index or vector<...xindex>), or IndexType if the set is empty.
static Type inferIndexType(MLIRContext *ctx, ValueRange multiIndex) {
>From fd2b27f247a0f8845ec45855a3861cc7ff8aa439 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 24 Mar 2026 23:57:25 +0000
Subject: [PATCH 3/5] [mlir][affine] Add lit tests for vector
linearize/delinearize index ops
Test coverage for the vector support added earlier in this series (PATCH 1/5):
- Roundtrip: 2D/3D static basis, dynamic basis, disjoint flag
- Canonicalization: linearize/delinearize cancellation, unit-extent
dropping, single-result/single-index folding
- Lowering: vector delinearize to arith div/mul/sub, vector linearize
to arith mul/add, 3D case, scalar path unchanged
- Linearize->offset->delinearize pattern (vector.gather use case)
Co-authored-by: Claude Opus 4.6 <noreply at anthropic.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
mlir/test/Dialect/Affine/ops.mlir | 42 ++++++
.../test/Dialect/Affine/vector-index-ops.mlir | 128 ++++++++++++++++++
2 files changed, 170 insertions(+)
create mode 100644 mlir/test/Dialect/Affine/vector-index-ops.mlir
diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir
index 1562f5b1693c0..35b07c1c7fe1f 100644
--- a/mlir/test/Dialect/Affine/ops.mlir
+++ b/mlir/test/Dialect/Affine/ops.mlir
@@ -316,6 +316,48 @@ func.func @linearize_mixed(%index0: index, %index1: index, %index2: index, %basi
return %1 : index
}
+// CHECK-LABEL: @delinearize_vector
+func.func @delinearize_vector(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>) {
+ // CHECK: affine.delinearize_index %{{.+}} into (4, 8) : vector<16xindex>, vector<16xindex>
+ %0:2 = affine.delinearize_index %vec into (4, 8) : vector<16xindex>, vector<16xindex>
+ return %0#0, %0#1 : vector<16xindex>, vector<16xindex>
+}
+
+// CHECK-LABEL: @delinearize_vector_3d
+func.func @delinearize_vector_3d(%vec: vector<8xindex>) -> (vector<8xindex>, vector<8xindex>, vector<8xindex>) {
+ // CHECK: affine.delinearize_index %{{.+}} into (2, 3, 4) : vector<8xindex>, vector<8xindex>, vector<8xindex>
+ %0:3 = affine.delinearize_index %vec into (2, 3, 4) : vector<8xindex>, vector<8xindex>, vector<8xindex>
+ return %0#0, %0#1, %0#2 : vector<8xindex>, vector<8xindex>, vector<8xindex>
+}
+
+// CHECK-LABEL: @delinearize_vector_dynamic_basis
+func.func @delinearize_vector_dynamic_basis(%vec: vector<4xindex>, %b0: index, %b1: index) -> (vector<4xindex>, vector<4xindex>) {
+ // CHECK: affine.delinearize_index %{{.+}} into (%{{.+}}, %{{.+}}) : vector<4xindex>, vector<4xindex>
+ %0:2 = affine.delinearize_index %vec into (%b0, %b1) : vector<4xindex>, vector<4xindex>
+ return %0#0, %0#1 : vector<4xindex>, vector<4xindex>
+}
+
+// CHECK-LABEL: @linearize_vector
+func.func @linearize_vector(%v0: vector<16xindex>, %v1: vector<16xindex>) -> vector<16xindex> {
+ // CHECK: affine.linearize_index [%{{.+}}, %{{.+}}] by (4, 8) : vector<16xindex>
+ %0 = affine.linearize_index [%v0, %v1] by (4, 8) : vector<16xindex>
+ return %0 : vector<16xindex>
+}
+
+// CHECK-LABEL: @linearize_vector_disjoint
+func.func @linearize_vector_disjoint(%v0: vector<16xindex>, %v1: vector<16xindex>) -> vector<16xindex> {
+ // CHECK: affine.linearize_index disjoint [%{{.+}}, %{{.+}}] by (4, 8) : vector<16xindex>
+ %0 = affine.linearize_index disjoint [%v0, %v1] by (4, 8) : vector<16xindex>
+ return %0 : vector<16xindex>
+}
+
+// CHECK-LABEL: @linearize_vector_3d
+func.func @linearize_vector_3d(%v0: vector<8xindex>, %v1: vector<8xindex>, %v2: vector<8xindex>) -> vector<8xindex> {
+ // CHECK: affine.linearize_index [%{{.+}}, %{{.+}}, %{{.+}}] by (2, 3, 4) : vector<8xindex>
+ %0 = affine.linearize_index [%v0, %v1, %v2] by (2, 3, 4) : vector<8xindex>
+ return %0 : vector<8xindex>
+}
+
// -----
// CHECK-LABEL: @gpu_launch_affine
diff --git a/mlir/test/Dialect/Affine/vector-index-ops.mlir b/mlir/test/Dialect/Affine/vector-index-ops.mlir
new file mode 100644
index 0000000000000..2c6fbc20a9316
--- /dev/null
+++ b/mlir/test/Dialect/Affine/vector-index-ops.mlir
@@ -0,0 +1,128 @@
+// RUN: mlir-opt %s -split-input-file --canonicalize | FileCheck %s --check-prefix=CANON
+// RUN: mlir-opt %s -split-input-file --affine-expand-index-ops-as-affine | FileCheck %s --check-prefix=EXPAND
+
+// Canonicalization: cancel disjoint linearize -> delinearize on vectors.
+
+// CANON-LABEL: @cancel_linearize_delinearize_vector
+// CANON-SAME: (%[[V0:.+]]: vector<16xindex>, %[[V1:.+]]: vector<16xindex>)
+// CANON: return %[[V0]], %[[V1]]
+func.func @cancel_linearize_delinearize_vector(%v0: vector<16xindex>, %v1: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>) {
+ %0 = affine.linearize_index disjoint [%v0, %v1] by (4, 8) : vector<16xindex>
+ %1:2 = affine.delinearize_index %0 into (4, 8) : vector<16xindex>, vector<16xindex>
+ return %1#0, %1#1 : vector<16xindex>, vector<16xindex>
+}
+
+// -----
+// Canonicalization: drop unit-extent basis on vector delinearize.
+
+// CANON-LABEL: @drop_unit_extent_vector
+// CANON-SAME: (%[[VEC:.+]]: vector<16xindex>)
+// CANON-DAG: %[[ZERO:.+]] = arith.constant dense<0> : vector<16xindex>
+// CANON: %[[R:.+]]:2 = affine.delinearize_index %[[VEC]] into (4, 8) : vector<16xindex>, vector<16xindex>
+// CANON: return %[[R]]#0, %[[ZERO]], %[[R]]#1
+func.func @drop_unit_extent_vector(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>, vector<16xindex>) {
+ %0:3 = affine.delinearize_index %vec into (4, 1, 8) : vector<16xindex>, vector<16xindex>, vector<16xindex>
+ return %0#0, %0#1, %0#2 : vector<16xindex>, vector<16xindex>, vector<16xindex>
+}
+
+// -----
+// Canonicalization: drop unit-extent basis on vector linearize.
+
+// CANON-LABEL: @drop_unit_linearize_vector
+// CANON-SAME: (%[[V0:.+]]: vector<8xindex>, %[[V1:.+]]: vector<8xindex>)
+// CANON: return %[[V0]]
+func.func @drop_unit_linearize_vector(%v0: vector<8xindex>, %v1: vector<8xindex>) -> vector<8xindex> {
+ %0 = affine.linearize_index disjoint [%v0, %v1] by (4, 1) : vector<8xindex>
+ return %0 : vector<8xindex>
+}
+
+// -----
+// Canonicalization: fold single-result vector delinearize to identity.
+
+// CANON-LABEL: @fold_single_result_vector
+// CANON-SAME: (%[[VEC:.+]]: vector<4xindex>)
+// CANON: return %[[VEC]]
+func.func @fold_single_result_vector(%vec: vector<4xindex>) -> vector<4xindex> {
+ %0:1 = affine.delinearize_index %vec into () : vector<4xindex>
+ return %0#0 : vector<4xindex>
+}
+
+// -----
+// Canonicalization: fold single-index vector linearize to identity.
+
+// CANON-LABEL: @fold_single_index_vector
+// CANON-SAME: (%[[VEC:.+]]: vector<4xindex>)
+// CANON: return %[[VEC]]
+func.func @fold_single_index_vector(%vec: vector<4xindex>) -> vector<4xindex> {
+ %0 = affine.linearize_index [%vec] by () : vector<4xindex>
+ return %0 : vector<4xindex>
+}
+
+// -----
+// Expansion: vector delinearize lowers to arith div/rem.
+
+// EXPAND-LABEL: @expand_delinearize_vector
+// EXPAND-SAME: (%[[VEC:.+]]: vector<16xindex>)
+// EXPAND-DAG: %[[C8:.+]] = arith.constant dense<8> : vector<16xindex>
+// EXPAND: %[[DIV:.+]] = arith.divsi %[[VEC]], %[[C8]]
+// EXPAND: %[[MUL:.+]] = arith.muli %[[DIV]], %[[C8]]
+// EXPAND: %[[REM:.+]] = arith.subi %[[VEC]], %[[MUL]]
+// EXPAND: return %[[DIV]], %[[REM]]
+func.func @expand_delinearize_vector(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>) {
+ %0:2 = affine.delinearize_index %vec into (4, 8) : vector<16xindex>, vector<16xindex>
+ return %0#0, %0#1 : vector<16xindex>, vector<16xindex>
+}
+
+// -----
+// Expansion: vector linearize lowers to arith mul/add.
+
+// EXPAND-LABEL: @expand_linearize_vector
+// EXPAND-SAME: (%[[V0:.+]]: vector<16xindex>, %[[V1:.+]]: vector<16xindex>)
+// EXPAND-DAG: %[[C8:.+]] = arith.constant dense<8> : vector<16xindex>
+// EXPAND: %[[MUL:.+]] = arith.muli %[[V0]], %[[C8]]
+// EXPAND: %[[ADD:.+]] = arith.addi %[[MUL]], %[[V1]]
+// EXPAND: return %[[ADD]]
+func.func @expand_linearize_vector(%v0: vector<16xindex>, %v1: vector<16xindex>) -> vector<16xindex> {
+ %0 = affine.linearize_index [%v0, %v1] by (4, 8) : vector<16xindex>
+ return %0 : vector<16xindex>
+}
+
+// -----
+// Expansion: 3D vector delinearize.
+
+// EXPAND-LABEL: @expand_delinearize_vector_3d
+// EXPAND-SAME: (%[[VEC:.+]]: vector<16xindex>)
+// EXPAND-DAG: %[[C4:.+]] = arith.constant dense<4> : vector<16xindex>
+// EXPAND-DAG: %[[C12:.+]] = arith.constant dense<12> : vector<16xindex>
+// EXPAND: %[[D0:.+]] = arith.divsi %[[VEC]], %[[C12]]
+// EXPAND: %[[M0:.+]] = arith.muli %[[D0]], %[[C12]]
+// EXPAND: %[[R0:.+]] = arith.subi %[[VEC]], %[[M0]]
+// EXPAND: %[[D1:.+]] = arith.divsi %[[R0]], %[[C4]]
+// EXPAND: %[[M1:.+]] = arith.muli %[[D1]], %[[C4]]
+// EXPAND: %[[R1:.+]] = arith.subi %[[R0]], %[[M1]]
+// EXPAND: return %[[D0]], %[[D1]], %[[R1]]
+func.func @expand_delinearize_vector_3d(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>, vector<16xindex>) {
+ %0:3 = affine.delinearize_index %vec into (2, 3, 4) : vector<16xindex>, vector<16xindex>, vector<16xindex>
+ return %0#0, %0#1, %0#2 : vector<16xindex>, vector<16xindex>, vector<16xindex>
+}
+
+// -----
+// Expansion: vector linearize -> offset -> delinearize pattern
+// (as would be used in vector.gather lowering).
+
+// EXPAND-LABEL: @vector_linearize_offset_delinearize
+// EXPAND-SAME: (%[[V0:.+]]: vector<4xindex>, %[[V1:.+]]: vector<4xindex>, %[[OFF:.+]]: vector<4xindex>)
+// EXPAND-DAG: %[[C8:.+]] = arith.constant dense<8> : vector<4xindex>
+// EXPAND: %[[LIN:.+]] = arith.muli %[[V0]], %[[C8]]
+// EXPAND: %[[LIN2:.+]] = arith.addi %[[LIN]], %[[V1]]
+// EXPAND: %[[FLAT:.+]] = arith.addi %[[LIN2]], %[[OFF]]
+// EXPAND: %[[DIV:.+]] = arith.divsi %[[FLAT]], %[[C8]]
+// EXPAND: %[[MUL:.+]] = arith.muli %[[DIV]], %[[C8]]
+// EXPAND: %[[REM:.+]] = arith.subi %[[FLAT]], %[[MUL]]
+// EXPAND: return %[[DIV]], %[[REM]]
+func.func @vector_linearize_offset_delinearize(%v0: vector<4xindex>, %v1: vector<4xindex>, %offsets: vector<4xindex>) -> (vector<4xindex>, vector<4xindex>) {
+ %0 = affine.linearize_index [%v0, %v1] by (4, 8) : vector<4xindex>
+ %1 = arith.addi %0, %offsets : vector<4xindex>
+ %2:2 = affine.delinearize_index %1 into (4, 8) : vector<4xindex>, vector<4xindex>
+ return %2#0, %2#1 : vector<4xindex>, vector<4xindex>
+}
>From 554a15883d096fceb858e5d851e41772fc6ace2c Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Wed, 25 Mar 2026 06:45:38 +0000
Subject: [PATCH 4/5] [mlir][affine] Add lit tests for vector
linearize/delinearize index ops
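Move the vector test cases out of the standalone vector-index-ops.mlir
(deleted by this patch) and into the existing test files:
- canonicalize.mlir: disjoint linearize->delinearize cancellation,
unit-extent basis dropping (delinearize and linearize), single-result
delinearize and single-index linearize identity folds
- affine-expand-index-ops-as-affine.mlir: lowering to arith ops,
3D case, linearize->offset->delinearize pattern
The ops.mlir roundtrip tests (2D/3D, dynamic basis, disjoint) come from
an earlier patch in this series; this patch does not modify ops.mlir.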
Co-authored-by: Claude Opus 4.6 <noreply at anthropic.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
.../affine-expand-index-ops-as-affine.mlir | 67 +++++++++
mlir/test/Dialect/Affine/canonicalize.mlir | 53 ++++++++
.../test/Dialect/Affine/vector-index-ops.mlir | 128 ------------------
3 files changed, 120 insertions(+), 128 deletions(-)
delete mode 100644 mlir/test/Dialect/Affine/vector-index-ops.mlir
diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir
index bf9f00da5793a..21595356936fa 100644
--- a/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir
+++ b/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir
@@ -68,3 +68,70 @@ func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: in
%0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4) : index
func.return %0 : index
}
+
+// -----
+
+// CHECK-LABEL: @expand_delinearize_vector
+// CHECK-SAME: (%[[VEC:.+]]: vector<16xindex>)
+// CHECK-DAG: %[[C8:.+]] = arith.constant dense<8> : vector<16xindex>
+// CHECK: %[[DIV:.+]] = arith.divsi %[[VEC]], %[[C8]]
+// CHECK: %[[MUL:.+]] = arith.muli %[[DIV]], %[[C8]]
+// CHECK: %[[REM:.+]] = arith.subi %[[VEC]], %[[MUL]]
+// CHECK: return %[[DIV]], %[[REM]]
+func.func @expand_delinearize_vector(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>) {
+ %0:2 = affine.delinearize_index %vec into (4, 8) : vector<16xindex>, vector<16xindex>
+ return %0#0, %0#1 : vector<16xindex>, vector<16xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @expand_linearize_vector
+// CHECK-SAME: (%[[V0:.+]]: vector<16xindex>, %[[V1:.+]]: vector<16xindex>)
+// CHECK-DAG: %[[C8:.+]] = arith.constant dense<8> : vector<16xindex>
+// CHECK: %[[MUL:.+]] = arith.muli %[[V0]], %[[C8]]
+// CHECK: %[[ADD:.+]] = arith.addi %[[MUL]], %[[V1]]
+// CHECK: return %[[ADD]]
+func.func @expand_linearize_vector(%v0: vector<16xindex>, %v1: vector<16xindex>) -> vector<16xindex> {
+ %0 = affine.linearize_index [%v0, %v1] by (4, 8) : vector<16xindex>
+ return %0 : vector<16xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @expand_delinearize_vector_3d
+// CHECK-SAME: (%[[VEC:.+]]: vector<16xindex>)
+// CHECK-DAG: %[[C4:.+]] = arith.constant dense<4> : vector<16xindex>
+// CHECK-DAG: %[[C12:.+]] = arith.constant dense<12> : vector<16xindex>
+// CHECK: %[[D0:.+]] = arith.divsi %[[VEC]], %[[C12]]
+// CHECK: %[[M0:.+]] = arith.muli %[[D0]], %[[C12]]
+// CHECK: %[[R0:.+]] = arith.subi %[[VEC]], %[[M0]]
+// CHECK: %[[D1:.+]] = arith.divsi %[[R0]], %[[C4]]
+// CHECK: %[[M1:.+]] = arith.muli %[[D1]], %[[C4]]
+// CHECK: %[[R1:.+]] = arith.subi %[[R0]], %[[M1]]
+// CHECK: return %[[D0]], %[[D1]], %[[R1]]
+func.func @expand_delinearize_vector_3d(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>, vector<16xindex>) {
+ %0:3 = affine.delinearize_index %vec into (2, 3, 4) : vector<16xindex>, vector<16xindex>, vector<16xindex>
+ return %0#0, %0#1, %0#2 : vector<16xindex>, vector<16xindex>, vector<16xindex>
+}
+
+// -----
+
+// Vector linearize -> offset -> delinearize pattern
+// (as would be used in vector.gather lowering).
+
+// CHECK-LABEL: @vector_linearize_offset_delinearize
+// CHECK-SAME: (%[[V0:.+]]: vector<4xindex>, %[[V1:.+]]: vector<4xindex>, %[[OFF:.+]]: vector<4xindex>)
+// CHECK-DAG: %[[C8:.+]] = arith.constant dense<8> : vector<4xindex>
+// CHECK: %[[LIN:.+]] = arith.muli %[[V0]], %[[C8]]
+// CHECK: %[[LIN2:.+]] = arith.addi %[[LIN]], %[[V1]]
+// CHECK: %[[FLAT:.+]] = arith.addi %[[LIN2]], %[[OFF]]
+// CHECK: %[[DIV:.+]] = arith.divsi %[[FLAT]], %[[C8]]
+// CHECK: %[[MUL:.+]] = arith.muli %[[DIV]], %[[C8]]
+// CHECK: %[[REM:.+]] = arith.subi %[[FLAT]], %[[MUL]]
+// CHECK: return %[[DIV]], %[[REM]]
+func.func @vector_linearize_offset_delinearize(%v0: vector<4xindex>, %v1: vector<4xindex>, %offsets: vector<4xindex>) -> (vector<4xindex>, vector<4xindex>) {
+ %0 = affine.linearize_index [%v0, %v1] by (4, 8) : vector<4xindex>
+ %1 = arith.addi %0, %offsets : vector<4xindex>
+ %2:2 = affine.delinearize_index %1 into (4, 8) : vector<4xindex>, vector<4xindex>
+ return %2#0, %2#1 : vector<4xindex>, vector<4xindex>
+}
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 5a0a2b004433e..008c4babb9408 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -2429,3 +2429,56 @@ func.func @linearize_dont_fold_poison_index(%arg0: index) -> index {
%ret = affine.linearize_index [%poison, %arg0] by (%c4) : index
return %ret : index
}
+
+// -----
+
+// CHECK-LABEL: @cancel_linearize_delinearize_vector
+// CHECK-SAME: (%[[V0:.+]]: vector<16xindex>, %[[V1:.+]]: vector<16xindex>)
+// CHECK: return %[[V0]], %[[V1]]
+func.func @cancel_linearize_delinearize_vector(%v0: vector<16xindex>, %v1: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>) {
+ %0 = affine.linearize_index disjoint [%v0, %v1] by (4, 8) : vector<16xindex>
+ %1:2 = affine.delinearize_index %0 into (4, 8) : vector<16xindex>, vector<16xindex>
+ return %1#0, %1#1 : vector<16xindex>, vector<16xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @drop_unit_extent_vector
+// CHECK-SAME: (%[[VEC:.+]]: vector<16xindex>)
+// CHECK-DAG: %[[ZERO:.+]] = arith.constant dense<0> : vector<16xindex>
+// CHECK: %[[R:.+]]:2 = affine.delinearize_index %[[VEC]] into (4, 8) : vector<16xindex>, vector<16xindex>
+// CHECK: return %[[R]]#0, %[[ZERO]], %[[R]]#1
+func.func @drop_unit_extent_vector(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>, vector<16xindex>) {
+ %0:3 = affine.delinearize_index %vec into (4, 1, 8) : vector<16xindex>, vector<16xindex>, vector<16xindex>
+ return %0#0, %0#1, %0#2 : vector<16xindex>, vector<16xindex>, vector<16xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @drop_unit_linearize_vector
+// CHECK-SAME: (%[[V0:.+]]: vector<8xindex>, %[[V1:.+]]: vector<8xindex>)
+// CHECK: return %[[V0]]
+func.func @drop_unit_linearize_vector(%v0: vector<8xindex>, %v1: vector<8xindex>) -> vector<8xindex> {
+ %0 = affine.linearize_index disjoint [%v0, %v1] by (4, 1) : vector<8xindex>
+ return %0 : vector<8xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_single_result_vector
+// CHECK-SAME: (%[[VEC:.+]]: vector<4xindex>)
+// CHECK: return %[[VEC]]
+func.func @fold_single_result_vector(%vec: vector<4xindex>) -> vector<4xindex> {
+ %0:1 = affine.delinearize_index %vec into () : vector<4xindex>
+ return %0#0 : vector<4xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_single_index_vector
+// CHECK-SAME: (%[[VEC:.+]]: vector<4xindex>)
+// CHECK: return %[[VEC]]
+func.func @fold_single_index_vector(%vec: vector<4xindex>) -> vector<4xindex> {
+ %0 = affine.linearize_index [%vec] by () : vector<4xindex>
+ return %0 : vector<4xindex>
+}
diff --git a/mlir/test/Dialect/Affine/vector-index-ops.mlir b/mlir/test/Dialect/Affine/vector-index-ops.mlir
deleted file mode 100644
index 2c6fbc20a9316..0000000000000
--- a/mlir/test/Dialect/Affine/vector-index-ops.mlir
+++ /dev/null
@@ -1,128 +0,0 @@
-// RUN: mlir-opt %s -split-input-file --canonicalize | FileCheck %s --check-prefix=CANON
-// RUN: mlir-opt %s -split-input-file --affine-expand-index-ops-as-affine | FileCheck %s --check-prefix=EXPAND
-
-// Canonicalization: cancel disjoint linearize -> delinearize on vectors.
-
-// CANON-LABEL: @cancel_linearize_delinearize_vector
-// CANON-SAME: (%[[V0:.+]]: vector<16xindex>, %[[V1:.+]]: vector<16xindex>)
-// CANON: return %[[V0]], %[[V1]]
-func.func @cancel_linearize_delinearize_vector(%v0: vector<16xindex>, %v1: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>) {
- %0 = affine.linearize_index disjoint [%v0, %v1] by (4, 8) : vector<16xindex>
- %1:2 = affine.delinearize_index %0 into (4, 8) : vector<16xindex>, vector<16xindex>
- return %1#0, %1#1 : vector<16xindex>, vector<16xindex>
-}
-
-// -----
-// Canonicalization: drop unit-extent basis on vector delinearize.
-
-// CANON-LABEL: @drop_unit_extent_vector
-// CANON-SAME: (%[[VEC:.+]]: vector<16xindex>)
-// CANON-DAG: %[[ZERO:.+]] = arith.constant dense<0> : vector<16xindex>
-// CANON: %[[R:.+]]:2 = affine.delinearize_index %[[VEC]] into (4, 8) : vector<16xindex>, vector<16xindex>
-// CANON: return %[[R]]#0, %[[ZERO]], %[[R]]#1
-func.func @drop_unit_extent_vector(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>, vector<16xindex>) {
- %0:3 = affine.delinearize_index %vec into (4, 1, 8) : vector<16xindex>, vector<16xindex>, vector<16xindex>
- return %0#0, %0#1, %0#2 : vector<16xindex>, vector<16xindex>, vector<16xindex>
-}
-
-// -----
-// Canonicalization: drop unit-extent basis on vector linearize.
-
-// CANON-LABEL: @drop_unit_linearize_vector
-// CANON-SAME: (%[[V0:.+]]: vector<8xindex>, %[[V1:.+]]: vector<8xindex>)
-// CANON: return %[[V0]]
-func.func @drop_unit_linearize_vector(%v0: vector<8xindex>, %v1: vector<8xindex>) -> vector<8xindex> {
- %0 = affine.linearize_index disjoint [%v0, %v1] by (4, 1) : vector<8xindex>
- return %0 : vector<8xindex>
-}
-
-// -----
-// Canonicalization: fold single-result vector delinearize to identity.
-
-// CANON-LABEL: @fold_single_result_vector
-// CANON-SAME: (%[[VEC:.+]]: vector<4xindex>)
-// CANON: return %[[VEC]]
-func.func @fold_single_result_vector(%vec: vector<4xindex>) -> vector<4xindex> {
- %0:1 = affine.delinearize_index %vec into () : vector<4xindex>
- return %0#0 : vector<4xindex>
-}
-
-// -----
-// Canonicalization: fold single-index vector linearize to identity.
-
-// CANON-LABEL: @fold_single_index_vector
-// CANON-SAME: (%[[VEC:.+]]: vector<4xindex>)
-// CANON: return %[[VEC]]
-func.func @fold_single_index_vector(%vec: vector<4xindex>) -> vector<4xindex> {
- %0 = affine.linearize_index [%vec] by () : vector<4xindex>
- return %0 : vector<4xindex>
-}
-
-// -----
-// Expansion: vector delinearize lowers to arith div/rem.
-
-// EXPAND-LABEL: @expand_delinearize_vector
-// EXPAND-SAME: (%[[VEC:.+]]: vector<16xindex>)
-// EXPAND-DAG: %[[C8:.+]] = arith.constant dense<8> : vector<16xindex>
-// EXPAND: %[[DIV:.+]] = arith.divsi %[[VEC]], %[[C8]]
-// EXPAND: %[[MUL:.+]] = arith.muli %[[DIV]], %[[C8]]
-// EXPAND: %[[REM:.+]] = arith.subi %[[VEC]], %[[MUL]]
-// EXPAND: return %[[DIV]], %[[REM]]
-func.func @expand_delinearize_vector(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>) {
- %0:2 = affine.delinearize_index %vec into (4, 8) : vector<16xindex>, vector<16xindex>
- return %0#0, %0#1 : vector<16xindex>, vector<16xindex>
-}
-
-// -----
-// Expansion: vector linearize lowers to arith mul/add.
-
-// EXPAND-LABEL: @expand_linearize_vector
-// EXPAND-SAME: (%[[V0:.+]]: vector<16xindex>, %[[V1:.+]]: vector<16xindex>)
-// EXPAND-DAG: %[[C8:.+]] = arith.constant dense<8> : vector<16xindex>
-// EXPAND: %[[MUL:.+]] = arith.muli %[[V0]], %[[C8]]
-// EXPAND: %[[ADD:.+]] = arith.addi %[[MUL]], %[[V1]]
-// EXPAND: return %[[ADD]]
-func.func @expand_linearize_vector(%v0: vector<16xindex>, %v1: vector<16xindex>) -> vector<16xindex> {
- %0 = affine.linearize_index [%v0, %v1] by (4, 8) : vector<16xindex>
- return %0 : vector<16xindex>
-}
-
-// -----
-// Expansion: 3D vector delinearize.
-
-// EXPAND-LABEL: @expand_delinearize_vector_3d
-// EXPAND-SAME: (%[[VEC:.+]]: vector<16xindex>)
-// EXPAND-DAG: %[[C4:.+]] = arith.constant dense<4> : vector<16xindex>
-// EXPAND-DAG: %[[C12:.+]] = arith.constant dense<12> : vector<16xindex>
-// EXPAND: %[[D0:.+]] = arith.divsi %[[VEC]], %[[C12]]
-// EXPAND: %[[M0:.+]] = arith.muli %[[D0]], %[[C12]]
-// EXPAND: %[[R0:.+]] = arith.subi %[[VEC]], %[[M0]]
-// EXPAND: %[[D1:.+]] = arith.divsi %[[R0]], %[[C4]]
-// EXPAND: %[[M1:.+]] = arith.muli %[[D1]], %[[C4]]
-// EXPAND: %[[R1:.+]] = arith.subi %[[R0]], %[[M1]]
-// EXPAND: return %[[D0]], %[[D1]], %[[R1]]
-func.func @expand_delinearize_vector_3d(%vec: vector<16xindex>) -> (vector<16xindex>, vector<16xindex>, vector<16xindex>) {
- %0:3 = affine.delinearize_index %vec into (2, 3, 4) : vector<16xindex>, vector<16xindex>, vector<16xindex>
- return %0#0, %0#1, %0#2 : vector<16xindex>, vector<16xindex>, vector<16xindex>
-}
-
-// -----
-// Expansion: vector linearize -> offset -> delinearize pattern
-// (as would be used in vector.gather lowering).
-
-// EXPAND-LABEL: @vector_linearize_offset_delinearize
-// EXPAND-SAME: (%[[V0:.+]]: vector<4xindex>, %[[V1:.+]]: vector<4xindex>, %[[OFF:.+]]: vector<4xindex>)
-// EXPAND-DAG: %[[C8:.+]] = arith.constant dense<8> : vector<4xindex>
-// EXPAND: %[[LIN:.+]] = arith.muli %[[V0]], %[[C8]]
-// EXPAND: %[[LIN2:.+]] = arith.addi %[[LIN]], %[[V1]]
-// EXPAND: %[[FLAT:.+]] = arith.addi %[[LIN2]], %[[OFF]]
-// EXPAND: %[[DIV:.+]] = arith.divsi %[[FLAT]], %[[C8]]
-// EXPAND: %[[MUL:.+]] = arith.muli %[[DIV]], %[[C8]]
-// EXPAND: %[[REM:.+]] = arith.subi %[[FLAT]], %[[MUL]]
-// EXPAND: return %[[DIV]], %[[REM]]
-func.func @vector_linearize_offset_delinearize(%v0: vector<4xindex>, %v1: vector<4xindex>, %offsets: vector<4xindex>) -> (vector<4xindex>, vector<4xindex>) {
- %0 = affine.linearize_index [%v0, %v1] by (4, 8) : vector<4xindex>
- %1 = arith.addi %0, %offsets : vector<4xindex>
- %2:2 = affine.delinearize_index %1 into (4, 8) : vector<4xindex>, vector<4xindex>
- return %2#0, %2#1 : vector<4xindex>, vector<4xindex>
-}
>From 1393eb0291ea07f5a97d34df8a9df85202f2ac09 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Wed, 25 Mar 2026 16:08:48 +0000
Subject: [PATCH 5/5] [mlir][affine] Add braces to the if/else in DropUnitExtentBasis zero helper
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index eacf5a3bb74fb..92f2015022120 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -5048,12 +5048,13 @@ struct DropUnitExtentBasis
auto getZero = [&]() -> Value {
if (!zero) {
Value scalarZero = arith::ConstantIndexOp::create(rewriter, loc, 0);
- if (auto vecTy = dyn_cast<VectorType>(indexType))
+ if (auto vecTy = dyn_cast<VectorType>(indexType)) {
zero = arith::ConstantOp::create(
rewriter, loc,
DenseElementsAttr::get(vecTy, rewriter.getIndexAttr(0)));
- else
+ } else {
zero = scalarZero;
+ }
}
return zero.value();
};
More information about the Mlir-commits
mailing list