[Mlir-commits] [mlir] [mlir][Utils] Add verifyIndexCount helper (NFC) (PR #176678)

Nick Kreeger llvmlistbot at llvm.org
Sun Jan 18 17:04:12 PST 2026


https://github.com/nkreeger created https://github.com/llvm/llvm-project/pull/176678

Description:
This change builds on #174336, which introduced shared VerificationUtils and assorted helpers.

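This patch adds a new verifyIndexCount() verification utility that checks whether the number of indices matches the rank of a shaped type and emits a consistent error message. The utility is applied to several ops across the MemRef and Vector dialects; the only observable change is the unified diagnostic text ("incorrect number of indices, expected N but got M"), which the updated tests reflect.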

>From 8659e499dda4c63f18c454d4cbf631652a7c1481 Mon Sep 17 00:00:00 2001
From: Nick Kreeger <nick.kreeger at gmail.com>
Date: Sun, 18 Jan 2026 18:59:54 -0600
Subject: [PATCH] [mlir][Utils] Add verifyIndexCount helper (NFC)

Description:
This change builds on #174336, which introduced shared
VerificationUtils and assorted helpers.

This patch adds a new verifyIndexCount() verification utility that
checks whether the number of indices matches the rank of a shaped type
and emits a consistent error message. The utility is applied to several
ops across the MemRef and Vector dialects. The only observable change
is the unified diagnostic text, which the updated tests reflect.
---
 .../mlir/Dialect/Utils/VerificationUtils.h    |  5 +++
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      |  6 +--
 mlir/lib/Dialect/Utils/VerificationUtils.cpp  | 10 +++++
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 41 ++++++++++---------
 mlir/test/Dialect/Linalg/invalid.mlir         |  2 +-
 mlir/test/Dialect/MemRef/invalid.mlir         |  2 +-
 mlir/test/Dialect/Vector/invalid.mlir         | 22 +++++-----
 7 files changed, 50 insertions(+), 38 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/VerificationUtils.h b/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
index 3d350aae7cf2f..16dfaebbfcdd2 100644
--- a/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
@@ -32,6 +32,11 @@ LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type,
 LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs,
                                StringRef lhsName, StringRef rhsName);
 
+/// Verify that `indexCount` equals the rank of `type` (which must be ranked).
+/// On mismatch, emits an op error on `op` and returns failure.
+LogicalResult verifyIndexCount(Operation *op, ShapedType type,
+                               size_t indexCount);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_VERIFICATIONUTILS_H
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b782a8be19154..382fcb84a6949 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1694,11 +1694,7 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 //===----------------------------------------------------------------------===//
 
 LogicalResult LoadOp::verify() {
-  if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
-    return emitOpError("incorrect number of indices for load, expected ")
-           << getMemRefType().getRank() << " but got " << getIndices().size();
-  }
-  return success();
+  return verifyIndexCount(*this, getMemRefType(), getIndices().size());
 }
 
 OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Utils/VerificationUtils.cpp b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
index 81f1e590a76ee..78606eb23bfa9 100644
--- a/mlir/lib/Dialect/Utils/VerificationUtils.cpp
+++ b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
@@ -36,3 +36,13 @@ LogicalResult mlir::verifyRanksMatch(Operation *op, ShapedType lhs,
   }
   return success();
 }
+
+LogicalResult mlir::verifyIndexCount(Operation *op, ShapedType type,
+                                     size_t indexCount) {
+  int64_t rank = type.getRank(); // getRank() requires a ranked `type`.
+  if (static_cast<int64_t>(indexCount) != rank) {
+    return op->emitOpError("incorrect number of indices, expected ")
+           << rank << " but got " << indexCount;
+  }
+  return success();
+}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 085f879c2d0e6..ca5c791532f18 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Utils/VerificationUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
@@ -5001,8 +5002,8 @@ LogicalResult TransferReadOp::verify() {
                : VectorType();
   auto sourceElementType = shapedType.getElementType();
 
-  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
-    return emitOpError("requires ") << shapedType.getRank() << " indices";
+  if (failed(verifyIndexCount(*this, shapedType, getIndices().size())))
+    return failure();
 
   if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                               shapedType, vectorType, maskType,
@@ -5462,8 +5463,8 @@ LogicalResult TransferWriteOp::verify() {
       maskType ? inferTransferOpMaskType(vectorType, permutationMap)
                : VectorType();
 
-  if (llvm::size(getIndices()) != shapedType.getRank())
-    return emitOpError("requires ") << shapedType.getRank() << " indices";
+  if (failed(verifyIndexCount(*this, shapedType, llvm::size(getIndices()))))
+    return failure();
 
   // We do not allow broadcast dimensions on TransferWriteOps for the moment,
   // as the semantics is unclear. This can be revisited later if necessary.
@@ -5853,8 +5854,8 @@ LogicalResult vector::LoadOp::verify() {
 
   if (resVecTy.getElementType() != memElemTy)
     return emitOpError("base and result element types should match");
-  if (llvm::size(getIndices()) != memRefTy.getRank())
-    return emitOpError("requires ") << memRefTy.getRank() << " indices";
+  if (failed(verifyIndexCount(*this, memRefTy, llvm::size(getIndices()))))
+    return failure();
   return success();
 }
 
@@ -5899,8 +5900,8 @@ LogicalResult vector::StoreOp::verify() {
 
   if (valueVecTy.getElementType() != memElemTy)
     return emitOpError("base and valueToStore element type should match");
-  if (llvm::size(getIndices()) != memRefTy.getRank())
-    return emitOpError("requires ") << memRefTy.getRank() << " indices";
+  if (failed(verifyIndexCount(*this, memRefTy, llvm::size(getIndices()))))
+    return failure();
   return success();
 }
 
@@ -5931,8 +5932,8 @@ LogicalResult MaskedLoadOp::verify() {
 
   if (resVType.getElementType() != memType.getElementType())
     return emitOpError("base and result element type should match");
-  if (llvm::size(getIndices()) != memType.getRank())
-    return emitOpError("requires ") << memType.getRank() << " indices";
+  if (failed(verifyIndexCount(*this, memType, llvm::size(getIndices()))))
+    return failure();
   if (resVType.getShape() != maskVType.getShape())
     return emitOpError("expected result shape to match mask shape");
   if (resVType != passVType)
@@ -5990,8 +5991,8 @@ LogicalResult MaskedStoreOp::verify() {
 
   if (valueVType.getElementType() != memType.getElementType())
     return emitOpError("base and valueToStore element type should match");
-  if (llvm::size(getIndices()) != memType.getRank())
-    return emitOpError("requires ") << memType.getRank() << " indices";
+  if (failed(verifyIndexCount(*this, memType, llvm::size(getIndices()))))
+    return failure();
   if (valueVType.getShape() != maskVType.getShape())
     return emitOpError("expected valueToStore shape to match mask shape");
   return success();
@@ -6050,8 +6051,8 @@ LogicalResult GatherOp::verify() {
 
   if (resVType.getElementType() != baseType.getElementType())
     return emitOpError("base and result element type should match");
-  if (llvm::size(getOffsets()) != baseType.getRank())
-    return emitOpError("requires ") << baseType.getRank() << " indices";
+  if (failed(verifyIndexCount(*this, baseType, llvm::size(getOffsets()))))
+    return failure();
   if (resVType.getShape() != indVType.getShape())
     return emitOpError("expected result dim to match indices dim");
   if (resVType.getShape() != maskVType.getShape())
@@ -6159,8 +6160,8 @@ LogicalResult ScatterOp::verify() {
 
   if (valueVType.getElementType() != baseType.getElementType())
     return emitOpError("base and valueToStore element type should match");
-  if (llvm::size(getOffsets()) != baseType.getRank())
-    return emitOpError("requires ") << baseType.getRank() << " indices";
+  if (failed(verifyIndexCount(*this, baseType, llvm::size(getOffsets()))))
+    return failure();
   if (valueVType.getShape() != indVType.getShape())
     return emitOpError("expected valueToStore dim to match indices dim");
   if (valueVType.getShape() != maskVType.getShape())
@@ -6242,8 +6243,8 @@ LogicalResult ExpandLoadOp::verify() {
 
   if (resVType.getElementType() != memType.getElementType())
     return emitOpError("base and result element type should match");
-  if (llvm::size(getIndices()) != memType.getRank())
-    return emitOpError("requires ") << memType.getRank() << " indices";
+  if (failed(verifyIndexCount(*this, memType, llvm::size(getIndices()))))
+    return failure();
   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
     return emitOpError("expected result dim to match mask dim");
   if (resVType != passVType)
@@ -6295,8 +6296,8 @@ LogicalResult CompressStoreOp::verify() {
 
   if (valueVType.getElementType() != memType.getElementType())
     return emitOpError("base and valueToStore element type should match");
-  if (llvm::size(getIndices()) != memType.getRank())
-    return emitOpError("requires ") << memType.getRank() << " indices";
+  if (failed(verifyIndexCount(*this, memType, llvm::size(getIndices()))))
+    return failure();
   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
     return emitOpError("expected valueToStore dim to match mask dim");
   return success();
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 355d801f8732c..10a4602626e80 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
 func.func @load_number_of_indices(%v : memref<f32>) {
-  // expected-error @+2 {{incorrect number of indices for load}}
+  // expected-error @+2 {{incorrect number of indices, expected 0 but got 1}}
   %c0 = arith.constant 0 : index
   memref.load %v[%c0] : memref<f32>
 }
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 2dfe733c13768..3a4471e9a6387 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -948,7 +948,7 @@ func.func @bad_alloc_wrong_symbol_count() {
 func.func @load_invalid_memref_indexes() {
   %0 = memref.alloca() : memref<10xi32>
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{incorrect number of indices for load, expected 1 but got 2}}
+  // expected-error@+1 {{incorrect number of indices, expected 1 but got 2}}
   %1 = memref.load %0[%c0, %c0] : memref<10xi32>
 }
 
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 54ad3d8ab0950..92da4fc6a35be 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -348,7 +348,7 @@ func.func @test_vector.transfer_read(%arg0: memref<4x3xf32>) {
 func.func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
   %c3 = arith.constant 3 : index
   %cst = arith.constant 3.0 : f32
-  // expected-error@+1 {{requires 2 indices}}
+  // expected-error@+1 {{incorrect number of indices, expected 2 but got 3}}
   %0 = vector.transfer_read %arg0[%c3, %c3, %c3], %cst { permutation_map = affine_map<()->(0)> } : memref<?x?xf32>, vector<128xf32>
 }
 
@@ -357,7 +357,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
 func.func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
   %c3 = arith.constant 3 : index
   %cst = arith.constant 3.0 : f32
-  // expected-error@+1 {{requires 2 indices}}
+  // expected-error@+1 {{incorrect number of indices, expected 2 but got 1}}
   %0 = vector.transfer_read %arg0[%c3], %cst { permutation_map = affine_map<()->(0)> } : memref<?x?xf32>, vector<128xf32>
 }
 
@@ -529,7 +529,7 @@ func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
 func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
   %c3 = arith.constant 3 : index
   %cst = arith.constant dense<3.0> : vector<128 x f32>
-  // expected-error@+1 {{requires 2 indices}}
+  // expected-error@+1 {{incorrect number of indices, expected 2 but got 3}}
   vector.transfer_write %cst, %arg0[%c3, %c3, %c3] {permutation_map = affine_map<()->(0)>} : vector<128xf32>, memref<?x?xf32>
 }
 
@@ -538,7 +538,7 @@ func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
 func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
   %c3 = arith.constant 3 : index
   %cst = arith.constant dense<3.0> : vector<128 x f32>
-  // expected-error@+1 {{requires 2 indices}}
+  // expected-error@+1 {{incorrect number of indices, expected 2 but got 1}}
   vector.transfer_write %cst, %arg0[%c3] {permutation_map = affine_map<()->(0)>} : vector<128xf32>, memref<?x?xf32>
 }
 
@@ -1328,7 +1328,7 @@ func.func @store_base_type_mismatch(%base : memref<?xf64>, %value : vector<16xf3
 // -----
 
 func.func @store_memref_index_mismatch(%base : memref<?xf32>, %value : vector<16xf32>) {
-  // expected-error@+1 {{'vector.store' op requires 1 indices}}
+  // expected-error@+1 {{'vector.store' op incorrect number of indices, expected 1 but got 0}}
   vector.store %value, %base[] : memref<?xf32>, vector<16xf32>
 }
 
@@ -1379,7 +1379,7 @@ func.func @maskedload_pass_thru_type_mask_mismatch(%base: memref<?xf32>, %mask:
 // -----
 
 func.func @maskedload_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
-  // expected-error@+1 {{'vector.maskedload' op requires 1 indices}}
+  // expected-error@+1 {{'vector.maskedload' op incorrect number of indices, expected 1 but got 0}}
   %0 = vector.maskedload %base[], %mask, %pass : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 }
 
@@ -1423,7 +1423,7 @@ func.func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15x
 
 func.func @maskedstore_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.maskedstore' op requires 1 indices}}
+  // expected-error@+1 {{'vector.maskedstore' op incorrect number of indices, expected 1 but got 2}}
   vector.maskedstore %base[%c0, %c0], %mask, %value : memref<?xf32>, vector<16xi1>, vector<16xf32>
 }
 
@@ -1452,7 +1452,7 @@ func.func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi
 func.func @gather_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi32>,
                             %mask: vector<16xi1>, %pass_thru: vector<16xf64>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.gather' op requires 2 indices}}
+  // expected-error@+1 {{'vector.gather' op incorrect number of indices, expected 2 but got 1}}
   %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
     : memref<?x?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf64> into vector<16xf64>
 }
@@ -1541,7 +1541,7 @@ func.func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16x
 func.func @scatter_memref_mismatch(%base: memref<?x?xf64>, %indices: vector<16xi32>,
                               %mask: vector<16xi1>, %value: vector<16xf64>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.scatter' op requires 2 indices}}
+  // expected-error@+1 {{'vector.scatter' op incorrect number of indices, expected 2 but got 1}}
   vector.scatter %base[%c0][%indices], %mask, %value
     : memref<?x?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf64>
 }
@@ -1630,7 +1630,7 @@ func.func @expand_pass_thru_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>,
 
 func.func @expand_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.expandload' op requires 2 indices}}
+  // expected-error@+1 {{'vector.expandload' op incorrect number of indices, expected 2 but got 1}}
   %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 }
 
@@ -1676,7 +1676,7 @@ func.func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>
 
 func.func @compress_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.compressstore' op requires 2 indices}}
+  // expected-error@+1 {{'vector.compressstore' op incorrect number of indices, expected 2 but got 3}}
   vector.compressstore %base[%c0, %c0, %c0], %mask, %value : memref<?x?xf32>, vector<16xi1>, vector<16xf32>
 }
 



More information about the Mlir-commits mailing list