[Mlir-commits] [mlir] [mlir][vector] Fix vector.gather lowering for strided memrefs. (PR #184706)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Mar 6 08:47:05 PST 2026
banach-space wrote:
Sorry for the delay responding!
> The old implementation did not take strides into account, which leads to wrong access.
IIUC, the original code was incorrect because it assumed that the underlying MemRef was **contiguous** in memory, right? However, in the presence of _non-default strides_ (i.e. a _non-identity layout_), the underlying MemRef will not be contiguous, right? So strides are actually secondary here; it's the non-contiguity that's key. In this sense, https://github.com/llvm/llvm-project/pull/181357 is tangential to this PR, and I would update the title to avoid confusion.
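To make the contiguity point concrete, here is a minimal Python model (not MLIR code) of how a strided memref maps indices to buffer positions, `offset + sum(i_k * stride_k)`. The helper names (`linear_offset`, `identity_strides`, `is_contiguous`) are hypothetical, and the 4x4 view of a 4x8 buffer, which would have the type `memref<4x4xf32, strided<[8, 1]>>`, is an illustrative assumption. It shows that a lowering which assumes a contiguous (identity-layout) memref computes the right position for row 0 but the wrong one for every later row.

```python
# Hypothetical model of MLIR's strided memref layout: the element at
# indices (i0, ..., in) lives at  offset + sum(ik * stride_k)  in the buffer.


def linear_offset(indices, strides, offset=0):
    """Flat buffer position of a multi-dimensional index under a strided layout."""
    assert len(indices) == len(strides)
    return offset + sum(i * s for i, s in zip(indices, strides))


def identity_strides(shape):
    """Row-major (identity-layout) strides for a shape, e.g. (4, 4) -> (4, 1)."""
    strides = []
    running = 1
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


def is_contiguous(shape, strides):
    """Simplified check: contiguous iff the strides equal the identity strides.
    (Ignores corner cases such as size-1 dimensions with arbitrary strides.)"""
    return tuple(strides) == identity_strides(shape)


# Assumed example: the left 4x4 half of a 4x8 row-major buffer,
# i.e. memref<4x4xf32, strided<[8, 1]>>.
view_shape = (4, 4)
view_strides = (8, 1)

# Element (1, 0) of the view: where it really lives vs. where a lowering
# that assumes a contiguous 4x4 memref would look for it.
actual = linear_offset((1, 0), view_strides)                   # 8
assumed = linear_offset((1, 0), identity_strides(view_shape))  # 4
```

Row 0 happens to agree under both layouts, which is one way such a bug can slip through tests that only touch the first row.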
From your summary:
> It is correct for contiguous memrefs, but not strided memrefs.
I don't think that that's 100% correct. Let me take this example from the test file:
```mlir
func.func @gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x3xindex>, %mask: vector<2x3xi1>, %pass_thru: vector<2x3xf32>) -> vector<2x3xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : memref<?x?xf32>, vector<2x3xindex>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
  return %0 : vector<2x3xf32>
}
```
After lowering:
```mlir
func.func @gather_memref_2d(%arg0: memref<?x?xf32>, %arg1: vector<2x3xindex>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = ub.poison : vector<2x3xf32>
  %1 = vector.extract %arg3[0] : vector<3xf32> from vector<2x3xf32>
  %2 = vector.extract %arg2[0, 0] : i1 from vector<2x3xi1>
  %3 = vector.extract %arg1[0, 0] : index from vector<2x3xindex>
  %4 = arith.addi %3, %c1 : index
  %5 = scf.if %2 -> (vector<3xf32>) {
    %29 = vector.load %arg0[%c0, %4] : memref<?x?xf32>, vector<1xf32>
    %30 = vector.extract %29[0] : f32 from vector<1xf32>
    %31 = vector.insert %30, %1 [0] : f32 into vector<3xf32>
    scf.yield %31 : vector<3xf32>
  } else {
    scf.yield %1 : vector<3xf32>
  }
  // ...
```
**Question:** How do we know that `%4` won't be out-of-bounds? Interestingly, "out-of-bounds" is actually supported (from `vector.load` [docs](https://mlir.llvm.org/docs/Dialects/Vector/#vectorload-vectorloadop)):
> Representation-wise, the ‘vector.load’ operation permits out-of-bounds reads. Support and implementation of out-of-bounds vector loads is target-specific
I suspect that the current lowering of `vector.gather` just happens to "work" with the LLVM (and SPIR-V?) lowering of `vector.load`, but, in general, could be incorrect in other scenarios.
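The following Python sketch illustrates the suspicion by modelling the address computation of `vector.load %arg0[%c0, %4]`, where `%4 = %3 + %c1`. It assumes that the `vector.load` lowering linearizes indices with the memref strides and performs no bounds check. That is an assumption about a typical LLVM-style lowering, not a statement of what every target does. The runtime shape `2x3` and the gather offset `4` are made-up values, and the helpers are hypothetical. The load index is out of bounds for dimension 1, yet the linearized position is still a valid element of a contiguous buffer. That position also equals a plain 1-D offset from `%base[0, 1]`.

```python
# Hypothetical model of the unrolled lowering's address computation for
#   %4 = arith.addi %3, %c1        (%3 = one element of the index vector)
#   vector.load %arg0[%c0, %4]
# Assumption: indices are linearized with the memref strides, no bounds check.


def linear_offset(indices, strides, offset=0):
    """Flat buffer position under a strided layout."""
    return offset + sum(i * s for i, s in zip(indices, strides))


def in_bounds(indices, shape):
    """True iff every index is within its dimension."""
    return all(0 <= i < d for i, d in zip(indices, shape))


def unflatten(flat, shape):
    """Row-major multi-index of a flat position within a contiguous buffer."""
    idx = []
    for d in reversed(shape):
        idx.append(flat % d)
        flat //= d
    return tuple(reversed(idx))


shape = (2, 3)       # assumed runtime sizes of %arg0 : memref<?x?xf32>
strides = (3, 1)     # identity layout, i.e. contiguous
gather_offset = 4    # assumed value of %3

load_indices = (0, 1 + gather_offset)        # (%c0, %4) == (0, 5)
flat = linear_offset(load_indices, strides)  # 5

print(in_bounds(load_indices, shape))  # False: column 5 exceeds dim 1 (size 3)
print(unflatten(flat, shape))          # (1, 2): still a valid buffer element
```

For a contiguous MemRef the out-of-bounds column silently "wraps" into the next row, which is why the current lowering can appear correct. Nothing in the `vector.load` semantics guarantees this behaviour on other targets.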
Note that your change is valid for both contiguous and non-contiguous MemRefs. With that in mind, I suggest that we avoid creating special cases and simply update the lowering unconditionally. That should be safer overall. WDYT?
Thanks for working on this!
https://github.com/llvm/llvm-project/pull/184706