[Mlir-commits] [mlir] [mlir][Utils] Add VerificationUtils for common op verification patterns (NFC) (PR #174336)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jan 4 08:25:53 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-bufferization
Author: Nick Kreeger (nkreeger)
<details>
<summary>Changes</summary>
Add a new VerificationUtils.h/cpp that provides reusable verification utilities for common patterns across MLIR dialects. This reduces code duplication and provides more consistent error messages.
New utilities:
- verifyDynamicDimensionCount: Check dynamic sizes match type dims
- verifyRanksMatch/verifyRankEquals/verifyRankInRange: Rank verification
- verifyIndexCountMatchesRank: Check index count equals rank
- verifyDimensionIndicesInRange/verifyDimensionIndicesUnique: Index validation
- verifyAllShapesMatch/verifyShapesCompatible: Shape verification
- verifyAllElementTypesMatch/verifyElementTypesMatch: Element type checks
- verifyElementCountsMatch: Total element count comparison
---
Patch is 25.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/174336.diff
9 Files Affected:
- (added) mlir/include/mlir/Dialect/Utils/VerificationUtils.h (+123)
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+5-3)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+11-25)
- (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+7-13)
- (modified) mlir/lib/Dialect/Utils/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/Utils/VerificationUtils.cpp (+235)
- (modified) mlir/test/Dialect/Bufferization/invalid.mlir (+1-1)
- (modified) mlir/test/Dialect/Linalg/invalid.mlir (+3-3)
- (modified) mlir/test/Dialect/Tensor/invalid.mlir (+1-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/VerificationUtils.h b/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
new file mode 100644
index 0000000000000..894d225c32b98
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
@@ -0,0 +1,123 @@
+//===- VerificationUtils.h - Common verification utilities ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines common verification utilities that can be shared
+// across multiple MLIR dialects. These utilities help reduce code duplication
+// for common verification patterns such as checking dynamic dimensions,
+// rank matching, and index validation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_UTILS_VERIFICATIONUTILS_H
+#define MLIR_DIALECT_UTILS_VERIFICATIONUTILS_H
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// Dynamic Dimension Verification
+//===----------------------------------------------------------------------===//
+
+/// Verify that the number of dynamic size operands matches the number of
+/// dynamic dimensions in the shaped type. Returns failure and emits an error
+/// if the counts don't match.
+LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type,
+ ValueRange dynamicSizes);
+
+//===----------------------------------------------------------------------===//
+// Rank Verification
+//===----------------------------------------------------------------------===//
+
+/// Verify that two shaped types have matching ranks. Returns failure and emits
+/// an error if ranks don't match.
+LogicalResult verifyRanksMatch(Operation *op, ShapedType type1,
+ ShapedType type2, StringRef name1,
+ StringRef name2);
+
+/// Verify that a shaped type has the expected rank. Returns failure and emits
+/// an error if the rank doesn't match.
+LogicalResult verifyRankEquals(Operation *op, ShapedType type,
+ int64_t expectedRank, StringRef typeName);
+
+/// Verify that a shaped type's rank is within the specified range [minRank,
+/// maxRank]. Returns failure and emits an error if out of range.
+LogicalResult verifyRankInRange(Operation *op, ShapedType type, int64_t minRank,
+ int64_t maxRank, StringRef typeName);
+
+//===----------------------------------------------------------------------===//
+// Index/Dimension Verification
+//===----------------------------------------------------------------------===//
+
+/// Verify that the number of indices matches the rank of the shaped type.
+/// Returns failure and emits an error if the counts don't match.
+LogicalResult verifyIndexCountMatchesRank(Operation *op, int64_t rank,
+ size_t indexCount,
+ StringRef indexName = "indices");
+
+/// Verify that all dimension indices in the array are within the valid range
+/// [0, maxDim). Returns failure and emits an error if any index is out of
+/// range.
+LogicalResult verifyDimensionIndicesInRange(Operation *op,
+ ArrayRef<int64_t> indices,
+ int64_t maxDim,
+ StringRef context = "dimensions");
+
+/// Verify that all dimension indices are unique (no duplicates). Returns
+/// failure and emits an error if duplicates are found.
+LogicalResult verifyDimensionIndicesUnique(Operation *op,
+ ArrayRef<int64_t> indices,
+ StringRef context = "dimensions");
+
+//===----------------------------------------------------------------------===//
+// Shape Verification
+//===----------------------------------------------------------------------===//
+
+/// Verify that all values in the range have the same shape. Returns failure
+/// and emits an error if shapes don't match.
+LogicalResult verifyAllShapesMatch(Operation *op, ValueRange values,
+ StringRef context);
+
+/// Verify that two shaped types have compatible shapes (same rank and matching
+/// dimensions, with dynamic dimensions considered compatible with any size).
+/// Returns failure and emits an error if incompatible.
+LogicalResult verifyShapesCompatible(Operation *op, ShapedType type1,
+ ShapedType type2, StringRef name1,
+ StringRef name2);
+
+//===----------------------------------------------------------------------===//
+// Element Type Verification
+//===----------------------------------------------------------------------===//
+
+/// Verify that all values in the range have the same element type. Returns
+/// failure and emits an error if element types don't match.
+LogicalResult verifyAllElementTypesMatch(Operation *op, ValueRange values,
+ StringRef context);
+
+/// Verify that two shaped types have the same element type. Returns failure
+/// and emits an error if they don't match.
+LogicalResult verifyElementTypesMatch(Operation *op, ShapedType type1,
+ ShapedType type2, StringRef name1,
+ StringRef name2);
+
+//===----------------------------------------------------------------------===//
+// Element Count Verification
+//===----------------------------------------------------------------------===//
+
+/// Verify that two shaped types have the same total number of elements.
+/// Returns failure and emits an error if element counts don't match.
+/// Useful for reshape and shape_cast operations.
+LogicalResult verifyElementCountsMatch(Operation *op, ShapedType type1,
+ ShapedType type2, StringRef name1,
+ StringRef name2);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_UTILS_VERIFICATIONUTILS_H
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 56ff2121e4620..5df97324c2dce 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/VerificationUtils.h"
#include "mlir/IR/Matchers.h"
#include <optional>
@@ -251,9 +252,10 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
LogicalResult AllocTensorOp::verify() {
if (getCopy() && !getDynamicSizes().empty())
return emitError("dynamic sizes not needed when copying a tensor");
- if (!getCopy() && getType().getNumDynamicDims() != getDynamicSizes().size())
- return emitError("expected ")
- << getType().getNumDynamicDims() << " dynamic sizes";
+ if (!getCopy() &&
+ failed(verifyDynamicDimensionCount(getOperation(), getType(),
+ getDynamicSizes())))
+ return failure();
if (getCopy() && getCopy().getType() != getType())
return emitError("expected that `copy` and return type match");
return success();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 210f9584c1e86..25b9f7743d294 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -24,6 +24,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Utils/VerificationUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -1883,36 +1884,21 @@ void ReduceOp::print(OpAsmPrinter &p) {
LogicalResult ReduceOp::verify() {
ArrayRef<int64_t> dimensionsRef = getDimensions();
- for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
- if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
- llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
- return emitOpError() << "expects all inputs to have the same shapes. "
- "Shape at input-index "
- << i
- << " is not equal to the shape at input-index 0.";
- }
- }
- for (int64_t i = 1; i < getNumDpsInits(); ++i) {
- if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
- llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
- return emitOpError() << "expects all outputs to have the same shapes. "
- "Shape at output-index "
- << i
- << " is not equal to the shape at output-index 0.";
- }
- }
+ if (failed(verifyAllShapesMatch(getOperation(), getInputs(), "inputs")))
+ return failure();
+ if (failed(verifyAllShapesMatch(getOperation(), getInits(), "inits")))
+ return failure();
auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
+ if (failed(verifyDimensionIndicesInRange(getOperation(), dimensionsRef,
+ inputType.getRank(),
+ "reduction dimensions")))
+ return failure();
+
DenseSet<int64_t> dimensionsToReduce;
- for (int64_t dimension : dimensionsRef) {
- if (dimension < 0 || dimension >= inputType.getRank()) {
- return emitOpError()
- << "dimensions for reduction should be in the range [0, "
- << inputType.getRank() - 1 << "].";
- }
+ for (int64_t dimension : dimensionsRef)
dimensionsToReduce.insert(dimension);
- }
auto inputDims = inputType.getShape();
auto initDims = initType.getShape();
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index a0c7e40c20a46..32bfa603677ba 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Utils/VerificationUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -1075,11 +1076,8 @@ void EmptyOp::build(OpBuilder &builder, OperationState &result,
}
LogicalResult EmptyOp::verify() {
- if (getType().getNumDynamicDims() != getDynamicSizes().size())
- return emitOpError("incorrect number of dynamic sizes, has ")
- << getDynamicSizes().size() << ", expected "
- << getType().getNumDynamicDims();
- return success();
+ return verifyDynamicDimensionCount(getOperation(), getType(),
+ getDynamicSizes());
}
LogicalResult
@@ -1354,9 +1352,8 @@ void ExtractOp::getAsmResultNames(
LogicalResult ExtractOp::verify() {
// Verify the # indices match if we have a ranked type.
auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
- if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
- return emitOpError("incorrect number of indices for extract_element");
- return success();
+ return verifyIndexCountMatchesRank(getOperation(), tensorType.getRank(),
+ getIndices().size());
}
/// If we have an ExtractOp consuming an InsertOp with the same
@@ -4039,11 +4036,8 @@ void SplatOp::getAsmResultNames(
}
LogicalResult SplatOp::verify() {
- if (getType().getNumDynamicDims() != getDynamicSizes().size())
- return emitOpError("incorrect number of dynamic sizes, has ")
- << getDynamicSizes().size() << ", expected "
- << getType().getNumDynamicDims();
- return success();
+ return verifyDynamicDimensionCount(getOperation(), getType(),
+ getDynamicSizes());
}
LogicalResult
diff --git a/mlir/lib/Dialect/Utils/CMakeLists.txt b/mlir/lib/Dialect/Utils/CMakeLists.txt
index f5bb687ae071a..7673da5932304 100644
--- a/mlir/lib/Dialect/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Utils/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_library(MLIRDialectUtils
ReshapeOpsUtils.cpp
StructuredOpsUtils.cpp
StaticValueUtils.cpp
+ VerificationUtils.cpp
DEPENDS
MLIRDialectUtilsIncGen
diff --git a/mlir/lib/Dialect/Utils/VerificationUtils.cpp b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
new file mode 100644
index 0000000000000..25b84a67956cb
--- /dev/null
+++ b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
@@ -0,0 +1,235 @@
+//===- VerificationUtils.cpp - Common verification utilities --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Utils/VerificationUtils.h"
+#include "llvm/ADT/DenseSet.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Dynamic Dimension Verification
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::verifyDynamicDimensionCount(Operation *op, ShapedType type,
+ ValueRange dynamicSizes) {
+ int64_t expectedCount = type.getNumDynamicDims();
+ int64_t actualCount = dynamicSizes.size();
+ if (expectedCount != actualCount) {
+ return op->emitOpError("incorrect number of dynamic sizes, has ")
+ << actualCount << ", expected " << expectedCount;
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Rank Verification
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::verifyRanksMatch(Operation *op, ShapedType type1,
+ ShapedType type2, StringRef name1,
+ StringRef name2) {
+ if (!type1.hasRank() || !type2.hasRank())
+ return success(); // Unranked types are considered compatible
+
+ int64_t rank1 = type1.getRank();
+ int64_t rank2 = type2.getRank();
+ if (rank1 != rank2) {
+ return op->emitOpError()
+ << name1 << " rank (" << rank1 << ") does not match " << name2
+ << " rank (" << rank2 << ")";
+ }
+ return success();
+}
+
+LogicalResult mlir::verifyRankEquals(Operation *op, ShapedType type,
+ int64_t expectedRank, StringRef typeName) {
+ if (!type.hasRank())
+ return success();
+
+ int64_t actualRank = type.getRank();
+ if (actualRank != expectedRank) {
+ return op->emitOpError()
+ << typeName << " must have rank " << expectedRank << ", but has "
+ << actualRank;
+ }
+ return success();
+}
+
+LogicalResult mlir::verifyRankInRange(Operation *op, ShapedType type,
+ int64_t minRank, int64_t maxRank,
+ StringRef typeName) {
+ if (!type.hasRank())
+ return success();
+
+ int64_t rank = type.getRank();
+ if (rank < minRank || rank > maxRank) {
+ return op->emitOpError()
+ << typeName << " rank must be in range [" << minRank << ", "
+ << maxRank << "], but has rank " << rank;
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Index/Dimension Verification
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::verifyIndexCountMatchesRank(Operation *op, int64_t rank,
+ size_t indexCount,
+ StringRef indexName) {
+ if (rank != static_cast<int64_t>(indexCount)) {
+ return op->emitOpError("incorrect number of ")
+ << indexName << ", has " << indexCount << ", expected " << rank;
+ }
+ return success();
+}
+
+LogicalResult mlir::verifyDimensionIndicesInRange(Operation *op,
+ ArrayRef<int64_t> indices,
+ int64_t maxDim,
+ StringRef context) {
+ for (int64_t index : indices) {
+ if (index < 0 || index >= maxDim) {
+ return op->emitOpError()
+ << context << " must be in the range [0, " << (maxDim - 1)
+ << "], but got " << index;
+ }
+ }
+ return success();
+}
+
+LogicalResult mlir::verifyDimensionIndicesUnique(Operation *op,
+ ArrayRef<int64_t> indices,
+ StringRef context) {
+ llvm::DenseSet<int64_t> seen;
+ for (int64_t index : indices) {
+ if (!seen.insert(index).second) {
+ return op->emitOpError()
+ << context << " contains duplicate index " << index;
+ }
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Shape Verification
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::verifyAllShapesMatch(Operation *op, ValueRange values,
+ StringRef context) {
+ if (values.empty())
+ return success();
+
+ auto firstType = llvm::dyn_cast<ShapedType>(values.front().getType());
+ if (!firstType || !firstType.hasRank())
+ return success();
+
+ ArrayRef<int64_t> firstShape = firstType.getShape();
+ for (auto [idx, value] : llvm::enumerate(values.drop_front())) {
+ auto type = llvm::dyn_cast<ShapedType>(value.getType());
+ if (!type || !type.hasRank())
+ continue;
+
+ if (type.getShape() != firstShape) {
+ return op->emitOpError()
+ << context << " must all have the same shape, but " << context
+ << "[0] has shape [" << firstShape << "] while " << context << "["
+ << (idx + 1) << "] has shape [" << type.getShape() << "]";
+ }
+ }
+ return success();
+}
+
+LogicalResult mlir::verifyShapesCompatible(Operation *op, ShapedType type1,
+ ShapedType type2, StringRef name1,
+ StringRef name2) {
+ if (!type1.hasRank() || !type2.hasRank())
+ return success();
+
+ if (type1.getRank() != type2.getRank()) {
+ return op->emitOpError()
+ << name1 << " and " << name2 << " must have the same rank";
+ }
+
+ ArrayRef<int64_t> shape1 = type1.getShape();
+ ArrayRef<int64_t> shape2 = type2.getShape();
+ for (auto [idx, dims] : llvm::enumerate(llvm::zip(shape1, shape2))) {
+ auto [dim1, dim2] = dims;
+ // Dynamic dimensions are compatible with anything
+ if (ShapedType::isDynamic(dim1) || ShapedType::isDynamic(dim2))
+ continue;
+ if (dim1 != dim2) {
+ return op->emitOpError()
+ << name1 << " and " << name2 << " have incompatible shapes at "
+ << "dimension " << idx << ": " << dim1 << " vs " << dim2;
+ }
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Element Type Verification
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::verifyAllElementTypesMatch(Operation *op, ValueRange values,
+ StringRef context) {
+ if (values.empty())
+ return success();
+
+ auto firstType = llvm::dyn_cast<ShapedType>(values.front().getType());
+ if (!firstType)
+ return success();
+
+ Type firstEle...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/174336
More information about the Mlir-commits
mailing list