[Mlir-commits] [mlir] [mlir][tosa] Fix unranked tosa canonicalizations (PR #188188)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 24 00:27:15 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Hocky Yudhiono (hockyy)

<details>
<summary>Changes</summary>

This PR fixes #188187: several crashes in TOSA constant folding and canonicalization when ops use unranked tensor types (`tensor<*x...>`), which TOSA permits. Some folders assumed ranked tensor types without checking for them. With unranked result or input types, canonicalization could therefore crash (segfault or assertion failure) instead of safely bailing out.

---
Full diff: https://github.com/llvm/llvm-project/pull/188188.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+7-5) 
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+12) 
- (modified) mlir/test/Dialect/Tosa/constant_folding.mlir (+36) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index b622cbedec1dc..73c83b370f5db 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1502,7 +1502,7 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
   auto rhsAttr =
       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
-  if (!lhsAttr || !rhsAttr)
+  if (!resultTy || !lhsAttr || !rhsAttr)
     return {};
 
   return binaryFolder<GreaterFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
@@ -1515,7 +1515,7 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
   auto rhsAttr =
       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
-  if (!lhsAttr || !rhsAttr)
+  if (!resultTy || !lhsAttr || !rhsAttr)
     return {};
 
   return binaryFolder<GreaterEqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
@@ -1538,7 +1538,7 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
     return DenseElementsAttr::get(resultTy, true);
   }
 
-  if (!lhsAttr || !rhsAttr)
+  if (!resultTy || !lhsAttr || !rhsAttr)
     return {};
 
   return binaryFolder<EqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
@@ -1736,8 +1736,10 @@ OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
   }
 
   auto input = getInput();
-  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
-  auto resultTy = llvm::cast<RankedTensorType>(getType());
+  auto inputTy = llvm::dyn_cast<RankedTensorType>(input.getType());
+  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
+  if (!inputTy || !resultTy)
+    return {};
   if (inputTy != resultTy)
     return {};
 
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 52098413f18d9..1afa513621b39 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -906,6 +906,18 @@ func.func @fold_resize_bilinear(%arg0 : tensor<1x15x13x1xi8>) -> tensor<1x15x13x
 
 // -----
 
+// CHECK-LABEL: @dont_fold_resize_unranked
+func.func @dont_fold_resize_unranked(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: tosa.resize
+  %scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
+  %offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<*xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %resize : tensor<*xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @canonicalize_concat_slice_final_axis
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12x1xf32>, %[[VAL_1:.*]]: tensor<1x12x12x1xf32>
 // CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index dc040d3231964..0868aed17d376 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -28,6 +28,42 @@ func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tens
 
 // -----
 
+// CHECK-LABEL: func @try_fold_equal_with_unranked_tensor_constants
+func.func @try_fold_equal_with_unranked_tensor_constants() {
+  // CHECK: tosa.equal
+  // CHECK-NEXT: return
+  %lhs = arith.constant dense<1> : tensor<1xi32>
+  %rhs = arith.constant dense<2> : tensor<1xi32>
+  %0 = tosa.equal %lhs, %rhs : (tensor<1xi32>, tensor<1xi32>) -> tensor<*xi1>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @try_fold_greater_with_unranked_tensor_constants
+func.func @try_fold_greater_with_unranked_tensor_constants() {
+  // CHECK: tosa.greater
+  // CHECK-NEXT: return
+  %lhs = arith.constant dense<1> : tensor<1xi32>
+  %rhs = arith.constant dense<2> : tensor<1xi32>
+  %0 = tosa.greater %lhs, %rhs : (tensor<1xi32>, tensor<1xi32>) -> tensor<*xi1>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @try_fold_greater_equal_with_unranked_tensor_constants
+func.func @try_fold_greater_equal_with_unranked_tensor_constants() {
+  // CHECK: tosa.greater_equal
+  // CHECK-NEXT: return
+  %lhs = arith.constant dense<1> : tensor<1xi32>
+  %rhs = arith.constant dense<2> : tensor<1xi32>
+  %0 = tosa.greater_equal %lhs, %rhs : (tensor<1xi32>, tensor<1xi32>) -> tensor<*xi1>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @fold_add_zero_rhs_f32
 func.func @fold_add_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %zero = "tosa.const"() {values = dense<0.0> : tensor<f32>} : () -> tensor<f32>

``````````

</details>


https://github.com/llvm/llvm-project/pull/188188


More information about the Mlir-commits mailing list