[Mlir-commits] [mlir] [mlir][vector] Clarify the semantics of gather/scatter indexing (nfc) (PR #181357)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 13 04:58:42 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

For context, see:
* https://discourse.llvm.org/t/rfc-semantics-of-vector-gather-indices-with-strided-memrefs

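Currently, `ConvertVectorToLLVM` rejects strided memrefs (e.g.
`memref<?xf32, strided<[2], offset: ?>>`, whose innermost dimension has a
non-unit stride) when lowering `vector.gather` and `vector.scatter`. This PR
adds tests to document that behavior. It also updates the `vector.gather` and
`vector.scatter` op documentation to state that their indices are logical
rather than physical, i.e. they do not encode any non-identity strides of the
underlying MemRef layout.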

Supporting strided memrefs in the lowering is left as future work.
However, it is still unclear whether gather/scatter on strided memrefs
should be supported at all (see the Discourse discussion above).

This PR also adds tests for `vector.load` and `vector.store` in
`invalid.mlir` to document that these ops reject memrefs whose most minor
(innermost) dimension has a non-unit stride.


---
Full diff: https://github.com/llvm/llvm-project/pull/181357.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+9-1) 
- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+2) 
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (+27) 
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+17) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index ddb04b6bbe40d..f49c578c91092 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2087,7 +2087,10 @@ def Vector_GatherOp :
     result[i,j] := if mask[i,j] then base[i0, i1, i2 + indices[i,j]]
                    else pass_thru[i,j]
     ```
-    The index into `base` only varies in the innermost ((k-1)-th) dimension.
+    The index into `base` only varies in the innermost ((k-1)-th) dimension and
+    is treated as a logical rather than a physical index, i.e. it does not
+    encode any potential non-identity strides in the underlying MemRef layout
+    (Tensors do not model strided layouts, so only logical indexing is possible).
 
     If a mask bit is set and the corresponding index is out-of-bounds for the
     given base, the behavior is undefined. If a mask bit is not set, the value
@@ -2180,6 +2183,11 @@ def Vector_ScatterOp
     is stored regardless of the index, and the index is allowed to be
     out-of-bounds.
 
+    The index into `base` only varies in the innermost ((k-1)-th) dimension and
+    is treated as a logical rather than a physical index, i.e. it does not
+    encode any potential non-identity strides in the underlying MemRef layout
+    (Tensors do not model strided layouts, so only logical indexing is possible).
+
     If the index vector contains two or more duplicate indices, the behavior is
     undefined. Underlying implementation may enforce strict sequential
     semantics.
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 05d541fe80356..0a255c853b07e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -290,6 +290,7 @@ class VectorGatherOpConversion
     MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
     assert(memRefType && "The base should be bufferized");
 
+    // TODO: Add support for strided MemRef.
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return rewriter.notifyMatchFailure(gather, "memref type not supported");
 
@@ -348,6 +349,7 @@ class VectorScatterOpConversion
     auto memRefType = dyn_cast<MemRefType>(scatter.getBaseType());
     assert(memRefType && "The base should be bufferized");
 
+    // TODO: Add support for strided MemRef.
     if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
       return rewriter.notifyMatchFailure(scatter, "memref type not supported");
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index cb48ca3374e8d..3d075fe3b24ee 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2052,6 +2052,20 @@ func.func @gather_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %ar
 
 // -----
 
+// TODO: Implement this lowering.
+func.func @negative_gather_on_strided_memref(%arg0: memref<?xf32, strided<[2], offset: ?>>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
+  %0 = arith.constant 0: index
+  %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3
+    : memref<?xf32, strided<[2], offset: ?>>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+  return %1 : vector<3xf32>
+}
+
+// CHECK-LABEL: func @negative_gather_on_strided_memref
+// CHECK-NOT: llvm.intr.masked.gather
+// CHECK: vector.gather
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.scatter
 //===----------------------------------------------------------------------===//
@@ -2138,6 +2152,19 @@ func.func @scatter_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %a
 // CHECK-LABEL: func @scatter_with_alignment
 // CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
 
+// -----
+
+// TODO: Implement this lowering.
+func.func @negative_scatter_on_strided_memref(%arg0: memref<?xf32, strided<[2], offset: ?>>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
+  %0 = arith.constant 0: index
+  vector.scatter %arg0[%0][%arg1], %arg2, %arg3
+    : memref<?xf32, strided<[2], offset: ?>>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+  return
+}
+
+// CHECK-LABEL: func @negative_scatter_on_strided_memref
+// CHECK-NOT: llvm.intr.masked.scatter
+// CHECK: vector.scatter
 
 // -----
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 28e1206ff3d0a..94cb0fe70d3fc 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -2038,6 +2038,15 @@ func.func @load_non_pow_of_2_alignment(%memref: memref<4xi32>, %c0: index) {
 
 // -----
 
+func.func @load_non_unit_stride(%src : memref<?xi8, strided<[2], offset: ?>>) {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{'vector.load' op most minor memref dim must have unit stride}}
+  %0 = vector.load %src[%c0] : memref<?xi8, strided<[2], offset: ?>>, vector<16xi8>
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.store
 //===----------------------------------------------------------------------===//
@@ -2064,3 +2073,11 @@ func.func @store_non_pow_of_2_alignment(%memref: memref<4xi32>, %val: vector<4xi
   vector.store %val, %memref[%c0] { alignment = 3 } : memref<4xi32>, vector<4xi32>
   return
 }
+
+// -----
+
+func.func @store_non_unit_stride(%src : memref<?xi8, strided<[2], offset:?>>,%val : vector<16xi8>, %c0: index) {
+  // expected-error @below {{'vector.store' op most minor memref dim must have unit stride}}
+  vector.store %val, %src[%c0] : memref<?xi8, strided<[2], offset: ?>>, vector<16xi8>
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/181357


More information about the Mlir-commits mailing list