[Mlir-commits] [mlir] [mlir][tosa] Verify the output shape of tosa.mul and tosa.rescale (PR #193952)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Apr 24 05:12:21 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
Author: Luke Hutton (lhutton1)
<details>
<summary>Changes</summary>
Verifying the provided output shape against an expected shape helps diagnose issues on op construction.
---
Full diff: https://github.com/llvm/llvm-project/pull/193952.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+54-47)
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+21)
``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index fa4bc120e9c1e..eb5556dc8ee74 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -662,6 +662,21 @@ static void printShapeToDiagnostic(InFlightDiagnostic &diag,
llvm::interleaveComma(shape, diag, printDim);
}
+static LogicalResult
+verifyOutputShapeCompatibleWithExpected(Operation *op, ShapedType outputType,
+ ArrayRef<int64_t> expectedShape,
+ StringRef outputName = "output") {
+ if (succeeded(verifyCompatibleShape(outputType.getShape(), expectedShape)))
+ return success();
+
+ InFlightDiagnostic diag = op->emitOpError("expected ");
+ diag << outputName << " shape ";
+ printShapeToDiagnostic(diag, outputType.getShape());
+ diag << " to be compatible with expected shape ";
+ printShapeToDiagnostic(diag, expectedShape);
+ return diag;
+}
+
LogicalResult verifyConvOutputSize(
Operation *op, const int64_t inputSize, const int64_t kernelSize,
const int64_t outputSize, const int64_t padBefore, const int64_t padAfter,
@@ -2542,6 +2557,10 @@ LogicalResult tosa::MulOp::verify() {
const bool aHasRank = aType.hasRank();
const bool bHasRank = bType.hasRank();
+
+ bool hasExpectedOutputShape = false;
+ SmallVector<int64_t> expectedOutputShape;
+
if (aHasRank && bHasRank) {
const int64_t aRank = aType.getRank();
const int64_t bRank = bType.getRank();
@@ -2550,12 +2569,12 @@ LogicalResult tosa::MulOp::verify() {
<< aRank << " and " << bRank;
// check for broadcast compatible shapes
- SmallVector<int64_t> resultShape;
if (!mlir::OpTrait::util::getBroadcastedShape(
- aType.getShape(), bType.getShape(), resultShape))
+ aType.getShape(), bType.getShape(), expectedOutputShape))
return emitOpError("a and b operands don't have broadcast-compatible "
"shapes, got ")
<< aType << " and " << bType;
+ hasExpectedOutputShape = true;
}
ShapedType resultType = cast<ShapedType>(output.getType());
@@ -2570,6 +2589,11 @@ LogicalResult tosa::MulOp::verify() {
return emitOpError("result type has different rank than b, got ")
<< resultRank << " vs " << bType.getRank();
+ if (hasExpectedOutputShape &&
+ failed(verifyOutputShapeCompatibleWithExpected(getOperation(), resultType,
+ expectedOutputShape)))
+ return failure();
+
return success();
}
@@ -4846,12 +4870,7 @@ LogicalResult TransposeConv2DOp::verify() {
}
LogicalResult RescaleOp::verify() {
- auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
- if (!inputType) {
- emitOpError("expect shaped tensor for input, got ") << getInput().getType();
- return failure();
- }
-
+ const auto inputType = llvm::cast<ShapedType>(getInput().getType());
auto inputElementType =
getStorageElementTypeOrSelf(inputType.getElementType());
if (!mlir::isa<IntegerType>(inputElementType)) {
@@ -4860,13 +4879,7 @@ LogicalResult RescaleOp::verify() {
return failure();
}
- auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
- if (!outputType) {
- emitOpError("expect shaped tensor for output, got ")
- << getOutput().getType();
- return failure();
- }
-
+ const auto outputType = llvm::cast<ShapedType>(getOutput().getType());
auto outputElementType =
getStorageElementTypeOrSelf(outputType.getElementType());
if (!mlir::isa<IntegerType>(outputElementType)) {
@@ -4891,19 +4904,7 @@ LogicalResult RescaleOp::verify() {
if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
return failure();
- auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
- if (!multiplierType) {
- emitOpError("expect shaped tensor for multiplier, got ")
- << getMultiplier().getType();
- return failure();
- }
-
- auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
- if (!shiftType) {
- emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
- return failure();
- }
-
+ const auto multiplierType = cast<ShapedType>(getMultiplier().getType());
// multiplier element type must be i32 for scale32 = true
if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
emitOpError("expect i32 element type for multiplier for scale32=true, got ")
@@ -4936,28 +4937,34 @@ LogicalResult RescaleOp::verify() {
numChannels = inputType.getDimSize(inputType.getRank() - 1);
}
- if (!multiplierType.hasRank())
- return success();
-
- ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
- // multiplier input has rank 1 by dialect definition
- if (multiplierShape[0] != ShapedType::kDynamic &&
- multiplierShape[0] != numChannels) {
- emitOpError("expect shape of { ")
- << numChannels << " } for multiplier input, got { "
- << multiplierShape[0] << " }";
- return failure();
+ if (outputType.hasRank()) {
+ if (failed(verifyOutputShapeCompatibleWithExpected(
+ getOperation(), outputType, inputType.getShape())))
+ return failure();
}
- if (!shiftType.hasRank())
- return success();
+ if (multiplierType.hasRank()) {
+ ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
+ // multiplier input has rank 1 by dialect definition
+ if (multiplierShape[0] != ShapedType::kDynamic &&
+ multiplierShape[0] != numChannels) {
+ emitOpError("expect shape of { ")
+ << numChannels << " } for multiplier input, got { "
+ << multiplierShape[0] << " }";
+ return failure();
+ }
+ }
- ArrayRef<int64_t> shiftShape = shiftType.getShape();
- // shift input has rank 1 by dialect definition
- if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
- emitOpError("expect shape of { ")
- << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
- return failure();
+ const auto shiftType = cast<ShapedType>(getShift().getType());
+ if (shiftType.hasRank()) {
+ ArrayRef<int64_t> shiftShape = shiftType.getShape();
+ // shift input has rank 1 by dialect definition
+ if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
+ emitOpError("expect shape of { ")
+ << numChannels << " } for shift input, got { " << shiftShape[0]
+ << " }";
+ return failure();
+ }
}
return success();
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 1572df5357877..ca16435099744 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -423,6 +423,27 @@ func.func @test_error_scalar_input_with_per_channel(%arg0: tensor<i8>) -> tensor
// -----
+func.func @test_rescale_invalid_static_output_shape(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x4xi8> {
+ %multiplier = "tosa.const"() <{values = dense<42> : tensor<1xi16>}> : () -> tensor<1xi16>
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %output_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.rescale' op expected output shape 13, 21, 4 to be compatible with expected shape 13, 21, 3}}
+ %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x4xi8>
+ return %0 : tensor<13x21x4xi8>
+}
+
+// -----
+
+func.func @test_mul_invalid_static_output_shape(%arg0: tensor<?x21x1xf32>, %arg1: tensor<?x1x3xf32>) -> tensor<?x21x2xf32> {
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.mul' op expected output shape ?, 21, 2 to be compatible with expected shape ?, 21, 3}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<?x21x1xf32>, tensor<?x1x3xf32>, tensor<1xi8>) -> tensor<?x21x2xf32>
+ return %0 : tensor<?x21x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: @test_gather_invalid_indices_N
func.func @test_gather_invalid_indices_N(%arg0: tensor<13x21x3xf32>, %arg1: tensor<12x26xi32>) -> tensor<13x26x3xf32> {
// expected-error@+1 {{'tosa.gather' op requires indices dimension 0 to have size 13, got 12}}
``````````
</details>
https://github.com/llvm/llvm-project/pull/193952
More information about the Mlir-commits
mailing list