[Mlir-commits] [mlir] [mlir][tosa] Allow dynamic dims in `--tosa-validate` pass (PR #171463)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Dec 9 07:59:53 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Luke Hutton (lhutton1)
<details>
<summary>Changes</summary>
This commit allows tensor dimensions to be dynamic when the specified target TOSA specification version is `1.1.draft` or higher. This is because this version of the specification supports representing operations whose dimensions remain dynamic until backend compile time.
---
Full diff: https://github.com/llvm/llvm-project/pull/171463.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+15-2)
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir (+8)
- (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir (+8)
``````````diff
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index b54ed5585d72d..9c7bc83f77ec7 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -688,9 +688,22 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op,
return op->emitOpError() << "failed level check: unranked tensor";
auto shape = type.getShape();
for (auto dim : shape) {
- if (mlir::ShapedType::isDynamic(dim))
+ const bool dimIsDynamic = mlir::ShapedType::isDynamic(dim);
+ const TosaSpecificationVersion targetVersion = targetEnv.getSpecVersion();
+ const TosaSpecificationVersion minRequiredVersion(1, 1);
+ if (targetVersion.isBackwardsCompatibleWith(minRequiredVersion) &&
+ dimIsDynamic)
+ // TOSA 1.1 and above supports dynamic dimensions, however, they must be
+ // resolved at backend compile time. Runtime dynamism is not currently
+ // supported. Checking this requirement is met is delegated to backends.
+ return success();
+
+ // When targeting TOSA 1.0 or below, dynamic dims are not supported
+ if (dimIsDynamic)
return op->emitOpError() << "failed level check: " << operandOrResult
- << " shape dimension cannot be dynamic";
+ << " shape dimension cannot be dynamic when"
+ << " targeting TOSA specification version 1.0"
+ << " or below";
}
int64_t element_bits = tosa::getBitWidth(getElementTypeOrSelf(type));
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
index 51089df238b84..fa40913834b5b 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p0-invalid.mlir
@@ -19,3 +19,11 @@ func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>,
%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32>
return %0 : tensor<1x14x28xf32>
}
+
+// -----
+
+func.func @test_dynamic_dims(%arg0: tensor<?x8x16xi8>) -> tensor<?x16xi32> {
+  // expected-error @+1 {{'tosa.argmax' op failed level check: operand shape dimension cannot be dynamic when targeting TOSA specification version 1.0 or below}}
+  %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<?x8x16xi8>) -> tensor<?x16xi32>
+  return %0 : tensor<?x16xi32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index c285ae3cf44ee..f0bdc645559b7 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -156,3 +156,11 @@ func.func @test_scatter_const_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: te
%0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
return %0 : tensor<2x52x3xf32>
}
+
+// -----
+
+// CHECK-LABEL: test_dynamic_dims
+func.func @test_dynamic_dims(%input: tensor<?x8x16xi8>) -> tensor<?x16xi32> {
+  %result = tosa.argmax %input { axis = 1 : i32 } : (tensor<?x8x16xi8>) -> tensor<?x16xi32>
+  return %result : tensor<?x16xi32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/171463
More information about the Mlir-commits
mailing list