[Mlir-commits] [mlir] 977355b - [mlir][tosa] Disallow inferable dim in reshape/slice validation (#182472)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 2 05:55:34 PST 2026


Author: Luke Hutton
Date: 2026-03-02T13:55:29Z
New Revision: 977355be38d10c62c3d5498a64618ec926348252

URL: https://github.com/llvm/llvm-project/commit/977355be38d10c62c3d5498a64618ec926348252
DIFF: https://github.com/llvm/llvm-project/commit/977355be38d10c62c3d5498a64618ec926348252.diff

LOG: [mlir][tosa] Disallow inferable dim in reshape/slice validation (#182472)

This commit ensures that the validation pass checks for the presence of
inferable dimensions (represented by -1) in reshape and slice
operations. These are not compliant with the TOSA specification. If such
dimensions are found, an error message is emitted indicating that they
do not conform to the TOSA specification.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
    mlir/test/Dialect/Tosa/error_if_check.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index c5196349dbb1a..f50914436b63d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -146,6 +146,12 @@ namespace tosa {
 
 bool isa_tosa_shape_type(mlir::Type t);
 
+/// Represents a dimension in the shape of a tensor that can be inferred
+/// based on the other provided dimensions. For example, in a reshape
+/// operation, -1 can be used to indicate a size that is the remainder
+/// of the other dimensions.
+constexpr int64_t kInferableDimSize = -1;
+
 } // namespace tosa
 
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 55415f201f216..6072aecdf347b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2280,8 +2280,6 @@ LogicalResult tosa::SliceOp::verify() {
       return emitOpError("length of size is not equal to rank of input shape");
   }
 
-  constexpr int64_t kInferableDimSize = -1;
-
   SmallVector<int64_t> startValues;
   tosa::getConstShapeValues(start.getDefiningOp(), startValues);
   if (startValues.size()) {
@@ -2626,9 +2624,10 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
     return mlir::success();
   }
 
-  int missingDims = llvm::count(shapeValues, -1);
+  int missingDims = llvm::count(shapeValues, kInferableDimSize);
   if (missingDims > 1)
-    return emitOpError() << "expected at most one target dimension to be -1";
+    return emitOpError() << "expected at most one target dimension to be "
+                         << kInferableDimSize;
 
   const auto outputType = dyn_cast<RankedTensorType>(getType());
   if (!outputType)
@@ -2639,11 +2638,12 @@ llvm::LogicalResult tosa::ReshapeOp::verify() {
 
   for (auto [newShapeDim, outputShapeDim] :
        zip(shapeValues, outputType.getShape())) {
-    if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic &&
+    if (newShapeDim != kInferableDimSize &&
+        newShapeDim != ShapedType::kDynamic &&
         outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
       return emitOpError() << "new shape is inconsistent with result shape";
 
-    if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1)
+    if (newShapeDim != ShapedType::kDynamic && newShapeDim < kInferableDimSize)
       return emitOpError() << "new shape has invalid tensor dimension size "
                            << newShapeDim;
   }

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 97572fdd13953..35b4b862dbff7 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -26,6 +26,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/FormatVariadic.h"
 
@@ -1214,6 +1215,50 @@ LogicalResult checkErrorIfPad(Operation *op) {
   return success();
 }
 
+LogicalResult checkErrorIfReshape(Operation *op) {
+  auto reshapeOp = dyn_cast<tosa::ReshapeOp>(op);
+  if (!reshapeOp)
+    return success();
+
+  SmallVector<int64_t> shapeValues;
+  if (!tosa::getConstShapeValues(reshapeOp.getShape().getDefiningOp(),
+                                 shapeValues))
+    return success();
+
+  if (llvm::is_contained(shapeValues, kInferableDimSize))
+    return op->emitOpError("shape input contains inferable dimension (")
+           << kInferableDimSize
+           << ") "
+              "which does not conform to the TOSA specification";
+
+  return success();
+}
+
+LogicalResult checkErrorIfSlice(Operation *op) {
+  auto sliceOp = dyn_cast<tosa::SliceOp>(op);
+  if (!sliceOp)
+    return success();
+
+  SmallVector<int64_t> startValues;
+  SmallVector<int64_t> sizeValues;
+  const bool hasStartValues = tosa::getConstShapeValues(
+      sliceOp.getStart().getDefiningOp(), startValues);
+  const bool hasSizeValues =
+      tosa::getConstShapeValues(sliceOp.getSize().getDefiningOp(), sizeValues);
+
+  if (hasStartValues && llvm::is_contained(startValues, kInferableDimSize))
+    return op->emitOpError("start input contains inferable dimension (")
+           << kInferableDimSize
+           << ") which does not conform to the TOSA specification";
+  if (hasSizeValues && llvm::is_contained(sizeValues, kInferableDimSize))
+    return op->emitOpError("size input contains inferable dimension (")
+           << kInferableDimSize
+           << ") which "
+              "does not conform to the TOSA specification";
+
+  return success();
+}
+
 static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
   return llvm::all_of(op->getOperands(), [&](auto operand) {
     Region *operandRegion = operand.getParentRegion();
@@ -1321,7 +1366,8 @@ LogicalResult checkErrorIfScatter(Operation *op) {
 LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
   if (failed(checkErrorIfResize(op)) || failed(checkErrorIfMul(op)) ||
       failed(checkErrorIfTable(op)) || failed(checkErrorIfRescale(op)) ||
-      failed(checkErrorIfPad(op)) || failed(checkErrorIfCondIf(op)) ||
+      failed(checkErrorIfPad(op)) || failed(checkErrorIfReshape(op)) ||
+      failed(checkErrorIfSlice(op)) || failed(checkErrorIfCondIf(op)) ||
       failed(checkErrorIfWhileLoop(op)) || failed(checkErrorIfScatter(op)))
     return failure();
   return success();

diff  --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 334f52a3407c7..ab31657019674 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -86,6 +86,38 @@ func.func @test_resize_invalid_boarder_x(%arg0: tensor<1x8x8x8xf32>) -> tensor<?
 
 // -----
 
+// CHECK-LABEL: test_reshape_inferable_dim
+func.func @test_reshape_inferable_dim(%arg0: tensor<4xf32>) -> tensor<?x2xf32> {
+  %shape = tosa.const_shape { values = dense<[-1, 2]> : tensor<2xindex> } : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.reshape' op shape input contains inferable dimension (-1) which does not conform to the TOSA specification}}
+  %0 = tosa.reshape %arg0, %shape : (tensor<4xf32>, !tosa.shape<2>) -> tensor<?x2xf32>
+  return %0 : tensor<?x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_slice_inferable_start
+func.func @test_slice_inferable_start(%arg0: tensor<4x4xf32>) -> tensor<2x2xf32> {
+  %start = tosa.const_shape { values = dense<[-1, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %size = tosa.const_shape { values = dense<[2, 2]> : tensor<2xindex> } : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.slice' op start input contains inferable dimension (-1) which does not conform to the TOSA specification}}
+  %0 = tosa.slice %arg0, %start, %size : (tensor<4x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: test_slice_inferable_size
+func.func @test_slice_inferable_size(%arg0: tensor<4x4xf32>) -> tensor<?x2xf32> {
+  %start = tosa.const_shape { values = dense<[0, 1]> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %size = tosa.const_shape { values = dense<[-1, 2]> : tensor<2xindex> } : () -> !tosa.shape<2>
+  // expected-error@+1 {{'tosa.slice' op size input contains inferable dimension (-1) which does not conform to the TOSA specification}}
+  %0 = tosa.slice %arg0, %start, %size : (tensor<4x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<?x2xf32>
+  return %0 : tensor<?x2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: test_mul_negative_shift
 func.func @test_mul_negative_shift(%arg0: tensor<1x8x8x8xi32>, %arg1: tensor<1x8x8x8xi32>) -> tensor<1x8x8x8xi32> {
   %shift = "tosa.const" () { values = dense<-1> : tensor<1xi8> } : () -> tensor<1xi8>


        


More information about the Mlir-commits mailing list