[Mlir-commits] [mlir] [mlir][tosa] Fix folding and canonicalization crashes on unranked tensors (PR #188188)

Hocky Yudhiono llvmlistbot at llvm.org
Tue Mar 24 00:46:56 PDT 2026


https://github.com/hockyy updated https://github.com/llvm/llvm-project/pull/188188

>From 4045185b9c26cc9815fe0eb50d36bf5423191b29 Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Tue, 24 Mar 2026 15:46:28 +0800
Subject: [PATCH] [mlir][tosa] Fix folding and canonicalization crashes on unranked tensors

---
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 50 ++++++++++++++-----
 .../Dialect/Tosa/Transforms/TosaFolders.cpp   |  4 ++
 mlir/test/Dialect/Tosa/canonicalize.mlir      | 20 ++++++++
 mlir/test/Dialect/Tosa/constant_folding.mlir  | 37 ++++++++++++++
 .../Tosa/tosa-layerwise-constant-fold.mlir    | 20 ++++++++
 5 files changed, 118 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index b622cbedec1dc..905cb60c464e8 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -31,6 +31,21 @@
 using namespace mlir;
 using namespace mlir::tosa;
 
+// Returns `op`'s result type if it is statically shaped, else a null type.
+template <typename OpTy>
+static ShapedType getStaticResultShapeForDenseFold(OpTy op) {
+  auto ty = llvm::dyn_cast<ShapedType>(op.getType());
+  if (!ty || !ty.hasStaticShape())
+    return {};
+  return ty;
+}
+
+// Folds `op` to `input` only if `input` already has `op`'s result type.
+template <typename OpTy>
+static OpFoldResult foldToInputIfTypeMatches(OpTy op, Value input) {
+  return input.getType() == op.getType() ? OpFoldResult(input) : OpFoldResult{};
+}
+
 //===----------------------------------------------------------------------===//
 // Operator Canonicalizers.
 //===----------------------------------------------------------------------===//
@@ -424,6 +439,8 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
                                 PatternRewriter &rewriter) const override {
     Value input = op.getInput();
     auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
+    if (!inputType)
+      return failure();
     auto inputElementType = inputType.getElementType();
 
     if (isa<FloatType>(inputElementType)) {
@@ -843,6 +860,8 @@ struct SliceDynamicSizeCanonicalization
   LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                 PatternRewriter &rewriter) const override {
     ShapedType resultType = cast<ShapedType>(sliceOp.getType());
+    if (!resultType.hasRank())
+      return rewriter.notifyMatchFailure(sliceOp, "output must be ranked");
 
     ElementsAttr sizeElems;
     if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
@@ -1502,7 +1521,7 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
   auto rhsAttr =
       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
-  if (!lhsAttr || !rhsAttr)
+  if (!resultTy || !lhsAttr || !rhsAttr)
     return {};
 
   return binaryFolder<GreaterFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
@@ -1515,7 +1534,7 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
   auto rhsAttr =
       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
-  if (!lhsAttr || !rhsAttr)
+  if (!resultTy || !lhsAttr || !rhsAttr)
     return {};
 
   return binaryFolder<GreaterEqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
@@ -1538,7 +1557,7 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
     return DenseElementsAttr::get(resultTy, true);
   }
 
-  if (!lhsAttr || !rhsAttr)
+  if (!resultTy || !lhsAttr || !rhsAttr)
     return {};
 
   return binaryFolder<EqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
@@ -1553,7 +1572,9 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
     return {};
 
   auto inTy = llvm::cast<ShapedType>(getInput().getType());
-  auto outTy = llvm::cast<ShapedType>(getType());
+  auto outTy = getStaticResultShapeForDenseFold(*this);
+  if (!outTy)
+    return {};
   auto inETy = inTy.getElementType();
   auto outETy = outTy.getElementType();
 
@@ -1736,8 +1757,10 @@ OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
   }
 
   auto input = getInput();
-  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
-  auto resultTy = llvm::cast<RankedTensorType>(getType());
+  auto inputTy = llvm::dyn_cast<RankedTensorType>(input.getType());
+  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
+  if (!inputTy || !resultTy)
+    return {};
   if (inputTy != resultTy)
     return {};
 
@@ -1756,7 +1779,7 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
   // If the dim-length is 1, tosa.reverse is a no-op.
   if (operandTy.hasRank() &&
       (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
-    return operand;
+    return foldToInputIfTypeMatches(*this, operand);
 
   return {};
 }
@@ -1920,7 +1943,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
   if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
     return {};
 
-  return getInput1();
+  return foldToInputIfTypeMatches(*this, getInput1());
 }
 
 OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
@@ -1953,15 +1976,14 @@ OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
     return {};
   }
 
-  return definingOp.getInput1();
+  return foldToInputIfTypeMatches(*this, definingOp.getInput1());
 }
 
 OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
   auto input = getInput1();
   // Element-wise abs(abs(x)) = abs(x)
-  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
-    return input;
-  }
+  if (input.getDefiningOp<tosa::AbsOp>())
+    return foldToInputIfTypeMatches(*this, input);
 
   return {};
 }
@@ -2008,7 +2030,9 @@ OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
   if (!inputAttr || !inputAttr.isSplat())
     return {};
 
-  auto shapeType = llvm::cast<ShapedType>(getType());
+  auto shapeType = getStaticResultShapeForDenseFold(*this);
+  if (!shapeType)
+    return {};
   if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
     auto floatVal = inputAttr.getSplatValue<APFloat>();
     return DenseElementsAttr::get(shapeType,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index 0a035bbd3df00..fcc1ca4f4d1ec 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -242,6 +242,8 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
   LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                 PatternRewriter &rewriter) const override {
     auto outputType = cast<ShapedType>(op.getType());
+    if (!outputType.hasStaticShape())
+      return failure();
     // TOSA supports quantized types.
     if (!outputType.getElementType().isIntOrIndexOrFloat())
       return failure();
@@ -299,6 +301,8 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
     auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
         inputValues, &ReciprocalOp::calcOneElement,
         cast<FloatType>(inputValues.getElementType()));
+    if (newTensor.getType() != recip.getType())
+      return failure();
 
     // Replace the use of the reciprocal with the transformed tensor
     rewriter.replaceOpWithNewOp<ConstOp>(recip, newTensor.getType(), newTensor);
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 52098413f18d9..2f1f583cf1c3a 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -906,6 +906,26 @@ func.func @fold_resize_bilinear(%arg0 : tensor<1x15x13x1xi8>) -> tensor<1x15x13x
 
 // -----
 
+// CHECK-LABEL: @dont_canonicalize_unranked_clamp
+func.func @dont_canonicalize_unranked_clamp(%input : tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: tosa.clamp
+  %clamped = tosa.clamp %input {min_val = 0.0 : f32, max_val = 1.0 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+  return %clamped : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @dont_canonicalize_unranked_slice_dynamic_size
+func.func @dont_canonicalize_unranked_slice_dynamic_size(%input: tensor<1x4xf32>) -> tensor<*xf32> {
+  // CHECK: tosa.slice
+  %offsets = tosa.const_shape {values = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %sizes = tosa.const_shape {values = dense<[1, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %slice = tosa.slice %input, %offsets, %sizes : (tensor<1x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %slice : tensor<*xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @canonicalize_concat_slice_final_axis
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12x1xf32>, %[[VAL_1:.*]]: tensor<1x12x12x1xf32>
 // CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index dc040d3231964..db643ff5f23f4 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -28,6 +28,43 @@ func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tens
 
 // -----
 
+// CHECK-LABEL: func @try_fold_unranked_constant_results
+func.func @try_fold_unranked_constant_results() {
+  // CHECK: tosa.equal
+  // CHECK: tosa.greater
+  // CHECK: tosa.greater_equal
+  // CHECK: tosa.cast
+  // CHECK: tosa.reciprocal
+  // CHECK-NEXT: return
+  %one = arith.constant dense<1> : tensor<1xi32>
+  %two = arith.constant dense<2> : tensor<1xi32>
+  %two_f = arith.constant dense<2.0> : tensor<1xf32>
+  %eq = tosa.equal %one, %two : (tensor<1xi32>, tensor<1xi32>) -> tensor<*xi1>
+  %gt = tosa.greater %one, %two : (tensor<1xi32>, tensor<1xi32>) -> tensor<*xi1>
+  %ge = tosa.greater_equal %one, %two : (tensor<1xi32>, tensor<1xi32>) -> tensor<*xi1>
+  %casted = tosa.cast %one : (tensor<1xi32>) -> tensor<*xf32>
+  %recip = tosa.reciprocal %two_f : (tensor<1xf32>) -> tensor<*xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @try_fold_unranked_identity_results
+func.func @try_fold_unranked_identity_results(%input: tensor<1xf32>) {
+  // CHECK: tosa.transpose
+  // CHECK: tosa.reverse
+  // CHECK: tosa.abs
+  // CHECK: tosa.abs
+  // CHECK-NEXT: return
+  %transposed = tosa.transpose %input { perms = array<i32: 0> } : (tensor<1xf32>) -> tensor<*xf32>
+  %reversed = tosa.reverse %input {axis = 0 : i32} : (tensor<1xf32>) -> tensor<*xf32>
+  %abs = tosa.abs %input : (tensor<1xf32>) -> tensor<1xf32>
+  %abs_of_abs = tosa.abs %abs : (tensor<1xf32>) -> tensor<*xf32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @fold_add_zero_rhs_f32
 func.func @fold_add_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %zero = "tosa.const"() {values = dense<0.0> : tensor<f32>} : () -> tensor<f32>
diff --git a/mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir b/mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir
index d95d267e8c907..5c56728f2fe71 100644
--- a/mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir
@@ -46,6 +46,26 @@ func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
 
 // -----
 
+// CHECK-LABEL: @transpose_nofold_unranked_result
+func.func @transpose_nofold_unranked_result() -> tensor<*xf32> {
+  %cst = "tosa.const"() {values = dense<1.0> : tensor<1xf32>} : () -> tensor<1xf32>
+  // CHECK: tosa.reshape
+  %transposed = tosa.transpose %cst { perms = array<i32: 0> } : (tensor<1xf32>) -> tensor<*xf32>
+  return %transposed : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @reciprocal_nofold_unranked_result
+func.func @reciprocal_nofold_unranked_result() -> tensor<*xf32> {
+  %cst = "tosa.const"() {values = dense<2.0> : tensor<1xf32>} : () -> tensor<1xf32>
+  // CHECK: tosa.reciprocal
+  %recip = tosa.reciprocal %cst : (tensor<1xf32>) -> tensor<*xf32>
+  return %recip : tensor<*xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @transpose_fold_splat
 func.func @transpose_fold_splat() -> tensor<3x2xf32> {
   %input = "tosa.const"() {values = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>



More information about the Mlir-commits mailing list