[Mlir-commits] [mlir] [mlir][vector] Lower vector.gather with delinearization approach (PR #184706)
Krzysztof Drewniak
llvmlistbot at llvm.org
Tue Mar 17 07:51:16 PDT 2026
================
@@ -183,22 +191,39 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
Value condMask = op.getMask();
Value base = op.getBase();
- // vector.load requires the most minor memref dim to have unit stride
- // (unless reading exactly 1 element)
+ // For multi-dimensional memrefs, use linearize+delinearize to compute
+ // correct N-D load indices from the 1-D gather index.
+ bool useDelinearization = false;
if (auto memType = dyn_cast<MemRefType>(base.getType())) {
+ // vector.load requires the most minor memref dim to have unit stride
+ // (unless reading exactly 1 element).
if (auto stridesAttr =
dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
if (stridesAttr.getStrides().back() != 1 &&
resultTy.getNumElements() != 1)
- return failure();
+ return rewriter.notifyMatchFailure(
+ op, "most minor memref dim must have unit stride");
}
+
+ if (memType.getRank() > 1)
+ useDelinearization = true;
}
Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
op.getIndices());
- auto baseOffsets = llvm::to_vector(op.getOffsets());
- Value lastBaseOffset = baseOffsets.back();
+ auto loadOffsets = llvm::to_vector(op.getOffsets());
+ Value lastLoadOffset = loadOffsets.back();
+
+ // Compute the memref shape and linearized offsets once, outside the
+ // per-element loop.
+ SmallVector<OpFoldResult> baseShape;
+ Value linearizedOffsets;
+ if (useDelinearization) {
+ baseShape = memref::getMixedSizes(rewriter, loc, base);
+ linearizedOffsets = affine::AffineLinearizeIndexOp::create(
----------------
krzysz00 wrote:
... OK, more precisely: the RFC didn't actually settle this question.
At a high level, these delinearization semantics are fine and reasonably useful, but we really need to update the documentation and all the vector.gather lowerings to match, including pseudocode that reflects the added complexity.
https://github.com/llvm/llvm-project/pull/184706
More information about the Mlir-commits mailing list