[Mlir-commits] [mlir] [mlir][Utils] Add verifyIndexCount helper (NFC) (PR #176678)
Nick Kreeger
llvmlistbot at llvm.org
Sun Jan 18 17:04:12 PST 2026
https://github.com/nkreeger created https://github.com/llvm/llvm-project/pull/176678
Description:
This change builds on #174336, which introduced shared VerificationUtils and assorted helpers.
This patch adds a new verifyIndexCount() verification utility that checks if the number of indices matches the rank of a shaped type and emits consistent error messages. The utility is applied to several ops across the MemRef and Vector dialects.
From 8659e499dda4c63f18c454d4cbf631652a7c1481 Mon Sep 17 00:00:00 2001
From: Nick Kreeger <nick.kreeger at gmail.com>
Date: Sun, 18 Jan 2026 18:59:54 -0600
Subject: [PATCH] [mlir][Utils] Add verifyIndexCount helper (NFC)
Description:
This change builds on #174336, which introduced shared
VerificationUtils and assorted helpers.
This patch adds a new verifyIndexCount() verification utility that
checks if the number of indices matches the rank of a shaped type and
emits consistent error messages. The utility is applied to several ops
across the MemRef and Vector dialects.
---
.../mlir/Dialect/Utils/VerificationUtils.h | 5 +++
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 6 +--
mlir/lib/Dialect/Utils/VerificationUtils.cpp | 10 +++++
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 41 ++++++++++---------
mlir/test/Dialect/Linalg/invalid.mlir | 2 +-
mlir/test/Dialect/MemRef/invalid.mlir | 2 +-
mlir/test/Dialect/Vector/invalid.mlir | 22 +++++-----
7 files changed, 50 insertions(+), 38 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/VerificationUtils.h b/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
index 3d350aae7cf2f..16dfaebbfcdd2 100644
--- a/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
@@ -32,6 +32,11 @@ LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type,
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs,
StringRef lhsName, StringRef rhsName);
+/// Verify that `indexCount` equals the rank of `type` (assumed ranked).
+/// On mismatch, emits an op error on `op` and returns failure.
+LogicalResult verifyIndexCount(Operation *op, ShapedType type,
+ size_t indexCount);
+
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_VERIFICATIONUTILS_H
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b782a8be19154..382fcb84a6949 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1694,11 +1694,7 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
//===----------------------------------------------------------------------===//
LogicalResult LoadOp::verify() {
- if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
- return emitOpError("incorrect number of indices for load, expected ")
- << getMemRefType().getRank() << " but got " << getIndices().size();
- }
- return success();
+ return verifyIndexCount(*this, getMemRefType(), getIndices().size());
}
OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Utils/VerificationUtils.cpp b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
index 81f1e590a76ee..78606eb23bfa9 100644
--- a/mlir/lib/Dialect/Utils/VerificationUtils.cpp
+++ b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
@@ -36,3 +36,13 @@ LogicalResult mlir::verifyRanksMatch(Operation *op, ShapedType lhs,
}
return success();
}
+
+LogicalResult mlir::verifyIndexCount(Operation *op, ShapedType type,
+ size_t indexCount) {
+ int64_t rank = type.getRank(); // Presumes a ranked `type`.
+ if (static_cast<int64_t>(indexCount) != rank) { // Avoid sign-compare.
+ return op->emitOpError("incorrect number of indices, expected ")
+ << rank << " but got " << indexCount;
+ }
+ return success();
+}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 085f879c2d0e6..ca5c791532f18 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -23,6 +23,7 @@
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Utils/VerificationUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -5001,8 +5002,8 @@ LogicalResult TransferReadOp::verify() {
: VectorType();
auto sourceElementType = shapedType.getElementType();
- if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
- return emitOpError("requires ") << shapedType.getRank() << " indices";
+ if (failed(verifyIndexCount(*this, shapedType, getIndices().size())))
+ return failure();
if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
shapedType, vectorType, maskType,
@@ -5462,8 +5463,8 @@ LogicalResult TransferWriteOp::verify() {
maskType ? inferTransferOpMaskType(vectorType, permutationMap)
: VectorType();
- if (llvm::size(getIndices()) != shapedType.getRank())
- return emitOpError("requires ") << shapedType.getRank() << " indices";
+ if (failed(verifyIndexCount(*this, shapedType, llvm::size(getIndices()))))
+ return failure();
// We do not allow broadcast dimensions on TransferWriteOps for the moment,
// as the semantics is unclear. This can be revisited later if necessary.
@@ -5853,8 +5854,8 @@ LogicalResult vector::LoadOp::verify() {
if (resVecTy.getElementType() != memElemTy)
return emitOpError("base and result element types should match");
- if (llvm::size(getIndices()) != memRefTy.getRank())
- return emitOpError("requires ") << memRefTy.getRank() << " indices";
+ if (failed(verifyIndexCount(*this, memRefTy, llvm::size(getIndices()))))
+ return failure();
return success();
}
@@ -5899,8 +5900,8 @@ LogicalResult vector::StoreOp::verify() {
if (valueVecTy.getElementType() != memElemTy)
return emitOpError("base and valueToStore element type should match");
- if (llvm::size(getIndices()) != memRefTy.getRank())
- return emitOpError("requires ") << memRefTy.getRank() << " indices";
+ if (failed(verifyIndexCount(*this, memRefTy, llvm::size(getIndices()))))
+ return failure();
return success();
}
@@ -5931,8 +5932,8 @@ LogicalResult MaskedLoadOp::verify() {
if (resVType.getElementType() != memType.getElementType())
return emitOpError("base and result element type should match");
- if (llvm::size(getIndices()) != memType.getRank())
- return emitOpError("requires ") << memType.getRank() << " indices";
+ if (failed(verifyIndexCount(*this, memType, llvm::size(getIndices()))))
+ return failure();
if (resVType.getShape() != maskVType.getShape())
return emitOpError("expected result shape to match mask shape");
if (resVType != passVType)
@@ -5990,8 +5991,8 @@ LogicalResult MaskedStoreOp::verify() {
if (valueVType.getElementType() != memType.getElementType())
return emitOpError("base and valueToStore element type should match");
- if (llvm::size(getIndices()) != memType.getRank())
- return emitOpError("requires ") << memType.getRank() << " indices";
+ if (failed(verifyIndexCount(*this, memType, llvm::size(getIndices()))))
+ return failure();
if (valueVType.getShape() != maskVType.getShape())
return emitOpError("expected valueToStore shape to match mask shape");
return success();
@@ -6050,8 +6051,8 @@ LogicalResult GatherOp::verify() {
if (resVType.getElementType() != baseType.getElementType())
return emitOpError("base and result element type should match");
- if (llvm::size(getOffsets()) != baseType.getRank())
- return emitOpError("requires ") << baseType.getRank() << " indices";
+ if (failed(verifyIndexCount(*this, baseType, llvm::size(getOffsets()))))
+ return failure();
if (resVType.getShape() != indVType.getShape())
return emitOpError("expected result dim to match indices dim");
if (resVType.getShape() != maskVType.getShape())
@@ -6159,8 +6160,8 @@ LogicalResult ScatterOp::verify() {
if (valueVType.getElementType() != baseType.getElementType())
return emitOpError("base and valueToStore element type should match");
- if (llvm::size(getOffsets()) != baseType.getRank())
- return emitOpError("requires ") << baseType.getRank() << " indices";
+ if (failed(verifyIndexCount(*this, baseType, llvm::size(getOffsets()))))
+ return failure();
if (valueVType.getShape() != indVType.getShape())
return emitOpError("expected valueToStore dim to match indices dim");
if (valueVType.getShape() != maskVType.getShape())
@@ -6242,8 +6243,8 @@ LogicalResult ExpandLoadOp::verify() {
if (resVType.getElementType() != memType.getElementType())
return emitOpError("base and result element type should match");
- if (llvm::size(getIndices()) != memType.getRank())
- return emitOpError("requires ") << memType.getRank() << " indices";
+ if (failed(verifyIndexCount(*this, memType, llvm::size(getIndices()))))
+ return failure();
if (resVType.getDimSize(0) != maskVType.getDimSize(0))
return emitOpError("expected result dim to match mask dim");
if (resVType != passVType)
@@ -6295,8 +6296,8 @@ LogicalResult CompressStoreOp::verify() {
if (valueVType.getElementType() != memType.getElementType())
return emitOpError("base and valueToStore element type should match");
- if (llvm::size(getIndices()) != memType.getRank())
- return emitOpError("requires ") << memType.getRank() << " indices";
+ if (failed(verifyIndexCount(*this, memType, llvm::size(getIndices()))))
+ return failure();
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
return emitOpError("expected valueToStore dim to match mask dim");
return success();
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 355d801f8732c..10a4602626e80 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
func.func @load_number_of_indices(%v : memref<f32>) {
- // expected-error @+2 {{incorrect number of indices for load}}
+ // expected-error @+2 {{incorrect number of indices, expected 0 but got 1}}
%c0 = arith.constant 0 : index
memref.load %v[%c0] : memref<f32>
}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 2dfe733c13768..3a4471e9a6387 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -948,7 +948,7 @@ func.func @bad_alloc_wrong_symbol_count() {
func.func @load_invalid_memref_indexes() {
%0 = memref.alloca() : memref<10xi32>
%c0 = arith.constant 0 : index
- // expected-error@+1 {{incorrect number of indices for load, expected 1 but got 2}}
+ // expected-error@+1 {{incorrect number of indices, expected 1 but got 2}}
%1 = memref.load %0[%c0, %c0] : memref<10xi32>
}
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 54ad3d8ab0950..92da4fc6a35be 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -348,7 +348,7 @@ func.func @test_vector.transfer_read(%arg0: memref<4x3xf32>) {
func.func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
%c3 = arith.constant 3 : index
%cst = arith.constant 3.0 : f32
- // expected-error@+1 {{requires 2 indices}}
+ // expected-error@+1 {{incorrect number of indices, expected 2 but got 3}}
%0 = vector.transfer_read %arg0[%c3, %c3, %c3], %cst { permutation_map = affine_map<()->(0)> } : memref<?x?xf32>, vector<128xf32>
}
@@ -357,7 +357,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
func.func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
%c3 = arith.constant 3 : index
%cst = arith.constant 3.0 : f32
- // expected-error@+1 {{requires 2 indices}}
+ // expected-error@+1 {{incorrect number of indices, expected 2 but got 1}}
%0 = vector.transfer_read %arg0[%c3], %cst { permutation_map = affine_map<()->(0)> } : memref<?x?xf32>, vector<128xf32>
}
@@ -529,7 +529,7 @@ func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
%c3 = arith.constant 3 : index
%cst = arith.constant dense<3.0> : vector<128 x f32>
- // expected-error@+1 {{requires 2 indices}}
+ // expected-error@+1 {{incorrect number of indices, expected 2 but got 3}}
vector.transfer_write %cst, %arg0[%c3, %c3, %c3] {permutation_map = affine_map<()->(0)>} : vector<128xf32>, memref<?x?xf32>
}
@@ -538,7 +538,7 @@ func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
%c3 = arith.constant 3 : index
%cst = arith.constant dense<3.0> : vector<128 x f32>
- // expected-error@+1 {{requires 2 indices}}
+ // expected-error@+1 {{incorrect number of indices, expected 2 but got 1}}
vector.transfer_write %cst, %arg0[%c3] {permutation_map = affine_map<()->(0)>} : vector<128xf32>, memref<?x?xf32>
}
@@ -1328,7 +1328,7 @@ func.func @store_base_type_mismatch(%base : memref<?xf64>, %value : vector<16xf3
// -----
func.func @store_memref_index_mismatch(%base : memref<?xf32>, %value : vector<16xf32>) {
- // expected-error@+1 {{'vector.store' op requires 1 indices}}
+ // expected-error@+1 {{'vector.store' op incorrect number of indices, expected 1 but got 0}}
vector.store %value, %base[] : memref<?xf32>, vector<16xf32>
}
@@ -1379,7 +1379,7 @@ func.func @maskedload_pass_thru_type_mask_mismatch(%base: memref<?xf32>, %mask:
// -----
func.func @maskedload_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
- // expected-error@+1 {{'vector.maskedload' op requires 1 indices}}
+ // expected-error@+1 {{'vector.maskedload' op incorrect number of indices, expected 1 but got 0}}
%0 = vector.maskedload %base[], %mask, %pass : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
}
@@ -1423,7 +1423,7 @@ func.func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15x
func.func @maskedstore_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
%c0 = arith.constant 0 : index
- // expected-error@+1 {{'vector.maskedstore' op requires 1 indices}}
+ // expected-error@+1 {{'vector.maskedstore' op incorrect number of indices, expected 1 but got 2}}
vector.maskedstore %base[%c0, %c0], %mask, %value : memref<?xf32>, vector<16xi1>, vector<16xf32>
}
@@ -1452,7 +1452,7 @@ func.func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi
func.func @gather_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %pass_thru: vector<16xf64>) {
%c0 = arith.constant 0 : index
- // expected-error@+1 {{'vector.gather' op requires 2 indices}}
+ // expected-error@+1 {{'vector.gather' op incorrect number of indices, expected 2 but got 1}}
%0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
: memref<?x?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf64> into vector<16xf64>
}
@@ -1541,7 +1541,7 @@ func.func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16x
func.func @scatter_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi32>,
%mask: vector<16xi1>, %value: vector<16xf64>) {
%c0 = arith.constant 0 : index
- // expected-error@+1 {{'vector.scatter' op requires 2 indices}}
+ // expected-error@+1 {{'vector.scatter' op incorrect number of indices, expected 2 but got 1}}
vector.scatter %base[%c0][%indices], %mask, %value
: memref<?x?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf64>
}
@@ -1630,7 +1630,7 @@ func.func @expand_pass_thru_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>,
func.func @expand_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
%c0 = arith.constant 0 : index
- // expected-error@+1 {{'vector.expandload' op requires 2 indices}}
+ // expected-error@+1 {{'vector.expandload' op incorrect number of indices, expected 2 but got 1}}
%0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
}
@@ -1676,7 +1676,7 @@ func.func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>
func.func @compress_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
%c0 = arith.constant 0 : index
- // expected-error@+1 {{'vector.compressstore' op requires 2 indices}}
+ // expected-error@+1 {{'vector.compressstore' op incorrect number of indices, expected 2 but got 3}}
vector.compressstore %base[%c0, %c0, %c0], %mask, %value : memref<?x?xf32>, vector<16xi1>, vector<16xf32>
}
More information about the Mlir-commits mailing list