[Mlir-commits] [mlir] [mlir][vector] Lower vector.gather with delinearization approach (PR #184706)

Krzysztof Drewniak llvmlistbot at llvm.org
Tue Mar 17 07:48:16 PDT 2026


================
@@ -183,22 +191,39 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
     Value condMask = op.getMask();
     Value base = op.getBase();
 
-    // vector.load requires the most minor memref dim to have unit stride
-    // (unless reading exactly 1 element)
+    // For multi-dimensional memrefs, use linearize+delinearize to compute
+    // correct N-D load indices from the 1-D gather index.
+    bool useDelinearization = false;
     if (auto memType = dyn_cast<MemRefType>(base.getType())) {
+      // vector.load requires the most minor memref dim to have unit stride
+      // (unless reading exactly 1 element).
       if (auto stridesAttr =
               dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
         if (stridesAttr.getStrides().back() != 1 &&
             resultTy.getNumElements() != 1)
-          return failure();
+          return rewriter.notifyMatchFailure(
+              op, "most minor memref dim must have unit stride");
       }
+
+      if (memType.getRank() > 1)
+        useDelinearization = true;
     }
 
     Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
         loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
         op.getIndices());
-    auto baseOffsets = llvm::to_vector(op.getOffsets());
-    Value lastBaseOffset = baseOffsets.back();
+    auto loadOffsets = llvm::to_vector(op.getOffsets());
+    Value lastLoadOffset = loadOffsets.back();
+
+    // Compute the memref shape and linearized offsets once, outside the
+    // per-element loop.
+    SmallVector<OpFoldResult> baseShape;
+    Value linearizedOffsets;
+    if (useDelinearization) {
+      baseShape = memref::getMixedSizes(rewriter, loc, base);
+      linearizedOffsets = affine::AffineLinearizeIndexOp::create(
----------------
krzysz00 wrote:

Well, no, that pseudocode specifies exactly what happens: you keep going in the final dimension of the memref. Adding 1 to the gather index always means a step of 1 * the stride of the final dimension in the memref (so 1, by definition), which is consistent with the way vector load works.

Basically, the semantics from the LLVM lowering are the current semantics that I thought we'd agreed to, not the ones from this PR.

(And what could "ignoring memref strides" possibly mean other than just continuing to step off the right edge of the trailing dimension the way we wrote out the pseudocode?)
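
To make the point concrete, here is a minimal sketch of the addressing rule as I read it. `gatherElementOffset` is a hypothetical helper written for illustration, not an MLIR API, and the 4x8 shape with strides `{8, 1}` is an assumed example: the base offsets are resolved through the memref strides, and the gather index is then added as a step along the trailing dimension with no wrap-around or delinearization.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (illustration only, not part of MLIR): linear element
// offset read by one gather lane, relative to the start of the memref.
// Assumes strides and baseOffsets have the same, non-zero rank.
std::int64_t gatherElementOffset(const std::vector<std::int64_t> &strides,
                                 const std::vector<std::int64_t> &baseOffsets,
                                 std::int64_t gatherIndex) {
  std::int64_t linear = 0;
  // Resolve the base offsets through the memref strides.
  for (std::size_t d = 0; d < strides.size(); ++d)
    linear += baseOffsets[d] * strides[d];
  // The gather index steps along the trailing dimension only. It is not
  // delinearized against the shape, so a large index simply runs off the
  // right edge of the row.
  linear += gatherIndex * strides.back();
  return linear;
}
```

For an assumed 4x8 memref with strides `{8, 1}`, base offsets `{1, 6}` and gather index 3 give 1 * 8 + 6 + 3 = 17. In an identity layout that element is row 2, column 1, i.e. the lane has stepped past the end of row 1 rather than being remapped through the shape.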



https://github.com/llvm/llvm-project/pull/184706

