[Mlir-commits] [mlir] [mlir][vector] Lower vector.gather with delinearization approach (PR #184706)

Andrzej Warzyński llvmlistbot at llvm.org
Mon Mar 16 14:55:26 PDT 2026


================
@@ -183,22 +191,39 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
     Value condMask = op.getMask();
     Value base = op.getBase();
 
-    // vector.load requires the most minor memref dim to have unit stride
-    // (unless reading exactly 1 element)
+    // For multi-dimensional memrefs, use linearize+delinearize to compute
+    // correct N-D load indices from the 1-D gather index.
+    bool useDelinearization = false;
     if (auto memType = dyn_cast<MemRefType>(base.getType())) {
+      // vector.load requires the most minor memref dim to have unit stride
+      // (unless reading exactly 1 element).
       if (auto stridesAttr =
               dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
         if (stridesAttr.getStrides().back() != 1 &&
             resultTy.getNumElements() != 1)
-          return failure();
+          return rewriter.notifyMatchFailure(
+              op, "most minor memref dim must have unit stride");
       }
+
+      if (memType.getRank() > 1)
+        useDelinearization = true;
     }
 
     Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
         loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
         op.getIndices());
-    auto baseOffsets = llvm::to_vector(op.getOffsets());
-    Value lastBaseOffset = baseOffsets.back();
+    auto loadOffsets = llvm::to_vector(op.getOffsets());
+    Value lastLoadOffset = loadOffsets.back();
+
+    // Compute the memref shape and linearized offsets once, outside the
+    // per-element loop.
+    SmallVector<OpFoldResult> baseShape;
+    Value linearizedOffsets;
+    if (useDelinearization) {
+      baseShape = memref::getMixedSizes(rewriter, loc, base);
+      linearizedOffsets = affine::AffineLinearizeIndexOp::create(
----------------
banach-space wrote:

Thanks for the example!

> Ok, let me clarify. If you do run this pattern on such a strided memref - let's say we've got gather indices [0, 1, ... 7] into a `memref<2x4xf32, strided<[128, 1]>>`, you'll produce `vector.load` operations from `[0, 0], ... [0, 3], [1, 0], ...` and `vector.load` will do the right thing.

Indeed, that is correct and that's the behaviour that I had in mind.
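For reference, the index arithmetic behind this example can be sketched in plain C++. Linearize the gather's base offsets against the memref's logical shape, add the 1-D gather index, then delinearize the result into N-D `vector.load` indices. This is a minimal sketch that assumes static sizes and in-bounds indices. `linearize`, `delinearize` and `loadIndexFor` are hypothetical helpers that only mimic the arithmetic of the `affine.linearize_index`/`affine.delinearize_index` pair. They are not the pattern's code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Row-major linearization of an N-D index against a logical shape
// (outermost dimension first). Strides play no role here.
int64_t linearize(const std::vector<int64_t> &idx,
                  const std::vector<int64_t> &shape) {
  assert(idx.size() == shape.size() && "rank mismatch");
  int64_t linear = 0;
  for (size_t d = 0; d < shape.size(); ++d)
    linear = linear * shape[d] + idx[d];
  return linear;
}

// Inverse of linearize for in-bounds indices: peel off the innermost
// dimensions with mod/div and leave the outermost dimension unbounded.
std::vector<int64_t> delinearize(int64_t linear,
                                 const std::vector<int64_t> &shape) {
  assert(!shape.empty() && "rank must be >= 1");
  std::vector<int64_t> idx(shape.size());
  for (size_t d = shape.size() - 1; d > 0; --d) {
    idx[d] = linear % shape[d];
    linear /= shape[d];
  }
  idx[0] = linear;
  return idx;
}

// Hypothetical helper: the N-D load index for the element at 1-D gather
// index `gatherIdx`, given the gather's base offsets and the memref shape.
std::vector<int64_t> loadIndexFor(const std::vector<int64_t> &baseOffsets,
                                  int64_t gatherIdx,
                                  const std::vector<int64_t> &shape) {
  return delinearize(linearize(baseOffsets, shape) + gatherIdx, shape);
}
```

With base offsets `[0, 0]` and shape `[2, 4]`, gather indices 0 to 7 map to `[0, 0] ... [0, 3], [1, 0] ... [1, 3]`. Strides do not enter this computation. They only matter later, when the logical indices are translated into addresses.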

> If you _don't_ run this pattern and go directly to the llvm dialect, you'll get loads from `[0, 0], [0, 1], ... [0, 3], [0, 4] ... [0, 7]`. This is consistent with the "as if identity layout" semantics we're defining currently, but **not** consistent with the behavior of this lowering.

This would be incorrect and would mean that we are missing a translation from logical to physical indices, which should happen when lowering a strided memref to something that does not encode strides. Could you point us at the lowering that you are referring to? We should fix that.
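To make the discrepancy concrete, the following sketch shows the logical-to-physical translation for a strided memref. The element offset is the layout offset plus the sum of each index times its stride. The sketch applies this to both readings of gather indices 0 to 7 into `memref<2x4xf32, strided<[128, 1]>>`. This is an illustration under stated assumptions, not the actual LLVM lowering. `physicalOffset` and `gatherOffsets` are hypothetical helpers. The `useDelinearization = false` branch models the direct-to-LLVM behaviour only as it is described in the quoted comment, where the gather index is added to the innermost offset.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Translate a logical N-D memref index into a physical element offset using
// a strided layout: layoutOffset + sum_k idx[k] * strides[k].
int64_t physicalOffset(const std::vector<int64_t> &idx,
                       const std::vector<int64_t> &strides,
                       int64_t layoutOffset = 0) {
  assert(idx.size() == strides.size() && "rank mismatch");
  int64_t off = layoutOffset;
  for (size_t k = 0; k < idx.size(); ++k)
    off += idx[k] * strides[k];
  return off;
}

// Physical offsets touched by gather indices 0..n-1 into a 2-D memref whose
// innermost logical size is `innerSize`, with zero base offsets.
//   useDelinearization = true:  index i -> [i / innerSize, i % innerSize]
//   useDelinearization = false: index i -> [0, i] (hypothetical model of the
//                               direct-to-LLVM path described in the thread)
std::vector<int64_t> gatherOffsets(int64_t n, int64_t innerSize,
                                   const std::vector<int64_t> &strides,
                                   bool useDelinearization) {
  std::vector<int64_t> out;
  for (int64_t i = 0; i < n; ++i) {
    std::vector<int64_t> idx =
        useDelinearization
            ? std::vector<int64_t>{i / innerSize, i % innerSize}
            : std::vector<int64_t>{0, i};
    out.push_back(physicalOffset(idx, strides));
  }
  return out;
}
```

With strides `[128, 1]`, the delinearized reading touches offsets 0 to 3 and 128 to 131. Adding the index to the innermost dimension touches 0 to 7 instead. With identity strides `[4, 1]` both readings coincide, which is why the mismatch only shows up for non-identity layouts.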

In any case, this change is correct and consistent with https://discourse.llvm.org/t/rfc-semantics-of-vector-gather-indices-with-strided-memrefs/ + https://github.com/llvm/llvm-project/pull/181357.

https://github.com/llvm/llvm-project/pull/184706

