[Mlir-commits] [mlir] [mlir][Utils] Add verifyElementTypesMatch helper (NFC) (PR #176668)

Nick Kreeger llvmlistbot at llvm.org
Sun Jan 18 14:02:55 PST 2026


https://github.com/nkreeger updated https://github.com/llvm/llvm-project/pull/176668

From 51a4900db6238a6b7d4af67d7b61a443d1f6e66b Mon Sep 17 00:00:00 2001
From: Nick Kreeger <nick.kreeger at gmail.com>
Date: Sun, 18 Jan 2026 15:56:22 -0600
Subject: [PATCH 1/2] [mlir][Utils] Add verifyElementTypesMatch helper (NFC)

This change builds on #174336 and #175880, which introduced shared
VerificationUtils with verifyDynamicDimensionCount() and
verifyRanksMatch() methods.

This patch adds a new verifyElementTypesMatch() verification utility
that checks if two shaped types have matching element types and emits
consistent error messages. The utility is applied to several ops across
the MemRef and Vector dialects.
---
 .../mlir/Dialect/Utils/VerificationUtils.h    |  6 ++++
 mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp      | 13 ++++---
 mlir/lib/Dialect/Utils/VerificationUtils.cpp  | 14 ++++++++
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 36 +++++++++++--------
 mlir/test/Dialect/MemRef/invalid.mlir         |  4 +--
 mlir/test/Dialect/Vector/invalid.mlir         | 14 ++++----
 6 files changed, 57 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/VerificationUtils.h b/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
index 3d350aae7cf2f..b0f4102d15d62 100644
--- a/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
@@ -32,6 +32,12 @@ LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type,
 LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs,
                                StringRef lhsName, StringRef rhsName);
 
+/// Verify that two shaped types have matching element types. Returns failure
+/// and emits an error if element types don't match.
+LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs,
+                                      ShapedType rhs, StringRef lhsName,
+                                      StringRef rhsName);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_VERIFICATIONUTILS_H
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b782a8be19154..2f6cb52d4a339 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -340,10 +340,9 @@ LogicalResult ReallocOp::verify() {
            << sourceType << " and result memref type " << resultType;
 
   // The source memref and the result memref should have the same element type.
-  if (sourceType.getElementType() != resultType.getElementType())
-    return emitError("different element types specified for source memref "
-                     "type ")
-           << sourceType << " and result memref type " << resultType;
+  if (failed(verifyElementTypesMatch(*this, sourceType, resultType, "source",
+                                     "result")))
+    return failure();
 
   // Verify that we have the dynamic dimension operand when it is needed.
   if (resultType.getNumDynamicDims() && !getDynamicResultSize())
@@ -1948,9 +1947,9 @@ LogicalResult ReinterpretCastOp::verify() {
   if (srcType.getMemorySpace() != resultType.getMemorySpace())
     return emitError("different memory spaces specified for source type ")
            << srcType << " and result memref type " << resultType;
-  if (srcType.getElementType() != resultType.getElementType())
-    return emitError("different element types specified for source type ")
-           << srcType << " and result memref type " << resultType;
+  if (failed(verifyElementTypesMatch(*this, srcType, resultType, "source",
+                                     "result")))
+    return failure();
 
   // Match sizes in result memref type and in static_sizes attribute.
   for (auto [idx, resultSize, expectedSize] :
diff --git a/mlir/lib/Dialect/Utils/VerificationUtils.cpp b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
index 81f1e590a76ee..95a3348d5de90 100644
--- a/mlir/lib/Dialect/Utils/VerificationUtils.cpp
+++ b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
@@ -36,3 +36,17 @@ LogicalResult mlir::verifyRanksMatch(Operation *op, ShapedType lhs,
   }
   return success();
 }
+
+LogicalResult mlir::verifyElementTypesMatch(Operation *op, ShapedType lhs,
+                                            ShapedType rhs, StringRef lhsName,
+                                            StringRef rhsName) {
+  Type lhsElementType = lhs.getElementType();
+  Type rhsElementType = rhs.getElementType();
+  if (lhsElementType != rhsElementType) {
+    return op->emitOpError()
+           << lhsName << " element type (" << lhsElementType
+           << ") does not match " << rhsName << " element type ("
+           << rhsElementType << ")";
+  }
+  return success();
+}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 085f879c2d0e6..3f8d6b5236dee 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Utils/VerificationUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
@@ -5929,8 +5930,9 @@ LogicalResult MaskedLoadOp::verify() {
   VectorType resVType = getVectorType();
   MemRefType memType = getMemRefType();
 
-  if (resVType.getElementType() != memType.getElementType())
-    return emitOpError("base and result element type should match");
+  if (failed(verifyElementTypesMatch(*this, memType, resVType, "base",
+                                     "result")))
+    return failure();
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
   if (resVType.getShape() != maskVType.getShape())
@@ -5988,8 +5990,9 @@ LogicalResult MaskedStoreOp::verify() {
   VectorType valueVType = getVectorType();
   MemRefType memType = getMemRefType();
 
-  if (valueVType.getElementType() != memType.getElementType())
-    return emitOpError("base and valueToStore element type should match");
+  if (failed(verifyElementTypesMatch(*this, memType, valueVType, "base",
+                                     "valueToStore")))
+    return failure();
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
   if (valueVType.getShape() != maskVType.getShape())
@@ -6048,8 +6051,9 @@ LogicalResult GatherOp::verify() {
   if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
     return emitOpError("requires base to be a memref or ranked tensor type");
 
-  if (resVType.getElementType() != baseType.getElementType())
-    return emitOpError("base and result element type should match");
+  if (failed(verifyElementTypesMatch(*this, baseType, resVType, "base",
+                                     "result")))
+    return failure();
   if (llvm::size(getOffsets()) != baseType.getRank())
     return emitOpError("requires ") << baseType.getRank() << " indices";
   if (resVType.getShape() != indVType.getShape())
@@ -6157,8 +6161,9 @@ LogicalResult ScatterOp::verify() {
   if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
     return emitOpError("requires base to be a memref or ranked tensor type");
 
-  if (valueVType.getElementType() != baseType.getElementType())
-    return emitOpError("base and valueToStore element type should match");
+  if (failed(verifyElementTypesMatch(*this, baseType, valueVType, "base",
+                                     "valueToStore")))
+    return failure();
   if (llvm::size(getOffsets()) != baseType.getRank())
     return emitOpError("requires ") << baseType.getRank() << " indices";
   if (valueVType.getShape() != indVType.getShape())
@@ -6240,8 +6245,9 @@ LogicalResult ExpandLoadOp::verify() {
   VectorType resVType = getVectorType();
   MemRefType memType = getMemRefType();
 
-  if (resVType.getElementType() != memType.getElementType())
-    return emitOpError("base and result element type should match");
+  if (failed(verifyElementTypesMatch(*this, memType, resVType, "base",
+                                     "result")))
+    return failure();
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
@@ -6293,8 +6299,9 @@ LogicalResult CompressStoreOp::verify() {
   VectorType valueVType = getVectorType();
   MemRefType memType = getMemRefType();
 
-  if (valueVType.getElementType() != memType.getElementType())
-    return emitOpError("base and valueToStore element type should match");
+  if (failed(verifyElementTypesMatch(*this, memType, valueVType, "base",
+                                     "valueToStore")))
+    return failure();
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
@@ -6355,8 +6362,9 @@ LogicalResult ShapeCastOp::verify() {
   VectorType resultType = getResultVectorType();
 
   // Check that element type is preserved
-  if (sourceType.getElementType() != resultType.getElementType())
-    return emitOpError("has different source and result element types");
+  if (failed(verifyElementTypesMatch(*this, sourceType, resultType, "source",
+                                     "result")))
+    return failure();
 
   // Check that number of elements is preserved
   int64_t sourceNElms = sourceType.getNumElements();
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 2dfe733c13768..46e010fc878fe 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -159,7 +159,7 @@ func.func @memref_reinterpret_cast_too_many_offsets(%in: memref<?xf32>) {
 // -----
 
 func.func @memref_reinterpret_cast_incompatible_element_types(%in: memref<*xf32>) {
-  // expected-error @+1 {{different element types specified}}
+  // expected-error @+1 {{source element type ('f32') does not match result element type ('i32')}}
   %out = memref.reinterpret_cast %in to
            offset: [0], sizes: [10], strides: [1]
          : memref<*xf32> to memref<10xi32, strided<[1], offset: 0>>
@@ -1144,7 +1144,7 @@ func.func @memref_realloc_sizes_2(%src : memref<?xf32>, %d : index)
 // -----
 
 func.func @memref_realloc_type(%src : memref<256xf32>) -> memref<?xi32>{
-  // expected-error@+1 {{different element types}}
+  // expected-error@+1 {{source element type ('f32') does not match result element type ('i32')}}
   %0 = memref.realloc %src : memref<256xf32> to memref<?xi32>
   return %0 : memref<?xi32>
 }
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 54ad3d8ab0950..28e1206ff3d0a 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1133,7 +1133,7 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
 
 
 func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error@+1 {{'vector.shape_cast' op has different source and result element types}}
+  // expected-error@+1 {{'vector.shape_cast' op source element type ('f32') does not match result element type ('i32')}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
 }
 
@@ -1356,7 +1356,7 @@ func.func @maskedload_non_power_of_2_alignment(%base: memref<4xi32>, %mask: vect
 
 func.func @maskedload_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.maskedload' op base and result element type should match}}
+  // expected-error@+1 {{'vector.maskedload' op base element type ('f64') does not match result element type ('f32')}}
   %0 = vector.maskedload %base[%c0], %mask, %pass : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 }
 
@@ -1407,7 +1407,7 @@ func.func @maskedstore_non_power_of_2_alignment(%base: memref<4xi32>, %mask: vec
 
 func.func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.maskedstore' op base and valueToStore element type should match}}
+  // expected-error@+1 {{'vector.maskedstore' op base element type ('f64') does not match valueToStore element type ('f32')}}
   vector.maskedstore %base[%c0], %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
 }
 
@@ -1442,7 +1442,7 @@ func.func @gather_from_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
 func.func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
                                 %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.gather' op base and result element type should match}}
+  // expected-error@+1 {{'vector.gather' op base element type ('f64') does not match result element type ('f32')}}
   %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
     : memref<?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 }
@@ -1531,7 +1531,7 @@ func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
 func.func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
                                  %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.scatter' op base and valueToStore element type should match}}
+  // expected-error@+1 {{'vector.scatter' op base element type ('f64') does not match valueToStore element type ('f32')}}
   vector.scatter %base[%c0][%indices], %mask, %value
     : memref<?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf32>
 }
@@ -1598,7 +1598,7 @@ func.func @scatter_non_power_of_2_alignment(%base: memref<?xf32>, %indices: vect
 
 func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.expandload' op base and result element type should match}}
+  // expected-error@+1 {{'vector.expandload' op base element type ('f64') does not match result element type ('f32')}}
   %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 }
 
@@ -1652,7 +1652,7 @@ func.func @expand_non_power_of_2_alignment(%base: memref<?xf32>, %mask: vector<1
 
 func.func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.compressstore' op base and valueToStore element type should match}}
+  // expected-error@+1 {{'vector.compressstore' op base element type ('f64') does not match valueToStore element type ('f32')}}
   vector.compressstore %base[%c0], %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
 }
 

From dfc03e884d2105a2514aa017d8cf7b4e315c21fe Mon Sep 17 00:00:00 2001
From: Nick Kreeger <nick.kreeger at gmail.com>
Date: Sun, 18 Jan 2026 16:02:41 -0600
Subject: [PATCH 2/2] clang-format.

---
 mlir/lib/Dialect/Utils/VerificationUtils.cpp |  7 +++----
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp     | 12 ++++++------
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Utils/VerificationUtils.cpp b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
index 95a3348d5de90..bb51581e18cde 100644
--- a/mlir/lib/Dialect/Utils/VerificationUtils.cpp
+++ b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
@@ -43,10 +43,9 @@ LogicalResult mlir::verifyElementTypesMatch(Operation *op, ShapedType lhs,
   Type lhsElementType = lhs.getElementType();
   Type rhsElementType = rhs.getElementType();
   if (lhsElementType != rhsElementType) {
-    return op->emitOpError()
-           << lhsName << " element type (" << lhsElementType
-           << ") does not match " << rhsName << " element type ("
-           << rhsElementType << ")";
+    return op->emitOpError() << lhsName << " element type (" << lhsElementType
+                             << ") does not match " << rhsName
+                             << " element type (" << rhsElementType << ")";
   }
   return success();
 }
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3f8d6b5236dee..7efba98602fd3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5930,8 +5930,8 @@ LogicalResult MaskedLoadOp::verify() {
   VectorType resVType = getVectorType();
   MemRefType memType = getMemRefType();
 
-  if (failed(verifyElementTypesMatch(*this, memType, resVType, "base",
-                                     "result")))
+  if (failed(
+          verifyElementTypesMatch(*this, memType, resVType, "base", "result")))
     return failure();
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
@@ -6051,8 +6051,8 @@ LogicalResult GatherOp::verify() {
   if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
     return emitOpError("requires base to be a memref or ranked tensor type");
 
-  if (failed(verifyElementTypesMatch(*this, baseType, resVType, "base",
-                                     "result")))
+  if (failed(
+          verifyElementTypesMatch(*this, baseType, resVType, "base", "result")))
     return failure();
   if (llvm::size(getOffsets()) != baseType.getRank())
     return emitOpError("requires ") << baseType.getRank() << " indices";
@@ -6245,8 +6245,8 @@ LogicalResult ExpandLoadOp::verify() {
   VectorType resVType = getVectorType();
   MemRefType memType = getMemRefType();
 
-  if (failed(verifyElementTypesMatch(*this, memType, resVType, "base",
-                                     "result")))
+  if (failed(
+          verifyElementTypesMatch(*this, memType, resVType, "base", "result")))
     return failure();
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";



More information about the Mlir-commits mailing list