[Mlir-commits] [mlir] [mlir][Utils] Add VerificationUtils for common op verification patterns (NFC) (PR #174336)

llvmlistbot at llvm.org
Sun Jan 4 08:25:53 PST 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir-bufferization

Author: Nick Kreeger (nkreeger)

<details>
<summary>Changes</summary>

Add a new VerificationUtils.h/cpp that provides reusable verification utilities for common patterns across MLIR dialects. This reduces code duplication and provides more consistent error messages.
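To make the deduplication concrete, here is a minimal standalone sketch in plain C++, not MLIR code. `checkDynamicDimensionCount`, `verifyEmptyLike`, and `verifySplatLike` are hypothetical names, shapes are modeled as `std::vector<int64_t>`, and `kDynamic` mirrors the `ShapedType::kDynamic` sentinel. Errors are returned as strings instead of going through `emitOpError`, but the message text follows the one used in the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <string>
#include <vector>

// Sentinel for a dynamic dimension, mirroring ShapedType::kDynamic.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Hypothetical stand-in for verifyDynamicDimensionCount: returns an error
// message when the number of dynamic-size operands differs from the number
// of dynamic dimensions in `shape`, std::nullopt otherwise.
std::optional<std::string>
checkDynamicDimensionCount(const std::vector<int64_t> &shape,
                           size_t numDynamicSizes) {
  int64_t expected = 0;
  for (int64_t dim : shape)
    if (dim == kDynamic)
      ++expected;
  int64_t actual = static_cast<int64_t>(numDynamicSizes);
  if (expected != actual)
    return "incorrect number of dynamic sizes, has " + std::to_string(actual) +
           ", expected " + std::to_string(expected);
  return std::nullopt;
}

// Two hypothetical verifiers that used to duplicate the check now share one
// helper, so they report identical diagnostics.
std::optional<std::string> verifyEmptyLike(const std::vector<int64_t> &shape,
                                           size_t numDynamicSizes) {
  return checkDynamicDimensionCount(shape, numDynamicSizes);
}

std::optional<std::string> verifySplatLike(const std::vector<int64_t> &shape,
                                           size_t numDynamicSizes) {
  return checkDynamicDimensionCount(shape, numDynamicSizes);
}
```

In the patch, `AllocTensorOp`, `EmptyOp`, and `SplatOp` all route through the one helper in the same way, so their dynamic-size diagnostics now share the same wording.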

New utilities:
- verifyDynamicDimensionCount: Check that the number of dynamic size operands matches the type's dynamic dims
- verifyRanksMatch/verifyRankEquals/verifyRankInRange: Rank verification
- verifyIndexCountMatchesRank: Check index count equals rank
- verifyDimensionIndicesInRange/verifyDimensionIndicesUnique: Index validation
- verifyAllShapesMatch/verifyShapesCompatible: Shape verification
- verifyAllElementTypesMatch/verifyElementTypesMatch: Element type checks
- verifyElementCountsMatch: Total element count comparison
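The index checks in the list above can be pictured with a small standalone sketch. It assumes plain `std::vector<int64_t>` indices and string errors in place of MLIR diagnostics. `checkIndicesInRange` and `checkIndicesUnique` are hypothetical stand-ins for `verifyDimensionIndicesInRange` and `verifyDimensionIndicesUnique`, with message text modeled on the patch.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for verifyDimensionIndicesInRange: every index must
// lie in [0, maxDim). Reports the first offending index.
std::optional<std::string>
checkIndicesInRange(const std::vector<int64_t> &indices, int64_t maxDim,
                    const std::string &context = "dimensions") {
  for (int64_t index : indices)
    if (index < 0 || index >= maxDim)
      return context + " must be in the range [0, " +
             std::to_string(maxDim - 1) + "], but got " +
             std::to_string(index);
  return std::nullopt;
}

// Hypothetical stand-in for verifyDimensionIndicesUnique: reports the first
// index that appears a second time.
std::optional<std::string>
checkIndicesUnique(const std::vector<int64_t> &indices,
                   const std::string &context = "dimensions") {
  std::unordered_set<int64_t> seen;
  for (int64_t index : indices)
    if (!seen.insert(index).second) // insert() returns false for a repeat
      return context + " contains duplicate index " + std::to_string(index);
  return std::nullopt;
}
```

A reduction-style verifier would typically run the range check against the input rank before the uniqueness check, so an out-of-range index is reported even when it is also duplicated.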

---

Patch is 25.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/174336.diff


9 Files Affected:

- (added) mlir/include/mlir/Dialect/Utils/VerificationUtils.h (+123) 
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+5-3) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+11-25) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+7-13) 
- (modified) mlir/lib/Dialect/Utils/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/Utils/VerificationUtils.cpp (+235) 
- (modified) mlir/test/Dialect/Bufferization/invalid.mlir (+1-1) 
- (modified) mlir/test/Dialect/Linalg/invalid.mlir (+3-3) 
- (modified) mlir/test/Dialect/Tensor/invalid.mlir (+1-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/VerificationUtils.h b/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
new file mode 100644
index 0000000000000..894d225c32b98
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
@@ -0,0 +1,123 @@
+//===- VerificationUtils.h - Common verification utilities ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines common verification utilities that can be shared
+// across multiple MLIR dialects. These utilities help reduce code duplication
+// for common verification patterns such as checking dynamic dimensions,
+// rank matching, and index validation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_UTILS_VERIFICATIONUTILS_H
+#define MLIR_DIALECT_UTILS_VERIFICATIONUTILS_H
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// Dynamic Dimension Verification
+//===----------------------------------------------------------------------===//
+
+/// Verify that the number of dynamic size operands matches the number of
+/// dynamic dimensions in the shaped type. Returns failure and emits an error
+/// if the counts don't match.
+LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type,
+                                          ValueRange dynamicSizes);
+
+//===----------------------------------------------------------------------===//
+// Rank Verification
+//===----------------------------------------------------------------------===//
+
+/// Verify that two shaped types have matching ranks. Returns failure and emits
+/// an error if ranks don't match.
+LogicalResult verifyRanksMatch(Operation *op, ShapedType type1,
+                               ShapedType type2, StringRef name1,
+                               StringRef name2);
+
+/// Verify that a shaped type has the expected rank. Returns failure and emits
+/// an error if the rank doesn't match.
+LogicalResult verifyRankEquals(Operation *op, ShapedType type,
+                               int64_t expectedRank, StringRef typeName);
+
+/// Verify that a shaped type's rank is within the specified range [minRank,
+/// maxRank]. Returns failure and emits an error if out of range.
+LogicalResult verifyRankInRange(Operation *op, ShapedType type, int64_t minRank,
+                                int64_t maxRank, StringRef typeName);
+
+//===----------------------------------------------------------------------===//
+// Index/Dimension Verification
+//===----------------------------------------------------------------------===//
+
+/// Verify that the number of indices matches the rank of the shaped type.
+/// Returns failure and emits an error if the counts don't match.
+LogicalResult verifyIndexCountMatchesRank(Operation *op, int64_t rank,
+                                          size_t indexCount,
+                                          StringRef indexName = "indices");
+
+/// Verify that all dimension indices in the array are within the valid range
+/// [0, maxDim). Returns failure and emits an error if any index is out of
+/// range.
+LogicalResult verifyDimensionIndicesInRange(Operation *op,
+                                            ArrayRef<int64_t> indices,
+                                            int64_t maxDim,
+                                            StringRef context = "dimensions");
+
+/// Verify that all dimension indices are unique (no duplicates). Returns
+/// failure and emits an error if duplicates are found.
+LogicalResult verifyDimensionIndicesUnique(Operation *op,
+                                           ArrayRef<int64_t> indices,
+                                           StringRef context = "dimensions");
+
+//===----------------------------------------------------------------------===//
+// Shape Verification
+//===----------------------------------------------------------------------===//
+
+/// Verify that all values in the range have the same shape. Returns failure
+/// and emits an error if shapes don't match.
+LogicalResult verifyAllShapesMatch(Operation *op, ValueRange values,
+                                   StringRef context);
+
+/// Verify that two shaped types have compatible shapes (same rank and matching
+/// dimensions, with dynamic dimensions considered compatible with any size).
+/// Returns failure and emits an error if incompatible.
+LogicalResult verifyShapesCompatible(Operation *op, ShapedType type1,
+                                     ShapedType type2, StringRef name1,
+                                     StringRef name2);
+
+//===----------------------------------------------------------------------===//
+// Element Type Verification
+//===----------------------------------------------------------------------===//
+
+/// Verify that all values in the range have the same element type. Returns
+/// failure and emits an error if element types don't match.
+LogicalResult verifyAllElementTypesMatch(Operation *op, ValueRange values,
+                                         StringRef context);
+
+/// Verify that two shaped types have the same element type. Returns failure
+/// and emits an error if they don't match.
+LogicalResult verifyElementTypesMatch(Operation *op, ShapedType type1,
+                                      ShapedType type2, StringRef name1,
+                                      StringRef name2);
+
+//===----------------------------------------------------------------------===//
+// Element Count Verification
+//===----------------------------------------------------------------------===//
+
+/// Verify that two shaped types have the same total number of elements.
+/// Returns failure and emits an error if element counts don't match.
+/// Useful for reshape and shape_cast operations.
+LogicalResult verifyElementCountsMatch(Operation *op, ShapedType type1,
+                                       ShapedType type2, StringRef name1,
+                                       StringRef name2);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_UTILS_VERIFICATIONUTILS_H
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 56ff2121e4620..5df97324c2dce 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/VerificationUtils.h"
 #include "mlir/IR/Matchers.h"
 #include <optional>
 
@@ -251,9 +252,10 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
 LogicalResult AllocTensorOp::verify() {
   if (getCopy() && !getDynamicSizes().empty())
     return emitError("dynamic sizes not needed when copying a tensor");
-  if (!getCopy() && getType().getNumDynamicDims() != getDynamicSizes().size())
-    return emitError("expected ")
-           << getType().getNumDynamicDims() << " dynamic sizes";
+  if (!getCopy() &&
+      failed(verifyDynamicDimensionCount(getOperation(), getType(),
+                                         getDynamicSizes())))
+    return failure();
   if (getCopy() && getCopy().getType() != getType())
     return emitError("expected that `copy` and return type match");
   return success();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 210f9584c1e86..25b9f7743d294 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Utils/VerificationUtils.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -1883,36 +1884,21 @@ void ReduceOp::print(OpAsmPrinter &p) {
 LogicalResult ReduceOp::verify() {
   ArrayRef<int64_t> dimensionsRef = getDimensions();
 
-  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
-    if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
-        llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
-      return emitOpError() << "expects all inputs to have the same shapes. "
-                              "Shape at input-index "
-                           << i
-                           << " is not equal to the shape at input-index 0.";
-    }
-  }
-  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
-    if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
-        llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
-      return emitOpError() << "expects all outputs to have the same shapes. "
-                              "Shape at output-index "
-                           << i
-                           << " is not equal to the shape at output-index 0.";
-    }
-  }
+  if (failed(verifyAllShapesMatch(getOperation(), getInputs(), "inputs")))
+    return failure();
+  if (failed(verifyAllShapesMatch(getOperation(), getInits(), "inits")))
+    return failure();
   auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
   auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
 
+  if (failed(verifyDimensionIndicesInRange(getOperation(), dimensionsRef,
+                                           inputType.getRank(),
+                                           "reduction dimensions")))
+    return failure();
+
   DenseSet<int64_t> dimensionsToReduce;
-  for (int64_t dimension : dimensionsRef) {
-    if (dimension < 0 || dimension >= inputType.getRank()) {
-      return emitOpError()
-             << "dimensions for reduction should be in the range [0, "
-             << inputType.getRank() - 1 << "].";
-    }
+  for (int64_t dimension : dimensionsRef)
     dimensionsToReduce.insert(dimension);
-  }
 
   auto inputDims = inputType.getShape();
   auto initDims = initType.getShape();
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index a0c7e40c20a46..32bfa603677ba 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Utils/VerificationUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -1075,11 +1076,8 @@ void EmptyOp::build(OpBuilder &builder, OperationState &result,
 }
 
 LogicalResult EmptyOp::verify() {
-  if (getType().getNumDynamicDims() != getDynamicSizes().size())
-    return emitOpError("incorrect number of dynamic sizes, has ")
-           << getDynamicSizes().size() << ", expected "
-           << getType().getNumDynamicDims();
-  return success();
+  return verifyDynamicDimensionCount(getOperation(), getType(),
+                                     getDynamicSizes());
 }
 
 LogicalResult
@@ -1354,9 +1352,8 @@ void ExtractOp::getAsmResultNames(
 LogicalResult ExtractOp::verify() {
   // Verify the # indices match if we have a ranked type.
   auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
-  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
-    return emitOpError("incorrect number of indices for extract_element");
-  return success();
+  return verifyIndexCountMatchesRank(getOperation(), tensorType.getRank(),
+                                     getIndices().size());
 }
 
 /// If we have an ExtractOp consuming an InsertOp with the same
@@ -4039,11 +4036,8 @@ void SplatOp::getAsmResultNames(
 }
 
 LogicalResult SplatOp::verify() {
-  if (getType().getNumDynamicDims() != getDynamicSizes().size())
-    return emitOpError("incorrect number of dynamic sizes, has ")
-           << getDynamicSizes().size() << ", expected "
-           << getType().getNumDynamicDims();
-  return success();
+  return verifyDynamicDimensionCount(getOperation(), getType(),
+                                     getDynamicSizes());
 }
 
 LogicalResult
diff --git a/mlir/lib/Dialect/Utils/CMakeLists.txt b/mlir/lib/Dialect/Utils/CMakeLists.txt
index f5bb687ae071a..7673da5932304 100644
--- a/mlir/lib/Dialect/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Utils/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_library(MLIRDialectUtils
   ReshapeOpsUtils.cpp
   StructuredOpsUtils.cpp
   StaticValueUtils.cpp
+  VerificationUtils.cpp
 
   DEPENDS
   MLIRDialectUtilsIncGen
diff --git a/mlir/lib/Dialect/Utils/VerificationUtils.cpp b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
new file mode 100644
index 0000000000000..25b84a67956cb
--- /dev/null
+++ b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
@@ -0,0 +1,235 @@
+//===- VerificationUtils.cpp - Common verification utilities --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Utils/VerificationUtils.h"
+#include "llvm/ADT/DenseSet.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Dynamic Dimension Verification
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::verifyDynamicDimensionCount(Operation *op, ShapedType type,
+                                                ValueRange dynamicSizes) {
+  int64_t expectedCount = type.getNumDynamicDims();
+  int64_t actualCount = dynamicSizes.size();
+  if (expectedCount != actualCount) {
+    return op->emitOpError("incorrect number of dynamic sizes, has ")
+           << actualCount << ", expected " << expectedCount;
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Rank Verification
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::verifyRanksMatch(Operation *op, ShapedType type1,
+                                     ShapedType type2, StringRef name1,
+                                     StringRef name2) {
+  if (!type1.hasRank() || !type2.hasRank())
+    return success(); // Unranked types are considered compatible
+
+  int64_t rank1 = type1.getRank();
+  int64_t rank2 = type2.getRank();
+  if (rank1 != rank2) {
+    return op->emitOpError()
+           << name1 << " rank (" << rank1 << ") does not match " << name2
+           << " rank (" << rank2 << ")";
+  }
+  return success();
+}
+
+LogicalResult mlir::verifyRankEquals(Operation *op, ShapedType type,
+                                     int64_t expectedRank, StringRef typeName) {
+  if (!type.hasRank())
+    return success();
+
+  int64_t actualRank = type.getRank();
+  if (actualRank != expectedRank) {
+    return op->emitOpError()
+           << typeName << " must have rank " << expectedRank << ", but has "
+           << actualRank;
+  }
+  return success();
+}
+
+LogicalResult mlir::verifyRankInRange(Operation *op, ShapedType type,
+                                      int64_t minRank, int64_t maxRank,
+                                      StringRef typeName) {
+  if (!type.hasRank())
+    return success();
+
+  int64_t rank = type.getRank();
+  if (rank < minRank || rank > maxRank) {
+    return op->emitOpError()
+           << typeName << " rank must be in range [" << minRank << ", "
+           << maxRank << "], but has rank " << rank;
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Index/Dimension Verification
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::verifyIndexCountMatchesRank(Operation *op, int64_t rank,
+                                                size_t indexCount,
+                                                StringRef indexName) {
+  if (rank != static_cast<int64_t>(indexCount)) {
+    return op->emitOpError("incorrect number of ")
+           << indexName << ", has " << indexCount << ", expected " << rank;
+  }
+  return success();
+}
+
+LogicalResult mlir::verifyDimensionIndicesInRange(Operation *op,
+                                                  ArrayRef<int64_t> indices,
+                                                  int64_t maxDim,
+                                                  StringRef context) {
+  for (int64_t index : indices) {
+    if (index < 0 || index >= maxDim) {
+      return op->emitOpError()
+             << context << " must be in the range [0, " << (maxDim - 1)
+             << "], but got " << index;
+    }
+  }
+  return success();
+}
+
+LogicalResult mlir::verifyDimensionIndicesUnique(Operation *op,
+                                                 ArrayRef<int64_t> indices,
+                                                 StringRef context) {
+  llvm::DenseSet<int64_t> seen;
+  for (int64_t index : indices) {
+    if (!seen.insert(index).second) {
+      return op->emitOpError()
+             << context << " contains duplicate index " << index;
+    }
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Shape Verification
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::verifyAllShapesMatch(Operation *op, ValueRange values,
+                                         StringRef context) {
+  if (values.empty())
+    return success();
+
+  auto firstType = llvm::dyn_cast<ShapedType>(values.front().getType());
+  if (!firstType || !firstType.hasRank())
+    return success();
+
+  ArrayRef<int64_t> firstShape = firstType.getShape();
+  for (auto [idx, value] : llvm::enumerate(values.drop_front())) {
+    auto type = llvm::dyn_cast<ShapedType>(value.getType());
+    if (!type || !type.hasRank())
+      continue;
+
+    if (type.getShape() != firstShape) {
+      return op->emitOpError()
+             << context << " must all have the same shape, but " << context
+             << "[0] has shape [" << firstShape << "] while " << context << "["
+             << (idx + 1) << "] has shape [" << type.getShape() << "]";
+    }
+  }
+  return success();
+}
+
+LogicalResult mlir::verifyShapesCompatible(Operation *op, ShapedType type1,
+                                           ShapedType type2, StringRef name1,
+                                           StringRef name2) {
+  if (!type1.hasRank() || !type2.hasRank())
+    return success();
+
+  if (type1.getRank() != type2.getRank()) {
+    return op->emitOpError()
+           << name1 << " and " << name2 << " must have the same rank";
+  }
+
+  ArrayRef<int64_t> shape1 = type1.getShape();
+  ArrayRef<int64_t> shape2 = type2.getShape();
+  for (auto [idx, dims] : llvm::enumerate(llvm::zip(shape1, shape2))) {
+    auto [dim1, dim2] = dims;
+    // Dynamic dimensions are compatible with anything
+    if (ShapedType::isDynamic(dim1) || ShapedType::isDynamic(dim2))
+      continue;
+    if (dim1 != dim2) {
+      return op->emitOpError()
+             << name1 << " and " << name2 << " have incompatible shapes at "
+             << "dimension " << idx << ": " << dim1 << " vs " << dim2;
+    }
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Element Type Verification
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::verifyAllElementTypesMatch(Operation *op, ValueRange values,
+                                               StringRef context) {
+  if (values.empty())
+    return success();
+
+  auto firstType = llvm::dyn_cast<ShapedType>(values.front().getType());
+  if (!firstType)
+    return success();
+
+  Type firstEle...
[truncated]

``````````
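For reference, the rank-and-dimension logic of `verifyShapesCompatible` shown in the diff can be exercised outside MLIR with the sketch below. It is a plain C++ approximation that assumes ranked shapes stored as `std::vector<int64_t>` and a `kDynamic` sentinel mirroring `ShapedType::kDynamic`; `checkShapesCompatible` is a hypothetical name. The real helper additionally returns success when either type is unranked, and it reports through `emitOpError` rather than returning a string.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <string>
#include <vector>

// Sentinel for a dynamic dimension, mirroring ShapedType::kDynamic.
constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

// Approximates verifyShapesCompatible for ranked shapes: ranks must agree,
// and each pair of static sizes must be equal. A dynamic size on either side
// is treated as compatible with any size.
std::optional<std::string>
checkShapesCompatible(const std::vector<int64_t> &shape1,
                      const std::vector<int64_t> &shape2,
                      const std::string &name1, const std::string &name2) {
  if (shape1.size() != shape2.size())
    return name1 + " and " + name2 + " must have the same rank";
  for (size_t idx = 0; idx < shape1.size(); ++idx) {
    int64_t dim1 = shape1[idx], dim2 = shape2[idx];
    if (dim1 == kDynamic || dim2 == kDynamic)
      continue; // dynamic dims match anything
    if (dim1 != dim2)
      return name1 + " and " + name2 + " have incompatible shapes at " +
             "dimension " + std::to_string(idx) + ": " +
             std::to_string(dim1) + " vs " + std::to_string(dim2);
  }
  return std::nullopt;
}
```

Note the contrast with `verifyAllShapesMatch` in the same patch, which compares `getShape()` arrays exactly, so a dynamic dimension there only matches another dynamic dimension.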

</details>


https://github.com/llvm/llvm-project/pull/174336

