[Mlir-commits] [mlir] a7bc628 - [mlir][tosa] Harden folds/canonicalizations for unranked and dynamic shapes (#188188)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 30 02:23:07 PDT 2026


Author: Hocky Yudhiono
Date: 2026-03-30T10:23:01+01:00
New Revision: a7bc628e44e69b43fbaf135a569691bf09fc083f

URL: https://github.com/llvm/llvm-project/commit/a7bc628e44e69b43fbaf135a569691bf09fc083f
DIFF: https://github.com/llvm/llvm-project/commit/a7bc628e44e69b43fbaf135a569691bf09fc083f.diff

LOG: [mlir][tosa] Harden folds/canonicalizations for unranked and dynamic shapes (#188188)

This MR fixes #188187 and #187974. It tightens TOSA constant folding and
identity-style folds so that they do not produce invalid or type-incorrect
results when the op's result type is unranked, has dynamic dimensions, or
is otherwise not a statically shaped `RankedTensorType`. Several paths
previously assumed ranked/static shapes, or folded to the operand without
checking that the result type matched the type of the value being returned.

`DenseElementsAttr::get`, `SplatElementsAttr::get`, and similar builders
require a static shape. When the result type is `tensor<*xT>` or has dynamic
dimensions, a fold must not fabricate a dense attribute with the wrong shape.

Returning the operand from a “no-op” fold is only valid when
`operand.getType() == op.getType()`. Otherwise the folder would change
the IR’s type semantics (e.g. ranked → unranked). Such type refinement is
supposed to be handled by `-tosa-infer-shapes` in the larger pipeline,
not by folders.

Assisted-by: CLion code completion, GPT 5.3 - Codex

---------

Co-authored-by: Sayan Saha <sayans at mathworks.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
    mlir/test/Dialect/Tosa/canonicalize.mlir
    mlir/test/Dialect/Tosa/constant_folding.mlir
    mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index b72838829ce7f..ecd485ae8d641 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -31,6 +31,12 @@
 using namespace mlir;
 using namespace mlir::tosa;
 
+namespace {
+OpFoldResult foldToInputIfTypeMatches(Type typeRef, Value input) {
+  return input.getType() == typeRef ? OpFoldResult(input) : OpFoldResult{};
+}
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // Operator Canonicalizers.
 //===----------------------------------------------------------------------===//
@@ -423,7 +429,7 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
   LogicalResult matchAndRewrite(tosa::ClampOp op,
                                 PatternRewriter &rewriter) const override {
     Value input = op.getInput();
-    auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
+    auto inputType = llvm::cast<ShapedType>(op.getInput().getType());
     auto inputElementType = inputType.getElementType();
 
     if (isa<FloatType>(inputElementType)) {
@@ -843,6 +849,8 @@ struct SliceDynamicSizeCanonicalization
   LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
                                 PatternRewriter &rewriter) const override {
     ShapedType resultType = cast<ShapedType>(sliceOp.getType());
+    if (!resultType.hasRank())
+      return rewriter.notifyMatchFailure(sliceOp, "output must be ranked");
 
     ElementsAttr sizeElems;
     if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
@@ -995,6 +1003,9 @@ binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy,
   if (!lhs || !rhs)
     return {};
 
+  if (!returnTy.hasRank() || !returnTy.hasStaticShape())
+    return {};
+
   const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
   const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
   if (lETy != rETy)
@@ -1043,6 +1054,9 @@ static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy,
   if (!val)
     return {};
 
+  if (!returnTy.hasRank() || !returnTy.hasStaticShape())
+    return {};
+
   const auto vETy = llvm::cast<ShapedType>(val.getType()).getElementType();
 
   if (val.isSplat()) {
@@ -1555,7 +1569,7 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
 }
 
 OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
-  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
+  auto resultTy = llvm::cast<ShapedType>(getType());
   auto lhsAttr =
       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
   auto rhsAttr =
@@ -1568,7 +1582,7 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
 }
 
 OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
-  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
+  auto resultTy = llvm::cast<ShapedType>(getType());
   auto lhsAttr =
       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
   auto rhsAttr =
@@ -1581,7 +1595,7 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
 }
 
 OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
-  auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
+  auto resultTy = llvm::cast<ShapedType>(getType());
   auto lhsAttr =
       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
   auto rhsAttr =
@@ -1592,7 +1606,7 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
 
   // If we are comparing an integer value to itself it is always true. We
   // can not do this with float due to float values.
-  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
+  if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy.hasRank() &&
       resultTy.hasStaticShape() && lhs == rhs) {
     return DenseElementsAttr::get(resultTy, true);
   }
@@ -1613,6 +1627,8 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
 
   auto inTy = llvm::cast<ShapedType>(getInput().getType());
   auto outTy = llvm::cast<ShapedType>(getType());
+  if (!outTy.hasRank() || !outTy.hasStaticShape())
+    return {};
   auto inETy = inTy.getElementType();
   auto outETy = outTy.getElementType();
 
@@ -1794,30 +1810,20 @@ OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
     return {};
   }
 
-  auto input = getInput();
-  auto inputTy = llvm::cast<RankedTensorType>(input.getType());
-  auto resultTy = llvm::cast<RankedTensorType>(getType());
-  if (inputTy != resultTy)
-    return {};
-
-  return input;
+  return foldToInputIfTypeMatches(getType(), getInput());
 }
 
 OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
   auto operand = getInput1();
   auto operandTy = llvm::cast<ShapedType>(operand.getType());
   auto axis = getAxis();
-  auto operandAttr =
-      llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getInput1());
-  if (operandAttr)
-    return operandAttr;
-
-  // If the dim-length is 1, tosa.reverse is a no-op.
-  if (operandTy.hasRank() &&
-      (operandTy.getRank() == 0 || operandTy.getDimSize(axis) == 1))
-    return operand;
-
-  return {};
+  // If the dim-length is 1, or reversing axis is unit-dim, also a no-op.
+  const bool isSplatInput =
+      llvm::isa_and_nonnull<SplatElementsAttr>(adaptor.getInput1());
+  if (!operandTy.hasRank() ||
+      (!isSplatInput && operandTy.getDimSize(axis) != 1))
+    return {};
+  return foldToInputIfTypeMatches(getType(), operand);
 }
 
 OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
@@ -1968,7 +1974,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
   // Transposing splat values just means reshaping.
   if (auto input =
           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1())) {
-    if (input.isSplat() && resultTy.hasStaticShape() &&
+    if (input.isSplat() && resultTy.hasRank() && resultTy.hasStaticShape() &&
         input.getType().getElementType() == resultTy.getElementType())
       return input.reshape(resultTy);
   }
@@ -1979,7 +1985,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) {
   if (!llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
     return {};
 
-  return getInput1();
+  return foldToInputIfTypeMatches(getType(), getInput1());
 }
 
 OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
@@ -2012,15 +2018,14 @@ OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
     return {};
   }
 
-  return definingOp.getInput1();
+  return foldToInputIfTypeMatches(getType(), definingOp.getInput1());
 }
 
 OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
   auto input = getInput1();
   // Element-wise abs(abs(x)) = abs(x)
-  if (auto op = input.getDefiningOp<tosa::AbsOp>()) {
-    return input;
-  }
+  if (input.getDefiningOp<tosa::AbsOp>())
+    return foldToInputIfTypeMatches(getType(), input);
 
   return {};
 }
@@ -2068,6 +2073,8 @@ OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
     return {};
 
   auto shapeType = llvm::cast<ShapedType>(getType());
+  if (!shapeType.hasRank() || !shapeType.hasStaticShape())
+    return {};
   if (auto floatType = llvm::dyn_cast<FloatType>(inputAttr.getElementType())) {
     auto floatVal = inputAttr.getSplatValue<APFloat>();
     return DenseElementsAttr::get(shapeType,

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
index 0a035bbd3df00..d6961628afc9f 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
@@ -242,6 +242,8 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
   LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                 PatternRewriter &rewriter) const override {
     auto outputType = cast<ShapedType>(op.getType());
+    if (!outputType.hasRank() || !outputType.hasStaticShape())
+      return failure();
     // TOSA supports quantized types.
     if (!outputType.getElementType().isIntOrIndexOrFloat())
       return failure();
@@ -295,6 +297,10 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
                  "tensor has a single user");
     }
 
+    if (inputTensor.getType() != recip.getType())
+      return rewriter.notifyMatchFailure(
+          recip, "input tensor and reciprocal output have different type");
+
     // Create a new tensor with the updated values
     auto newTensor = applyElementWise<APFloat, APFloat, FloatType>(
         inputValues, &ReciprocalOp::calcOneElement,

diff  --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index bf53f06be3e07..367a60c4d2a8d 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -900,12 +900,77 @@ func.func @fold_resize_bilinear(%arg0 : tensor<1x15x13x1xi8>) -> tensor<1x15x13x
   %scale = tosa.const_shape { values = dense<[2, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
   %offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
   %border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
-  %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x15x13x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x15x13x1xi8>
+  %resize = tosa.resize %arg0, %scale, %offset, %border {mode = BILINEAR} : (tensor<1x15x13x1xi8>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x15x13x1xi8>
   return %resize : tensor<1x15x13x1xi8>
 }
 
 // -----
 
+// ResizeOp::fold: unit scale (1:1 Y and X), zero offset/border, in/out types equal.
+// CHECK-LABEL: @fold_resize_identity_scale
+func.func @fold_resize_identity_scale(%arg0 : tensor<1x15x13x1xf32>) -> tensor<1x15x13x1xf32> {
+  // CHECK-NOT: tosa.resize
+  %scale = tosa.const_shape { values = dense<[1, 1, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
+  %offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x15x13x1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x15x13x1xf32>
+  return %resize : tensor<1x15x13x1xf32>
+}
+
+// -----
+// CHECK-LABEL: @fold_resize_identity_scale_to_unranked
+func.func @fold_resize_identity_scale_to_unranked(%arg0 : tensor<1x15x13x1xf32>) -> tensor<*xf32> {
+  // CHECK: tosa.resize
+  %scale = tosa.const_shape { values = dense<[1, 1, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
+  %offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x15x13x1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %resize : tensor<*xf32>
+}
+
+// -----
+
+// Same parameters except scale_y_n != scale_y_d: fold must not apply.
+// CHECK-LABEL: @resize_nofold_asymmetric_y_scale
+func.func @resize_nofold_asymmetric_y_scale(%arg0 : tensor<1x15x13x1xf32>) -> tensor<1x29x13x1xf32> {
+  // CHECK: tosa.resize
+  %scale = tosa.const_shape { values = dense<[4, 2, 1, 1]> : tensor<4xindex> } : () -> !tosa.shape<4>
+  %offset = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %border = tosa.const_shape { values = dense<0> : tensor<2xindex> } : () -> !tosa.shape<2>
+  %resize = tosa.resize %arg0, %scale, %offset, %border {mode = NEAREST_NEIGHBOR} : (tensor<1x15x13x1xf32>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x29x13x1xf32>
+  return %resize : tensor<1x29x13x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @dont_canonicalize_unranked_clamp
+func.func @dont_canonicalize_unranked_clamp(%arg0 : tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: tosa.clamp
+  %0 = tosa.clamp %arg0 {min_val = 0.0 : f32, max_val = 1.0 : f32} : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @dont_canonicalize_unranked_to_ranked_clamp
+func.func @dont_canonicalize_unranked_to_ranked_clamp(%arg0 : tensor<*xf32>) -> tensor<1xf32> {
+  // CHECK: tosa.clamp
+  %0 = tosa.clamp %arg0 {min_val = 0.0 : f32, max_val = 1.0 : f32} : (tensor<*xf32>) -> tensor<1xf32>
+  return %0 : tensor<1xf32>
+}
+// -----
+
+// CHECK-LABEL: @dont_canonicalize_unranked_slice_dynamic_size
+func.func @dont_canonicalize_unranked_slice_dynamic_size(%arg0: tensor<1x4xf32>) -> tensor<*xf32> {
+  // CHECK: tosa.slice
+  %start = tosa.const_shape {values = dense<[0, 0]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %size = tosa.const_shape {values = dense<[1, -1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+  %0 = tosa.slice %arg0, %start, %size : (tensor<1x4xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @canonicalize_concat_slice_final_axis
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x12x12x1xf32>, %[[VAL_1:.*]]: tensor<1x12x12x1xf32>
 // CHECK: return %[[VAL_0]], %[[VAL_1]] : tensor<1x12x12x1xf32>, tensor<1x12x12x1xf32>
@@ -1202,6 +1267,37 @@ func.func @reverse_quant_fold() -> tensor<1x!quant.uniform<i8:f32, 3.07574046018
 
 // -----
 
+// ReverseOp::fold: unranked operand has hasRank() == false;
+// CHECK-LABEL: @reverse_nofold_unranked_operand
+func.func @reverse_nofold_unranked_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: tosa.reverse
+  %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// Unit-dim no-op, but mismatch type
+// CHECK-LABEL: @reverse_nofold_unit_dim_unranked_result
+func.func @reverse_nofold_unit_dim_unranked_result(%arg0: tensor<1x4xf32>) -> tensor<*xf32> {
+  // CHECK: tosa.reverse
+  %0 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<1x4xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+// Splat fold returns the operand ElementsAttr; But result type doesn't match.
+// CHECK-LABEL: @reverse_nofold_splat_type_unmatch
+func.func @reverse_nofold_splat_type_unmatch() -> tensor<*xf32> {
+  // CHECK: tosa.reverse
+  %0 = "tosa.const"() <{values = dense<1.0> : tensor<4xf32>}> : () -> tensor<4xf32>
+  %1 = tosa.reverse %0 {axis = 0 : i32} : (tensor<4xf32>) -> tensor<*xf32>
+  return %1 : tensor<*xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @select_quant_fold
 func.func @select_quant_fold() -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>> {
    // CHECK: %[[CONST_0:.*]] = "tosa.const"() <{values = dense<0> : tensor<i8>}> : () -> tensor<!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>

diff  --git a/mlir/test/Dialect/Tosa/constant_folding.mlir b/mlir/test/Dialect/Tosa/constant_folding.mlir
index bc7debf277496..30ba340afaa2d 100644
--- a/mlir/test/Dialect/Tosa/constant_folding.mlir
+++ b/mlir/test/Dialect/Tosa/constant_folding.mlir
@@ -28,6 +28,46 @@ func.func @try_fold_equal_with_unranked_tensor(%arg0: tensor<4xi32>, %arg1: tens
 
 // -----
 
+// CHECK-LABEL: func @try_fold_unranked_constant_results
+func.func @try_fold_unranked_constant_results() {
+  // CHECK: tosa.equal
+  // CHECK: tosa.greater
+  // CHECK: tosa.greater_equal
+  // CHECK: tosa.cast
+  // CHECK: tosa.reciprocal
+  // CHECK: tosa.abs
+  // CHECK-NEXT: return
+  %lhs = arith.constant dense<1> : tensor<1xi32>
+  %rhs = arith.constant dense<2> : tensor<1xi32>
+  %f = arith.constant dense<2.0> : tensor<1xf32>
+  %0 = tosa.equal %lhs, %rhs : (tensor<1xi32>, tensor<1xi32>) -> tensor<*xi1>
+  %1 = tosa.greater %lhs, %rhs : (tensor<1xi32>, tensor<1xi32>) -> tensor<*xi1>
+  %2 = tosa.greater_equal %lhs, %rhs : (tensor<1xi32>, tensor<1xi32>) -> tensor<*xi1>
+  %3 = tosa.cast %lhs : (tensor<1xi32>) -> tensor<*xf32>
+  %4 = tosa.reciprocal %f : (tensor<1xf32>) -> tensor<*xf32>
+  %5 = tosa.abs %f : (tensor<1xf32>) -> tensor<*xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @try_fold_unranked_identity_results
+func.func @try_fold_unranked_identity_results(%arg0: tensor<1xf32>) {
+  // CHECK: tosa.transpose
+  // CHECK: tosa.reverse
+  // CHECK: tosa.abs
+  // CHECK: tosa.negate
+  // CHECK-NEXT: return
+  %zp = arith.constant dense<0.0> : tensor<1xf32>
+  %0 = tosa.transpose %arg0 { perms = array<i32: 0> } : (tensor<1xf32>) -> tensor<*xf32>
+  %1 = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<1xf32>) -> tensor<*xf32>
+  %3 = tosa.abs %arg0 : (tensor<1xf32>) -> tensor<*xf32>
+  %4 = tosa.negate %arg0, %zp, %zp : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<*xf32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @fold_add_zero_rhs_f32
 func.func @fold_add_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %zero = "tosa.const"() {values = dense<0.0> : tensor<f32>} : () -> tensor<f32>
@@ -1501,4 +1541,3 @@ func.func @test_slice_shape() -> !tosa.shape<4> {
   %d = tosa.slice_shape %a, %b, %c : (!tosa.shape<6>, tensor<1xi32>, tensor<1xi32>) -> !tosa.shape<4>
   return %d : !tosa.shape<4>
 }
-// -----

diff  --git a/mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir b/mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir
index d95d267e8c907..711dfe4d2405e 100644
--- a/mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-layerwise-constant-fold.mlir
@@ -46,6 +46,25 @@ func.func @transpose_nofold_shape(%arg0: tensor<3x4xf32>) -> tensor<?x?xf32> {
 
 // -----
 
+// CHECK-LABEL: @transpose_nofold_unranked_result_not_reshape
+func.func @transpose_nofold_unranked_result_not_reshape(%arg0: tensor<6x7xf32>) -> tensor<*xf32> {
+  // CHECK: tosa.transpose
+  %1 = tosa.transpose %arg0 { perms = array<i32: 1, 0> }: (tensor<6x7xf32>) -> tensor<*xf32>
+  return %1 : tensor<*xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @reciprocal_nofold_unranked_result
+func.func @reciprocal_nofold_unranked_result() -> tensor<*xf32> {
+  %input = "tosa.const"() {values = dense<2.0> : tensor<6x7xf32>} : () -> tensor<6x7xf32>
+  // CHECK: tosa.reciprocal
+  %1 = tosa.reciprocal %input : (tensor<6x7xf32>) -> tensor<*xf32>
+  return %1 : tensor<*xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @transpose_fold_splat
 func.func @transpose_fold_splat() -> tensor<3x2xf32> {
   %input = "tosa.const"() {values = dense<4.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>


        


More information about the Mlir-commits mailing list