[Mlir-commits] [mlir] [mlir][Utils] Add VerificationUtils for common op verification patterns (NFC) (PR #174336)
Nick Kreeger
llvmlistbot at llvm.org
Mon Jan 5 12:54:43 PST 2026
https://github.com/nkreeger updated https://github.com/llvm/llvm-project/pull/174336
>From 227d34e04e1ec06b7ceff6955980a7ba25bd0378 Mon Sep 17 00:00:00 2001
From: Nick Kreeger <nick.kreeger at microsoft.com>
Date: Sun, 4 Jan 2026 10:22:36 -0600
Subject: [PATCH 1/3] [mlir][Utils] Add VerificationUtils for common op
verification patterns (NFC).
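Add a new VerificationUtils.h/cpp that provides reusable verification
utilities for common patterns across MLIR dialects. This reduces code
duplication and makes verifier error messages more consistent. As a
result, the wording of several diagnostics changes, and the affected
invalid.mlir expectations are updated accordingly.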
New utilities:
- verifyDynamicDimensionCount: Check that the number of dynamic size operands
  matches the number of dynamic dimensions in the type
- verifyRanksMatch/verifyRankEquals/verifyRankInRange: Rank verification
- verifyIndexCountMatchesRank: Check that the index count equals the rank
- verifyDimensionIndicesInRange/verifyDimensionIndicesUnique: Dimension index
  validation (range and uniqueness)
- verifyAllShapesMatch/verifyShapesCompatible: Shape verification
- verifyAllElementTypesMatch/verifyElementTypesMatch: Element type checks
- verifyElementCountsMatch: Total element count comparison
---
.../mlir/Dialect/Utils/VerificationUtils.h | 123 +++++++++
.../Bufferization/IR/BufferizationOps.cpp | 8 +-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 36 +--
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 20 +-
mlir/lib/Dialect/Utils/CMakeLists.txt | 1 +
mlir/lib/Dialect/Utils/VerificationUtils.cpp | 235 ++++++++++++++++++
mlir/test/Dialect/Bufferization/invalid.mlir | 2 +-
mlir/test/Dialect/Linalg/invalid.mlir | 6 +-
mlir/test/Dialect/Tensor/invalid.mlir | 2 +-
9 files changed, 387 insertions(+), 46 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/Utils/VerificationUtils.h
create mode 100644 mlir/lib/Dialect/Utils/VerificationUtils.cpp
diff --git a/mlir/include/mlir/Dialect/Utils/VerificationUtils.h b/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
new file mode 100644
index 0000000000000..894d225c32b98
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
@@ -0,0 +1,123 @@
+//===- VerificationUtils.h - Common verification utilities ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines common verification utilities that can be shared
+// across multiple MLIR dialects. These utilities help reduce code duplication
+// for common verification patterns such as checking dynamic dimensions,
+// rank matching, and index validation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_UTILS_VERIFICATIONUTILS_H
+#define MLIR_DIALECT_UTILS_VERIFICATIONUTILS_H
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+
+//===----------------------------------------------------------------------===//
+// Dynamic Dimension Verification
+//===----------------------------------------------------------------------===//
+
+/// Verify that the number of dynamic size operands matches the number of
+/// dynamic dimensions in the shaped type. Returns failure and emits an error
+/// if the counts don't match.
+LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type,
+ ValueRange dynamicSizes);
+
+//===----------------------------------------------------------------------===//
+// Rank Verification
+//===----------------------------------------------------------------------===//
+
+/// Verify that two shaped types have matching ranks. Returns failure and emits
+/// an error if ranks don't match.
+LogicalResult verifyRanksMatch(Operation *op, ShapedType type1,
+ ShapedType type2, StringRef name1,
+ StringRef name2);
+
+/// Verify that a shaped type has the expected rank. Returns failure and emits
+/// an error if the rank doesn't match.
+LogicalResult verifyRankEquals(Operation *op, ShapedType type,
+ int64_t expectedRank, StringRef typeName);
+
+/// Verify that a shaped type's rank is within the specified range [minRank,
+/// maxRank]. Returns failure and emits an error if out of range.
+LogicalResult verifyRankInRange(Operation *op, ShapedType type, int64_t minRank,
+ int64_t maxRank, StringRef typeName);
+
+//===----------------------------------------------------------------------===//
+// Index/Dimension Verification
+//===----------------------------------------------------------------------===//
+
+/// Verify that the number of indices matches the rank of the shaped type.
+/// Returns failure and emits an error if the counts don't match.
+LogicalResult verifyIndexCountMatchesRank(Operation *op, int64_t rank,
+ size_t indexCount,
+ StringRef indexName = "indices");
+
+/// Verify that all dimension indices in the array are within the valid range
+/// [0, maxDim). Returns failure and emits an error if any index is out of
+/// range.
+LogicalResult verifyDimensionIndicesInRange(Operation *op,
+ ArrayRef<int64_t> indices,
+ int64_t maxDim,
+ StringRef context = "dimensions");
+
+/// Verify that all dimension indices are unique (no duplicates). Returns
+/// failure and emits an error if duplicates are found.
+LogicalResult verifyDimensionIndicesUnique(Operation *op,
+ ArrayRef<int64_t> indices,
+ StringRef context = "dimensions");
+
+//===----------------------------------------------------------------------===//
+// Shape Verification
+//===----------------------------------------------------------------------===//
+
+/// Verify that all values in the range have the same shape. Returns failure
+/// and emits an error if shapes don't match.
+LogicalResult verifyAllShapesMatch(Operation *op, ValueRange values,
+ StringRef context);
+
+/// Verify that two shaped types have compatible shapes (same rank and matching
+/// dimensions, with dynamic dimensions considered compatible with any size).
+/// Returns failure and emits an error if incompatible.
+LogicalResult verifyShapesCompatible(Operation *op, ShapedType type1,
+ ShapedType type2, StringRef name1,
+ StringRef name2);
+
+//===----------------------------------------------------------------------===//
+// Element Type Verification
+//===----------------------------------------------------------------------===//
+
+/// Verify that all values in the range have the same element type. Returns
+/// failure and emits an error if element types don't match.
+LogicalResult verifyAllElementTypesMatch(Operation *op, ValueRange values,
+ StringRef context);
+
+/// Verify that two shaped types have the same element type. Returns failure
+/// and emits an error if they don't match.
+LogicalResult verifyElementTypesMatch(Operation *op, ShapedType type1,
+ ShapedType type2, StringRef name1,
+ StringRef name2);
+
+//===----------------------------------------------------------------------===//
+// Element Count Verification
+//===----------------------------------------------------------------------===//
+
+/// Verify that two shaped types have the same total number of elements.
+/// Returns failure and emits an error if element counts don't match.
+/// Useful for reshape and shape_cast operations.
+LogicalResult verifyElementCountsMatch(Operation *op, ShapedType type1,
+ ShapedType type2, StringRef name1,
+ StringRef name2);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_UTILS_VERIFICATIONUTILS_H
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 56ff2121e4620..5df97324c2dce 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/VerificationUtils.h"
#include "mlir/IR/Matchers.h"
#include <optional>
@@ -251,9 +252,10 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
LogicalResult AllocTensorOp::verify() {
if (getCopy() && !getDynamicSizes().empty())
return emitError("dynamic sizes not needed when copying a tensor");
- if (!getCopy() && getType().getNumDynamicDims() != getDynamicSizes().size())
- return emitError("expected ")
- << getType().getNumDynamicDims() << " dynamic sizes";
+ if (!getCopy() &&
+ failed(verifyDynamicDimensionCount(getOperation(), getType(),
+ getDynamicSizes())))
+ return failure();
if (getCopy() && getCopy().getType() != getType())
return emitError("expected that `copy` and return type match");
return success();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 210f9584c1e86..25b9f7743d294 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -24,6 +24,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Utils/VerificationUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -1883,36 +1884,21 @@ void ReduceOp::print(OpAsmPrinter &p) {
LogicalResult ReduceOp::verify() {
ArrayRef<int64_t> dimensionsRef = getDimensions();
- for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
- if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
- llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
- return emitOpError() << "expects all inputs to have the same shapes. "
- "Shape at input-index "
- << i
- << " is not equal to the shape at input-index 0.";
- }
- }
- for (int64_t i = 1; i < getNumDpsInits(); ++i) {
- if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
- llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
- return emitOpError() << "expects all outputs to have the same shapes. "
- "Shape at output-index "
- << i
- << " is not equal to the shape at output-index 0.";
- }
- }
+ if (failed(verifyAllShapesMatch(getOperation(), getInputs(), "inputs")))
+ return failure();
+ if (failed(verifyAllShapesMatch(getOperation(), getInits(), "inits")))
+ return failure();
auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
+ if (failed(verifyDimensionIndicesInRange(getOperation(), dimensionsRef,
+ inputType.getRank(),
+ "reduction dimensions")))
+ return failure();
+
DenseSet<int64_t> dimensionsToReduce;
- for (int64_t dimension : dimensionsRef) {
- if (dimension < 0 || dimension >= inputType.getRank()) {
- return emitOpError()
- << "dimensions for reduction should be in the range [0, "
- << inputType.getRank() - 1 << "].";
- }
+ for (int64_t dimension : dimensionsRef)
dimensionsToReduce.insert(dimension);
- }
auto inputDims = inputType.getShape();
auto initDims = initType.getShape();
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index a0c7e40c20a46..32bfa603677ba 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Utils/VerificationUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
@@ -1075,11 +1076,8 @@ void EmptyOp::build(OpBuilder &builder, OperationState &result,
}
LogicalResult EmptyOp::verify() {
- if (getType().getNumDynamicDims() != getDynamicSizes().size())
- return emitOpError("incorrect number of dynamic sizes, has ")
- << getDynamicSizes().size() << ", expected "
- << getType().getNumDynamicDims();
- return success();
+ return verifyDynamicDimensionCount(getOperation(), getType(),
+ getDynamicSizes());
}
LogicalResult
@@ -1354,9 +1352,8 @@ void ExtractOp::getAsmResultNames(
LogicalResult ExtractOp::verify() {
// Verify the # indices match if we have a ranked type.
auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
- if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
- return emitOpError("incorrect number of indices for extract_element");
- return success();
+ return verifyIndexCountMatchesRank(getOperation(), tensorType.getRank(),
+ getIndices().size());
}
/// If we have an ExtractOp consuming an InsertOp with the same
@@ -4039,11 +4036,8 @@ void SplatOp::getAsmResultNames(
}
LogicalResult SplatOp::verify() {
- if (getType().getNumDynamicDims() != getDynamicSizes().size())
- return emitOpError("incorrect number of dynamic sizes, has ")
- << getDynamicSizes().size() << ", expected "
- << getType().getNumDynamicDims();
- return success();
+ return verifyDynamicDimensionCount(getOperation(), getType(),
+ getDynamicSizes());
}
LogicalResult
diff --git a/mlir/lib/Dialect/Utils/CMakeLists.txt b/mlir/lib/Dialect/Utils/CMakeLists.txt
index f5bb687ae071a..7673da5932304 100644
--- a/mlir/lib/Dialect/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Utils/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_library(MLIRDialectUtils
ReshapeOpsUtils.cpp
StructuredOpsUtils.cpp
StaticValueUtils.cpp
+ VerificationUtils.cpp
DEPENDS
MLIRDialectUtilsIncGen
diff --git a/mlir/lib/Dialect/Utils/VerificationUtils.cpp b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
new file mode 100644
index 0000000000000..25b84a67956cb
--- /dev/null
+++ b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
@@ -0,0 +1,235 @@
+//===- VerificationUtils.cpp - Common verification utilities --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Utils/VerificationUtils.h"
+#include "llvm/ADT/DenseSet.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Dynamic Dimension Verification
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::verifyDynamicDimensionCount(Operation *op, ShapedType type,
+ ValueRange dynamicSizes) {
+ int64_t expectedCount = type.getNumDynamicDims();
+ int64_t actualCount = dynamicSizes.size();
+ if (expectedCount != actualCount) {
+ return op->emitOpError("incorrect number of dynamic sizes, has ")
+ << actualCount << ", expected " << expectedCount;
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Rank Verification
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::verifyRanksMatch(Operation *op, ShapedType type1,
+ ShapedType type2, StringRef name1,
+ StringRef name2) {
+ if (!type1.hasRank() || !type2.hasRank())
+ return success(); // Unranked types are considered compatible
+
+ int64_t rank1 = type1.getRank();
+ int64_t rank2 = type2.getRank();
+ if (rank1 != rank2) {
+ return op->emitOpError()
+ << name1 << " rank (" << rank1 << ") does not match " << name2
+ << " rank (" << rank2 << ")";
+ }
+ return success();
+}
+
+LogicalResult mlir::verifyRankEquals(Operation *op, ShapedType type,
+ int64_t expectedRank, StringRef typeName) {
+ if (!type.hasRank())
+ return success();
+
+ int64_t actualRank = type.getRank();
+ if (actualRank != expectedRank) {
+ return op->emitOpError()
+ << typeName << " must have rank " << expectedRank << ", but has "
+ << actualRank;
+ }
+ return success();
+}
+
+LogicalResult mlir::verifyRankInRange(Operation *op, ShapedType type,
+ int64_t minRank, int64_t maxRank,
+ StringRef typeName) {
+ if (!type.hasRank())
+ return success();
+
+ int64_t rank = type.getRank();
+ if (rank < minRank || rank > maxRank) {
+ return op->emitOpError()
+ << typeName << " rank must be in range [" << minRank << ", "
+ << maxRank << "], but has rank " << rank;
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Index/Dimension Verification
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::verifyIndexCountMatchesRank(Operation *op, int64_t rank,
+ size_t indexCount,
+ StringRef indexName) {
+ if (rank != static_cast<int64_t>(indexCount)) {
+ return op->emitOpError("incorrect number of ")
+ << indexName << ", has " << indexCount << ", expected " << rank;
+ }
+ return success();
+}
+
+LogicalResult mlir::verifyDimensionIndicesInRange(Operation *op,
+ ArrayRef<int64_t> indices,
+ int64_t maxDim,
+ StringRef context) {
+ for (int64_t index : indices) {
+ if (index < 0 || index >= maxDim) {
+ return op->emitOpError()
+ << context << " must be in the range [0, " << (maxDim - 1)
+ << "], but got " << index;
+ }
+ }
+ return success();
+}
+
+LogicalResult mlir::verifyDimensionIndicesUnique(Operation *op,
+ ArrayRef<int64_t> indices,
+ StringRef context) {
+ llvm::DenseSet<int64_t> seen;
+ for (int64_t index : indices) {
+ if (!seen.insert(index).second) {
+ return op->emitOpError()
+ << context << " contains duplicate index " << index;
+ }
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Shape Verification
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::verifyAllShapesMatch(Operation *op, ValueRange values,
+ StringRef context) {
+ if (values.empty())
+ return success();
+
+ auto firstType = llvm::dyn_cast<ShapedType>(values.front().getType());
+ if (!firstType || !firstType.hasRank())
+ return success();
+
+ ArrayRef<int64_t> firstShape = firstType.getShape();
+ for (auto [idx, value] : llvm::enumerate(values.drop_front())) {
+ auto type = llvm::dyn_cast<ShapedType>(value.getType());
+ if (!type || !type.hasRank())
+ continue;
+
+ if (type.getShape() != firstShape) {
+ return op->emitOpError()
+ << context << " must all have the same shape, but " << context
+ << "[0] has shape [" << firstShape << "] while " << context << "["
+ << (idx + 1) << "] has shape [" << type.getShape() << "]";
+ }
+ }
+ return success();
+}
+
+LogicalResult mlir::verifyShapesCompatible(Operation *op, ShapedType type1,
+ ShapedType type2, StringRef name1,
+ StringRef name2) {
+ if (!type1.hasRank() || !type2.hasRank())
+ return success();
+
+ if (type1.getRank() != type2.getRank()) {
+ return op->emitOpError()
+ << name1 << " and " << name2 << " must have the same rank";
+ }
+
+ ArrayRef<int64_t> shape1 = type1.getShape();
+ ArrayRef<int64_t> shape2 = type2.getShape();
+ for (auto [idx, dims] : llvm::enumerate(llvm::zip(shape1, shape2))) {
+ auto [dim1, dim2] = dims;
+ // Dynamic dimensions are compatible with anything
+ if (ShapedType::isDynamic(dim1) || ShapedType::isDynamic(dim2))
+ continue;
+ if (dim1 != dim2) {
+ return op->emitOpError()
+ << name1 << " and " << name2 << " have incompatible shapes at "
+ << "dimension " << idx << ": " << dim1 << " vs " << dim2;
+ }
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Element Type Verification
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::verifyAllElementTypesMatch(Operation *op, ValueRange values,
+ StringRef context) {
+ if (values.empty())
+ return success();
+
+ auto firstType = llvm::dyn_cast<ShapedType>(values.front().getType());
+ if (!firstType)
+ return success();
+
+ Type firstElementType = firstType.getElementType();
+ for (auto [idx, value] : llvm::enumerate(values.drop_front())) {
+ auto type = llvm::dyn_cast<ShapedType>(value.getType());
+ if (!type)
+ continue;
+
+ if (type.getElementType() != firstElementType) {
+ return op->emitOpError()
+ << context << " must all have the same element type, but "
+ << context << "[0] has element type " << firstElementType
+ << " while " << context << "[" << (idx + 1) << "] has element type "
+ << type.getElementType();
+ }
+ }
+ return success();
+}
+
+LogicalResult mlir::verifyElementTypesMatch(Operation *op, ShapedType type1,
+ ShapedType type2, StringRef name1,
+ StringRef name2) {
+ if (type1.getElementType() != type2.getElementType()) {
+ return op->emitOpError()
+ << name1 << " element type (" << type1.getElementType()
+ << ") does not match " << name2 << " element type ("
+ << type2.getElementType() << ")";
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Element Count Verification
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::verifyElementCountsMatch(Operation *op, ShapedType type1,
+ ShapedType type2, StringRef name1,
+ StringRef name2) {
+ if (!type1.hasStaticShape() || !type2.hasStaticShape())
+ return success(); // Can't verify dynamic shapes at compile time
+
+ int64_t count1 = type1.getNumElements();
+ int64_t count2 = type2.getNumElements();
+ if (count1 != count2) {
+ return op->emitOpError()
+ << name1 << " has " << count1 << " elements, but " << name2
+ << " has " << count2 << " elements";
+ }
+ return success();
+}
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 9884b040119d0..76aba14bc50f2 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -2,7 +2,7 @@
func.func @alloc_tensor_missing_dims(%arg0: index)
{
- // expected-error @+1 {{expected 2 dynamic sizes}}
+ // expected-error @+1 {{incorrect number of dynamic sizes, has 1, expected 2}}
%0 = bufferization.alloc_tensor(%arg0) : tensor<4x?x?x5xf32>
return
}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 5a699135604b7..f9b411d7ca766 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -802,7 +802,7 @@ func.func @reduce_input_vs_init_dimension_mismatch(
func.func @reduce_dimensions_out_of_range(%input: tensor<16x32x64xf32>,
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
- // expected-error @+1 {{'linalg.reduce' op dimensions for reduction should be in the range [0, 2].}}
+ // expected-error @+1 {{'linalg.reduce' op reduction dimensions must be in the range [0, 2], but got 3}}
%reduce = linalg.reduce
ins(%input:tensor<16x32x64xf32>)
outs(%init:tensor<16x64xf32>)
@@ -922,7 +922,7 @@ func.func @reduce_wrong_block_argument_output_type(
func.func @reduce_different_input_shapes(%input1: tensor<16x32x64xf32>,
%init1: tensor<16x64xf32>, %input2: tensor<17x32x64xf32>,
%init2: tensor<17x64xf32>) -> (tensor<16x64xf32>, tensor<17x64xf32>) {
- // expected-error @+1{{'linalg.reduce' op expects all inputs to have the same shapes. Shape at input-index 1 is not equal to the shape at input-index 0.}}
+ // expected-error @+1{{'linalg.reduce' op inputs must all have the same shape, but inputs[0] has shape [16, 32, 64] while inputs[1] has shape [17, 32, 64]}}
%reduce, %reduce2 = linalg.reduce
ins(%input1, %input2 : tensor<16x32x64xf32>, tensor<17x32x64xf32>)
outs(%init1, %init2 : tensor<16x64xf32>, tensor<17x64xf32>)
@@ -940,7 +940,7 @@ func.func @reduce_different_input_shapes(%input1: tensor<16x32x64xf32>,
func.func @reduce_different_output_shapes(%input1: tensor<16x32x64xf32>,
%init1: tensor<16x64xf32>, %input2: tensor<16x32x64xf32>,
%init2: tensor<17x64xf32>) -> (tensor<16x64xf32>, tensor<17x64xf32>) {
- // expected-error @+1{{'linalg.reduce' op expects all outputs to have the same shapes. Shape at output-index 1 is not equal to the shape at output-index 0.}}
+ // expected-error @+1{{'linalg.reduce' op inits must all have the same shape, but inits[0] has shape [16, 64] while inits[1] has shape [17, 64]}}
%reduce, %reduce2 = linalg.reduce
ins(%input1, %input2 : tensor<16x32x64xf32>, tensor<16x32x64xf32>)
outs(%init1, %init2 : tensor<16x64xf32>, tensor<17x64xf32>)
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index f36678c3d7589..910ed8a89edd6 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -65,7 +65,7 @@ func.func @concat_static_shape_mismatch(%arg0: tensor<3xf32>) {
// -----
func.func @extract_too_many_indices(%arg0: tensor<?xf32>) {
- // expected-error at +1 {{incorrect number of indices for extract_element}}
+ // expected-error at +1 {{incorrect number of indices, has 0, expected 1}}
%0 = tensor.extract %arg0[] : tensor<?xf32>
return
}
>From 75141d40a3420e4c40a94f4bce380d5b801243a1 Mon Sep 17 00:00:00 2001
From: Nick Kreeger <nick.kreeger at microsoft.com>
Date: Sun, 4 Jan 2026 10:38:13 -0600
Subject: [PATCH 2/3] Fix clang-format issues in VerificationUtils and
 BufferizationOps (NFC).
---
.../Bufferization/IR/BufferizationOps.cpp | 5 ++---
mlir/lib/Dialect/Utils/VerificationUtils.cpp | 19 ++++++++-----------
2 files changed, 10 insertions(+), 14 deletions(-)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 5df97324c2dce..eda6bf276be06 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -252,9 +252,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
LogicalResult AllocTensorOp::verify() {
if (getCopy() && !getDynamicSizes().empty())
return emitError("dynamic sizes not needed when copying a tensor");
- if (!getCopy() &&
- failed(verifyDynamicDimensionCount(getOperation(), getType(),
- getDynamicSizes())))
+ if (!getCopy() && failed(verifyDynamicDimensionCount(
+ getOperation(), getType(), getDynamicSizes())))
return failure();
if (getCopy() && getCopy().getType() != getType())
return emitError("expected that `copy` and return type match");
diff --git a/mlir/lib/Dialect/Utils/VerificationUtils.cpp b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
index 25b84a67956cb..60a591b9fbaa4 100644
--- a/mlir/lib/Dialect/Utils/VerificationUtils.cpp
+++ b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
@@ -53,9 +53,8 @@ LogicalResult mlir::verifyRankEquals(Operation *op, ShapedType type,
int64_t actualRank = type.getRank();
if (actualRank != expectedRank) {
- return op->emitOpError()
- << typeName << " must have rank " << expectedRank << ", but has "
- << actualRank;
+ return op->emitOpError() << typeName << " must have rank " << expectedRank
+ << ", but has " << actualRank;
}
return success();
}
@@ -95,9 +94,8 @@ LogicalResult mlir::verifyDimensionIndicesInRange(Operation *op,
StringRef context) {
for (int64_t index : indices) {
if (index < 0 || index >= maxDim) {
- return op->emitOpError()
- << context << " must be in the range [0, " << (maxDim - 1)
- << "], but got " << index;
+ return op->emitOpError() << context << " must be in the range [0, "
+ << (maxDim - 1) << "], but got " << index;
}
}
return success();
@@ -195,8 +193,8 @@ LogicalResult mlir::verifyAllElementTypesMatch(Operation *op, ValueRange values,
return op->emitOpError()
<< context << " must all have the same element type, but "
<< context << "[0] has element type " << firstElementType
- << " while " << context << "[" << (idx + 1) << "] has element type "
- << type.getElementType();
+ << " while " << context << "[" << (idx + 1)
+ << "] has element type " << type.getElementType();
}
}
return success();
@@ -227,9 +225,8 @@ LogicalResult mlir::verifyElementCountsMatch(Operation *op, ShapedType type1,
int64_t count1 = type1.getNumElements();
int64_t count2 = type2.getNumElements();
if (count1 != count2) {
- return op->emitOpError()
- << name1 << " has " << count1 << " elements, but " << name2
- << " has " << count2 << " elements";
+ return op->emitOpError() << name1 << " has " << count1 << " elements, but "
+ << name2 << " has " << count2 << " elements";
}
return success();
}
>From 6f5a4ba534311c120987689e185ca1a661a043bf Mon Sep 17 00:00:00 2001
From: Nick Kreeger <nick.kreeger at microsoft.com>
Date: Mon, 5 Jan 2026 14:54:21 -0600
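Subject: [PATCH 3/3] Keep only verifyDynamicDimensionCount() for now.
Drop the other helpers from VerificationUtils.h/cpp and revert the call
sites in LinalgOps.cpp and TensorOps.cpp that used them, together with
their updated test expectations.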
---
.../mlir/Dialect/Utils/VerificationUtils.h | 93 +-------
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 36 ++-
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 5 +-
mlir/lib/Dialect/Utils/VerificationUtils.cpp | 210 ------------------
mlir/test/Dialect/Linalg/invalid.mlir | 6 +-
mlir/test/Dialect/Tensor/invalid.mlir | 2 +-
6 files changed, 33 insertions(+), 319 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Utils/VerificationUtils.h b/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
index 894d225c32b98..c1c3cc6231eb6 100644
--- a/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
@@ -8,8 +8,7 @@
//
// This header file defines common verification utilities that can be shared
// across multiple MLIR dialects. These utilities help reduce code duplication
-// for common verification patterns such as checking dynamic dimensions,
-// rank matching, and index validation.
+// for common verification patterns.
//
//===----------------------------------------------------------------------===//
@@ -22,102 +21,12 @@
namespace mlir {
-//===----------------------------------------------------------------------===//
-// Dynamic Dimension Verification
-//===----------------------------------------------------------------------===//
-
/// Verify that the number of dynamic size operands matches the number of
/// dynamic dimensions in the shaped type. Returns failure and emits an error
/// if the counts don't match.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type,
ValueRange dynamicSizes);
-//===----------------------------------------------------------------------===//
-// Rank Verification
-//===----------------------------------------------------------------------===//
-
-/// Verify that two shaped types have matching ranks. Returns failure and emits
-/// an error if ranks don't match.
-LogicalResult verifyRanksMatch(Operation *op, ShapedType type1,
- ShapedType type2, StringRef name1,
- StringRef name2);
-
-/// Verify that a shaped type has the expected rank. Returns failure and emits
-/// an error if the rank doesn't match.
-LogicalResult verifyRankEquals(Operation *op, ShapedType type,
- int64_t expectedRank, StringRef typeName);
-
-/// Verify that a shaped type's rank is within the specified range [minRank,
-/// maxRank]. Returns failure and emits an error if out of range.
-LogicalResult verifyRankInRange(Operation *op, ShapedType type, int64_t minRank,
- int64_t maxRank, StringRef typeName);
-
-//===----------------------------------------------------------------------===//
-// Index/Dimension Verification
-//===----------------------------------------------------------------------===//
-
-/// Verify that the number of indices matches the rank of the shaped type.
-/// Returns failure and emits an error if the counts don't match.
-LogicalResult verifyIndexCountMatchesRank(Operation *op, int64_t rank,
- size_t indexCount,
- StringRef indexName = "indices");
-
-/// Verify that all dimension indices in the array are within the valid range
-/// [0, maxDim). Returns failure and emits an error if any index is out of
-/// range.
-LogicalResult verifyDimensionIndicesInRange(Operation *op,
- ArrayRef<int64_t> indices,
- int64_t maxDim,
- StringRef context = "dimensions");
-
-/// Verify that all dimension indices are unique (no duplicates). Returns
-/// failure and emits an error if duplicates are found.
-LogicalResult verifyDimensionIndicesUnique(Operation *op,
- ArrayRef<int64_t> indices,
- StringRef context = "dimensions");
-
-//===----------------------------------------------------------------------===//
-// Shape Verification
-//===----------------------------------------------------------------------===//
-
-/// Verify that all values in the range have the same shape. Returns failure
-/// and emits an error if shapes don't match.
-LogicalResult verifyAllShapesMatch(Operation *op, ValueRange values,
- StringRef context);
-
-/// Verify that two shaped types have compatible shapes (same rank and matching
-/// dimensions, with dynamic dimensions considered compatible with any size).
-/// Returns failure and emits an error if incompatible.
-LogicalResult verifyShapesCompatible(Operation *op, ShapedType type1,
- ShapedType type2, StringRef name1,
- StringRef name2);
-
-//===----------------------------------------------------------------------===//
-// Element Type Verification
-//===----------------------------------------------------------------------===//
-
-/// Verify that all values in the range have the same element type. Returns
-/// failure and emits an error if element types don't match.
-LogicalResult verifyAllElementTypesMatch(Operation *op, ValueRange values,
- StringRef context);
-
-/// Verify that two shaped types have the same element type. Returns failure
-/// and emits an error if they don't match.
-LogicalResult verifyElementTypesMatch(Operation *op, ShapedType type1,
- ShapedType type2, StringRef name1,
- StringRef name2);
-
-//===----------------------------------------------------------------------===//
-// Element Count Verification
-//===----------------------------------------------------------------------===//
-
-/// Verify that two shaped types have the same total number of elements.
-/// Returns failure and emits an error if element counts don't match.
-/// Useful for reshape and shape_cast operations.
-LogicalResult verifyElementCountsMatch(Operation *op, ShapedType type1,
- ShapedType type2, StringRef name1,
- StringRef name2);
-
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_VERIFICATIONUTILS_H
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 25b9f7743d294..210f9584c1e86 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -24,7 +24,6 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
-#include "mlir/Dialect/Utils/VerificationUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -1884,21 +1883,36 @@ void ReduceOp::print(OpAsmPrinter &p) {
LogicalResult ReduceOp::verify() {
ArrayRef<int64_t> dimensionsRef = getDimensions();
- if (failed(verifyAllShapesMatch(getOperation(), getInputs(), "inputs")))
- return failure();
- if (failed(verifyAllShapesMatch(getOperation(), getInits(), "inits")))
- return failure();
+ for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
+ if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
+ llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
+ return emitOpError() << "expects all inputs to have the same shapes. "
+ "Shape at input-index "
+ << i
+ << " is not equal to the shape at input-index 0.";
+ }
+ }
+ for (int64_t i = 1; i < getNumDpsInits(); ++i) {
+ if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
+ llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
+ return emitOpError() << "expects all outputs to have the same shapes. "
+ "Shape at output-index "
+ << i
+ << " is not equal to the shape at output-index 0.";
+ }
+ }
auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
- if (failed(verifyDimensionIndicesInRange(getOperation(), dimensionsRef,
- inputType.getRank(),
- "reduction dimensions")))
- return failure();
-
DenseSet<int64_t> dimensionsToReduce;
- for (int64_t dimension : dimensionsRef)
+ for (int64_t dimension : dimensionsRef) {
+ if (dimension < 0 || dimension >= inputType.getRank()) {
+ return emitOpError()
+ << "dimensions for reduction should be in the range [0, "
+ << inputType.getRank() - 1 << "].";
+ }
dimensionsToReduce.insert(dimension);
+ }
auto inputDims = inputType.getShape();
auto initDims = initType.getShape();
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 32bfa603677ba..a6bd1fc55be78 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1352,8 +1352,9 @@ void ExtractOp::getAsmResultNames(
LogicalResult ExtractOp::verify() {
// Verify the # indices match if we have a ranked type.
auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
- return verifyIndexCountMatchesRank(getOperation(), tensorType.getRank(),
- getIndices().size());
+ if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
+ return emitOpError("incorrect number of indices for extract_element");
+ return success();
}
/// If we have an ExtractOp consuming an InsertOp with the same
diff --git a/mlir/lib/Dialect/Utils/VerificationUtils.cpp b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
index 60a591b9fbaa4..22b224713a6a3 100644
--- a/mlir/lib/Dialect/Utils/VerificationUtils.cpp
+++ b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
@@ -7,14 +7,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Utils/VerificationUtils.h"
-#include "llvm/ADT/DenseSet.h"
using namespace mlir;
-//===----------------------------------------------------------------------===//
-// Dynamic Dimension Verification
-//===----------------------------------------------------------------------===//
-
LogicalResult mlir::verifyDynamicDimensionCount(Operation *op, ShapedType type,
ValueRange dynamicSizes) {
int64_t expectedCount = type.getNumDynamicDims();
@@ -25,208 +20,3 @@ LogicalResult mlir::verifyDynamicDimensionCount(Operation *op, ShapedType type,
}
return success();
}
-
-//===----------------------------------------------------------------------===//
-// Rank Verification
-//===----------------------------------------------------------------------===//
-
-LogicalResult mlir::verifyRanksMatch(Operation *op, ShapedType type1,
- ShapedType type2, StringRef name1,
- StringRef name2) {
- if (!type1.hasRank() || !type2.hasRank())
- return success(); // Unranked types are considered compatible
-
- int64_t rank1 = type1.getRank();
- int64_t rank2 = type2.getRank();
- if (rank1 != rank2) {
- return op->emitOpError()
- << name1 << " rank (" << rank1 << ") does not match " << name2
- << " rank (" << rank2 << ")";
- }
- return success();
-}
-
-LogicalResult mlir::verifyRankEquals(Operation *op, ShapedType type,
- int64_t expectedRank, StringRef typeName) {
- if (!type.hasRank())
- return success();
-
- int64_t actualRank = type.getRank();
- if (actualRank != expectedRank) {
- return op->emitOpError() << typeName << " must have rank " << expectedRank
- << ", but has " << actualRank;
- }
- return success();
-}
-
-LogicalResult mlir::verifyRankInRange(Operation *op, ShapedType type,
- int64_t minRank, int64_t maxRank,
- StringRef typeName) {
- if (!type.hasRank())
- return success();
-
- int64_t rank = type.getRank();
- if (rank < minRank || rank > maxRank) {
- return op->emitOpError()
- << typeName << " rank must be in range [" << minRank << ", "
- << maxRank << "], but has rank " << rank;
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Index/Dimension Verification
-//===----------------------------------------------------------------------===//
-
-LogicalResult mlir::verifyIndexCountMatchesRank(Operation *op, int64_t rank,
- size_t indexCount,
- StringRef indexName) {
- if (rank != static_cast<int64_t>(indexCount)) {
- return op->emitOpError("incorrect number of ")
- << indexName << ", has " << indexCount << ", expected " << rank;
- }
- return success();
-}
-
-LogicalResult mlir::verifyDimensionIndicesInRange(Operation *op,
- ArrayRef<int64_t> indices,
- int64_t maxDim,
- StringRef context) {
- for (int64_t index : indices) {
- if (index < 0 || index >= maxDim) {
- return op->emitOpError() << context << " must be in the range [0, "
- << (maxDim - 1) << "], but got " << index;
- }
- }
- return success();
-}
-
-LogicalResult mlir::verifyDimensionIndicesUnique(Operation *op,
- ArrayRef<int64_t> indices,
- StringRef context) {
- llvm::DenseSet<int64_t> seen;
- for (int64_t index : indices) {
- if (!seen.insert(index).second) {
- return op->emitOpError()
- << context << " contains duplicate index " << index;
- }
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Shape Verification
-//===----------------------------------------------------------------------===//
-
-LogicalResult mlir::verifyAllShapesMatch(Operation *op, ValueRange values,
- StringRef context) {
- if (values.empty())
- return success();
-
- auto firstType = llvm::dyn_cast<ShapedType>(values.front().getType());
- if (!firstType || !firstType.hasRank())
- return success();
-
- ArrayRef<int64_t> firstShape = firstType.getShape();
- for (auto [idx, value] : llvm::enumerate(values.drop_front())) {
- auto type = llvm::dyn_cast<ShapedType>(value.getType());
- if (!type || !type.hasRank())
- continue;
-
- if (type.getShape() != firstShape) {
- return op->emitOpError()
- << context << " must all have the same shape, but " << context
- << "[0] has shape [" << firstShape << "] while " << context << "["
- << (idx + 1) << "] has shape [" << type.getShape() << "]";
- }
- }
- return success();
-}
-
-LogicalResult mlir::verifyShapesCompatible(Operation *op, ShapedType type1,
- ShapedType type2, StringRef name1,
- StringRef name2) {
- if (!type1.hasRank() || !type2.hasRank())
- return success();
-
- if (type1.getRank() != type2.getRank()) {
- return op->emitOpError()
- << name1 << " and " << name2 << " must have the same rank";
- }
-
- ArrayRef<int64_t> shape1 = type1.getShape();
- ArrayRef<int64_t> shape2 = type2.getShape();
- for (auto [idx, dims] : llvm::enumerate(llvm::zip(shape1, shape2))) {
- auto [dim1, dim2] = dims;
- // Dynamic dimensions are compatible with anything
- if (ShapedType::isDynamic(dim1) || ShapedType::isDynamic(dim2))
- continue;
- if (dim1 != dim2) {
- return op->emitOpError()
- << name1 << " and " << name2 << " have incompatible shapes at "
- << "dimension " << idx << ": " << dim1 << " vs " << dim2;
- }
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Element Type Verification
-//===----------------------------------------------------------------------===//
-
-LogicalResult mlir::verifyAllElementTypesMatch(Operation *op, ValueRange values,
- StringRef context) {
- if (values.empty())
- return success();
-
- auto firstType = llvm::dyn_cast<ShapedType>(values.front().getType());
- if (!firstType)
- return success();
-
- Type firstElementType = firstType.getElementType();
- for (auto [idx, value] : llvm::enumerate(values.drop_front())) {
- auto type = llvm::dyn_cast<ShapedType>(value.getType());
- if (!type)
- continue;
-
- if (type.getElementType() != firstElementType) {
- return op->emitOpError()
- << context << " must all have the same element type, but "
- << context << "[0] has element type " << firstElementType
- << " while " << context << "[" << (idx + 1)
- << "] has element type " << type.getElementType();
- }
- }
- return success();
-}
-
-LogicalResult mlir::verifyElementTypesMatch(Operation *op, ShapedType type1,
- ShapedType type2, StringRef name1,
- StringRef name2) {
- if (type1.getElementType() != type2.getElementType()) {
- return op->emitOpError()
- << name1 << " element type (" << type1.getElementType()
- << ") does not match " << name2 << " element type ("
- << type2.getElementType() << ")";
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// Element Count Verification
-//===----------------------------------------------------------------------===//
-
-LogicalResult mlir::verifyElementCountsMatch(Operation *op, ShapedType type1,
- ShapedType type2, StringRef name1,
- StringRef name2) {
- if (!type1.hasStaticShape() || !type2.hasStaticShape())
- return success(); // Can't verify dynamic shapes at compile time
-
- int64_t count1 = type1.getNumElements();
- int64_t count2 = type2.getNumElements();
- if (count1 != count2) {
- return op->emitOpError() << name1 << " has " << count1 << " elements, but "
- << name2 << " has " << count2 << " elements";
- }
- return success();
-}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index f9b411d7ca766..5a699135604b7 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -802,7 +802,7 @@ func.func @reduce_input_vs_init_dimension_mismatch(
func.func @reduce_dimensions_out_of_range(%input: tensor<16x32x64xf32>,
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
- // expected-error @+1 {{'linalg.reduce' op reduction dimensions must be in the range [0, 2], but got 3}}
+ // expected-error @+1 {{'linalg.reduce' op dimensions for reduction should be in the range [0, 2].}}
%reduce = linalg.reduce
ins(%input:tensor<16x32x64xf32>)
outs(%init:tensor<16x64xf32>)
@@ -922,7 +922,7 @@ func.func @reduce_wrong_block_argument_output_type(
func.func @reduce_different_input_shapes(%input1: tensor<16x32x64xf32>,
%init1: tensor<16x64xf32>, %input2: tensor<17x32x64xf32>,
%init2: tensor<17x64xf32>) -> (tensor<16x64xf32>, tensor<17x64xf32>) {
- // expected-error @+1{{'linalg.reduce' op inputs must all have the same shape, but inputs[0] has shape [16, 32, 64] while inputs[1] has shape [17, 32, 64]}}
+ // expected-error @+1{{'linalg.reduce' op expects all inputs to have the same shapes. Shape at input-index 1 is not equal to the shape at input-index 0.}}
%reduce, %reduce2 = linalg.reduce
ins(%input1, %input2 : tensor<16x32x64xf32>, tensor<17x32x64xf32>)
outs(%init1, %init2 : tensor<16x64xf32>, tensor<17x64xf32>)
@@ -940,7 +940,7 @@ func.func @reduce_different_input_shapes(%input1: tensor<16x32x64xf32>,
func.func @reduce_different_output_shapes(%input1: tensor<16x32x64xf32>,
%init1: tensor<16x64xf32>, %input2: tensor<16x32x64xf32>,
%init2: tensor<17x64xf32>) -> (tensor<16x64xf32>, tensor<17x64xf32>) {
- // expected-error @+1{{'linalg.reduce' op inits must all have the same shape, but inits[0] has shape [16, 64] while inits[1] has shape [17, 64]}}
+ // expected-error @+1{{'linalg.reduce' op expects all outputs to have the same shapes. Shape at output-index 1 is not equal to the shape at output-index 0.}}
%reduce, %reduce2 = linalg.reduce
ins(%input1, %input2 : tensor<16x32x64xf32>, tensor<16x32x64xf32>)
outs(%init1, %init2 : tensor<16x64xf32>, tensor<17x64xf32>)
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 910ed8a89edd6..f36678c3d7589 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -65,7 +65,7 @@ func.func @concat_static_shape_mismatch(%arg0: tensor<3xf32>) {
// -----
func.func @extract_too_many_indices(%arg0: tensor<?xf32>) {
- // expected-error@+1 {{incorrect number of indices, has 0, expected 1}}
+ // expected-error@+1 {{incorrect number of indices for extract_element}}
%0 = tensor.extract %arg0[] : tensor<?xf32>
return
}
More information about the Mlir-commits
mailing list