[Mlir-commits] [mlir] 585efb4 - [mlir][Utils] Add verifyRanksMatch helper (NFC) (#175880)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jan 17 10:50:56 PST 2026
Author: Nick Kreeger
Date: 2026-01-17T18:50:51Z
New Revision: 585efb4c74c865b548876abc594939c40d262f6e
URL: https://github.com/llvm/llvm-project/commit/585efb4c74c865b548876abc594939c40d262f6e
DIFF: https://github.com/llvm/llvm-project/commit/585efb4c74c865b548876abc594939c40d262f6e.diff
LOG: [mlir][Utils] Add verifyRanksMatch helper (NFC) (#175880)
This change builds on https://github.com/llvm/llvm-project/pull/174336,
which introduced shared VerificationUtils with an initial
verifyDynamicDimensionCount() method.
This patch adds a new verifyRanksMatch() verification utility that
checks if two shaped types have matching ranks and emits consistent
error messages. The utility is applied to several ops across multiple
MLIR dialects.
---------
Co-authored-by: Andrzej Warzyński <andrzej.warzynski at gmail.com>
Added:
Modified:
mlir/include/mlir/Dialect/Utils/VerificationUtils.h
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/lib/Dialect/Utils/VerificationUtils.cpp
mlir/test/Dialect/Bufferization/invalid.mlir
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Tosa/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Utils/VerificationUtils.h b/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
index c1c3cc6231eb6..3d350aae7cf2f 100644
--- a/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/VerificationUtils.h
@@ -27,6 +27,11 @@ namespace mlir {
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type,
ValueRange dynamicSizes);
+/// Verify that `lhs` and `rhs` have equal ranks; passes if either is unranked.
+/// On mismatch, emits an op error naming lhsName/rhsName and returns failure.
+LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs,
+ StringRef lhsName, StringRef rhsName);
+
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_VERIFICATIONUTILS_H
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 4515a5b5a2671..cab51a3372429 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -702,8 +702,9 @@ LogicalResult MaterializeInDestinationOp::verify() {
if (srcType.hasRank() != destType.hasRank())
return emitOpError("source/destination shapes are incompatible");
if (srcType.hasRank()) {
- if (srcType.getRank() != destType.getRank())
- return emitOpError("rank mismatch between source and destination shape");
+ if (failed(verifyRanksMatch(getOperation(), srcType, destType, "source",
+ "destination")))
+ return failure();
for (auto [src, dest] :
llvm::zip(srcType.getShape(), destType.getShape())) {
if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 210f9584c1e86..e9f617a785d22 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -24,6 +24,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Utils/VerificationUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
@@ -2043,9 +2044,9 @@ LogicalResult TransposeOp::verify() {
int64_t rank = inputType.getRank();
- if (rank != initType.getRank())
- return emitOpError() << "input rank " << rank
- << " does not match init rank " << initType.getRank();
+ if (failed(verifyRanksMatch(getOperation(), inputType, initType, "input",
+ "init")))
+ return failure();
if (rank != static_cast<int64_t>(permutationRef.size()))
return emitOpError() << "size of permutation " << permutationRef.size()
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 4aaadf28d7a61..033feb79405df 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/VerificationUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
@@ -1140,9 +1141,8 @@ static LogicalResult verifyPoolingOp(T op) {
const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
return op.emitOpError("calculated output ")
- << dimName << " did not match expected: "
- << "calculated=" << calculatedOutSize
- << ", expected=" << outputSize;
+ << dimName << " did not match expected: " << "calculated="
+ << calculatedOutSize << ", expected=" << outputSize;
return success();
};
@@ -2047,8 +2047,7 @@ LogicalResult MatmulTBlockScaledOp::verify() {
multiplesOfC != C / blockSize)
return emitOpError(
"expect scale operands dimension 2 to equal C/block_size (")
- << C << "/" << blockSize << ")"
- << ", got " << multiplesOfC;
+ << C << "/" << blockSize << ")" << ", got " << multiplesOfC;
// Verify output shape
N = ShapedType::isDynamic(N) ? D : N;
@@ -2142,13 +2141,11 @@ LogicalResult tosa::PadOp::verify() {
if (!inputType || !outputType)
return success();
- auto inputRank = inputType.getRank();
- auto outputRank = outputType.getRank();
- if (inputRank != outputRank)
- return emitOpError() << "expect same input and output tensor rank, but got "
- << "inputRank: " << inputRank
- << ", outputRank: " << outputRank;
+ if (failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
+ "output")))
+ return failure();
+ auto inputRank = inputType.getRank();
DenseIntElementsAttr paddingAttr;
if (!matchPattern(getPadding(), m_Constant(&paddingAttr))) {
return failure();
@@ -2389,9 +2386,9 @@ LogicalResult tosa::TableOp::verify() {
if (!inputType.hasRank() || !outputType.hasRank())
return success();
- if (inputType.getRank() != outputType.getRank())
- return emitOpError()
- << "expected input tensor rank to equal result tensor rank";
+ if (failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
+ "result")))
+ return failure();
auto inputDims = inputType.getShape();
auto outputDims = outputType.getShape();
@@ -2481,8 +2478,10 @@ LogicalResult tosa::TileOp::verify() {
if (inputType.getRank() != multiplesRank)
return emitOpError("expect 'multiples' to have rank ")
<< inputType.getRank() << " but got " << multiplesRank << ".";
- if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
- return emitOpError("expect same input and output tensor rank.");
+ if (outputType.hasRank() &&
+ failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
+ "output")))
+ return failure();
} else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
return emitOpError("expect 'multiples' array to have length ")
<< outputType.getRank() << " but got " << multiplesRank << ".";
@@ -4230,8 +4229,8 @@ LogicalResult CastFromBlockScaledOp::verify() {
const Type outputDataType = getResult().getType();
if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
return emitOpError() << "require compatible shapes for input_data ("
- << inputDataType << ") and "
- << "output_data (" << outputDataType << ")";
+ << inputDataType << ") and " << "output_data ("
+ << outputDataType << ")";
const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
@@ -4258,10 +4257,10 @@ LogicalResult CastFromBlockScaledOp::verify() {
failed(verifyCompatibleShape(
ArrayRef<int64_t>(inputDataDims).drop_back(1),
ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
- return emitOpError() << "require compatible shapes for input_data ("
- << inputDataType << ") and "
- << "input_scale (" << inputScaleType
- << ") except for the last dimension";
+ return emitOpError()
+ << "require compatible shapes for input_data (" << inputDataType
+ << ") and " << "input_scale (" << inputScaleType
+ << ") except for the last dimension";
const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
inputScaleDims.back()};
@@ -4306,8 +4305,8 @@ LogicalResult CastToBlockScaledOp::verify() {
const Type outputDataType = getResult(0).getType();
if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
return emitOpError() << "require compatible shapes for input_data ("
- << inputDataType << ") and "
- << "output_data (" << outputDataType << ")";
+ << inputDataType << ") and " << "output_data ("
+ << outputDataType << ")";
const unsigned int blockSize =
BlockSizeAttr::getBlockSizeValue(getBlockSize());
@@ -4336,8 +4335,8 @@ LogicalResult CastToBlockScaledOp::verify() {
ArrayRef<int64_t>(outputDataDims).drop_back(1),
ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
return emitOpError() << "require compatible shapes for output_data ("
- << outputDataType << ") and "
- << "output_scale (" << outputScaleType
+ << outputDataType << ") and " << "output_scale ("
+ << outputScaleType
<< ") except for the last dimension";
const int64_t outputDataLastDim = outputDataDims.back();
@@ -4513,9 +4512,9 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
if (functionType.getNumInputs() != operands.size()) {
return parser.emitError(parser.getCurrentLocation())
- << "expected as many input types as operands "
- << "(expected " << operands.size() << " got "
- << functionType.getNumInputs() << ")";
+ << "expected as many input types as operands " << "(expected "
+ << operands.size() << " got " << functionType.getNumInputs()
+ << ")";
}
// Resolve input operands.
@@ -4764,9 +4763,8 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
if (functionType.getNumInputs() != operands.size()) {
return parser.emitError(typeLoc)
- << "expected as many input types as operands "
- << "(expected " << operands.size() << " got "
- << functionType.getNumInputs() << ")";
+ << "expected as many input types as operands " << "(expected "
+ << operands.size() << " got " << functionType.getNumInputs() << ")";
}
// Resolve input operands.
diff --git a/mlir/lib/Dialect/Utils/VerificationUtils.cpp b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
index 22b224713a6a3..9b3bfda42b91a 100644
--- a/mlir/lib/Dialect/Utils/VerificationUtils.cpp
+++ b/mlir/lib/Dialect/Utils/VerificationUtils.cpp
@@ -20,3 +20,19 @@ LogicalResult mlir::verifyDynamicDimensionCount(Operation *op, ShapedType type,
}
return success();
}
+
+LogicalResult mlir::verifyRanksMatch(Operation *op, ShapedType lhs,
+                                     ShapedType rhs, StringRef lhsName,
+                                     StringRef rhsName) {
+  if (!lhs.hasRank() || !rhs.hasRank())
+    return success(); // Unranked types are considered compatible.
+  // Both ranked: report any mismatch using the caller-supplied names.
+  int64_t lhsRank = lhs.getRank();
+  int64_t rhsRank = rhs.getRank();
+  if (lhsRank != rhsRank) {
+    return op->emitOpError()
+           << lhsName << " rank (" << lhsRank << ") does not match " << rhsName
+           << " rank (" << rhsRank << ")";
+  }
+  return success();
+}
diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir
index 76aba14bc50f2..f3368b9108d65 100644
--- a/mlir/test/Dialect/Bufferization/invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/invalid.mlir
@@ -51,7 +51,7 @@ func.func @invalid_materialize_in_destination(%arg0: tensor<4xf32>, %arg1: tenso
// -----
func.func @invalid_materialize_in_destination(%arg0: tensor<5x5xf32>, %arg1: tensor<5xf32>) {
- // expected-error @below{{rank mismatch between source and destination shape}}
+  // Both operands are ranked, so the verifier reaches the rank comparison.
+ // expected-error @below{{'bufferization.materialize_in_destination' op source rank (2) does not match destination rank (1)}}
bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<5x5xf32>, tensor<5xf32>) -> tensor<5xf32>
}
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index d70cdceed6b86..355d801f8732c 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -1068,7 +1068,7 @@ func.func @transpose_rank_permutation_size_mismatch(
func.func @transpose_input_init_rank_mismatch(%input: tensor<16x32xf32>,
%init: tensor<32x64x16xf32>) -> tensor<32x64x16xf32> {
- // expected-error @+1 {{'linalg.transpose' op input rank 2 does not match init rank 3}}
+ // expected-error @+1 {{'linalg.transpose' op input rank (2) does not match init rank (3)}}
%transpose = linalg.transpose
ins(%input:tensor<16x32xf32>)
outs(%init:tensor<32x64x16xf32>)
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index aafb688750433..3fb53fa2cd41f 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -292,7 +292,7 @@ func.func @test_pad_const_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<1xi
func.func @test_pad_io_rank_mismatch(%arg0: tensor<13x21xf32>) {
%0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4>
%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
-  // expected-error@+1 {{'tosa.pad' op expect same input and output tensor rank}}
+  // A rank-2 input cannot be padded into a rank-3 result.
+  // expected-error@+1 {{'tosa.pad' op input rank (2) does not match output rank (3)}}
%1 = tosa.pad %arg0, %0, %pad_const : (tensor<13x21xf32>, !tosa.shape<4>, tensor<1xf32>) -> tensor<13x21x3xf32>
}
@@ -663,7 +663,7 @@ func.func @test_tile_invalid_multiples_value() {
func.func @test_tile_io_rank_mismatch() {
%0 = tensor.empty() : tensor<4x31xf32>
%multiples = tosa.const_shape { values = dense<[2, 2]> : tensor<2xindex> } : () -> !tosa.shape<2>
-  // expected-error@+1 {{'tosa.tile' op expect same input and output tensor rank.}}
+  // 'multiples' matches the input rank, so the input/output rank check fires.
+  // expected-error@+1 {{'tosa.tile' op input rank (2) does not match output rank (3)}}
%1 = tosa.tile %0, %multiples : (tensor<4x31xf32>, !tosa.shape<2>) -> tensor<4x31x31xf32>
return
}
@@ -682,7 +682,7 @@ func.func @test_table_rank0_table(%arg0: tensor<64xi16>, %arg1: tensor<i16>) {
// CHECK-LABEL: test_table_io_rank_mismatch
func.func @test_table_io_rank_mismatch(%arg0: tensor<64xi16>, %arg1: tensor<6xi16>) {
-  // expected-error@+1 {{'tosa.table' op expected input tensor rank to equal result tensor rank}}
+  // Both types are ranked, so the verifier compares ranks (1 vs 2).
+  // expected-error@+1 {{'tosa.table' op input rank (1) does not match result rank (2)}}
%0 = tosa.table %arg0, %arg1 : (tensor<64xi16>, tensor<6xi16>) -> tensor<64x?xi16>
return
}
More information about the Mlir-commits
mailing list