[Mlir-commits] [mlir] [mlir][vector] Lower vector.gather with delinearization approach (PR #184706)
Krzysztof Drewniak
llvmlistbot at llvm.org
Tue Mar 17 07:52:50 PDT 2026
================
@@ -183,22 +191,39 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
Value condMask = op.getMask();
Value base = op.getBase();
- // vector.load requires the most minor memref dim to have unit stride
- // (unless reading exactly 1 element)
+ // For multi-dimensional memrefs, use linearize+delinearize to compute
+ // correct N-D load indices from the 1-D gather index.
+ bool useDelinearization = false;
if (auto memType = dyn_cast<MemRefType>(base.getType())) {
+ // vector.load requires the most minor memref dim to have unit stride
+ // (unless reading exactly 1 element).
if (auto stridesAttr =
dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
if (stridesAttr.getStrides().back() != 1 &&
resultTy.getNumElements() != 1)
- return failure();
+ return rewriter.notifyMatchFailure(
+ op, "most minor memref dim must have unit stride");
}
+
+ if (memType.getRank() > 1)
+ useDelinearization = true;
}
Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
op.getIndices());
- auto baseOffsets = llvm::to_vector(op.getOffsets());
- Value lastBaseOffset = baseOffsets.back();
+ auto loadOffsets = llvm::to_vector(op.getOffsets());
+ Value lastLoadOffset = loadOffsets.back();
+
+ // Compute the memref shape and linearized offsets once, outside the
+ // per-element loop.
+ SmallVector<OpFoldResult> baseShape;
+ Value linearizedOffsets;
+ if (useDelinearization) {
+ baseShape = memref::getMixedSizes(rewriter, loc, base);
+ linearizedOffsets = affine::AffineLinearizeIndexOp::create(
----------------
krzysz00 wrote:
But also, while it "works", these delinearization semantics are a pain to do transformations on, and I don't endorse them
https://github.com/llvm/llvm-project/pull/184706
More information about the Mlir-commits
mailing list