[Mlir-commits] [mlir] [mlir][vector] Clarify the semantics of gather/scatter indexing (nfc) (PR #181357)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Mar 17 03:08:07 PDT 2026


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/181357

>From b6c231199ea5805e59d25827954e928a5653bdad Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 13 Feb 2026 09:26:06 +0000
Subject: [PATCH 1/4] [mlir][vector] Clarify the semantics of gather/scatter
 indexing (nfc)

For context, see:
* https://discourse.llvm.org/t/rfc-semantics-of-vector-gather-indices-with-strided-memrefs

Currently, `ConvertVectorToLLVM` rejects strided memrefs when lowering
`vector.gather` and `vector.scatter`. This PR adds tests to document
that behavior.

Supporting strided memrefs in the lowering is left as future work.
However, it is still unclear whether gather/scatter on strided memrefs
should be supported at all (see the Discourse discussion above).

This PR also adds tests for `vector.load` and `vector.store` in
`invalid.mlir` to document that these ops do not support memrefs with
a non-unit stride in the innermost dimension.
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 10 ++++++-
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      |  2 ++
 .../vector-to-llvm-interface.mlir             | 27 +++++++++++++++++++
 mlir/test/Dialect/Vector/invalid.mlir         | 17 +++++++++++-
 4 files changed, 54 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 43ad435ccf1c1..019a087484a01 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2098,7 +2098,10 @@ def Vector_GatherOp :
     result[i,j] := if mask[i,j] then base[i0, i1, i2 + indices[i,j]]
                    else pass_thru[i,j]
     ```
-    The index into `base` only varies in the innermost ((k-1)-th) dimension.
+    The index into `base` only varies in the innermost ((k-1)-th) dimension and
+    is treated as a logical rather than a physical index, i.e. it does not
+    encode any potential non-identity strides in the underlying MemRef layout
+    (Tensors do not model strided layouts, so only logical indexing is possible).
 
     If a mask bit is set and the corresponding index is out-of-bounds for the
     given base, the behavior is undefined. If a mask bit is not set, the value
@@ -2191,6 +2194,11 @@ def Vector_ScatterOp
     is stored regardless of the index, and the index is allowed to be
     out-of-bounds.
 
+    The index into `base` only varies in the innermost ((k-1)-th) dimension and
+    is treated as a logical rather than a physical index, i.e. it does not
+    encode any potential non-identity strides in the underlying MemRef layout
+    (Tensors do not model strided layouts, so only logical indexing is possible).
+
     If the index vector contains two or more duplicate indices, the behavior is
     undefined. Underlying implementation may enforce strict sequential
     semantics.
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9d81702581131..815909169c6b8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -290,6 +290,7 @@ class VectorGatherOpConversion
     MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
     assert(memRefType && "The base should be bufferized");
 
+    // TODO: Add support for strided MemRef.
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return rewriter.notifyMatchFailure(gather, "memref type not supported");
 
@@ -348,6 +349,7 @@ class VectorScatterOpConversion
     auto memRefType = dyn_cast<MemRefType>(scatter.getBaseType());
     assert(memRefType && "The base should be bufferized");
 
+    // TODO: Add support for strided MemRef.
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return rewriter.notifyMatchFailure(scatter, "memref type not supported");
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 49c55f5b54496..076209cbc7a4c 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2066,6 +2066,20 @@ func.func @gather_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %ar
 
 // -----
 
+// TODO: Implement this lowering.
+func.func @negative_gather_on_strided_memref(%arg0: memref<?xf32, strided<[2], offset: ?>>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
+  %0 = arith.constant 0: index
+  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3
+    : memref<?xf32, strided<[2], offset: ?>>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+  return %1 : vector<3xf32>
+}
+
+// CHECK-LABEL: func @negative_gather_on_strided_memref
+// CHECK-NOT: llvm.intr.masked.gather
+// CHECK: vector.gather
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.scatter
 //===----------------------------------------------------------------------===//
@@ -2152,6 +2166,19 @@ func.func @scatter_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %a
 // CHECK-LABEL: func @scatter_with_alignment
 // CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
 
+// -----
+
+// TODO: Implement this lowering.
+func.func @negative_scatter_on_strided_memref(%arg0: memref<?xf32, strided<[2], offset: ?>>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
+  %0 = arith.constant 0: index
+  vector.scatter %arg0[%0][%arg1], %arg2, %arg3
+    : memref<?xf32, strided<[2], offset: ?>>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+  return
+}
+
+// CHECK-LABEL: func @negative_scatter_on_strided_memref
+// CHECK-NOT: llvm.intr.masked.scatter
+// CHECK: vector.scatter
 
 // -----
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 3957455ccc76e..c92e5c9dcb73f 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -2046,6 +2046,15 @@ func.func @load_non_pow_of_2_alignment(%memref: memref<4xi32>, %c0: index) {
 
 // -----
 
+func.func @load_non_unit_stride(%src : memref<?xi8, strided<[2], offset: ?>>) {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{'vector.load' op most minor memref dim must have unit stride}}
+  %0 = vector.load %src[%c0] : memref<?xi8, strided<[2], offset: ?>>, vector<16xi8>
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.store
 //===----------------------------------------------------------------------===//
@@ -2073,6 +2082,13 @@ func.func @store_non_pow_of_2_alignment(%memref: memref<4xi32>, %val: vector<4xi
   return
 }
 
+// -----
+func.func @store_non_unit_stride(%src : memref<?xi8, strided<[2], offset:?>>,%val : vector<16xi8>, %c0: index) {
+  // expected-error @below {{'vector.store' op most minor memref dim must have unit stride}}
+  vector.store %val, %src[%c0] : memref<?xi8, strided<[2], offset: ?>>, vector<16xi8>
+  return
+}
+
 // -----
 
 // Verify that vector.bitcast rejects vectors with i0 (zero-bitwidth) element type.
@@ -2127,4 +2143,3 @@ func.func @scan_i0(%a: vector<4xi0>, %init: vector<1xi0>) -> (vector<4xi0>, vect
   %0:2 = vector.scan <add>, %a, %init {inclusive = true, reduction_dim = 0 : i64} :
     vector<4xi0>, vector<1xi0>
   return %0#0, %0#1 : vector<4xi0>, vector<1xi0>
-}

>From 1e9b50a70ee917019fed4f18032e356ed845e72b Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 16 Feb 2026 20:03:49 +0000
Subject: [PATCH 2/4] Update per Jakub's suggestion

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 51 +++++++++++++------
 1 file changed, 35 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 019a087484a01..3d582f72c8a74 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2082,26 +2082,28 @@ def Vector_GatherOp :
     3-D and the result is 2-D:
 
     ```mlir
-    func.func @gather_3D_to_2D(
-        %base: memref<?x10x?xf32>, %ofs_0: index, %ofs_1: index, %ofs_2: index,
-        %indices: vector<2x3xi32>, %mask: vector<2x3xi1>,
-        %fall_thru: vector<2x3xf32>) -> vector<2x3xf32> {
-            %result = vector.gather %base[%ofs_0, %ofs_1, %ofs_2]
-                                   [%indices], %mask, %fall_thru : [...]
-            return %result : vector<2x3xf32>
+    func.func @gather_3D_to_2D(%base: memref<?x10x?xf32>, %ofs_0: index,
+        %ofs_1: index, %ofs_2: index, %indices: vector<2x3xi32>,
+        %mask: vector<2x3xi1>, %fall_thru: vector<2x3xf32>) -> vector<2x3xf32> {
+      %result = vector.gather %base[%ofs_0, %ofs_1, %ofs_2]
+                              [%indices], %mask, %fall_thru : [...]
     }
     ```
 
     The indexing semantics are then,
 
     ```
-    result[i,j] := if mask[i,j] then base[i0, i1, i2 + indices[i,j]]
+    result[i,j] := if mask[i,j] then base[%ofs_0, %ofs_1, %ofs_2 + indices[i,j]]
                    else pass_thru[i,j]
     ```
-    The index into `base` only varies in the innermost ((k-1)-th) dimension and
-    is treated as a logical rather than a physical index, i.e. it does not
-    encode any potential non-identity strides in the underlying MemRef layout
-    (Tensors do not model strided layouts, so only logical indexing is possible).
+    Note, `indices` are element offsets - they are expressed in units of
+    elements (not bytes). The offset is added to the underlying memory address,
+    so if the resulting position exceeds the size of the innermost dimension,
+    it naturally advances into the next row and/or plane according to the
+    identity (row-major) layout of 3D `base` (col = dim 0, row = dim 1, plane
+    = dim 2). Importantly, `indices` are interpreted assuming an identity
+    (contiguous) MemRef layout and do not account for non-identity strides.
+
 
     If a mask bit is set and the corresponding index is out-of-bounds for the
     given base, the behavior is undefined. If a mask bit is not set, the value
@@ -2194,10 +2196,27 @@ def Vector_ScatterOp
     is stored regardless of the index, and the index is allowed to be
     out-of-bounds.
 
-    The index into `base` only varies in the innermost ((k-1)-th) dimension and
-    is treated as a logical rather than a physical index, i.e. it does not
-    encode any potential non-identity strides in the underlying MemRef layout
-    (Tensors do not model strided layouts, so only logical indexing is possible).
+    ```mlir
+    func.func @scatter_2D_to_3D(%base: memref<?x10x?xf32>, %ofs_0: index,
+        %ofs_1: index, %ofs_2: index, %indices: vector<2x3xi32>,
+        %mask: vector<2x3xi1>, %src: vector<2x3xf32>) {
+      vector.scatter %base[%ofs_0, %ofs_1, %ofs_2][%indices], %mask, %src : [...]
+    }
+    ```
+    The indexing semantics are then,
+
+    ```
+    if mask[i,j] then
+      base[%ofs_0, %ofs_1, %ofs_2 + indices[i,j]] := valueToStore[i,j]
+    ```
+
+    Note, `indices` are element offsets - they are expressed in units of
+    elements (not bytes). The offset is added to the underlying memory address,
+    so if the resulting position exceeds the size of the innermost dimension,
+    it naturally advances into the next row and/or plane according to the
+    identity (row-major) layout of 3D `base` (col = dim 0, row = dim 1, plane
+    = dim 2). Importantly, `indices` are interpreted assuming an identity
+    (contiguous) MemRef layout and do not account for non-identity strides.
 
     If the index vector contains two or more duplicate indices, the behavior is
     undefined. Underlying implementation may enforce strict sequential

>From a7a26c644e823a4681e550137eb5c95c92065700 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 3 Mar 2026 19:50:56 +0000
Subject: [PATCH 3/4] Tweak

---
 .../mlir/Dialect/Vector/IR/VectorOps.td       | 26 ++++++++++---------
 1 file changed, 14 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3d582f72c8a74..f857d0677a439 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2097,12 +2097,13 @@ def Vector_GatherOp :
                    else pass_thru[i,j]
     ```
     Note, `indices` are element offsets - they are expressed in units of
-    elements (not bytes). The offset is added to the underlying memory address,
-    so if the resulting position exceeds the size of the innermost dimension,
-    it naturally advances into the next row and/or plane according to the
-    identity (row-major) layout of 3D `base` (col = dim 0, row = dim 1, plane
-    = dim 2). Importantly, `indices` are interpreted assuming an identity
-    (contiguous) MemRef layout and do not account for non-identity strides.
+    elements (not bytes). Each element in `indices` represents a displacement
+    in units of elements from the starting element, i.e. `%base[%ofs_0, %ofs_1,
+    %ofs_2]` for the example above. If the resulting position exceeds the size
+    of a dimension, it naturally advances into the next row and/or plane
+    according to the identity (row-major) layout of `base`. Importantly, for
+    MemRefs, `indices` are interpreted assuming an identity (contiguous) MemRef
+    layout and do not account for non-identity strides.
 
 
     If a mask bit is set and the corresponding index is out-of-bounds for the
@@ -2211,12 +2212,13 @@ def Vector_ScatterOp
     ```
 
     Note, `indices` are element offsets - they are expressed in units of
-    elements (not bytes). The offset is added to the underlying memory address,
-    so if the resulting position exceeds the size of the innermost dimension,
-    it naturally advances into the next row and/or plane according to the
-    identity (row-major) layout of 3D `base` (col = dim 0, row = dim 1, plane
-    = dim 2). Importantly, `indices` are interpreted assuming an identity
-    (contiguous) MemRef layout and do not account for non-identity strides.
+    elements (not bytes). Each element in `indices` represents a displacement
+    in units of elements from the starting element, i.e. `%base[%ofs_0, %ofs_1,
+    %ofs_2]` for the example above. If the resulting position exceeds the size
+    of a dimension, it naturally advances into the next row and/or plane
+    according to the identity (row-major) layout of `base`. Importantly, for
+    MemRefs, `indices` are interpreted assuming an identity (contiguous) MemRef
+    layout and do not account for non-identity strides.
 
     If the index vector contains two or more duplicate indices, the behavior is
     undefined. Underlying implementation may enforce strict sequential

>From fd02f731e6309deba2e7ee00807f3d667c22f902 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 17 Mar 2026 09:38:36 +0000
Subject: [PATCH 4/4] Extract test changes (and related TODOs) - these will be
 moved to a new PR

---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      |  2 --
 .../vector-to-llvm-interface.mlir             | 27 -------------------
 mlir/test/Dialect/Vector/invalid.mlir         | 17 +-----------
 3 files changed, 1 insertion(+), 45 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 815909169c6b8..9d81702581131 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -290,7 +290,6 @@ class VectorGatherOpConversion
     MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
     assert(memRefType && "The base should be bufferized");
 
-    // TODO: Add support for strided MemRef.
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return rewriter.notifyMatchFailure(gather, "memref type not supported");
 
@@ -349,7 +348,6 @@ class VectorScatterOpConversion
     auto memRefType = dyn_cast<MemRefType>(scatter.getBaseType());
     assert(memRefType && "The base should be bufferized");
 
-    // TODO: Add support for strided MemRef.
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return rewriter.notifyMatchFailure(scatter, "memref type not supported");
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 076209cbc7a4c..49c55f5b54496 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2066,20 +2066,6 @@ func.func @gather_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %ar
 
 // -----
 
-// TODO: Implement this lowering.
-func.func @negative_gather_on_strided_memref(%arg0: memref<?xf32, strided<[2], offset: ?>>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
-  %0 = arith.constant 0: index
-  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3
-    : memref<?xf32, strided<[2], offset: ?>>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
-  return %1 : vector<3xf32>
-}
-
-// CHECK-LABEL: func @negative_gather_on_strided_memref
-// CHECK-NOT: llvm.intr.masked.gather
-// CHECK: vector.gather
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // vector.scatter
 //===----------------------------------------------------------------------===//
@@ -2166,19 +2152,6 @@ func.func @scatter_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %a
 // CHECK-LABEL: func @scatter_with_alignment
 // CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
 
-// -----
-
-// TODO: Implement this lowering.
-func.func @negative_scatter_on_strided_memref(%arg0: memref<?xf32, strided<[2], offset: ?>>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
-  %0 = arith.constant 0: index
-  vector.scatter %arg0[%0][%arg1], %arg2, %arg3
-    : memref<?xf32, strided<[2], offset: ?>>, vector<3xi32>, vector<3xi1>, vector<3xf32>
-  return
-}
-
-// CHECK-LABEL: func @negative_scatter_on_strided_memref
-// CHECK-NOT: llvm.intr.masked.scatter
-// CHECK: vector.scatter
 
 // -----
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c92e5c9dcb73f..3957455ccc76e 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -2046,15 +2046,6 @@ func.func @load_non_pow_of_2_alignment(%memref: memref<4xi32>, %c0: index) {
 
 // -----
 
-func.func @load_non_unit_stride(%src : memref<?xi8, strided<[2], offset: ?>>) {
-  %c0 = arith.constant 0 : index
-  // expected-error @+1 {{'vector.load' op most minor memref dim must have unit stride}}
-  %0 = vector.load %src[%c0] : memref<?xi8, strided<[2], offset: ?>>, vector<16xi8>
-  return
-}
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // vector.store
 //===----------------------------------------------------------------------===//
@@ -2082,13 +2073,6 @@ func.func @store_non_pow_of_2_alignment(%memref: memref<4xi32>, %val: vector<4xi
   return
 }
 
-// -----
-func.func @store_non_unit_stride(%src : memref<?xi8, strided<[2], offset:?>>,%val : vector<16xi8>, %c0: index) {
-  // expected-error @below {{'vector.store' op most minor memref dim must have unit stride}}
-  vector.store %val, %src[%c0] : memref<?xi8, strided<[2], offset: ?>>, vector<16xi8>
-  return
-}
-
 // -----
 
 // Verify that vector.bitcast rejects vectors with i0 (zero-bitwidth) element type.
@@ -2143,3 +2127,4 @@ func.func @scan_i0(%a: vector<4xi0>, %init: vector<1xi0>) -> (vector<4xi0>, vect
   %0:2 = vector.scan <add>, %a, %init {inclusive = true, reduction_dim = 0 : i64} :
     vector<4xi0>, vector<1xi0>
   return %0#0, %0#1 : vector<4xi0>, vector<1xi0>
+}



More information about the Mlir-commits mailing list