[Mlir-commits] [mlir] [mlir][vector] Clarify the semantics of gather/scatter indexing (nfc) (PR #181357)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Mar 17 05:07:19 PDT 2026
https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/181357
>From b6c231199ea5805e59d25827954e928a5653bdad Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 13 Feb 2026 09:26:06 +0000
Subject: [PATCH 1/6] [mlir][vector] Clarify the semantics of gather/scatter
indexing (nfc)
For context, see:
* https://discourse.llvm.org/t/rfc-semantics-of-vector-gather-indices-with-strided-memrefs
Currently, `ConvertVectorToLLVM` rejects strided memrefs when lowering
`vector.gather` and `vector.scatter`. This PR adds tests to document
that behavior.
Supporting strided memrefs in the lowering is left as future work.
However, it is still unclear whether gather/scatter on strided memrefs
should be supported at all (see the Discourse discussion above).
This PR also adds tests for `vector.load` and `vector.store` in
`invalid.mlir` to document that these ops do not support strided
memrefs.
---
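Reviewer note (not part of the patch): the distinction between a logical and a
physical index can be sketched with a small C++ reference model. This is only
one reading of the semantics, and it assumes a 1-D base such as the
`memref<?xf32, strided<[2], offset: ?>>` used in the new tests. `StridedView1D`,
`physicalPosition` and `gatherLogical` are hypothetical names introduced for
illustration; they are not MLIR APIs, and the sketch does not settle whether
strided bases should be supported at all.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for memref<?xf32, strided<[stride], offset: offset>>.
struct StridedView1D {
  const float *data;   // underlying allocation
  std::int64_t offset; // physical position (in elements) of logical index 0
  std::int64_t stride; // physical distance (in elements) between neighbours
  std::int64_t size;   // number of logical elements
};

// Map a logical index to a physical element position. In this model the
// stride is applied here, i.e. by the "lowering", not by the caller.
inline std::int64_t physicalPosition(const StridedView1D &view,
                                     std::int64_t logical) {
  return view.offset + logical * view.stride;
}

// Masked 1-D gather in which `indices` are logical element offsets from
// `base`. Masked-off lanes take the pass-through value and never touch memory.
inline std::vector<float> gatherLogical(const StridedView1D &view,
                                        std::int64_t base,
                                        const std::vector<std::int32_t> &indices,
                                        const std::vector<bool> &mask,
                                        const std::vector<float> &passThru) {
  std::vector<float> result(indices.size());
  for (std::size_t i = 0; i < indices.size(); ++i) {
    result[i] = mask[i] ? view.data[physicalPosition(view, base + indices[i])]
                        : passThru[i];
  }
  return result;
}
```

With a stride of 2 and an offset of 1, logical index 2 reads physical element 5.
The caller passes 2, not 5, which is what "the index does not encode the
stride" means in this model.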
.../mlir/Dialect/Vector/IR/VectorOps.td | 10 ++++++-
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 2 ++
.../vector-to-llvm-interface.mlir | 27 +++++++++++++++++++
mlir/test/Dialect/Vector/invalid.mlir | 17 +++++++++++-
4 files changed, 54 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 43ad435ccf1c1..019a087484a01 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2098,7 +2098,10 @@ def Vector_GatherOp :
result[i,j] := if mask[i,j] then base[i0, i1, i2 + indices[i,j]]
else pass_thru[i,j]
```
- The index into `base` only varies in the innermost ((k-1)-th) dimension.
+ The index into `base` only varies in the innermost ((k-1)-th) dimension and
+ is treated as a logical rather than a physical index, i.e. it does not
+ encode any potential non-identity strides in the underlying MemRef layout
+ (Tensors do not model strided layouts, so only logical indexing is possible).
If a mask bit is set and the corresponding index is out-of-bounds for the
given base, the behavior is undefined. If a mask bit is not set, the value
@@ -2191,6 +2194,11 @@ def Vector_ScatterOp
is stored regardless of the index, and the index is allowed to be
out-of-bounds.
+ The index into `base` only varies in the innermost ((k-1)-th) dimension and
+ is treated as a logical rather than a physical index, i.e. it does not
+ encode any potential non-identity strides in the underlying MemRef layout
+ (Tensors do not model strided layouts, so only logical indexing is possible).
+
If the index vector contains two or more duplicate indices, the behavior is
undefined. Underlying implementation may enforce strict sequential
semantics.
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9d81702581131..815909169c6b8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -290,6 +290,7 @@ class VectorGatherOpConversion
MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
assert(memRefType && "The base should be bufferized");
+ // TODO: Add support for strided MemRef.
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return rewriter.notifyMatchFailure(gather, "memref type not supported");
@@ -348,6 +349,7 @@ class VectorScatterOpConversion
auto memRefType = dyn_cast<MemRefType>(scatter.getBaseType());
assert(memRefType && "The base should be bufferized");
+ // TODO: Add support for strided MemRef.
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return rewriter.notifyMatchFailure(scatter, "memref type not supported");
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 49c55f5b54496..076209cbc7a4c 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2066,6 +2066,20 @@ func.func @gather_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %ar
// -----
+// TODO: Implement this lowering.
+func.func @negative_gather_on_strided_memref(%arg0: memref<?xf32, strided<[2], offset: ?>>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
+ %0 = arith.constant 0: index
+ %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3
+ : memref<?xf32, strided<[2], offset: ?>>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+ return %1 : vector<3xf32>
+}
+
+// CHECK-LABEL: func @negative_gather_on_strided_memref
+// CHECK-NOT: llvm.intr.masked.gather
+// CHECK: vector.gather
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.scatter
//===----------------------------------------------------------------------===//
@@ -2152,6 +2166,19 @@ func.func @scatter_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %a
// CHECK-LABEL: func @scatter_with_alignment
// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+// -----
+
+// TODO: Implement this lowering.
+func.func @negative_scatter_on_strided_memref(%arg0: memref<?xf32, strided<[2], offset: ?>>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
+ %0 = arith.constant 0: index
+ vector.scatter %arg0[%0][%arg1], %arg2, %arg3
+ : memref<?xf32, strided<[2], offset: ?>>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+ return
+}
+
+// CHECK-LABEL: func @negative_scatter_on_strided_memref
+// CHECK-NOT: llvm.intr.masked.scatter
+// CHECK: vector.scatter
// -----
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 3957455ccc76e..c92e5c9dcb73f 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -2046,6 +2046,15 @@ func.func @load_non_pow_of_2_alignment(%memref: memref<4xi32>, %c0: index) {
// -----
+func.func @load_non_unit_stride(%src : memref<?xi8, strided<[2], offset: ?>>) {
+ %c0 = arith.constant 0 : index
+ // expected-error @+1 {{'vector.load' op most minor memref dim must have unit stride}}
+ %0 = vector.load %src[%c0] : memref<?xi8, strided<[2], offset: ?>>, vector<16xi8>
+ return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.store
//===----------------------------------------------------------------------===//
@@ -2073,6 +2082,13 @@ func.func @store_non_pow_of_2_alignment(%memref: memref<4xi32>, %val: vector<4xi
return
}
+// -----
+func.func @store_non_unit_stride(%src : memref<?xi8, strided<[2], offset:?>>,%val : vector<16xi8>, %c0: index) {
+ // expected-error @below {{'vector.store' op most minor memref dim must have unit stride}}
+ vector.store %val, %src[%c0] : memref<?xi8, strided<[2], offset: ?>>, vector<16xi8>
+ return
+}
+
// -----
// Verify that vector.bitcast rejects vectors with i0 (zero-bitwidth) element type.
@@ -2127,4 +2143,3 @@ func.func @scan_i0(%a: vector<4xi0>, %init: vector<1xi0>) -> (vector<4xi0>, vect
%0:2 = vector.scan <add>, %a, %init {inclusive = true, reduction_dim = 0 : i64} :
vector<4xi0>, vector<1xi0>
return %0#0, %0#1 : vector<4xi0>, vector<1xi0>
-}
>From 1e9b50a70ee917019fed4f18032e356ed845e72b Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 16 Feb 2026 20:03:49 +0000
Subject: [PATCH 2/6] Update per Jakub's suggestion
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 51 +++++++++++++------
1 file changed, 35 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 019a087484a01..3d582f72c8a74 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2082,26 +2082,28 @@ def Vector_GatherOp :
3-D and the result is 2-D:
```mlir
- func.func @gather_3D_to_2D(
- %base: memref<?x10x?xf32>, %ofs_0: index, %ofs_1: index, %ofs_2: index,
- %indices: vector<2x3xi32>, %mask: vector<2x3xi1>,
- %fall_thru: vector<2x3xf32>) -> vector<2x3xf32> {
- %result = vector.gather %base[%ofs_0, %ofs_1, %ofs_2]
- [%indices], %mask, %fall_thru : [...]
- return %result : vector<2x3xf32>
+ %base: memref<?x10x?xf32>, %ofs_0: index, %ofs_1: index, %ofs_2: index,
+ %indices: vector<2x3xi32>, %mask: vector<2x3xi1>,
+ %fall_thru: vector<2x3xf32>) -> vector<2x3xf32> {
+ %result = vector.gather %base[%ofs_0, %ofs_1, %ofs_2]
+ [%indices], %mask, %fall_thru : [...]
}
```
The indexing semantics are then,
```
- result[i,j] := if mask[i,j] then base[i0, i1, i2 + indices[i,j]]
+ result[i,j] := if mask[i,j] then base[%ofs_0, %ofs_1, %ofs_2 + indices[i,j]]
else pass_thru[i,j]
```
- The index into `base` only varies in the innermost ((k-1)-th) dimension and
- is treated as a logical rather than a physical index, i.e. it does not
- encode any potential non-identity strides in the underlying MemRef layout
- (Tensors do not model strided layouts, so only logical indexing is possible).
+ Note, `indices` are element offsets - they are expressed in units of
+ elements (not bytes). The offset is added to the underlying memory address,
+ so if the resulting position exceeds the size of the innermost dimension,
+ it naturally advances into the next row and/or plane according to the
+ identity (row-major) layout of 3D `base` (col = dim 0, row = dim 1, plane
+ = dim 2). Importantly, `indices` are interpreted assuming an identity
+ (contiguous) MemRef layout and do not account for non-identity strides.
+
If a mask bit is set and the corresponding index is out-of-bounds for the
given base, the behavior is undefined. If a mask bit is not set, the value
@@ -2194,10 +2196,27 @@ def Vector_ScatterOp
is stored regardless of the index, and the index is allowed to be
out-of-bounds.
- The index into `base` only varies in the innermost ((k-1)-th) dimension and
- is treated as a logical rather than a physical index, i.e. it does not
- encode any potential non-identity strides in the underlying MemRef layout
- (Tensors do not model strided layouts, so only logical indexing is possible).
+ ```mlir
+ %base: memref<?x10x?xf32>, %ofs_0: index, %ofs_1: index, %ofs_2: index,
+ %indices: vector<2x3xi32>, %mask: vector<2x3xi1>,
+ %src: vector<2x3xf32>) -> memref<?x10x?xf32> {
+ %result = vector.scatter %base[%ofs_0, %ofs_1, %ofs_2]
+ [%indices], %mask, %src : [...]
+ ```
+ The indexing semantics are then,
+
+ ```
+ if mask[i,j] then
+ base[%ofs_0, %ofs_1, %ofs_2 + indices[i,j]] := valueToStore[i,j]
+ ```
+
+ Note, `indices` are element offsets - they are expressed in units of
+ elements (not bytes). The offset is added to the underlying memory address,
+ so if the resulting position exceeds the size of the innermost dimension,
+ it naturally advances into the next row and/or plane according to the
+ identity (row-major) layout of 3D `base` (col = dim 0, row = dim 1, plane
+ = dim 2). Importantly, `indices` are interpreted assuming an identity
+ (contiguous) MemRef layout and do not account for non-identity strides.
If the index vector contains two or more duplicate indices, the behavior is
undefined. Underlying implementation may enforce strict sequential
>From a7a26c644e823a4681e550137eb5c95c92065700 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 3 Mar 2026 19:50:56 +0000
Subject: [PATCH 3/6] Tweak
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 26 ++++++++++---------
1 file changed, 14 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3d582f72c8a74..f857d0677a439 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2097,12 +2097,13 @@ def Vector_GatherOp :
else pass_thru[i,j]
```
Note, `indices` are element offsets - they are expressed in units of
- elements (not bytes). The offset is added to the underlying memory address,
- so if the resulting position exceeds the size of the innermost dimension,
- it naturally advances into the next row and/or plane according to the
- identity (row-major) layout of 3D `base` (col = dim 0, row = dim 1, plane
- = dim 2). Importantly, `indices` are interpreted assuming an identity
- (contiguous) MemRef layout and do not account for non-identity strides.
+ elements (not bytes). Each element in `indices` represents a displacement
+ in units of elements from the starting element, i.e. `%base[%ofs_0, %ofs_1,
+ %ofs_2]` for the example above. If the resulting position exceeds the size
+ of a dimension, it naturally advances into the next row and/or plane
+ according to the identity (row-major) layout of `base`. Importantly, for
+ MemRefs, `indices` are interpreted assuming an identity (contiguous) MemRef
+ layout and do not account for non-identity strides.
If a mask bit is set and the corresponding index is out-of-bounds for the
@@ -2211,12 +2212,13 @@ def Vector_ScatterOp
```
Note, `indices` are element offsets - they are expressed in units of
- elements (not bytes). The offset is added to the underlying memory address,
- so if the resulting position exceeds the size of the innermost dimension,
- it naturally advances into the next row and/or plane according to the
- identity (row-major) layout of 3D `base` (col = dim 0, row = dim 1, plane
- = dim 2). Importantly, `indices` are interpreted assuming an identity
- (contiguous) MemRef layout and do not account for non-identity strides.
+ elements (not bytes). Each element in `indices` represents a displacement
+ in units of elements from the starting element, i.e. `%base[%ofs_0, %ofs_1,
+ %ofs_2]` for the example above. If the resulting position exceeds the size
+ of a dimension, it naturally advances into the next row and/or plane
+ according to the identity (row-major) layout of `base`. Importantly, for
+ MemRefs, `indices` are interpreted assuming an identity (contiguous) MemRef
+ layout and do not account for non-identity strides.
If the index vector contains two or more duplicate indices, the behavior is
undefined. Underlying implementation may enforce strict sequential
>From fd02f731e6309deba2e7ee00807f3d667c22f902 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 17 Mar 2026 09:38:36 +0000
Subject: [PATCH 4/6] Extract changes in tests - these will be moved to a new
PR
---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 2 --
.../vector-to-llvm-interface.mlir | 27 -------------------
mlir/test/Dialect/Vector/invalid.mlir | 17 +-----------
3 files changed, 1 insertion(+), 45 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 815909169c6b8..9d81702581131 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -290,7 +290,6 @@ class VectorGatherOpConversion
MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
assert(memRefType && "The base should be bufferized");
- // TODO: Add support for strided MemRef.
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return rewriter.notifyMatchFailure(gather, "memref type not supported");
@@ -349,7 +348,6 @@ class VectorScatterOpConversion
auto memRefType = dyn_cast<MemRefType>(scatter.getBaseType());
assert(memRefType && "The base should be bufferized");
- // TODO: Add support for strided MemRef.
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return rewriter.notifyMatchFailure(scatter, "memref type not supported");
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 076209cbc7a4c..49c55f5b54496 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2066,20 +2066,6 @@ func.func @gather_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %ar
// -----
-// TODO: Implement this lowering.
-func.func @negative_gather_on_strided_memref(%arg0: memref<?xf32, strided<[2], offset: ?>>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
- %0 = arith.constant 0: index
- %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3
- : memref<?xf32, strided<[2], offset: ?>>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
- return %1 : vector<3xf32>
-}
-
-// CHECK-LABEL: func @negative_gather_on_strided_memref
-// CHECK-NOT: llvm.intr.masked.gather
-// CHECK: vector.gather
-
-// -----
-
//===----------------------------------------------------------------------===//
// vector.scatter
//===----------------------------------------------------------------------===//
@@ -2166,19 +2152,6 @@ func.func @scatter_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %a
// CHECK-LABEL: func @scatter_with_alignment
// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
-// -----
-
-// TODO: Implement this lowering.
-func.func @negative_scatter_on_strided_memref(%arg0: memref<?xf32, strided<[2], offset: ?>>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
- %0 = arith.constant 0: index
- vector.scatter %arg0[%0][%arg1], %arg2, %arg3
- : memref<?xf32, strided<[2], offset: ?>>, vector<3xi32>, vector<3xi1>, vector<3xf32>
- return
-}
-
-// CHECK-LABEL: func @negative_scatter_on_strided_memref
-// CHECK-NOT: llvm.intr.masked.scatter
-// CHECK: vector.scatter
// -----
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index c92e5c9dcb73f..3957455ccc76e 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -2046,15 +2046,6 @@ func.func @load_non_pow_of_2_alignment(%memref: memref<4xi32>, %c0: index) {
// -----
-func.func @load_non_unit_stride(%src : memref<?xi8, strided<[2], offset: ?>>) {
- %c0 = arith.constant 0 : index
- // expected-error @+1 {{'vector.load' op most minor memref dim must have unit stride}}
- %0 = vector.load %src[%c0] : memref<?xi8, strided<[2], offset: ?>>, vector<16xi8>
- return
-}
-
-// -----
-
//===----------------------------------------------------------------------===//
// vector.store
//===----------------------------------------------------------------------===//
@@ -2082,13 +2073,6 @@ func.func @store_non_pow_of_2_alignment(%memref: memref<4xi32>, %val: vector<4xi
return
}
-// -----
-func.func @store_non_unit_stride(%src : memref<?xi8, strided<[2], offset:?>>,%val : vector<16xi8>, %c0: index) {
- // expected-error @below {{'vector.store' op most minor memref dim must have unit stride}}
- vector.store %val, %src[%c0] : memref<?xi8, strided<[2], offset: ?>>, vector<16xi8>
- return
-}
-
// -----
// Verify that vector.bitcast rejects vectors with i0 (zero-bitwidth) element type.
@@ -2143,3 +2127,4 @@ func.func @scan_i0(%a: vector<4xi0>, %init: vector<1xi0>) -> (vector<4xi0>, vect
%0:2 = vector.scan <add>, %a, %init {inclusive = true, reduction_dim = 0 : i64} :
vector<4xi0>, vector<1xi0>
return %0#0, %0#1 : vector<4xi0>, vector<1xi0>
+}
>From 9513d44b9fc1f921485fec10b8fb84aa82d78378 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 17 Mar 2026 10:19:32 +0000
Subject: [PATCH 5/6] Trim the docs - remove the contested addition.
---
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 12 ++++--------
1 file changed, 4 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index f857d0677a439..81b6958062756 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2099,10 +2099,8 @@ def Vector_GatherOp :
Note, `indices` are element offsets - they are expressed in units of
elements (not bytes). Each element in `indices` represents a displacement
in units of elements from the starting element, i.e. `%base[%ofs_0, %ofs_1,
- %ofs_2]` for the example above. If the resulting position exceeds the size
- of a dimension, it naturally advances into the next row and/or plane
- according to the identity (row-major) layout of `base`. Importantly, for
- MemRefs, `indices` are interpreted assuming an identity (contiguous) MemRef
+ %ofs_2]` for the example above. Importantly, for MemRefs, `indices` are
+ %interpreted assuming an identity (contiguous) MemRef
layout and do not account for non-identity strides.
@@ -2214,10 +2212,8 @@ def Vector_ScatterOp
Note, `indices` are element offsets - they are expressed in units of
elements (not bytes). Each element in `indices` represents a displacement
in units of elements from the starting element, i.e. `%base[%ofs_0, %ofs_1,
- %ofs_2]` for the example above. If the resulting position exceeds the size
- of a dimension, it naturally advances into the next row and/or plane
- according to the identity (row-major) layout of `base`. Importantly, for
- MemRefs, `indices` are interpreted assuming an identity (contiguous) MemRef
+ %ofs_2]` for the example above. Importantly, for MemRefs, `indices` are
+ %interpreted assuming an identity (contiguous) MemRef
layout and do not account for non-identity strides.
If the index vector contains two or more duplicate indices, the behavior is
>From 11a0c50b7fb0d08869144646d139825588bf6cfe Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 17 Mar 2026 12:06:57 +0000
Subject: [PATCH 6/6] Address comments from Renato.
---
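Reviewer note (not part of the patch): the gather formula documented below,
`%result[i,j] := if %mask[i,j] then %base[%ofs_0, %ofs_1, %ofs_2 + %indices[i,j]]
else %pass_thru[i,j]`, can be sketched as a C++ reference model for the
3-D-to-2-D example. `Base3D`, `Vec2x3` and `gather3DTo2D` are hypothetical names
used only for illustration. The sketch assumes a dense base with an identity
layout. Where the documentation leaves an out-of-bounds index with a set mask
bit undefined, this model throws instead, which is a choice of the sketch and
not documented behaviour.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical dense 3-D buffer with an identity (row-major) layout.
struct Base3D {
  std::array<std::int64_t, 3> shape;
  std::vector<float> data; // shape[0] * shape[1] * shape[2] elements

  float at(std::int64_t i0, std::int64_t i1, std::int64_t i2) const {
    // Undefined behaviour in the op documentation; this sketch rejects it.
    if (i0 < 0 || i0 >= shape[0] || i1 < 0 || i1 >= shape[1] || i2 < 0 ||
        i2 >= shape[2])
      throw std::out_of_range("index out of bounds for base");
    return data[static_cast<std::size_t>((i0 * shape[1] + i1) * shape[2] + i2)];
  }
};

template <typename T> using Vec2x3 = std::array<std::array<T, 3>, 2>;

// result[i][j] = mask[i][j] ? base[ofs0, ofs1, ofs2 + indices[i][j]]
//                           : passThru[i][j]
// Only the innermost coordinate varies; indices are in units of elements.
inline Vec2x3<float> gather3DTo2D(const Base3D &base, std::int64_t ofs0,
                                  std::int64_t ofs1, std::int64_t ofs2,
                                  const Vec2x3<std::int32_t> &indices,
                                  const Vec2x3<bool> &mask,
                                  const Vec2x3<float> &passThru) {
  Vec2x3<float> result{};
  for (std::size_t i = 0; i < 2; ++i) {
    for (std::size_t j = 0; j < 3; ++j) {
      // A masked-off lane never evaluates its index, so that index may be
      // out of bounds without consequence.
      result[i][j] = mask[i][j] ? base.at(ofs0, ofs1, ofs2 + indices[i][j])
                                : passThru[i][j];
    }
  }
  return result;
}
```

A scatter model would mirror this one, writing `valueToStore[i][j]` to the same
position whenever the mask bit is set.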
.../mlir/Dialect/Vector/IR/VectorOps.td | 22 ++++++++++---------
1 file changed, 12 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 81b6958062756..0bd4f4ffe11b2 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2084,24 +2084,25 @@ def Vector_GatherOp :
```mlir
%base: memref<?x10x?xf32>, %ofs_0: index, %ofs_1: index, %ofs_2: index,
%indices: vector<2x3xi32>, %mask: vector<2x3xi1>,
- %fall_thru: vector<2x3xf32>) -> vector<2x3xf32> {
+ %pass_thru: vector<2x3xf32>) -> vector<2x3xf32> {
%result = vector.gather %base[%ofs_0, %ofs_1, %ofs_2]
- [%indices], %mask, %fall_thru : [...]
+ [%indices], %mask, %pass_thru : [...]
}
```
The indexing semantics are then,
```
- result[i,j] := if mask[i,j] then base[%ofs_0, %ofs_1, %ofs_2 + indices[i,j]]
- else pass_thru[i,j]
+ %result[i,j] := if %mask[i,j] then %base[%ofs_0, %ofs_1, %ofs_2 + %indices[i,j]]
+ else %pass_thru[i,j]
```
Note, `indices` are element offsets - they are expressed in units of
elements (not bytes). Each element in `indices` represents a displacement
in units of elements from the starting element, i.e. `%base[%ofs_0, %ofs_1,
%ofs_2]` for the example above. Importantly, for MemRefs, `indices` are
- %interpreted assuming an identity (contiguous) MemRef
- layout and do not account for non-identity strides.
+ interpreted assuming an identity (contiguous) MemRef layout. Any
+ non-identity layout (e.g. strided) is not reflected in the indices
+ themselves and is instead handled during lowering.
If a mask bit is set and the corresponding index is out-of-bounds for the
@@ -2205,16 +2206,17 @@ def Vector_ScatterOp
The indexing semantics are then,
```
- if mask[i,j] then
- base[%ofs_0, %ofs_1, %ofs_2 + indices[i,j]] := valueToStore[i,j]
+ if %mask[i,j] then
+ %base[%ofs_0, %ofs_1, %ofs_2 + %indices[i,j]] := %valueToStore[i,j]
```
Note, `indices` are element offsets - they are expressed in units of
elements (not bytes). Each element in `indices` represents a displacement
in units of elements from the starting element, i.e. `%base[%ofs_0, %ofs_1,
%ofs_2]` for the example above. Importantly, for MemRefs, `indices` are
- %interpreted assuming an identity (contiguous) MemRef
- layout and do not account for non-identity strides.
+ interpreted assuming an identity (contiguous) MemRef layout. Any
+ non-identity layout (e.g. strided) is not reflected in the indices
+ themselves and is instead handled during lowering.
If the index vector contains two or more duplicate indices, the behavior is
undefined. Underlying implementation may enforce strict sequential