[Mlir-commits] [mlir] [mlir][vector] Fix vector.gather lowering for strided memrefs. (PR #184706)

Han-Chung Wang llvmlistbot at llvm.org
Fri Mar 6 12:58:23 PST 2026


hanhanW wrote:

> Sorry for the delay responding!
> 
> > The old implementation did not take strides into account, which leads to wrong access.
> 
> IIUC, the original code was incorrect because it assumed that the underlying MemRef was **contiguous** in memory, right? However, in the presence of _non-default strides_ (i.e. _non-identity layout_), the underlying MemRef will not be contiguous, right? So, strides are actually secondary here. Instead, it's the non-contiguity that's key. In this sense, #181357 is tangential to this PR - I would update the title to avoid the confusion.
> 
> From your summary:
> 
> > It is correct for contiguous memrefs, but not strided memrefs.
> 
> I don't think that that's 100% correct. Let me take this example from the test file:
> 
> ```mlir
> func.func @gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x3xindex>, %mask: vector<2x3xi1>, %pass_thru: vector<2x3xf32>) -> vector<2x3xf32> {
>   %c0 = arith.constant 0 : index
>   %c1 = arith.constant 1 : index
>   %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : memref<?x?xf32>, vector<2x3xindex>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
>   return %0 : vector<2x3xf32>
> }
> ```
> 
> After lowering:
> 
> ```mlir
>   func.func @gather_memref_2d(%arg0: memref<?x?xf32>, %arg1: vector<2x3xindex>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
>     %c0 = arith.constant 0 : index
>     %c1 = arith.constant 1 : index
>     %0 = ub.poison : vector<2x3xf32>
>     %1 = vector.extract %arg3[0] : vector<3xf32> from vector<2x3xf32>
>     %2 = vector.extract %arg2[0, 0] : i1 from vector<2x3xi1>
>     %3 = vector.extract %arg1[0, 0] : index from vector<2x3xindex>
>     %4 = arith.addi %3, %c1 : index
>     %5 = scf.if %2 -> (vector<3xf32>) {
>       %29 = vector.load %arg0[%c0, %4] : memref<?x?xf32>, vector<1xf32>
>       %30 = vector.extract %29[0] : f32 from vector<1xf32>
>       %31 = vector.insert %30, %1 [0] : f32 into vector<3xf32>
>       scf.yield %31 : vector<3xf32>
>     } else {
>       scf.yield %1 : vector<3xf32>
>     }
>     // ...
> ```
> 
> **Question:** How do we know that `%4` won't be out-of-bounds? Interestingly, "out-of-bounds" is actually supported (from `vector.load` [docs](https://mlir.llvm.org/docs/Dialects/Vector/#vectorload-vectorloadop)):
> 
> > Representation-wise, the ‘vector.load’ operation permits out-of-bounds reads. Support and implementation of out-of-bounds vector loads is target-specific
> 
> I suspect that the current lowering of `vector.gather` just happens to "work" with the LLVM (and SPIR-V?) lowering of `vector.load`, but, in general, could be incorrect in other scenarios.
> 
> Note, your change is valid for both contiguous and non-contiguous MemRef. With that in mind, I suggest that we avoid creating special cases and simply update the lowering unconditionally for all cases. That should be safer overall. WDYT?
> 
> Thanks for working on this!

Very good point, and many thanks for the example! Yeah, now that you've clarified the semantics for me, I see it was only correct for 1-D cases. Using the delinearization approach sounds like a good idea to me, many thanks!
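A minimal sketch of the delinearizing approach for a single gather element, written against the 2-D example quoted above. It assumes that gather offsets are counted in the row-major logical index space of the base memref (linearized with its sizes, not its strides). The function name `@load_gather_element` is hypothetical, and this is not the code from the PR. Compared with the old lowering, the `arith.addi` on the innermost index is replaced by linearize, add, delinearize, so the indices handed to `vector.load` stay in bounds and the memref layout (strides, offset) is applied only by the load itself.

```mlir
// Hypothetical helper (not from the PR): loads the gather element at 1-D
// offset %off from base element [%i0, %i1], assuming offsets are counted in
// the row-major logical index space of %base.
func.func @load_gather_element(%base: memref<?x?xf32, strided<[?, 1], offset: ?>>,
                               %i0: index, %i1: index, %off: index) -> f32 {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = memref.dim %base, %c0 : memref<?x?xf32, strided<[?, 1], offset: ?>>
  %d1 = memref.dim %base, %c1 : memref<?x?xf32, strided<[?, 1], offset: ?>>
  // Linearize the base indices using the logical sizes (not the strides).
  %lin = affine.linearize_index [%i0, %i1] by (%d0, %d1) : index
  // Apply the 1-D gather offset in that linear index space.
  %pos = arith.addi %lin, %off : index
  // Delinearize back to (row, col); both are in bounds whenever %pos is.
  %idx:2 = affine.delinearize_index %pos into (%d0, %d1) : index, index
  // The load maps the in-bounds indices to memory using the layout.
  %v = vector.load %base[%idx#0, %idx#1]
      : memref<?x?xf32, strided<[?, 1], offset: ?>>, vector<1xf32>
  %e = vector.extract %v[0] : f32 from vector<1xf32>
  return %e : f32
}
```

For example, for a 4x4 view with strides [8, 1] (say, a `memref.subview` of an 8x8 buffer taken at [0, 0]), base indices [0, 1] with offset 4 give linear position 5, which delinearizes to [1, 1] and is read from memory offset 1 * 8 + 1 = 9. The old lowering would instead load [0, 5], i.e. memory offset 5, which lies outside the view. For an identity-layout 4x4 memref the same [0, 5] happens to land on offset 5, which is element [1, 1], which is why the contiguous case appeared to work.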

https://github.com/llvm/llvm-project/pull/184706