[Mlir-commits] [mlir] [mlir][tosa] Disallow inferable dim in reshape/slice validation (PR #182472)
Luke Hutton
llvmlistbot at llvm.org
Mon Mar 2 02:19:53 PST 2026
https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/182472
>From 14d55503e9e922448c57310e8ee0b337504ff217 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Fri, 20 Feb 2026 10:15:40 +0000
Subject: [PATCH 1/2] [mlir][tosa] Disallow inferable dim in reshape/slice
validation
This commit makes the validation pass check reshape and slice operations
for inferable dimensions (represented by -1) in their constant shape
operands: `shape` for reshape, and `start` and `size` for slice. Such
dimensions are not compliant with the TOSA specification. If one is found,
an error is emitted stating that it does not conform to the TOSA specification.
Change-Id: Ie507cee4e47659a60367d4986733b1666df763be
---
.../Tosa/Transforms/TosaValidation.cpp | 45 ++++++++++++++++++-
mlir/test/Dialect/Tosa/error_if_check.mlir | 32 +++++++++++++
2 files changed, 76 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 97572fdd13953..92d25e1824987 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -26,6 +26,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
@@ -1214,6 +1215,47 @@ LogicalResult checkErrorIfPad(Operation *op) {
return success();
}
+LogicalResult checkErrorIfReshape(Operation *op) {
+ auto reshapeOp = dyn_cast<tosa::ReshapeOp>(op);
+ if (!reshapeOp)
+ return success();
+
+ constexpr int64_t kInferableDim = -1;
+ SmallVector<int64_t> shapeValues;
+ if (!tosa::getConstShapeValues(reshapeOp.getShape().getDefiningOp(),
+ shapeValues))
+ return success();
+
+ if (llvm::is_contained(shapeValues, kInferableDim))
+ return op->emitOpError("shape input contains inferable dimension (-1) "
+ "which does not conform to the TOSA specification");
+
+ return success();
+}
+
+LogicalResult checkErrorIfSlice(Operation *op) {
+ auto sliceOp = dyn_cast<tosa::SliceOp>(op);
+ if (!sliceOp)
+ return success();
+
+ constexpr int64_t kInferableDim = -1;
+ SmallVector<int64_t> startValues;
+ SmallVector<int64_t> sizeValues;
+ const bool hasStartValues = tosa::getConstShapeValues(
+ sliceOp.getStart().getDefiningOp(), startValues);
+ const bool hasSizeValues =
+ tosa::getConstShapeValues(sliceOp.getSize().getDefiningOp(), sizeValues);
+
+ if (hasStartValues && llvm::is_contained(startValues, kInferableDim))
+ return op->emitOpError("start input contains inferable dimension (-1) "
+ "which does not conform to the TOSA specification");
+ if (hasSizeValues && llvm::is_contained(sizeValues, kInferableDim))
+ return op->emitOpError("size input contains inferable dimension (-1) which "
+ "does not conform to the TOSA specification");
+
+ return success();
+}
+
static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
return llvm::all_of(op->getOperands(), [&](auto operand) {
Region *operandRegion = operand.getParentRegion();
@@ -1321,7 +1363,8 @@ LogicalResult checkErrorIfScatter(Operation *op) {
LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
if (failed(checkErrorIfResize(op)) || failed(checkErrorIfMul(op)) ||
failed(checkErrorIfTable(op)) || failed(checkErrorIfRescale(op)) ||
- failed(checkErrorIfPad(op)) || failed(checkErrorIfCondIf(op)) ||
+ failed(checkErrorIfPad(op)) || failed(checkErrorIfReshape(op)) ||
+ failed(checkErrorIfSlice(op)) || failed(checkErrorIfCondIf(op)) ||
failed(checkErrorIfWhileLoop(op)) || failed(checkErrorIfScatter(op)))
return failure();
return success();
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 334f52a3407c7..ab31657019674 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -86,6 +86,38 @@ func.func @test_resize_invalid_boarder_x(%arg0: tensor<1x8x8x8xf32>) -> tensor<?
// -----
+// CHECK-LABEL: test_reshape_inferable_dim
+func.func @test_reshape_inferable_dim(%arg0: tensor<4xf32>) -> tensor<?x2xf32> {
+ %shape = tosa.const_shape { values = dense<[-1, 2]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.reshape' op shape input contains inferable dimension (-1) which does not conform to the TOSA specification}}
+ %0 = tosa.reshape %arg0, %shape : (tensor<4xf32>, !tosa.shape<2>) -> tensor<?x2xf32>
+ return %0 : tensor<?x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_slice_inferable_start
+func.func @test_slice_inferable_start(%arg0: tensor<4x4xf32>) -> tensor<2x2xf32> {
+ %start = tosa.const_shape { values = dense<[-1, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %size = tosa.const_shape { values = dense<[2, 2]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.slice' op start input contains inferable dimension (-1) which does not conform to the TOSA specification}}
+ %0 = tosa.slice %arg0, %start, %size : (tensor<4x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<2x2xf32>
+ return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_slice_inferable_size
+func.func @test_slice_inferable_size(%arg0: tensor<4x4xf32>) -> tensor<?x2xf32> {
+ %start = tosa.const_shape { values = dense<[0, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %size = tosa.const_shape { values = dense<[-1, 2]> : tensor<2xindex> } : () -> !tosa.shape<2>
+ // expected-error at +1 {{'tosa.slice' op size input contains inferable dimension (-1) which does not conform to the TOSA specification}}
+ %0 = tosa.slice %arg0, %start, %size : (tensor<4x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x2xf32>
+ return %0 : tensor<?x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: test_mul_negative_shift
func.func @test_mul_negative_shift(%arg0: tensor<1x8x8x8xi32>, %arg1: tensor<1x8x8x8xi32>) -> tensor<1x8x8x8xi32> {
%shift = "tosa.const" () { values = dense<-1> : tensor<1xi8> } : () -> tensor<1xi8>
>From cc781766ad56c2fa3c35c429a326f1990fdaf161 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Mon, 2 Mar 2026 10:16:26 +0000
Subject: [PATCH 2/2] Move `kInferableDimSize` to TosaOps.h
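Moves the `kInferableDimSize` constant to a common header (TosaOps.h). The op verifiers in TosaOps.cpp and the TOSA validation pass now share a single definition instead of each declaring a local -1 constant or using a literal -1.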
Change-Id: I39e1d11c55816756239cc262b458f825ea337b0b
---
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h | 6 +++++
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 12 ++++-----
.../Tosa/Transforms/TosaValidation.cpp | 25 +++++++++++--------
3 files changed, 26 insertions(+), 17 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index c5196349dbb1a..f50914436b63d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -146,6 +146,12 @@ namespace tosa {
bool isa_tosa_shape_type(mlir::Type t);
+/// Represents a dimension in the shape of a tensor that can be inferred
+/// based on the other provided dimensions. For example, in a reshape
+/// operation, -1 can be used for a size that is derived from the input
+/// element count and the remaining (known) output dimensions.
+constexpr int64_t kInferableDimSize = -1;
+
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 55415f201f216..6072aecdf347b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2280,8 +2280,6 @@ LogicalResult tosa::SliceOp::verify() {
return emitOpError("length of size is not equal to rank of input shape");
}
- constexpr int64_t kInferableDimSize = -1;
-
SmallVector<int64_t> startValues;
tosa::getConstShapeValues(start.getDefiningOp(), startValues);
if (startValues.size()) {
@@ -2626,9 +2624,10 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
return mlir::success();
}
- int missingDims = llvm::count(shapeValues, -1);
+ int missingDims = llvm::count(shapeValues, kInferableDimSize);
if (missingDims > 1)
- return emitOpError() << "expected at most one target dimension to be -1";
+ return emitOpError() << "expected at most one target dimension to be "
+ << kInferableDimSize;
const auto outputType = dyn_cast<RankedTensorType>(getType());
if (!outputType)
@@ -2639,11 +2638,12 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
for (auto [newShapeDim, outputShapeDim] :
zip(shapeValues, outputType.getShape())) {
- if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
+ if (newShapeDim != kInferableDimSize &&
+ newShapeDim != ShapedType::kDynamic &&
outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
return emitOpError() << "new shape is inconsistent with result shape";
- if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
+ if (newShapeDim != ShapedType::kDynamic && newShapeDim < kInferableDimSize)
return emitOpError() << "new shape has invalid tensor dimension size "
<< newShapeDim;
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 92d25e1824987..35b4b862dbff7 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1220,15 +1220,16 @@ LogicalResult checkErrorIfReshape(Operation *op) {
if (!reshapeOp)
return success();
- constexpr int64_t kInferableDim = -1;
SmallVector<int64_t> shapeValues;
if (!tosa::getConstShapeValues(reshapeOp.getShape().getDefiningOp(),
shapeValues))
return success();
- if (llvm::is_contained(shapeValues, kInferableDim))
- return op->emitOpError("shape input contains inferable dimension (-1) "
- "which does not conform to the TOSA specification");
+ if (llvm::is_contained(shapeValues, kInferableDimSize))
+ return op->emitOpError("shape input contains inferable dimension (")
+ << kInferableDimSize
+ << ") "
+ "which does not conform to the TOSA specification";
return success();
}
@@ -1238,7 +1239,6 @@ LogicalResult checkErrorIfSlice(Operation *op) {
if (!sliceOp)
return success();
- constexpr int64_t kInferableDim = -1;
SmallVector<int64_t> startValues;
SmallVector<int64_t> sizeValues;
const bool hasStartValues = tosa::getConstShapeValues(
@@ -1246,12 +1246,15 @@ LogicalResult checkErrorIfSlice(Operation *op) {
const bool hasSizeValues =
tosa::getConstShapeValues(sliceOp.getSize().getDefiningOp(), sizeValues);
- if (hasStartValues && llvm::is_contained(startValues, kInferableDim))
- return op->emitOpError("start input contains inferable dimension (-1) "
- "which does not conform to the TOSA specification");
- if (hasSizeValues && llvm::is_contained(sizeValues, kInferableDim))
- return op->emitOpError("size input contains inferable dimension (-1) which "
- "does not conform to the TOSA specification");
+ if (hasStartValues && llvm::is_contained(startValues, kInferableDimSize))
+ return op->emitOpError("start input contains inferable dimension (")
+ << kInferableDimSize
+ << ") which does not conform to the TOSA specification";
+ if (hasSizeValues && llvm::is_contained(sizeValues, kInferableDimSize))
+ return op->emitOpError("size input contains inferable dimension (")
+ << kInferableDimSize
+ << ") which "
+ "does not conform to the TOSA specification";
return success();
}
More information about the Mlir-commits
mailing list