[llvm-branch-commits] [mlir] e1cc4f0 - Revert "[mlir][vector] add consistent stride verification to `masked load/sto…"
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jun 25 07:30:30 PDT 2026
Author: Ingo Müller
Date: 2026-06-25T16:30:26+02:00
New Revision: e1cc4f00bbe62b10e880008f61705278ef58d410
URL: https://github.com/llvm/llvm-project/commit/e1cc4f00bbe62b10e880008f61705278ef58d410
DIFF: https://github.com/llvm/llvm-project/commit/e1cc4f00bbe62b10e880008f61705278ef58d410.diff
LOG: Revert "[mlir][vector] add consistent stride verification to `masked load/sto…"
This reverts commit 4d4c865933e1048842f836490e02296b2cb48711.
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index a5e5095514dc2..24442a6336090 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1946,9 +1946,6 @@ def Vector_MaskedLoadOp :
: memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
```
- The memref must have non-negative strides. Negative strides are not supported
- and will trigger a verification error.
-
An optional `alignment` attribute allows to specify the byte alignment of the
load operation. It must be a positive power of 2. The operation must access
memory at an address aligned to this boundary. Violating this requirement
@@ -2044,9 +2041,6 @@ def Vector_MaskedStoreOp :
: memref<?x?xf32>, vector<16xi1>, vector<16xf32>
```
- The memref must have non-negative strides. Negative strides are not supported
- and will trigger a verification error.
-
An optional `alignment` attribute allows to specify the byte alignment of the
store operation. It must be a positive power of 2. The operation must access
memory at an address aligned to this boundary. Violating this requirement
@@ -2141,9 +2135,6 @@ def Vector_GatherOp :
during progressively lowering to bring other memory operations closer to
hardware ISA support for a gather.
- The memref must have non-negative strides. Negative strides are not supported
- and will trigger a verification error.
-
An optional `alignment` attribute allows to specify the byte alignment of the
gather operation. It must be a positive power of 2. The operation must access
memory at an address aligned to this boundary. Violating this requirement
@@ -2237,9 +2228,6 @@ def Vector_ScatterOp
correspond to those of the `llvm.masked.scatter`
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-scatter-intrinsics).
- The memref must have non-negative strides. Negative strides are not supported
- and will trigger a verification error.
-
An optional `alignment` attribute allows to specify the byte alignment of the
scatter operation. It must be a positive power of 2. The operation must access
memory at an address aligned to this boundary. Violating this requirement
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 9ce05fd70cd6b..81ffabca6ecf0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6195,9 +6195,7 @@ LogicalResult vector::LoadOp::verify() {
if (failed(verifyLoadStoreMemRefLayout(*this, resVecTy, memRefTy)))
return failure();
- // Negative strides are not supported on vector.load. The lowering to LLVM
- // emits arithmetic operations (e.g., GEP, mul) with nuw flags that assume
- // non-negative strides to avoid undefined behavior.
+ // Negative strides are not supported on vector.load.
if (memref::hasNegativeStaticStride(memRefTy))
return emitOpError("memref strides must be non-negative");
@@ -6247,9 +6245,7 @@ LogicalResult vector::StoreOp::verify() {
if (failed(verifyLoadStoreMemRefLayout(*this, valueVecTy, memRefTy)))
return failure();
- // Negative strides are not supported on vector.store. The lowering to LLVM
- // emits arithmetic operations (e.g., GEP, mul) with nuw flags that assume
- // non-negative strides to avoid undefined behavior.
+ // Negative strides are not supported on vector.store.
if (memref::hasNegativeStaticStride(memRefTy))
return emitOpError("memref strides must be non-negative");
@@ -6297,15 +6293,6 @@ LogicalResult MaskedLoadOp::verify() {
VectorType resVType = getVectorType();
MemRefType memType = getMemRefType();
- if (failed(verifyLoadStoreMemRefLayout(*this, resVType, memType)))
- return failure();
-
- // Negative strides are not supported on vector.maskedload. The lowering to
- // LLVM emits arithmetic operations (e.g., GEP, mul) with nuw flags that
- // assume non-negative strides to avoid undefined behavior.
- if (memref::hasNegativeStaticStride(memType))
- return emitOpError("memref strides must be non-negative");
-
if (failed(
verifyElementTypesMatch(*this, memType, resVType, "base", "result")))
return failure();
@@ -6366,15 +6353,6 @@ LogicalResult MaskedStoreOp::verify() {
VectorType valueVType = getVectorType();
MemRefType memType = getMemRefType();
- if (failed(verifyLoadStoreMemRefLayout(*this, valueVType, memType)))
- return failure();
-
- // Negative strides are not supported on vector.maskedstore. The lowering to
- // LLVM emits arithmetic operations (e.g., GEP, mul) with nuw flags that
- // assume non-negative strides to avoid undefined behavior.
- if (memref::hasNegativeStaticStride(memType))
- return emitOpError("memref strides must be non-negative");
-
if (failed(verifyElementTypesMatch(*this, memType, valueVType, "base",
"valueToStore")))
return failure();
@@ -6436,13 +6414,6 @@ LogicalResult GatherOp::verify() {
if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
return emitOpError("requires base to be a memref or ranked tensor type");
- // Negative strides are not supported on vector.gather.
- // The lowering to LLVM emits arithmetic operations (e.g., GEP, mul) with nuw
- // flags that assume non-negative strides to avoid undefined behavior.
- if (auto memRefType = dyn_cast<MemRefType>(baseType))
- if (memref::hasNegativeStaticStride(memRefType))
- return emitOpError("memref strides must be non-negative");
-
if (failed(
verifyElementTypesMatch(*this, baseType, resVType, "base", "result")))
return failure();
@@ -6557,13 +6528,6 @@ LogicalResult ScatterOp::verify() {
if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
return emitOpError("requires base to be a memref or ranked tensor type");
- // Negative strides are not supported on vector.scatter.
- // The lowering to LLVM emits arithmetic operations (e.g., GEP, mul) with nuw
- // flags that assume non-negative strides to avoid undefined behavior.
- if (auto memRefType = dyn_cast<MemRefType>(baseType))
- if (memref::hasNegativeStaticStride(memRefType))
- return emitOpError("memref strides must be non-negative");
-
if (failed(verifyElementTypesMatch(*this, baseType, valueVType, "base",
"valueToStore")))
return failure();
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 800ef75fde864..403581e338a6f 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1413,15 +1413,6 @@ func.func @maskedload_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>
// -----
-func.func @maskedload_negative_stride(%src: memref<100x100xf32, strided<[-100, 1]>>, %mask: vector<8xi1>, %pass: vector<8xf32>) -> vector<8xf32> {
- %c0 = arith.constant 0 : index
- // expected-error @+1 {{'vector.maskedload' op memref strides must be non-negative}}
- %0 = vector.maskedload %src[%c0, %c0], %mask, %pass : memref<100x100xf32, strided<[-100, 1]>>, vector<8xi1>, vector<8xf32> into vector<8xf32>
- return %0 : vector<8xf32>
-}
-
-// -----
-
//===----------------------------------------------------------------------===//
// vector.maskedstore
//===----------------------------------------------------------------------===//
@@ -1466,15 +1457,6 @@ func.func @maskedstore_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1
// -----
-func.func @maskedstore_negative_stride(%src: memref<100x100xf32, strided<[-100, 1]>>, %mask: vector<8xi1>, %value: vector<8xf32>) {
- %c0 = arith.constant 0 : index
- // expected-error @+1 {{'vector.maskedstore' op memref strides must be non-negative}}
- vector.maskedstore %src[%c0, %c0], %mask, %value : memref<100x100xf32, strided<[-100, 1]>>, vector<8xi1>, vector<8xf32>
- return
-}
-
-// -----
-
func.func @gather_from_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
@@ -1572,16 +1554,6 @@ func.func @gather_tensor_alignment(%base: tensor<16xf32>, %indices: vector<16xi3
// -----
-func.func @gather_negative_stride(%src: memref<100x100xf32, strided<[-100, 1]>>, %indices: vector<16xi32>,
- %mask: vector<16xi1>, %pass_thru: vector<16xf32>, %idx: index) -> vector<16xf32> {
- // expected-error @+1 {{'vector.gather' op memref strides must be non-negative}}
- %0 = vector.gather %src[%idx, %idx][%indices], %mask, %pass_thru
- : memref<100x100xf32, strided<[-100, 1]>>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
- return %0 : vector<16xf32>
-}
-
-// -----
-
func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
@@ -1670,16 +1642,6 @@ func.func @scatter_tensor_alignment(%base: tensor<?xf32>, %indices: vector<16xi3
// -----
-func.func @scatter_negative_stride(%src: memref<100x100xf32, strided<[-100, 1]>>, %indices: vector<16xi32>,
- %mask: vector<16xi1>, %value: vector<16xf32>, %idx: index) {
- // expected-error @+1 {{'vector.scatter' op memref strides must be non-negative}}
- vector.scatter %src[%idx, %idx][%indices], %mask, %value
- : memref<100x100xf32, strided<[-100, 1]>>, vector<16xi32>, vector<16xi1>, vector<16xf32>
- return
-}
-
-// -----
-
func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
// expected-error@+1 {{'vector.expandload' op base element type ('f64') does not match result element type ('f32')}}
More information about the llvm-branch-commits
mailing list