[Mlir-commits] [mlir] 3a1288c - [mlir][Utils] Add verifyElementTypesMatch helper (NFC) (#176668)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 23 08:41:06 PST 2026


Author: Nick Kreeger
Date: 2026-01-23T16:41:00Z
New Revision: 3a1288c0fbe113834dfa8835223bd8f1830ceb0d

URL: https://github.com/llvm/llvm-project/commit/3a1288c0fbe113834dfa8835223bd8f1830ceb0d
DIFF: https://github.com/llvm/llvm-project/commit/3a1288c0fbe113834dfa8835223bd8f1830ceb0d.diff

LOG: [mlir][Utils] Add verifyElementTypesMatch helper (NFC) (#176668)

This change builds on #174336 and #175880, which introduced shared
VerificationUtils with verifyDynamicDimensionCount() and
verifyRanksMatch() methods.

This patch adds a new verifyElementTypesMatch() verification utility
that checks if two shaped types have matching element types and emits
consistent error messages. The utility is applied to several ops across
the MemRef and Vector dialects.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Utils/VerificationUtils.h
    mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
    mlir/lib/Dialect/Utils/VerificationUtils.cpp
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/MemRef/invalid.mlir
    mlir/test/Dialect/Vector/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Utils/VerificationUtils.h b/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
index 3d350aae7cf2f..b0f4102d15d62 100644
--- a/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
@@ -32,6 +32,12 @@ LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type,
 LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs,
                                StringRef lhsName, StringRef rhsName);
 
+/// Verify that two shaped types have matching element types. Returns failure
+/// and emits an error if element types don't match.
+LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs,
+                                      ShapedType rhs, StringRef lhsName,
+                                      StringRef rhsName);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_VERIFICATIONUTILS_H

diff  --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a9103b4d438ea..1ca0cea0f6f2f 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -340,10 +340,9 @@ LogicalResult ReallocOp::verify() {
            << sourceType << " and result memref type " << resultType;
 
   // The source memref and the result memref should have the same element type.
-  if (sourceType.getElementType() != resultType.getElementType())
-    return emitError("different element types specified for source memref "
-                     "type ")
-           << sourceType << " and result memref type " << resultType;
+  if (failed(verifyElementTypesMatch(*this, sourceType, resultType, "source",
+                                     "result")))
+    return failure();
 
   // Verify that we have the dynamic dimension operand when it is needed.
   if (resultType.getNumDynamicDims() && !getDynamicResultSize())
@@ -1965,9 +1964,9 @@ LogicalResult ReinterpretCastOp::verify() {
   if (srcType.getMemorySpace() != resultType.getMemorySpace())
     return emitError("different memory spaces specified for source type ")
            << srcType << " and result memref type " << resultType;
-  if (srcType.getElementType() != resultType.getElementType())
-    return emitError("different element types specified for source type ")
-           << srcType << " and result memref type " << resultType;
+  if (failed(verifyElementTypesMatch(*this, srcType, resultType, "source",
+                                     "result")))
+    return failure();
 
   // Match sizes in result memref type and in static_sizes attribute.
   for (auto [idx, resultSize, expectedSize] :

diff  --git a/mlir/lib/Dialect/Utils/VerificationUtils.cpp b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
index 81f1e590a76ee..bb51581e18cde 100644
--- a/mlir/lib/Dialect/Utils/VerificationUtils.cpp
+++ b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
@@ -36,3 +36,16 @@ LogicalResult mlir::verifyRanksMatch(Operation *op, ShapedType lhs,
   }
   return success();
 }
+
+LogicalResult mlir::verifyElementTypesMatch(Operation *op, ShapedType lhs,
+                                            ShapedType rhs, StringRef lhsName,
+                                            StringRef rhsName) {
+  Type lhsElementType = lhs.getElementType();
+  Type rhsElementType = rhs.getElementType();
+  if (lhsElementType != rhsElementType) {
+    return op->emitOpError() << lhsName << " element type (" << lhsElementType
+                             << ") does not match " << rhsName
+                             << " element type (" << rhsElementType << ")";
+  }
+  return success();
+}

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index ae23a4158155f..c6ffb81ede06e 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Utils/VerificationUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
@@ -5996,8 +5997,9 @@ LogicalResult MaskedLoadOp::verify() {
   VectorType resVType = getVectorType();
   MemRefType memType = getMemRefType();
 
-  if (resVType.getElementType() != memType.getElementType())
-    return emitOpError("base and result element type should match");
+  if (failed(
+          verifyElementTypesMatch(*this, memType, resVType, "base", "result")))
+    return failure();
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
   if (resVType.getShape() != maskVType.getShape())
@@ -6055,8 +6057,9 @@ LogicalResult MaskedStoreOp::verify() {
   VectorType valueVType = getVectorType();
   MemRefType memType = getMemRefType();
 
-  if (valueVType.getElementType() != memType.getElementType())
-    return emitOpError("base and valueToStore element type should match");
+  if (failed(verifyElementTypesMatch(*this, memType, valueVType, "base",
+                                     "valueToStore")))
+    return failure();
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
   if (valueVType.getShape() != maskVType.getShape())
@@ -6115,8 +6118,9 @@ LogicalResult GatherOp::verify() {
   if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
     return emitOpError("requires base to be a memref or ranked tensor type");
 
-  if (resVType.getElementType() != baseType.getElementType())
-    return emitOpError("base and result element type should match");
+  if (failed(
+          verifyElementTypesMatch(*this, baseType, resVType, "base", "result")))
+    return failure();
   if (llvm::size(getOffsets()) != baseType.getRank())
     return emitOpError("requires ") << baseType.getRank() << " indices";
   if (resVType.getShape() != indVType.getShape())
@@ -6224,8 +6228,9 @@ LogicalResult ScatterOp::verify() {
   if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
     return emitOpError("requires base to be a memref or ranked tensor type");
 
-  if (valueVType.getElementType() != baseType.getElementType())
-    return emitOpError("base and valueToStore element type should match");
+  if (failed(verifyElementTypesMatch(*this, baseType, valueVType, "base",
+                                     "valueToStore")))
+    return failure();
   if (llvm::size(getOffsets()) != baseType.getRank())
     return emitOpError("requires ") << baseType.getRank() << " indices";
   if (valueVType.getShape() != indVType.getShape())
@@ -6307,8 +6312,9 @@ LogicalResult ExpandLoadOp::verify() {
   VectorType resVType = getVectorType();
   MemRefType memType = getMemRefType();
 
-  if (resVType.getElementType() != memType.getElementType())
-    return emitOpError("base and result element type should match");
+  if (failed(
+          verifyElementTypesMatch(*this, memType, resVType, "base", "result")))
+    return failure();
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
@@ -6360,8 +6366,9 @@ LogicalResult CompressStoreOp::verify() {
   VectorType valueVType = getVectorType();
   MemRefType memType = getMemRefType();
 
-  if (valueVType.getElementType() != memType.getElementType())
-    return emitOpError("base and valueToStore element type should match");
+  if (failed(verifyElementTypesMatch(*this, memType, valueVType, "base",
+                                     "valueToStore")))
+    return failure();
   if (llvm::size(getIndices()) != memType.getRank())
     return emitOpError("requires ") << memType.getRank() << " indices";
   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
@@ -6422,8 +6429,9 @@ LogicalResult ShapeCastOp::verify() {
   VectorType resultType = getResultVectorType();
 
   // Check that element type is preserved
-  if (sourceType.getElementType() != resultType.getElementType())
-    return emitOpError("has different source and result element types");
+  if (failed(verifyElementTypesMatch(*this, sourceType, resultType, "source",
+                                     "result")))
+    return failure();
 
   // Check that number of elements is preserved
   int64_t sourceNElms = sourceType.getNumElements();

diff  --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index 2dfe733c13768..46e010fc878fe 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -159,7 +159,7 @@ func.func @memref_reinterpret_cast_too_many_offsets(%in: memref<?xf32>) {
 // -----
 
 func.func @memref_reinterpret_cast_incompatible_element_types(%in: memref<*xf32>) {
-  // expected-error @+1 {{different element types specified}}
+  // expected-error @+1 {{source element type ('f32') does not match result element type ('i32')}}
   %out = memref.reinterpret_cast %in to
            offset: [0], sizes: [10], strides: [1]
          : memref<*xf32> to memref<10xi32, strided<[1], offset: 0>>
@@ -1144,7 +1144,7 @@ func.func @memref_realloc_sizes_2(%src : memref<?xf32>, %d : index)
 // -----
 
 func.func @memref_realloc_type(%src : memref<256xf32>) -> memref<?xi32>{
-  // expected-error@+1 {{different element types}}
+  // expected-error@+1 {{source element type ('f32') does not match result element type ('i32')}}
   %0 = memref.realloc %src : memref<256xf32> to memref<?xi32>
   return %0 : memref<?xi32>
 }

diff  --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 54ad3d8ab0950..28e1206ff3d0a 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1133,7 +1133,7 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
 
 
 func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error@+1 {{'vector.shape_cast' op has different source and result element types}}
+  // expected-error@+1 {{'vector.shape_cast' op source element type ('f32') does not match result element type ('i32')}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
 }
 
@@ -1356,7 +1356,7 @@ func.func @maskedload_non_power_of_2_alignment(%base: memref<4xi32>, %mask: vect
 
 func.func @maskedload_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.maskedload' op base and result element type should match}}
+  // expected-error@+1 {{'vector.maskedload' op base element type ('f64') does not match result element type ('f32')}}
   %0 = vector.maskedload %base[%c0], %mask, %pass : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 }
 
@@ -1407,7 +1407,7 @@ func.func @maskedstore_non_power_of_2_alignment(%base: memref<4xi32>, %mask: vec
 
 func.func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.maskedstore' op base and valueToStore element type should match}}
+  // expected-error@+1 {{'vector.maskedstore' op base element type ('f64') does not match valueToStore element type ('f32')}}
   vector.maskedstore %base[%c0], %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
 }
 
@@ -1442,7 +1442,7 @@ func.func @gather_from_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
 func.func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
                                 %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.gather' op base and result element type should match}}
+  // expected-error@+1 {{'vector.gather' op base element type ('f64') does not match result element type ('f32')}}
   %0 = vector.gather %base[%c0][%indices], %mask, %pass_thru
     : memref<?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 }
@@ -1531,7 +1531,7 @@ func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
 func.func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
                                  %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.scatter' op base and valueToStore element type should match}}
+  // expected-error@+1 {{'vector.scatter' op base element type ('f64') does not match valueToStore element type ('f32')}}
   vector.scatter %base[%c0][%indices], %mask, %value
     : memref<?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf32>
 }
@@ -1598,7 +1598,7 @@ func.func @scatter_non_power_of_2_alignment(%base: memref<?xf32>, %indices: vect
 
 func.func @expand_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.expandload' op base and result element type should match}}
+  // expected-error@+1 {{'vector.expandload' op base element type ('f64') does not match result element type ('f32')}}
   %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?xf64>, vector<16xi1>, vector<16xf32> into vector<16xf32>
 }
 
@@ -1652,7 +1652,7 @@ func.func @expand_non_power_of_2_alignment(%base: memref<?xf32>, %mask: vector<1
 
 func.func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{'vector.compressstore' op base and valueToStore element type should match}}
+  // expected-error@+1 {{'vector.compressstore' op base element type ('f64') does not match valueToStore element type ('f32')}}
   vector.compressstore %base[%c0], %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
 }
 


        


More information about the Mlir-commits mailing list