[Mlir-commits] [mlir] a78d1d9 - [mlir][vector] Add missing tests (nfc) (#186990)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 17 05:07:33 PDT 2026
Author: Andrzej Warzyński
Date: 2026-03-17T12:07:28Z
New Revision: a78d1d9a8b0c361353a1fec64450353e66c59a60
URL: https://github.com/llvm/llvm-project/commit/a78d1d9a8b0c361353a1fec64450353e66c59a60
DIFF: https://github.com/llvm/llvm-project/commit/a78d1d9a8b0c361353a1fec64450353e66c59a60.diff
LOG: [mlir][vector] Add missing tests (nfc) (#186990)
Currently, `ConvertVectorToLLVM` rejects strided memrefs when lowering
`vector.gather` and `vector.scatter`. This PR adds tests to document
that behavior.
Supporting strided memrefs in the lowering is left as future work.
However, it is still unclear whether gather/scatter on strided memrefs
should be supported at all (see the Discourse discussion [1]).
This PR also adds tests for `vector.load` and `vector.store` in
`invalid.mlir` to document that these ops reject memrefs whose most
minor (innermost) dimension has a non-unit stride.
[1] https://discourse.llvm.org/t/rfc-semantics-of-vector-gather-indices-with-strided-memrefs
Added:
Modified:
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
mlir/test/Dialect/Vector/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9d81702581131..815909169c6b8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -290,6 +290,7 @@ class VectorGatherOpConversion
MemRefType memRefType = dyn_cast<MemRefType>(gather.getBaseType());
assert(memRefType && "The base should be bufferized");
+ // TODO: Add support for strided MemRef.
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return rewriter.notifyMatchFailure(gather, "memref type not supported");
@@ -348,6 +349,7 @@ class VectorScatterOpConversion
auto memRefType = dyn_cast<MemRefType>(scatter.getBaseType());
assert(memRefType && "The base should be bufferized");
+ // TODO: Add support for strided MemRef.
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
return rewriter.notifyMatchFailure(scatter, "memref type not supported");
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 49c55f5b54496..076209cbc7a4c 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2066,6 +2066,20 @@ func.func @gather_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %ar
// -----
+// TODO: Implement this lowering.
+func.func @negative_gather_on_strided_memref(%arg0: memref<?xf32, strided<[2], offset: ?>>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> {
+ %0 = arith.constant 0: index
+ %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3
+ : memref<?xf32, strided<[2], offset: ?>>, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32>
+ return %1 : vector<3xf32>
+}
+
+// CHECK-LABEL: func @negative_gather_on_strided_memref
+// CHECK-NOT: llvm.intr.masked.gather
+// CHECK: vector.gather
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.scatter
//===----------------------------------------------------------------------===//
@@ -2152,6 +2166,19 @@ func.func @scatter_with_alignment(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %a
// CHECK-LABEL: func @scatter_with_alignment
// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xf32>, vector<3xi1> into vector<3x!llvm.ptr>
+// -----
+
+// TODO: Implement this lowering.
+func.func @negative_scatter_on_strided_memref(%arg0: memref<?xf32, strided<[2], offset: ?>>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) {
+ %0 = arith.constant 0: index
+ vector.scatter %arg0[%0][%arg1], %arg2, %arg3
+ : memref<?xf32, strided<[2], offset: ?>>, vector<3xi32>, vector<3xi1>, vector<3xf32>
+ return
+}
+
+// CHECK-LABEL: func @negative_scatter_on_strided_memref
+// CHECK-NOT: llvm.intr.masked.scatter
+// CHECK: vector.scatter
// -----
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 3957455ccc76e..d8e08c8b2a850 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -2046,6 +2046,15 @@ func.func @load_non_pow_of_2_alignment(%memref: memref<4xi32>, %c0: index) {
// -----
+func.func @load_non_unit_stride(%src : memref<?xi8, strided<[2], offset: ?>>) {
+ %c0 = arith.constant 0 : index
+ // expected-error @+1 {{'vector.load' op most minor memref dim must have unit stride}}
+ %0 = vector.load %src[%c0] : memref<?xi8, strided<[2], offset: ?>>, vector<16xi8>
+ return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.store
//===----------------------------------------------------------------------===//
@@ -2073,6 +2082,13 @@ func.func @store_non_pow_of_2_alignment(%memref: memref<4xi32>, %val: vector<4xi
return
}
+// -----
+func.func @store_non_unit_stride(%src : memref<?xi8, strided<[2], offset:?>>,%val : vector<16xi8>, %c0: index) {
+ // expected-error @below {{'vector.store' op most minor memref dim must have unit stride}}
+ vector.store %val, %src[%c0] : memref<?xi8, strided<[2], offset: ?>>, vector<16xi8>
+ return
+}
+
// -----
// Verify that vector.bitcast rejects vectors with i0 (zero-bitwidth) element type.
More information about the Mlir-commits mailing list