[Mlir-commits] [mlir] [mlir][vector] Lower vector.gather with delinearization approach (PR #184706)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Mar 17 05:45:58 PDT 2026
================
@@ -183,22 +191,39 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
Value condMask = op.getMask();
Value base = op.getBase();
- // vector.load requires the most minor memref dim to have unit stride
- // (unless reading exactly 1 element)
+ // For multi-dimensional memrefs, use linearize+delinearize to compute
+ // correct N-D load indices from the 1-D gather index.
+ bool useDelinearization = false;
if (auto memType = dyn_cast<MemRefType>(base.getType())) {
+ // vector.load requires the most minor memref dim to have unit stride
+ // (unless reading exactly 1 element).
if (auto stridesAttr =
dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
if (stridesAttr.getStrides().back() != 1 &&
resultTy.getNumElements() != 1)
- return failure();
+ return rewriter.notifyMatchFailure(
+ op, "most minor memref dim must have unit stride");
}
+
+ if (memType.getRank() > 1)
+ useDelinearization = true;
}
Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
op.getIndices());
- auto baseOffsets = llvm::to_vector(op.getOffsets());
- Value lastBaseOffset = baseOffsets.back();
+ auto loadOffsets = llvm::to_vector(op.getOffsets());
+ Value lastLoadOffset = loadOffsets.back();
+
+ // Compute the memref shape and linearized offsets once, outside the
+ // per-element loop.
+ SmallVector<OpFoldResult> baseShape;
+ Value linearizedOffsets;
+ if (useDelinearization) {
+ baseShape = memref::getMixedSizes(rewriter, loc, base);
+ linearizedOffsets = affine::AffineLinearizeIndexOp::create(
----------------
banach-space wrote:
> - the vector.gather to LLVM lowering - doesn't try do any such delinearizations / stride application.
Thanks for pointing that out - I agree this should be fixed.
> Second, to quote that documentation update itself
I don’t think that’s the case - my interpretation is slightly different. Let me break the wording down:
* The part _“If the resulting position exceeds the size of a dimension, it naturally advances into the next row and/or plane according to the identity (row-major) layout of base.”_ is precisely what the delinearization is modeling.
* The part _“indices are interpreted assuming an identity (contiguous) MemRef layout and do not account for non-identity strides.”_ is also respected: the computation ignores MemRef strides and operates purely on a contiguous (row-major) view.
So while the mechanism involves delinearization, it is still consistent with the documented semantics.
As a side note, to keep https://github.com/llvm/llvm-project/pull/181357 focused on the agreed RFC semantics, I’ve removed that specific wording change there - let’s continue this part of the discussion here.
> Third, if we look at the pseudocode in the gather documentation, it also implies no delinearization like this
I agree the pseudocode suggests that, but it is incomplete: it does not specify what happens when the index exceeds the size of the last dimension.
The delinearization is one way to make that behaviour explicit and consistent with row-major traversal. I see this PR as filling in that missing detail rather than contradicting the existing documentation.
> I don't think having complex delinearization logic like that showing up in LLVM lowering is the right way to go, and I'd like to go to k-tuples of indices instead
That’s a reasonable direction, but it would require changing the op’s definition. The approach here works within the current constraints of `vector.gather` (single index), and provides a consistent lowering for those semantics.
https://github.com/llvm/llvm-project/pull/184706