[Mlir-commits] [mlir] [mlir][Utils] Add verifyElementTypesMatch helper (NFC) (PR #176668)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jan 18 13:59:34 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir-vector

Author: Nick Kreeger (nkreeger)

<details>
<summary>Changes</summary>

This change builds on #<!-- -->174336 and #<!-- -->175880, which introduced the shared VerificationUtils helpers verifyDynamicDimensionCount() and verifyRanksMatch().

This patch adds a new verifyElementTypesMatch() verification utility that checks if two shaped types have matching element types and emits consistent error messages. The utility is applied to several ops across the MemRef and Vector dialects.

---
Full diff: https://github.com/llvm/llvm-project/pull/176668.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/Utils/VerificationUtils.h (+6) 
- (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+6-7) 
- (modified) mlir/lib/Dialect/Utils/VerificationUtils.cpp (+14) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+22-14) 
- (modified) mlir/test/Dialect/MemRef/invalid.mlir (+2-2) 
- (modified) mlir/test/Dialect/Vector/invalid.mlir (+7-7) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/VerificationUtils.h b/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
index 3d350aae7cf2f..b0f4102d15d62 100644
--- a/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
@@ -32,6 +32,12 @@ LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type,
 LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs,
                                StringRef lhsName, StringRef rhsName);
 
+/// Verify that two shaped types have matching element types. Returns failure
+/// and emits an error if element types don't match.
+LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs,
+                                      ShapedType rhs, StringRef lhsName,
+                                      StringRef rhsName);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_VERIFICATIONUTILS_H
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index b782a8be19154..2f6cb52d4a339 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -340,10 +340,9 @@ LogicalResult ReallocOp::verify() {
            << sourceType << " and result memref type " << resultType;
 
   // The source memref and the result memref should have the same element type.
-  if (sourceType.getElementType() != resultType.getElementType())
-    return emitError("different element types specified for source memref "
-                     "type ")
-           << sourceType << " and result memref type " << resultType;
+  if (failed(verifyElementTypesMatch(*this, sourceType, resultType, "source",
+                                     "result")))
+    return failure();
 
   // Verify that we have the dynamic dimension operand when it is needed.
   if (resultType.getNumDynamicDims() && !getDynamicResultSize())
@@ -1948,9 +1947,9 @@ LogicalResult ReinterpretCastOp::verify() {
   if (srcType.getMemorySpace() != resultType.getMemorySpace())
     return emitError("different memory spaces specified for source type ")
            << srcType << " and result memref type " << resultType;
-  if (srcType.getElementType() != resultType.getElementType())
-    return emitError("different element types specified for source type ")
-           << srcType << " and result memref type " << resultType;
+  if (failed(verifyElementTypesMatch(*this, srcType, resultType, "source",
+                                     "result")))
+    return failure();
 
   // Match sizes in result memref type and in static_sizes attribute.
   for (auto [idx, resultSize, expectedSize] :
diff --git a/mlir/lib/Dialect/Utils/VerificationUtils.cpp b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
index 81f1e590a76ee..95a3348d5de90 100644
--- a/mlir/lib/Dialect/Utils/VerificationUtils.cpp
+++ b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
@@ -36,3 +36,17 @@ LogicalResult mlir::verifyRanksMatch(Operation *op, ShapedType lhs,
   }
   return success();
 }
+
+LogicalResult mlir::verifyElementTypesMatch(Operation *op, ShapedType lhs,
+                                            ShapedType rhs, StringRef lhsName,
+                                            StringRef rhsName) {
+  Type lhsElementType = lhs.getElementType();
+  Type rhsElementType = rhs.getElementType();
+  if (lhsElementType != rhsElementType) {
+    return op->emitOpError()
+           << lhsName << " element type (" << lhsElementType
+           << ") does not match " << rhsName << " element type ("
+           << rhsElementType << ")";
+  }
+  return success();
+}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 085f879c2d0e6..3f8d6b5236dee 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Utils/VerificationUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
@@ -5929,8 +5930,9 @@ LogicalResult MaskedLoadOp::verify() {
   VectorType resVType = getVectorType();
   MemRefType memType = getMemRefType();
 
-  if (resVType.getElementType() != memType.getElementType())
-    return emitOpError("base and result element type should match");
+  if (failed(verifyElementTypesMatch(*this, memType, resVType, "base",
+                                     "result")))
+    return failure();
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
   if (resVType.getShape() != maskVType.getShape())
@@ -5988,8 +5990,9 @@ LogicalResult MaskedStoreOp::verify() {
   VectorType valueVType = getVectorType();
   MemRefType memType = getMemRefType();
 
-  if (valueVType.getElementType() != memType.getElementType())
-    return emitOpError("base and valueToStore element type should match");
+  if (failed(verifyElementTypesMatch(*this, memType, valueVType, "base",
+                                     "valueToStore")))
+    return failure();
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
   if (valueVType.getShape() != maskVType.getShape())
@@ -6048,8 +6051,9 @@ LogicalResult GatherOp::verify() {
   if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
     return emitOpError("requires base to be a memref or ranked tensor type");
 
-  if (resVType.getElementType() != baseType.getElementType())
-    return emitOpError("base and result element type should match");
+  if (failed(verifyElementTypesMatch(*this, baseType, resVType, "base",
+                                     "result")))
+    return failure();
   if (llvm::size(getOffsets()) != baseType.getRank())
     return emitOpError("requires ") << baseType.getRank() << " indices";
   if (resVType.getShape() != indVType.getShape())
@@ -6157,8 +6161,9 @@ LogicalResult ScatterOp::verify() {
   if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
     return emitOpError("requires base to be a memref or ranked tensor type");
 
-  if (valueVType.getElementType() != baseType.getElementType())
-    return emitOpError("base and valueToStore element type should match");
+  if (failed(verifyElementTypesMatch(*this, baseType, valueVType, "base",
+                                     "valueToStore")))
+    return failure();
   if (llvm::size(getOffsets()) != baseType.getRank())
     return emitOpError("requires ") << baseType.getRank() << " indices";
   if (valueVType.getShape() != indVType.getShape())
@@ -6240,8 +6245,9 @@ LogicalResult ExpandLoadOp::verify() {
   VectorType resVType = getVectorType();
   MemRefType memType = getMemRefType();
 
-  if (resVType.getElementType() != memType.getElementType())
-    return emitOpError("base and result element type should match");
+  if (failed(verifyElementTypesMatch(*this, memType, resVType, "base",
+                                     "result")))
+    return failure();
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
@@ -6293,8 +6299,9 @@ LogicalResult CompressStoreOp::verify() {
   VectorType valueVType = getVectorType();
   MemRefType memType = getMemRefType();
 
-  if (valueVType.getElementType() != memType.getElementType())
-    return emitOpError("base and valueToStore element type should match");
+  if (failed(verifyElementTypesMatch(*this, memType, valueVType, "base",
+                                     "valueToStore")))
+    return failure();
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
@@ -6355,8 +6362,9 @@ LogicalResult ShapeCastOp::verify() {
   VectorType resultType = getResultVectorType();
 
   // Check that element type is preserved
-  if (sourceType.getElementType() != resultType.getElementType())
-    return emitOpError("has different source and result element types");
+  if (failed(verifyElementTypesMatch(*this, sourceType, resultType, "source",
+                                     "result")))
+    return failure();
 
   // Check that number of elements is preserved
   int64_t sourceNElms = sourceType.getNumElements();
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 2dfe733c13768..46e010fc878fe 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -159,7 +159,7 @@ func.func @memref_reinterpret_cast_too_many_offsets(%in: memref<?xf32>) {
 // -----
 
 func.func @memref_reinterpret_cast_incompatible_element_types(%in: memref<*xf32>) {
-  // expected-error @+1 {{different element types specified}}
+  // expected-error @+1 {{source element type ('f32') does not match result element type ('i32')}}
   %out = memref.reinterpret_cast %in to
            offset: [0], sizes: [10], strides: [1]
          : memref<*xf32> to memref<10xi32, strided<[1], offset: 0>>
@@ -1144,7 +1144,7 @@ func.func @memref_realloc_sizes_2(%src : memref<?xf32>, %d : index)
 // -----
 
 func.func @memref_realloc_type(%src : memref<256xf32>) -> memref<?xi32>{
-  // expected-error@+1 {{different element types}}
+  // expected-error@+1 {{source element type ('f32') does not match result element type ('i32')}}
   %0 = memref.realloc %src : memref<256xf32> to memref<?xi32>
   return %0 : memref<?xi32>
 }
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 54ad3d8ab0950..28e1206ff3d0a 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1133,7 +1133,7 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
 
 
 func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error@+1 {{'vector.shape_cast' op has different source and result element types}}
+  // expected-error@+1 {{'vector.shape_cast' op source element type ('f32') does not match result element type ('i32')}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
 }
 
@@ -1356,7 +1356,7 @@ func.func @maskedload_non_power_of_2_alignment(%base: memref<4xi32>, %mask: vect
 
 func.func @maskedload_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.maskedload' op base and result element type should match}}
+  // expected-error@+1 {{'vector.maskedload' op base element type ('f64') does not match result element type ('f32')}}
   %0 = vector.maskedload %base[%c0], %mask, %pass : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 }
 
@@ -1407,7 +1407,7 @@ func.func @maskedstore_non_power_of_2_alignment(%base: memref<4xi32>, %mask: vec
 
 func.func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.maskedstore' op base and valueToStore element type should match}}
+  // expected-error@+1 {{'vector.maskedstore' op base element type ('f64') does not match valueToStore element type ('f32')}}
   vector.maskedstore %base[%c0], %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
 }
 
@@ -1442,7 +1442,7 @@ func.func @gather_from_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
 func.func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
                                 %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.gather' op base and result element type should match}}
+  // expected-error@+1 {{'vector.gather' op base element type ('f64') does not match result element type ('f32')}}
   %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
     : memref<?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 }
@@ -1531,7 +1531,7 @@ func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
 func.func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
                                  %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.scatter' op base and valueToStore element type should match}}
+  // expected-error@+1 {{'vector.scatter' op base element type ('f64') does not match valueToStore element type ('f32')}}
   vector.scatter %base[%c0][%indices], %mask, %value
     : memref<?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf32>
 }
@@ -1598,7 +1598,7 @@ func.func @scatter_non_power_of_2_alignment(%base: memref<?xf32>, %indices: vect
 
 func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.expandload' op base and result element type should match}}
+  // expected-error@+1 {{'vector.expandload' op base element type ('f64') does not match result element type ('f32')}}
   %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 }
 
@@ -1652,7 +1652,7 @@ func.func @expand_non_power_of_2_alignment(%base: memref<?xf32>, %mask: vector<1
 
 func.func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.compressstore' op base and valueToStore element type should match}}
+  // expected-error@+1 {{'vector.compressstore' op base element type ('f64') does not match valueToStore element type ('f32')}}
   vector.compressstore %base[%c0], %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/176668


More information about the Mlir-commits mailing list