[Mlir-commits] [mlir] [mlir][vector] Fix vector.gather lowering for strided memrefs. (PR #184706)
Han-Chung Wang
llvmlistbot at llvm.org
Fri Mar 6 12:58:23 PST 2026
hanhanW wrote:
> Sorry for the delay responding!
>
> > The old implementation did not take strides into account, which leads to wrong access.
>
> IIUC, the original code was incorrect because it assumed that the underlying MemRef was **contiguous** in memory, right? However, in the presence of _non-default strides_ (i.e. _non-identity layout_), the underlying MemRef will not be contiguous, right? So, strides are actually secondary here. Instead, it's the non-contiguity that's key. In this sense, #181357 is tangential to this PR - I would update the title to avoid the confusion.
>
> From your summary:
>
> > It is correct for contiguous memrefs, but not strided memrefs.
>
> I don't think that that's 100% correct. Let me take this example from the test file:
>
> ```mlir
> func.func @gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x3xindex>, %mask: vector<2x3xi1>, %pass_thru: vector<2x3xf32>) -> vector<2x3xf32> {
>   %c0 = arith.constant 0 : index
>   %c1 = arith.constant 1 : index
>   %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : memref<?x?xf32>, vector<2x3xindex>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
>   return %0 : vector<2x3xf32>
> }
> ```
>
> After lowering:
>
> ```mlir
> func.func @gather_memref_2d(%arg0: memref<?x?xf32>, %arg1: vector<2x3xindex>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
>   %c0 = arith.constant 0 : index
>   %c1 = arith.constant 1 : index
>   %0 = ub.poison : vector<2x3xf32>
>   %1 = vector.extract %arg3[0] : vector<3xf32> from vector<2x3xf32>
>   %2 = vector.extract %arg2[0, 0] : i1 from vector<2x3xi1>
>   %3 = vector.extract %arg1[0, 0] : index from vector<2x3xindex>
>   %4 = arith.addi %3, %c1 : index
>   %5 = scf.if %2 -> (vector<3xf32>) {
>     %29 = vector.load %arg0[%c0, %4] : memref<?x?xf32>, vector<1xf32>
>     %30 = vector.extract %29[0] : f32 from vector<1xf32>
>     %31 = vector.insert %30, %1 [0] : f32 into vector<3xf32>
>     scf.yield %31 : vector<3xf32>
>   } else {
>     scf.yield %1 : vector<3xf32>
>   }
>   // ...
> ```
>
> **Question:** How do we know that `%4` won't be out-of-bounds? Interestingly, "out-of-bounds" is actually supported (from `vector.load` [docs](https://mlir.llvm.org/docs/Dialects/Vector/#vectorload-vectorloadop)):
>
> > Representation-wise, the ‘vector.load’ operation permits out-of-bounds reads. Support and implementation of out-of-bounds vector loads is target-specific
>
> I suspect that the current lowering of `vector.gather` just happens to "work" with the LLVM (and SPIR-V?) lowering of `vector.load`, but, in general, could be incorrect in other scenarios.
>
> Note, your change is valid for both contiguous and non-contiguous MemRef. With that in mind, I suggest that we avoid creating special cases and simply update the lowering unconditionally for all cases. That should be safer overall. WDYT?
>
> Thanks for working on this!
Very good point, and many thanks for the example! Yeah, now that you have clarified the semantics for me, I see that it was only correct for 1-D cases. The delinearizing approach sounds like a good idea to me, many thanks!
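
To make the difference concrete, here is a rough Python sketch of the address arithmetic only; it is not the actual lowering. It assumes the gather index is interpreted as a linearized offset from the base position over the logical (row-major) shape, which is how I read the discussion above. The helper names (`element_address`, `delinearize`, `innermost_add`, `delinearized`) and the static 4x3 shape are hypothetical and chosen purely for illustration.

```python
def element_address(indices, strides, offset=0):
    """Physical element offset of a multi-dimensional index in a strided buffer."""
    return offset + sum(i * s for i, s in zip(indices, strides))


def delinearize(linear, shape):
    """Split a row-major linear index into per-dimension indices."""
    indices = []
    for dim in reversed(shape):
        indices.append(linear % dim)
        linear //= dim
    return tuple(reversed(indices))


def innermost_add(base, gather_offset, strides):
    """Old-style addressing: add the gather offset to the innermost index only."""
    i, j = base
    return element_address((i, j + gather_offset), strides)


def delinearized(base, gather_offset, shape, strides):
    """Linearize the base over the logical shape, add the offset, delinearize, then apply strides."""
    i, j = base
    linear = i * shape[1] + j + gather_offset
    return element_address(delinearize(linear, shape), strides)


shape = (4, 3)        # logical 4x3 memref (assumed static for the sketch)
contiguous = (3, 1)   # identity layout: row stride equals the row length
padded = (8, 1)       # non-contiguous: each row is padded to 8 elements
base, gather_offset = (0, 1), 4

# Contiguous layout: both schemes land on the same physical element.
same_old = innermost_add(base, gather_offset, contiguous)
same_new = delinearized(base, gather_offset, shape, contiguous)

# Padded layout: the innermost add walks into the row padding,
# while the delinearized index moves to logical element (1, 2).
pad_old = innermost_add(base, gather_offset, padded)
pad_new = delinearized(base, gather_offset, shape, padded)
```

In the contiguous case the out-of-bounds inner index `j + gather_offset` happens to alias the right element, which would explain why the old lowering appeared to work; with a padded row stride the two schemes diverge.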
https://github.com/llvm/llvm-project/pull/184706