[Mlir-commits] [mlir] [mlir][vector] Account for subview offset in gather lowering. (PR #195359)

Han-Chung Wang llvmlistbot at llvm.org
Fri May 1 14:13:34 PDT 2026


https://github.com/hanhanW created https://github.com/llvm/llvm-project/pull/195359

Strided vector.gather on a column subview was reading the wrong column because the rewrite to a collapsed gather dropped the subview's static offset.

>From 140cc291821e73b3f452a6a27a23f72955f58257 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 1 May 2026 13:50:32 -0700
Subject: [PATCH] [mlir][vector] Account for subview offset in gather lowering.

Strided vector.gather on a column subview was reading the wrong column
because the rewrite to a collapsed gather dropped the subview's static
offset.

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../Vector/Transforms/LowerVectorGather.cpp   | 72 +++++++++++++++----
 .../Vector/vector-gather-lowering.mlir        | 58 +++++++++++++++
 2 files changed, 116 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 7194d41d60df7..de39f61b222b7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 
@@ -79,23 +80,29 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
 };
 
 /// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
-/// MemRef with updated indices that model the strided access.
+/// MemRef with updated offsets/indices that model the strided access.
 ///
 /// ```mlir
-///   %subview = memref.subview %M (...)
-///     : memref<100x3xf32> to memref<100xf32, strided<[3]>>
-///   %gather = vector.gather %subview[%idxs] (...)
-///     : memref<100xf32, strided<[3]>>
+///   %subview = memref.subview %M[0, 1] [100, 1] [1, 1]
+///     : memref<100x3xf32> to memref<100xf32, strided<[3], offset: 1>>
+///   %gather = vector.gather %subview[%c0] [%idxs] (...)
+///     : memref<100xf32, strided<[3], offset: 1>>
 /// ```
 /// ==>
 /// ```mlir
 ///   %collapse_shape = memref.collapse_shape %M (...)
 ///     : memref<100x3xf32> into memref<300xf32>
 ///   %new_idxs = arith.muli %idxs, %c3 : vector<4xindex>
-///   %gather = vector.gather %collapse_shape[%new_idxs] (...)
+///   %new_off  = arith.addi %c0_scaled, %subview_offset : index
+///   %gather = vector.gather %collapse_shape[%new_off] [%new_idxs] (...)
 ///     : memref<300xf32> (...)
 /// ```
 ///
+/// The subview's static offset (the linearized position of its first element
+/// in the source memref) must be folded into the gather's base offsets. That
+/// way, a subview selecting e.g. column `j_sub` of a row-major `MxN` memref
+/// reads from `M_base + j_sub + idx * N` rather than `M_base + idx * N`.
+///
 /// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
 /// but should be fairly straightforward to extend beyond that.
 struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
@@ -134,27 +141,64 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
     if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
       return failure();
 
+    // The result memref's offset is the linearized position of the subview's
+    // first element within the source memref. Bail out on dynamic offsets so
+    // we don't have to materialize them; the conditional-load fallback will
+    // still produce correct code.
+    int64_t subviewOffset = stridedLayoutAttr.getOffset();
+    if (ShapedType::isDynamic(subviewOffset))
+      return failure();
+
     // 1. Collapse the input memref so that it's "flat".
     SmallVector<ReassociationIndices> reassoc = {{0, 1}};
     Value collapsed = memref::CollapseShapeOp::create(
         rewriter, op.getLoc(), subview.getSource(), reassoc);
 
-    // 2. Generate new gather indices that will model the
-    // strided access.
+    // 2. Generate new gather indices that will model the strided access.
+    // Take `memref<4xf32, strided<[3], offset: 1>>` and lane k as an example.
+    // For the rewrite to be correct, the flat positions must match:
+    //   new_off + new_idxs[k] = 1 + (base_off + idxs[k]) * 3
+    //                         = 1 + base_off * 3 + idxs[k] * 3
+    // So the newIdxs is scaled with the stride.
     IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
     VectorType vType = op.getIndices().getType();
     Value mulCst = arith::ConstantOp::create(
         rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
-
     Value newIdxs =
         arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst);
 
-    // 3. Create an updated gather op with the collapsed input memref and the
-    // updated indices.
+    // 3. Linearize the gather's base offsets through the source memref. On the
+    // collapsed memref the trailing offset must be scaled by the source's
+    // trailing dim and shifted by the subview's static offset.
+    // Pick new_idxs[k] = idxs[k] * 3 (that's step 2), and solve for new_off:
+    //   new_off = 1 + base_off * 3
+    //           = subview_offset + base_off * stride
+    SmallVector<Value> newOffsets(op.getOffsets());
+    bool trailingOffsetIsZero = isZeroInteger(newOffsets.back());
+    if (!trailingOffsetIsZero) {
+      Value strideVal =
+          arith::ConstantIndexOp::create(rewriter, op.getLoc(), srcTrailingDim);
+      newOffsets.back() = arith::MulIOp::create(rewriter, op.getLoc(),
+                                                newOffsets.back(), strideVal);
+    }
+    if (subviewOffset != 0) {
+      Value subviewOffsetValue =
+          arith::ConstantIndexOp::create(rewriter, op.getLoc(), subviewOffset);
+      if (trailingOffsetIsZero) {
+        newOffsets.back() = subviewOffsetValue;
+      } else {
+        newOffsets.back() =
+            arith::AddIOp::create(rewriter, op.getLoc(), newOffsets.back(),
+                                  subviewOffsetValue)
+                .getResult();
+      }
+    }
+
+    // 4. Create an updated gather op with the collapsed input memref and the
+    // updated offsets/indices.
     Value newGather = vector::GatherOp::create(
-        rewriter, op.getLoc(), op.getResult().getType(), collapsed,
-        op.getOffsets(), newIdxs, op.getMask(), op.getPassThru(),
-        op.getAlignmentAttr());
+        rewriter, op.getLoc(), op.getResult().getType(), collapsed, newOffsets,
+        newIdxs, op.getMask(), op.getPassThru(), op.getAlignmentAttr());
     rewriter.replaceOp(op, newGather);
 
     return success();
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 59b13e300e5e5..1df7da00ba077 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -360,3 +360,61 @@ func.func @gather_memref_2d_delinearize_nonzero_offsets(
       vector<2xi1>, vector<2xf32> into vector<2xf32>
   return %0 : vector<2xf32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @strided_gather_with_offset(
+// CHECK-SAME:        %[[BASE:.+]]: memref<4x3xf32>,
+// CHECK-SAME:        %[[IDXS:.+]]: vector<2xindex>
+// CHECK-DAG:       %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG:       %[[CST_3:.+]] = arith.constant dense<3> : vector<2xindex>
+// CHECK:           %[[COLLAPSED:.+]] = memref.collapse_shape %[[BASE]] {{\[\[}}0, 1]] : memref<4x3xf32> into memref<12xf32>
+// CHECK:           %[[NEW_IDXS:.+]] = arith.muli %[[IDXS]], %[[CST_3]]
+// CHECK:           %[[IDX_0:.+]] = vector.extract %[[NEW_IDXS]][0]
+// CHECK:           %[[ADDR_0:.+]] = arith.addi %[[IDX_0]], %[[C1]]
+// CHECK:           scf.if
+// CHECK:             vector.load %[[COLLAPSED]][%[[ADDR_0]]] : memref<12xf32>
+// CHECK:           %[[IDX_1:.+]] = vector.extract %[[NEW_IDXS]][1]
+// CHECK:           %[[ADDR_1:.+]] = arith.addi %[[IDX_1]], %[[C1]]
+// CHECK:           scf.if
+// CHECK:             vector.load %[[COLLAPSED]][%[[ADDR_1]]] : memref<12xf32>
+func.func @strided_gather_with_offset(%base: memref<4x3xf32>,
+                                      %idxs: vector<2xindex>,
+                                      %mask: vector<2xi1>,
+                                      %pass_thru: vector<2xf32>)
+    -> vector<2xf32> {
+  %c0 = arith.constant 0 : index
+  %sub = memref.subview %base[0, 1] [4, 1] [1, 1]
+    : memref<4x3xf32> to memref<4xf32, strided<[3], offset: 1>>
+  %0 = vector.gather %sub[%c0] [%idxs], %mask, %pass_thru
+    : memref<4xf32, strided<[3], offset: 1>>, vector<2xindex>,
+      vector<2xi1>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// -----
+
+// TODO: Support dynamic offsets.
+// CHECK-LABEL:   func.func @negative_strided_gather_with_dynamic_offset(
+// CHECK-SAME:        %[[BASE:.+]]: memref<4x3xf32>,
+// CHECK-SAME:        %[[COL:.+]]: index,
+// CHECK-NOT:       memref.collapse_shape
+// CHECK:           %[[SUB:.+]] = memref.subview %[[BASE]][0, %[[COL]]] [4, 1] [1, 1]
+// CHECK-SAME:        : memref<4x3xf32> to memref<4xf32, strided<[3], offset: ?>>
+// CHECK:           %[[RES:.+]] = vector.gather %[[SUB]]
+// CHECK-SAME:        : memref<4xf32, strided<[3], offset: ?>>
+// CHECK:           return %[[RES]]
+func.func @negative_strided_gather_with_dynamic_offset(
+    %base: memref<4x3xf32>,
+    %col: index,
+    %idxs: vector<2xindex>,
+    %mask: vector<2xi1>,
+    %pass_thru: vector<2xf32>) -> vector<2xf32> {
+  %c0 = arith.constant 0 : index
+  %sub = memref.subview %base[0, %col] [4, 1] [1, 1]
+    : memref<4x3xf32> to memref<4xf32, strided<[3], offset: ?>>
+  %0 = vector.gather %sub[%c0] [%idxs], %mask, %pass_thru
+    : memref<4xf32, strided<[3], offset: ?>>, vector<2xindex>,
+      vector<2xi1>, vector<2xf32> into vector<2xf32>
+  return %0 : vector<2xf32>
+}



More information about the Mlir-commits mailing list