[Mlir-commits] [mlir] [mlir][vector] Account for subview offset in gather lowering. (PR #195359)
Han-Chung Wang
llvmlistbot at llvm.org
Fri May 1 14:13:34 PDT 2026
https://github.com/hanhanW created https://github.com/llvm/llvm-project/pull/195359
Strided vector.gather on a column subview was reading the wrong column because the rewrite to a collapsed gather dropped the subview's static offset.
From 140cc291821e73b3f452a6a27a23f72955f58257 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Fri, 1 May 2026 13:50:32 -0700
Subject: [PATCH] [mlir][vector] Account for subview offset in gather lowering.
Strided vector.gather on a column subview was reading the wrong column
because the rewrite to a collapsed gather dropped the subview's static
offset.
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
.../Vector/Transforms/LowerVectorGather.cpp | 72 +++++++++++++++----
.../Vector/vector-gather-lowering.mlir | 58 +++++++++++++++
2 files changed, 116 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 7194d41d60df7..de39f61b222b7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -23,6 +23,7 @@
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
@@ -79,23 +80,29 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
};
/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
-/// MemRef with updated indices that model the strided access.
+/// MemRef with updated offsets/indices that model the strided access.
///
/// ```mlir
-/// %subview = memref.subview %M (...)
-/// : memref<100x3xf32> to memref<100xf32, strided<[3]>>
-/// %gather = vector.gather %subview[%idxs] (...)
-/// : memref<100xf32, strided<[3]>>
+/// %subview = memref.subview %M[0, 1] [100, 1] [1, 1]
+/// : memref<100x3xf32> to memref<100xf32, strided<[3], offset: 1>>
+/// %gather = vector.gather %subview[%c0] [%idxs] (...)
+/// : memref<100xf32, strided<[3], offset: 1>>
/// ```
/// ==>
/// ```mlir
/// %collapse_shape = memref.collapse_shape %M (...)
/// : memref<100x3xf32> into memref<300xf32>
/// %new_idxs = arith.muli %idxs, %c3 : vector<4xindex>
-/// %gather = vector.gather %collapse_shape[%new_idxs] (...)
+/// %new_off = arith.addi %c0_scaled, %subview_offset : index
+/// %gather = vector.gather %collapse_shape[%new_off] [%new_idxs] (...)
/// : memref<300xf32> (...)
/// ```
///
+/// The subview's static offset (the linearized position of the first element
+/// in the source memref) must be folded into the gather's base offsets, so a
+/// subview that selects e.g. column `j_sub` of a row-major `MxN` memref still
+/// reads from `M_base + j_sub + idx * N` instead of `M_base + idx * N`.
+///
/// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
/// but should be fairly straightforward to extend beyond that.
struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
@@ -134,27 +141,64 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
return failure();
+ // The result memref's offset is the linearized position of the subview's
+ // first element within the source memref. Bail out on dynamic offsets so
+ // we don't have to materialize them; the conditional-load fallback will
+ // still produce correct code.
+ int64_t subviewOffset = stridedLayoutAttr.getOffset();
+ if (ShapedType::isDynamic(subviewOffset))
+ return failure();
+
// 1. Collapse the input memref so that it's "flat".
SmallVector<ReassociationIndices> reassoc = {{0, 1}};
Value collapsed = memref::CollapseShapeOp::create(
rewriter, op.getLoc(), subview.getSource(), reassoc);
- // 2. Generate new gather indices that will model the
- // strided access.
+ // 2. Generate new gather indices that will model the strided access.
+ // Take `memref<4xf32, strided<[3], offset: 1>>` and lane k as an example.
+ // For the rewrite to be correct, the flat positions must match:
+ // new_off + new_idxs[k] = 1 + (base_off + idxs[k]) * 3
+ // = 1 + base_off * 3 + idxs[k] * 3
+ // So newIdxs is the original indices scaled by the stride.
IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
VectorType vType = op.getIndices().getType();
Value mulCst = arith::ConstantOp::create(
rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
-
Value newIdxs =
arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst);
- // 3. Create an updated gather op with the collapsed input memref and the
- // updated indices.
+ // 3. Linearize the gather's base offsets through the source memref. On the
+ // collapsed memref the trailing offset must be scaled by the source's
+ // trailing dim and shifted by the subview's static offset.
+ // Pick new_idxs[k] = idxs[k] * 3 (that's step 2), and solve for new_off:
+ // new_off = 1 + base_off * 3
+ // = subview_offset + base_off * stride
+ SmallVector<Value> newOffsets(op.getOffsets());
+ bool trailingOffsetIsZero = isZeroInteger(newOffsets.back());
+ if (!trailingOffsetIsZero) {
+ Value strideVal =
+ arith::ConstantIndexOp::create(rewriter, op.getLoc(), srcTrailingDim);
+ newOffsets.back() = arith::MulIOp::create(rewriter, op.getLoc(),
+ newOffsets.back(), strideVal);
+ }
+ if (subviewOffset != 0) {
+ Value subviewOffsetValue =
+ arith::ConstantIndexOp::create(rewriter, op.getLoc(), subviewOffset);
+ if (trailingOffsetIsZero) {
+ newOffsets.back() = subviewOffsetValue;
+ } else {
+ newOffsets.back() =
+ arith::AddIOp::create(rewriter, op.getLoc(), newOffsets.back(),
+ subviewOffsetValue)
+ .getResult();
+ }
+ }
+
+ // 4. Create an updated gather op with the collapsed input memref and the
+ // updated offsets/indices.
Value newGather = vector::GatherOp::create(
- rewriter, op.getLoc(), op.getResult().getType(), collapsed,
- op.getOffsets(), newIdxs, op.getMask(), op.getPassThru(),
- op.getAlignmentAttr());
+ rewriter, op.getLoc(), op.getResult().getType(), collapsed, newOffsets,
+ newIdxs, op.getMask(), op.getPassThru(), op.getAlignmentAttr());
rewriter.replaceOp(op, newGather);
return success();
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 59b13e300e5e5..1df7da00ba077 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -360,3 +360,61 @@ func.func @gather_memref_2d_delinearize_nonzero_offsets(
vector<2xi1>, vector<2xf32> into vector<2xf32>
return %0 : vector<2xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @strided_gather_with_offset(
+// CHECK-SAME: %[[BASE:.+]]: memref<4x3xf32>,
+// CHECK-SAME: %[[IDXS:.+]]: vector<2xindex>
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[CST_3:.+]] = arith.constant dense<3> : vector<2xindex>
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[BASE]] {{\[\[}}0, 1]] : memref<4x3xf32> into memref<12xf32>
+// CHECK: %[[NEW_IDXS:.+]] = arith.muli %[[IDXS]], %[[CST_3]]
+// CHECK: %[[IDX_0:.+]] = vector.extract %[[NEW_IDXS]][0]
+// CHECK: %[[ADDR_0:.+]] = arith.addi %[[IDX_0]], %[[C1]]
+// CHECK: scf.if
+// CHECK: vector.load %[[COLLAPSED]][%[[ADDR_0]]] : memref<12xf32>
+// CHECK: %[[IDX_1:.+]] = vector.extract %[[NEW_IDXS]][1]
+// CHECK: %[[ADDR_1:.+]] = arith.addi %[[IDX_1]], %[[C1]]
+// CHECK: scf.if
+// CHECK: vector.load %[[COLLAPSED]][%[[ADDR_1]]] : memref<12xf32>
+func.func @strided_gather_with_offset(%base: memref<4x3xf32>,
+ %idxs: vector<2xindex>,
+ %mask: vector<2xi1>,
+ %pass_thru: vector<2xf32>)
+ -> vector<2xf32> {
+ %c0 = arith.constant 0 : index
+ %sub = memref.subview %base[0, 1] [4, 1] [1, 1]
+ : memref<4x3xf32> to memref<4xf32, strided<[3], offset: 1>>
+ %0 = vector.gather %sub[%c0] [%idxs], %mask, %pass_thru
+ : memref<4xf32, strided<[3], offset: 1>>, vector<2xindex>,
+ vector<2xi1>, vector<2xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// -----
+
+// TODO: Support dynamic offsets.
+// CHECK-LABEL: func.func @negative_strided_gather_with_dynamic_offset(
+// CHECK-SAME: %[[BASE:.+]]: memref<4x3xf32>,
+// CHECK-SAME: %[[COL:.+]]: index,
+// CHECK-NOT: memref.collapse_shape
+// CHECK: %[[SUB:.+]] = memref.subview %[[BASE]][0, %[[COL]]] [4, 1] [1, 1]
+// CHECK-SAME: : memref<4x3xf32> to memref<4xf32, strided<[3], offset: ?>>
+// CHECK: %[[RES:.+]] = vector.gather %[[SUB]]
+// CHECK-SAME: : memref<4xf32, strided<[3], offset: ?>>
+// CHECK: return %[[RES]]
+func.func @negative_strided_gather_with_dynamic_offset(
+ %base: memref<4x3xf32>,
+ %col: index,
+ %idxs: vector<2xindex>,
+ %mask: vector<2xi1>,
+ %pass_thru: vector<2xf32>) -> vector<2xf32> {
+ %c0 = arith.constant 0 : index
+ %sub = memref.subview %base[0, %col] [4, 1] [1, 1]
+ : memref<4x3xf32> to memref<4xf32, strided<[3], offset: ?>>
+ %0 = vector.gather %sub[%c0] [%idxs], %mask, %pass_thru
+ : memref<4xf32, strided<[3], offset: ?>>, vector<2xindex>,
+ vector<2xi1>, vector<2xf32> into vector<2xf32>
+ return %0 : vector<2xf32>
+}
More information about the Mlir-commits
mailing list