[Mlir-commits] [mlir] [mlir][tosa] Verify the output shape of tosa.mul and tosa.rescale (PR #193952)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 24 05:12:21 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

<details>
<summary>Changes</summary>

Verifying the provided output shape against an expected shape helps diagnose issues on op construction.

---
Full diff: https://github.com/llvm/llvm-project/pull/193952.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+54-47) 
- (modified) mlir/test/Dialect/Tosa/verifier.mlir (+21) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index fa4bc120e9c1e..eb5556dc8ee74 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -662,6 +662,21 @@ static void printShapeToDiagnostic(InFlightDiagnostic &diag,
   llvm::interleaveComma(shape, diag, printDim);
 }
 
+static LogicalResult
+verifyOutputShapeCompatibleWithExpected(Operation *op, ShapedType outputType,
+                                        ArrayRef<int64_t> expectedShape,
+                                        StringRef outputName = "output") {
+  if (succeeded(verifyCompatibleShape(outputType.getShape(), expectedShape)))
+    return success();
+
+  InFlightDiagnostic diag = op->emitOpError("expected ");
+  diag << outputName << " shape ";
+  printShapeToDiagnostic(diag, outputType.getShape());
+  diag << " to be compatible with expected shape ";
+  printShapeToDiagnostic(diag, expectedShape);
+  return diag;
+}
+
 LogicalResult verifyConvOutputSize(
     Operation *op, const int64_t inputSize, const int64_t kernelSize,
     const int64_t outputSize, const int64_t padBefore, const int64_t padAfter,
@@ -2542,6 +2557,10 @@ LogicalResult tosa::MulOp::verify() {
 
   const bool aHasRank = aType.hasRank();
   const bool bHasRank = bType.hasRank();
+
+  bool hasExpectedOutputShape = false;
+  SmallVector<int64_t> expectedOutputShape;
+
   if (aHasRank && bHasRank) {
     const int64_t aRank = aType.getRank();
     const int64_t bRank = bType.getRank();
@@ -2550,12 +2569,12 @@ LogicalResult tosa::MulOp::verify() {
              << aRank << " and " << bRank;
 
     // check for broadcast compatible shapes
-    SmallVector<int64_t> resultShape;
     if (!mlir::OpTrait::util::getBroadcastedShape(
-            aType.getShape(), bType.getShape(), resultShape))
+            aType.getShape(), bType.getShape(), expectedOutputShape))
       return emitOpError("a and b operands don't have broadcast-compatible "
                          "shapes, got ")
              << aType << " and " << bType;
+    hasExpectedOutputShape = true;
   }
 
   ShapedType resultType = cast<ShapedType>(output.getType());
@@ -2570,6 +2589,11 @@ LogicalResult tosa::MulOp::verify() {
     return emitOpError("result type has different rank than b, got ")
            << resultRank << " vs " << bType.getRank();
 
+  if (hasExpectedOutputShape &&
+      failed(verifyOutputShapeCompatibleWithExpected(getOperation(), resultType,
+                                                     expectedOutputShape)))
+    return failure();
+
   return success();
 }
 
@@ -4846,12 +4870,7 @@ LogicalResult TransposeConv2DOp::verify() {
 }
 
 LogicalResult RescaleOp::verify() {
-  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
-  if (!inputType) {
-    emitOpError("expect shaped tensor for input, got ") << getInput().getType();
-    return failure();
-  }
-
+  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
   auto inputElementType =
       getStorageElementTypeOrSelf(inputType.getElementType());
   if (!mlir::isa<IntegerType>(inputElementType)) {
@@ -4860,13 +4879,7 @@ LogicalResult RescaleOp::verify() {
     return failure();
   }
 
-  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
-  if (!outputType) {
-    emitOpError("expect shaped tensor for output, got ")
-        << getOutput().getType();
-    return failure();
-  }
-
+  const auto outputType = llvm::cast<ShapedType>(getOutput().getType());
   auto outputElementType =
       getStorageElementTypeOrSelf(outputType.getElementType());
   if (!mlir::isa<IntegerType>(outputElementType)) {
@@ -4891,19 +4904,7 @@ LogicalResult RescaleOp::verify() {
   if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
     return failure();
 
-  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
-  if (!multiplierType) {
-    emitOpError("expect shaped tensor for multiplier, got ")
-        << getMultiplier().getType();
-    return failure();
-  }
-
-  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
-  if (!shiftType) {
-    emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
-    return failure();
-  }
-
+  const auto multiplierType = cast<ShapedType>(getMultiplier().getType());
   // multiplier element type must be i32 for scale32 = true
   if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
     emitOpError("expect i32 element type for multiplier for scale32=true, got ")
@@ -4936,28 +4937,34 @@ LogicalResult RescaleOp::verify() {
     numChannels = inputType.getDimSize(inputType.getRank() - 1);
   }
 
-  if (!multiplierType.hasRank())
-    return success();
-
-  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
-  // multiplier input has rank 1 by dialect definition
-  if (multiplierShape[0] != ShapedType::kDynamic &&
-      multiplierShape[0] != numChannels) {
-    emitOpError("expect shape of { ")
-        << numChannels << " } for multiplier input, got { "
-        << multiplierShape[0] << " }";
-    return failure();
+  if (outputType.hasRank()) {
+    if (failed(verifyOutputShapeCompatibleWithExpected(
+            getOperation(), outputType, inputType.getShape())))
+      return failure();
   }
 
-  if (!shiftType.hasRank())
-    return success();
+  if (multiplierType.hasRank()) {
+    ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
+    // multiplier input has rank 1 by dialect definition
+    if (multiplierShape[0] != ShapedType::kDynamic &&
+        multiplierShape[0] != numChannels) {
+      emitOpError("expect shape of { ")
+          << numChannels << " } for multiplier input, got { "
+          << multiplierShape[0] << " }";
+      return failure();
+    }
+  }
 
-  ArrayRef<int64_t> shiftShape = shiftType.getShape();
-  // shift input has rank 1 by dialect definition
-  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
-    emitOpError("expect shape of { ")
-        << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
-    return failure();
+  const auto shiftType = cast<ShapedType>(getShift().getType());
+  if (shiftType.hasRank()) {
+    ArrayRef<int64_t> shiftShape = shiftType.getShape();
+    // shift input has rank 1 by dialect definition
+    if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
+      emitOpError("expect shape of { ")
+          << numChannels << " } for shift input, got { " << shiftShape[0]
+          << " }";
+      return failure();
+    }
   }
 
   return success();
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 1572df5357877..ca16435099744 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -423,6 +423,27 @@ func.func @test_error_scalar_input_with_per_channel(%arg0: tensor<i8>) -> tensor
 
 // -----
 
+func.func @test_rescale_invalid_static_output_shape(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x4xi8> {
+  %multiplier = "tosa.const"() <{values = dense<42> : tensor<1xi16>}> : () -> tensor<1xi16>
+  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %output_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.rescale' op expected output shape 13, 21, 4 to be compatible with expected shape 13, 21, 3}}
+  %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = SINGLE_ROUND, per_channel = false, input_unsigned = false, output_unsigned = false} : (tensor<13x21x3xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x4xi8>
+  return %0 : tensor<13x21x4xi8>
+}
+
+// -----
+
+func.func @test_mul_invalid_static_output_shape(%arg0: tensor<?x21x1xf32>, %arg1: tensor<?x1x3xf32>) -> tensor<?x21x2xf32> {
+  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  // expected-error@+1 {{'tosa.mul' op expected output shape ?, 21, 2 to be compatible with expected shape ?, 21, 3}}
+  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<?x21x1xf32>, tensor<?x1x3xf32>, tensor<1xi8>) -> tensor<?x21x2xf32>
+  return %0 : tensor<?x21x2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_gather_invalid_indices_N
 func.func @test_gather_invalid_indices_N(%arg0: tensor<13x21x3xf32>, %arg1: tensor<12x26xi32>) -> tensor<13x26x3xf32> {
  // expected-error@+1 {{'tosa.gather' op requires indices dimension 0 to have size 13, got 12}}

``````````

</details>


https://github.com/llvm/llvm-project/pull/193952


More information about the Mlir-commits mailing list