[Mlir-commits] [mlir] [mlir][vector] Fix vector.gather lowering for strided memrefs. (PR #184706)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Mar 6 08:47:05 PST 2026


banach-space wrote:

Sorry for the delay in responding!

> The old implementation did not take strides into account, which leads to wrong access. 

IIUC, the original code was incorrect because it assumed that the underlying MemRef was **contiguous** in memory, right? However, in the presence of _non-default strides_ (i.e. a _non-identity layout_), the underlying MemRef will not be contiguous, right? So, strides are actually secondary here. Instead, it's the non-contiguity that's key. In this sense, https://github.com/llvm/llvm-project/pull/181357 is tangential to this PR. I would update the title to avoid confusion.
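
To make the distinction concrete, here is a minimal sketch of the address arithmetic. It is plain Python, not MLIR and not the code from this PR; the `linearize` helper and the 8x8 backing buffer are assumptions made up for illustration. A 4x4 view into an 8x8 row-major buffer has strides `[8, 1]`, so treating it as contiguous (row stride 4) reads the wrong element.

```python
def linearize(indices, strides, offset=0):
    """Map multi-dimensional indices to a position in the flat buffer."""
    return offset + sum(i * s for i, s in zip(indices, strides))

# Hypothetical backing storage: an 8x8 row-major buffer whose values
# equal their own flat position, so a read shows which slot was hit.
buffer = list(range(64))

# A 4x4 view of the top-left corner: rows are still 8 elements apart.
view_strides = (8, 1)

# Strides that a truly contiguous 4x4 buffer would have.
contiguous_strides = (4, 1)

idx = (1, 2)
correct = buffer[linearize(idx, view_strides)]      # 1*8 + 2 = 10
wrong = buffer[linearize(idx, contiguous_strides)]  # 1*4 + 2 = 6
```

The two reads only agree when the view's strides match the contiguous ones, which is why I think contiguity, rather than strides as such, is the property to name.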

From your summary:
> It is correct for contiguous memrefs, but not strided memrefs.

I don't think that that's 100% correct. Let me take this example from the test file:
```mlir
 func.func @gather_memref_2d(%base: memref<?x?xf32>, %v: vector<2x3xindex>, %mask: vector<2x3xi1>, %pass_thru: vector<2x3xf32>) -> vector<2x3xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = vector.gather %base[%c0, %c1][%v], %mask, %pass_thru : memref<?x?xf32>, vector<2x3xindex>, vector<2x3xi1>, vector<2x3xf32> into vector<2x3xf32>
  return %0 : vector<2x3xf32>
 }
```

After lowering:
```mlir
  func.func @gather_memref_2d(%arg0: memref<?x?xf32>, %arg1: vector<2x3xindex>, %arg2: vector<2x3xi1>, %arg3: vector<2x3xf32>) -> vector<2x3xf32> {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %0 = ub.poison : vector<2x3xf32>
    %1 = vector.extract %arg3[0] : vector<3xf32> from vector<2x3xf32>
    %2 = vector.extract %arg2[0, 0] : i1 from vector<2x3xi1>
    %3 = vector.extract %arg1[0, 0] : index from vector<2x3xindex>
    %4 = arith.addi %3, %c1 : index
    %5 = scf.if %2 -> (vector<3xf32>) {
      %29 = vector.load %arg0[%c0, %4] : memref<?x?xf32>, vector<1xf32>
      %30 = vector.extract %29[0] : f32 from vector<1xf32>
      %31 = vector.insert %30, %1 [0] : f32 into vector<3xf32>
      scf.yield %31 : vector<3xf32>
    } else {
      scf.yield %1 : vector<3xf32>
    }
    // ...
```

**Question:** How do we know that `%4` won't be out-of-bounds? Interestingly, "out-of-bounds" is actually supported (from `vector.load` [docs](https://mlir.llvm.org/docs/Dialects/Vector/#vectorload-vectorloadop)):
> Representation-wise, the ‘vector.load’ operation permits out-of-bounds reads. Support and implementation of out-of-bounds vector loads is target-specific

I suspect that the current lowering of `vector.gather` just happens to "work" with the LLVM (and SPIR-V?) lowering of `vector.load`, but it could be incorrect in other scenarios.
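
Here is a rough sketch of why it might happen to "work" for contiguous buffers. It is plain Python rather than MLIR, and it assumes the target lowers the load to unchecked address arithmetic (my assumption, the docs only say this is target-specific). The sizes and the `flat_address` helper are hypothetical.

```python
def flat_address(row, col, row_stride, col_stride=1):
    """Address arithmetic with no bounds checks (assumed target behavior)."""
    return row * row_stride + col * col_stride

ROWS, COLS = 2, 3          # hypothetical logical shape of the memref
base_row, base_col = 0, 1  # mirrors %c0 and %c1 in the example above
gather_index = 3           # hypothetical value taken from the index vector

# Column 4 is out of bounds in the innermost dimension (valid: 0..2).
col = base_col + gather_index

# Contiguous layout (row stride == COLS): the overflowing column wraps into
# the next row, so the load still touches a valid element.
addr = flat_address(base_row, col, row_stride=COLS)
wrapped_to = divmod(addr, COLS)  # (row, col) the address really belongs to

# Padded layout (row stride 5 > COLS): the same indices land in row 0's padding.
PADDED_STRIDE = 5
padded_addr = flat_address(base_row, col, row_stride=PADDED_STRIDE)
in_padding = (padded_addr % PADDED_STRIDE) >= COLS
```

In the contiguous case the out-of-bounds column silently aliases element `(1, 1)`; in the padded case it reads memory that is not part of the MemRef at all.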

Note, your change is valid for both contiguous and non-contiguous MemRefs. With that in mind, I suggest that we avoid creating special cases and simply update the lowering unconditionally for all cases. That should be safer overall. WDYT?

Thanks for working on this!

https://github.com/llvm/llvm-project/pull/184706

