[Mlir-commits] [mlir] [mlir][vector] Lower vector.gather with delinearization approach (PR #184706)
Krzysztof Drewniak
llvmlistbot at llvm.org
Mon Mar 16 15:02:31 PDT 2026
================
@@ -183,22 +191,39 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
Value condMask = op.getMask();
Value base = op.getBase();
- // vector.load requires the most minor memref dim to have unit stride
- // (unless reading exactly 1 element)
+ // For multi-dimensional memrefs, use linearize+delinearize to compute
+ // correct N-D load indices from the 1-D gather index.
+ bool useDelinearization = false;
if (auto memType = dyn_cast<MemRefType>(base.getType())) {
+ // vector.load requires the most minor memref dim to have unit stride
+ // (unless reading exactly 1 element).
if (auto stridesAttr =
dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
if (stridesAttr.getStrides().back() != 1 &&
resultTy.getNumElements() != 1)
- return failure();
+ return rewriter.notifyMatchFailure(
+ op, "most minor memref dim must have unit stride");
}
+
+ if (memType.getRank() > 1)
+ useDelinearization = true;
}
Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
op.getIndices());
- auto baseOffsets = llvm::to_vector(op.getOffsets());
- Value lastBaseOffset = baseOffsets.back();
+ auto loadOffsets = llvm::to_vector(op.getOffsets());
+ Value lastLoadOffset = loadOffsets.back();
+
+ // Compute the memref shape and linearized offsets once, outside the
+ // per-element loop.
+ SmallVector<OpFoldResult> baseShape;
+ Value linearizedOffsets;
+ if (useDelinearization) {
+ baseShape = memref::getMixedSizes(rewriter, loc, base);
+ linearizedOffsets = affine::AffineLinearizeIndexOp::create(
----------------
krzysz00 wrote:
First, https://github.com/llvm/llvm-project/blob/51937fc9969c39bafc4991ceeb7c7113696aa7df/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp#L277 - the vector.gather to LLVM lowering - doesn't try to do any such delinearization or stride application.
Second, to quote that documentation update itself:
> If the resulting position exceeds the size
> of a dimension, it naturally advances into the next row and/or plane
> according to the identity (row-major) layout of `base`. Importantly, for
> MemRefs, `indices` are interpreted assuming an identity (contiguous) MemRef
> layout and do not account for non-identity strides.
This PR violates that wording.
Third, if we look at the pseudocode in the gather documentation, it also implies no delinearization like this:
```
result[i,j] := if mask[i,j] then base[%ofs_0, %ofs_1, %ofs_2 + indices[i,j]]
else pass_thru[i,j]
```
Note how the index is added to the last dimension.
https://github.com/llvm/llvm-project/pull/184706