[Mlir-commits] [mlir] [mlir][tosa] Fix unranked tosa canonicalizations (PR #188188)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 24 00:27:16 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
Author: Hocky Yudhiono (hockyy)
<details>
<summary>Changes</summary>
This PR fixes #188187: multiple crashes in TOSA constant folding and canonicalization when ops use unranked tensor types (`tensor<*x...>`), which TOSA allows. Some folders assumed ranked tensor types and used them without checking. With unranked result or input types, canonicalization could crash (segfault or assertion failure) instead of safely bailing out.
---
Full diff: https://github.com/llvm/llvm-project/pull/188188.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+7-5)
- (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+12)
- (modified) mlir/test/Dialect/Tosa/constant_folding.mlir (+36)
``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index b622cbedec1dc..73c83b370f5db 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1502,7 +1502,7 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
auto rhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
- if (!lhsAttr || !rhsAttr)
+ if (!resultTy || !lhsAttr || !rhsAttr)
return {};
return binaryFolder<GreaterFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
@@ -1515,7 +1515,7 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
auto rhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
- if (!lhsAttr || !rhsAttr)
+ if (!resultTy || !lhsAttr || !rhsAttr)
return {};
return binaryFolder<GreaterEqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
@@ -1538,7 +1538,7 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
return DenseElementsAttr::get(resultTy, true);
}
- if (!lhsAttr || !rhsAttr)
+ if (!resultTy || !lhsAttr || !rhsAttr)
return {};
return binaryFolder<EqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
@@ -1736,8 +1736,10 @@ OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
}
auto input = getInput();
- auto inputTy = llvm::cast<RankedTensorType>(input.getType());
- auto resultTy = llvm::cast<RankedTensorType>(getType());
+ auto inputTy = llvm::dyn_cast<RankedTensorType>(input.getType());
+ auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
+ if (!inputTy || !resultTy)
+ return {};
if (inputTy != resultTy)
return {};
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 52098413f18d9..1afa513621b39 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -906,6 +906,18 @@ func.func @fold_resize_bilinear(%arg0 : tensor<1x15x13x1xi8>) -> tensor<1x15x13x
// -----
+// CHECK-LABEL: @dont_fold_resize_unranked
+func.func @dont_fold_resize_unranked(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
+ // CHECK: tosa.resize
+ %scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
+ %offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<*xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+ return %resize : tensor<*xf32>
+}
+
+// -----
+
// CHECK-LABEL: @canonicalize_concat_slice_final_axis
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12x1xf32>, %[[VAL_1:.*]]: tensor<1x12x12x1xf32>
// CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
diff --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index dc040d3231964..0868aed17d376 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -28,6 +28,42 @@ func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tens
// -----
+// CHECK-LABEL: func @try_fold_equal_with_unranked_tensor_constants
+func.func @try_fold_equal_with_unranked_tensor_constants() {
+ // CHECK: tosa.equal
+ // CHECK-NEXT: return
+ %lhs = arith.constant dense<1> : tensor<1xi32>
+ %rhs = arith.constant dense<2> : tensor<1xi32>
+ %0 = tosa.equal %lhs, %rhs : (tensor<1xi32>, tensor<1xi32>) -> tensor<*xi1>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @try_fold_greater_with_unranked_tensor_constants
+func.func @try_fold_greater_with_unranked_tensor_constants() {
+ // CHECK: tosa.greater
+ // CHECK-NEXT: return
+ %lhs = arith.constant dense<1> : tensor<1xi32>
+ %rhs = arith.constant dense<2> : tensor<1xi32>
+ %0 = tosa.greater %lhs, %rhs : (tensor<1xi32>, tensor<1xi32>) -> tensor<*xi1>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @try_fold_greater_equal_with_unranked_tensor_constants
+func.func @try_fold_greater_equal_with_unranked_tensor_constants() {
+ // CHECK: tosa.greater_equal
+ // CHECK-NEXT: return
+ %lhs = arith.constant dense<1> : tensor<1xi32>
+ %rhs = arith.constant dense<2> : tensor<1xi32>
+ %0 = tosa.greater_equal %lhs, %rhs : (tensor<1xi32>, tensor<1xi32>) -> tensor<*xi1>
+ return
+}
+
+// -----
+
// CHECK-LABEL: @fold_add_zero_rhs_f32
func.func @fold_add_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
%zero = "tosa.const"() {values = dense<0.0> : tensor<f32>} : () -> tensor<f32>
``````````
</details>
https://github.com/llvm/llvm-project/pull/188188
More information about the Mlir-commits mailing list