[Mlir-commits] [mlir] [mlir][tosa] Enhance TosaInferShapes pass for simple shape inference (PR #178418)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 28 05:17:47 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

<details>
<summary>Changes</summary>

This commit enhances the TosaInferShapes pass with two new options:
- fold-shape-expressions
- convert-function-boundaries

The "fold-shape-expressions" option enables greedily folding the newly added TOSA shape operations when possible. Folding these operations directly within TosaInferShapes is useful since it allows shapes of later operations to be inferred in a single pass.

The "convert-function-boundaries" option updates the return types of a function to the newly inferred output shapes. This avoids the need for additional tensor.cast operations at function boundaries. This option is particularly useful when resolving a dynamically shaped function to a fully static one.

When both of these options are used in conjunction with the "tosa-input-shapes" pass option, a dynamically shaped function can be resolved to a fully static one in a single pass.

Note: This PR is split into two commits. [68888db](https://github.com/llvm/llvm-project/commit/68888dbb1e70cc3f4383ccc2bab88f5325c347d6) is a simple refactor and consists of no logic changes. [db9a6a3](https://github.com/llvm/llvm-project/commit/db9a6a323afbd8cfb3b7f90b2cb172b68f8bcfdf) includes the changes for shape inference.

---

Patch is 41.94 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/178418.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td (+10) 
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp (+228-173) 
- (added) mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir (+59) 
- (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+84-116) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
index 03be41d684f3f..5979ce4962e55 100644
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -42,6 +42,17 @@ def TosaInferShapesPass : Pass<"tosa-infer-shapes", "func::FuncOp"> {
     "tensor::TensorDialect",
     "tosa::TosaDialect",
   ];
+
+  let options = [
+    Option<"foldShapeExpressions", "fold-shape-expressions", "bool",
+           /*default=*/"false",
+           "Fold TOSA shape operations when they have known input values">,
+    Option<"convertFunctionBoundaries", "convert-function-boundaries", "bool",
+           /*default=*/"false",
+           "If enabled, the pass also updates the function result types to "
+           "the newly inferred types. Otherwise casts are inserted at the "
+           "function return boundary.">
+  ];
 }
 
 def TosaMakeBroadcastablePass
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 4d347c02ee16d..37644ee8c03f8 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/FoldUtils.h"
 
 namespace mlir {
 namespace tosa {
@@ -128,179 +129,6 @@ class TypeModificationState {
   llvm::SmallVector<std::pair<Value, Type>> oldTypes;
 };
 
-void propagateShapesInRegion(Region &region, TypeModificationState &state);
-
-void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
-  IfOp ifOp = dyn_cast<IfOp>(op);
-  if (!ifOp)
-    return;
-
-  for (auto &region : op.getRegions()) {
-    Block &frontBlock = region.front();
-    if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
-      return;
-
-    for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
-      auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
-      auto blockArg = frontBlock.getArgument(i - 1);
-      auto oldType = cast<ShapedType>(blockArg.getType());
-
-      if (inferredTy.hasRank()) {
-        Type newType = oldType.clone(inferredTy.getShape());
-        state.setType(blockArg, newType);
-      }
-    }
-
-    for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
-      ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
-          ifOp.getOperand(i + 1).getType());
-      ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
-          frontBlock.getArgument(i).getType());
-      ValueKnowledge joinedKnowledge =
-          ValueKnowledge::join(operandKnowledge, blockKnowledge);
-      if (!joinedKnowledge)
-        continue;
-      state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
-    }
-
-    propagateShapesInRegion(region, state);
-  }
-}
-
-void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
-  WhileOp whileOp = dyn_cast<WhileOp>(op);
-  if (!whileOp)
-    return;
-
-  // Determine what the expected argument types are to the cond/body blocks.
-  // The expected arguments should be compatible with ever iteration of the
-  // loop body / condition for tosa.while.
-  SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
-
-  bool hasNewTypes = true;
-  while (hasNewTypes) {
-    TypeModificationState localState;
-
-    // Set types on the block args.
-    Region &bodyRegion = op.getRegion(1);
-    Block &block = bodyRegion.front();
-    for (int i = 0, s = argTypes.size(); i < s; i++) {
-      localState.setType(block.getArgument(i), argTypes[i]);
-    }
-
-    // Propagate to the end.
-    propagateShapesInRegion(bodyRegion, localState);
-
-    // Find all the tosa yield types and verify there is a single one.
-    llvm::SmallVector<YieldOp> yieldOps;
-    for (auto &block : bodyRegion)
-      if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
-        yieldOps.push_back(yieldOp);
-
-    assert(yieldOps.size() == 1 && "missing or non-unique yield op");
-    // Using the new tosa.yield operand types, infer the new subtypes.
-    llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
-    for (auto ty : argTypes) {
-      yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
-    }
-
-    for (auto yieldOp : yieldOps) {
-      for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
-        auto newKnowledge =
-            ValueKnowledge::getKnowledgeFromType(it.value().getType());
-        yieldTypeInfo[it.index()] =
-            ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
-      }
-    }
-
-    // This should never happen.
-    if (yieldTypeInfo.size() != argTypes.size()) {
-      op.emitWarning("has a tosa.yield with the incorrect number of operands");
-      return;
-    }
-
-    // Determine the new block args and see if any changed.
-    hasNewTypes = false;
-    for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
-      Type newType = yieldTypeInfo[i].getType();
-      hasNewTypes |= (newType != argTypes[i]);
-      argTypes[i] = newType;
-    }
-
-    // Roll back all changes made during the speculative part of the algorithm.
-    localState.rollBack();
-  }
-
-  // We now set the block arguments according to the most recent shape
-  // inference results. This gives us the block arg types for the next
-  // iteration.
-  for (auto &region : op.getRegions()) {
-    for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
-      state.setType(region.front().getArgument(i), argTypes[i]);
-    }
-
-    propagateShapesInRegion(region, state);
-  }
-}
-
-void propagateShapesInRegion(Region &region, TypeModificationState &state) {
-  Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();
-
-  for (auto &block : region) {
-    for (Operation &op : block) {
-      if (op.getDialect() != tosaDialect)
-        continue;
-
-      propagateShapesToTosaIf(op, state);
-      propagateShapesToTosaWhile(op, state);
-
-      InferShapedTypeOpInterface shapeInterface =
-          dyn_cast<InferShapedTypeOpInterface>(op);
-      if (!shapeInterface)
-        continue;
-
-      SmallVector<ShapedTypeComponents> returnedShapes;
-
-      if (shapeInterface
-              .inferReturnTypeComponents(
-                  op.getContext(), op.getLoc(), op.getOperands(),
-                  op.getDiscardableAttrDictionary(), op.getPropertiesStorage(),
-                  op.getRegions(), returnedShapes)
-              .succeeded()) {
-        for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
-          Value result = std::get<0>(it);
-          ShapedTypeComponents predictedShape = std::get<1>(it);
-
-          // Determine the knowledge based on the output type.
-          // TODO: should also query WIP type probably
-          Type resultTy = result.getType();
-          auto currentKnowledge =
-              ValueKnowledge::getKnowledgeFromType(resultTy);
-
-          // Compute the knowledge based on the inferred type.
-          auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
-          inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
-          inferredKnowledge.hasRank = predictedShape.hasRank();
-          if (predictedShape.hasRank()) {
-            for (auto dim : predictedShape.getDims()) {
-              inferredKnowledge.sizes.push_back(dim);
-            }
-          }
-
-          // Compute the new type based on the joined version.
-          auto newKnowledge =
-              ValueKnowledge::join(currentKnowledge, inferredKnowledge);
-          if (!newKnowledge)
-            continue;
-
-          // Set new type
-          state.setType(result, newKnowledge.getType());
-        }
-      }
-    }
-  }
-}
-
 /// Recursively validate tosa ops with SameOperandsAndResultRank trait in region
 /// and all nested regions
 void validateSameOperandsAndResultRankTrait(Region &region) {
@@ -333,6 +161,10 @@ void validateSameOperandsAndResultRankTrait(Region &region) {
 struct TosaInferShapes
     : public tosa::impl::TosaInferShapesPassBase<TosaInferShapes> {
 public:
+  // Reuse the generated base-class constructors (default and options), so
+  // new pass options do not need to be copied by hand here.
+  using TosaInferShapesPassBase::TosaInferShapesPassBase;
+
   void runOnOperation() override {
     func::FuncOp func = getOperation();
     TypeModificationState state;
@@ -340,6 +175,264 @@ struct TosaInferShapes
     state.commit();
 
     validateSameOperandsAndResultRankTrait(func.getBody());
+
+    if (convertFunctionBoundaries)
+      convertFunctionReturnTypes(func);
+  }
+
+private:
+  /// Refines the block argument types of each region of a tosa.cond_if from
+  /// its input operands (operand 0, the condition, is skipped), then
+  /// propagates shapes through the region. Does nothing for other ops.
+  void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
+    IfOp ifOp = dyn_cast<IfOp>(op);
+    if (!ifOp)
+      return;
+
+    for (auto &region : op.getRegions()) {
+      Block &frontBlock = region.front();
+      if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
+        return;
+
+      for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
+        auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
+        auto blockArg = frontBlock.getArgument(i - 1);
+        auto oldType = cast<ShapedType>(blockArg.getType());
+
+        if (inferredTy.hasRank()) {
+          Type newType = oldType.clone(inferredTy.getShape());
+          state.setType(blockArg, newType);
+        }
+      }
+
+      for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
+        ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
+            ifOp.getOperand(i + 1).getType());
+        ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
+            frontBlock.getArgument(i).getType());
+        ValueKnowledge joinedKnowledge =
+            ValueKnowledge::join(operandKnowledge, blockKnowledge);
+        if (!joinedKnowledge)
+          continue;
+        state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
+      }
+
+      propagateShapesInRegion(region, state);
+    }
+  }
+
+  /// Finds block argument types for a tosa.while that hold for every loop
+  /// iteration: the body region is re-inferred with speculative (rolled back)
+  /// type changes until the types met with the tosa.yield operands stop
+  /// changing. The final types are then applied to both regions.
+  void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
+    WhileOp whileOp = dyn_cast<WhileOp>(op);
+    if (!whileOp)
+      return;
+
+    // Determine what the expected argument types are to the cond/body blocks.
+    // The expected arguments should be compatible with every iteration of the
+    // loop body / condition for tosa.while.
+    SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
+
+    bool hasNewTypes = true;
+    while (hasNewTypes) {
+      TypeModificationState localState;
+
+      // Set types on the block args.
+      Region &bodyRegion = op.getRegion(1);
+      Block &block = bodyRegion.front();
+      for (int i = 0, s = argTypes.size(); i < s; i++) {
+        localState.setType(block.getArgument(i), argTypes[i]);
+      }
+
+      // Propagate to the end.
+      propagateShapesInRegion(bodyRegion, localState);
+
+      // Find all the tosa yield types and verify there is a single one.
+      llvm::SmallVector<YieldOp> yieldOps;
+      for (Block &bodyBlock : bodyRegion)
+        if (auto yieldOp = dyn_cast<YieldOp>(bodyBlock.getTerminator()))
+          yieldOps.push_back(yieldOp);
+
+      assert(yieldOps.size() == 1 && "missing or non-unique yield op");
+      // Using the new tosa.yield operand types, infer the new subtypes.
+      llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
+      for (auto ty : argTypes) {
+        yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
+      }
+
+      for (auto yieldOp : yieldOps) {
+        // This should never happen. It must be checked before indexing
+        // yieldTypeInfo, which always has exactly argTypes.size() entries.
+        if (yieldOp.getNumOperands() != argTypes.size()) {
+          op.emitWarning(
+              "has a tosa.yield with the incorrect number of operands");
+          localState.rollBack();
+          return;
+        }
+
+        for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
+          auto newKnowledge =
+              ValueKnowledge::getKnowledgeFromType(it.value().getType());
+          yieldTypeInfo[it.index()] =
+              ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
+        }
+      }
+
+      // Determine the new block args and see if any changed.
+      hasNewTypes = false;
+      for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
+        Type newType = yieldTypeInfo[i].getType();
+        hasNewTypes |= (newType != argTypes[i]);
+        argTypes[i] = newType;
+      }
+
+      // Roll back all changes made during the speculative part of the
+      // algorithm.
+      localState.rollBack();
+    }
+
+    // We now set the block arguments according to the most recent shape
+    // inference results. This gives us the block arg types for the next
+    // iteration.
+    for (auto &region : op.getRegions()) {
+      for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
+        state.setType(region.front().getArgument(i), argTypes[i]);
+      }
+
+      propagateShapesInRegion(region, state);
+    }
+  }
+
+  /// Infers and records result types for the TOSA ops in `region`, recursing
+  /// into tosa.cond_if / tosa.while regions. When `foldShapeExpressions` is
+  /// set, ops with the TosaShapeOperator trait are folded instead, so that
+  /// later ops in the same pass can see the folded shape values.
+  void propagateShapesInRegion(Region &region, TypeModificationState &state) {
+    MLIRContext *ctx = region.getContext();
+    Dialect *tosaDialect = ctx->getLoadedDialect<TosaDialect>();
+    OperationFolder folder(ctx);
+
+    for (auto &block : region) {
+      // Early-increment iteration, since folding may erase `op`.
+      for (Operation &op : llvm::make_early_inc_range(block)) {
+        if (op.getDialect() != tosaDialect)
+          continue;
+
+        propagateShapesToTosaIf(op, state);
+        propagateShapesToTosaWhile(op, state);
+
+        if (foldShapeExpressions &&
+            op.hasTrait<OpTrait::tosa::TosaShapeOperator>()) {
+          (void)folder.tryToFold(&op);
+          continue;
+        }
+
+        InferShapedTypeOpInterface shapeInterface =
+            dyn_cast<InferShapedTypeOpInterface>(op);
+        if (!shapeInterface)
+          continue;
+
+        SmallVector<ShapedTypeComponents> returnedShapes;
+
+        if (shapeInterface
+                .inferReturnTypeComponents(
+                    op.getContext(), op.getLoc(), op.getOperands(),
+                    op.getDiscardableAttrDictionary(),
+                    op.getPropertiesStorage(), op.getRegions(), returnedShapes)
+                .succeeded()) {
+          for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
+            Value result = std::get<0>(it);
+            ShapedTypeComponents predictedShape = std::get<1>(it);
+
+            // Determine the knowledge based on the output type.
+            // TODO: should also query WIP type probably
+            Type resultTy = result.getType();
+            auto currentKnowledge =
+                ValueKnowledge::getKnowledgeFromType(resultTy);
+
+            // Compute the knowledge based on the inferred type.
+            auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
+            inferredKnowledge.dtype =
+                cast<ShapedType>(resultTy).getElementType();
+            inferredKnowledge.hasRank = predictedShape.hasRank();
+            if (predictedShape.hasRank()) {
+              for (auto dim : predictedShape.getDims()) {
+                inferredKnowledge.sizes.push_back(dim);
+              }
+            }
+
+            // Compute the new type based on the joined version.
+            auto newKnowledge =
+                ValueKnowledge::join(currentKnowledge, inferredKnowledge);
+            if (!newKnowledge)
+              continue;
+
+            // Set new type
+            state.setType(result, newKnowledge.getType());
+          }
+        }
+      }
+    }
+  }
+
+  /// Makes each func.return return the source of any tensor.cast feeding it
+  /// and updates the function result types to match. The function is left
+  /// unchanged when it has no func.return (e.g. an external declaration) or
+  /// when its func.return ops disagree on the resulting types, since either
+  /// case would otherwise produce a wrong or inconsistent signature.
+  void convertFunctionReturnTypes(func::FuncOp func) {
+    SmallVector<func::ReturnOp> returnOps;
+    func.walk([&](func::ReturnOp ret) { returnOps.push_back(ret); });
+    if (returnOps.empty())
+      return;
+
+    // Look through tensor.cast ops, such as those inserted at the function
+    // boundary when the inferred types were committed.
+    auto stripCast = [](Value v) -> Value {
+      if (auto castOp = v.getDefiningOp<tensor::CastOp>())
+        return castOp.getSource();
+      return v;
+    };
+
+    SmallVector<SmallVector<Value>> newReturnValues;
+    newReturnValues.reserve(returnOps.size());
+    for (func::ReturnOp ret : returnOps) {
+      SmallVector<Value> &values = newReturnValues.emplace_back();
+      for (Value v : ret.getOperands())
+        values.push_back(stripCast(v));
+    }
+
+    // Every func.return must agree on the new result types.
+    SmallVector<Type> newReturnTypes =
+        llvm::to_vector(ValueRange(newReturnValues.front()).getTypes());
+    for (ArrayRef<Value> values : newReturnValues)
+      if (!llvm::equal(ValueRange(values).getTypes(), newReturnTypes))
+        return;
+
+    // Collect the casts up front: one cast may feed several operands or
+    // returns, and it only becomes dead once all of them are rewritten.
+    SmallVector<Operation *> casts;
+    for (func::ReturnOp ret : returnOps)
+      for (Value v : ret.getOperands())
+        if (auto castOp = v.getDefiningOp<tensor::CastOp>())
+          if (!llvm::is_contained(casts, castOp.getOperation()))
+            casts.push_back(castOp);
+
+    IRRewriter rewriter(func.getContext());
+    for (size_t i = 0, e = returnOps.size(); i < e; ++i) {
+      func::ReturnOp ret = returnOps[i];
+      rewriter.modifyOpInPlace(
+          ret, [&] { ret->setOperands(newReturnValues[i]); });
+    }
+    for (Operation *castOp : casts)
+      if (castOp->use_empty())
+        rewriter.eraseOp(castOp);
+
+    func.setType(FunctionType::get(func.getContext(),
+                                   func.getFunctionType().getInputs(),
+                                   newReturnTypes));
   }
 };
 } // namespace
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir
new file mode 100644
index 0000000000000..1d8cc863ebe3e
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes-fold-shape-expressions.mlir
@@ -0,0 +1,59 @@
+// RUN: mlir-opt --split-input-file --verify-diagnostics --tosa-infer-shapes="fold-shape-expressions" %s | FileCheck %s --check-prefixes=CHECK,DEFAULT
+// RUN: mlir-opt --split-input-file --verify-diagnostics --tosa-infer-shapes="convert-function-boundaries fold-shape-expressions" %s | FileCheck %s --check-prefixes=CHECK,FUNCBOUND
+
+// -----
+
+// CHECK-LABEL: test_simple_shape_expression
+func.func @test_simple_shape_expression(%arg0: tensor<7x12xi32>, %arg1: tensor<80xi32>, %arg2: tensor<4xi32>) -> tensor<?xi32> {
+  // CHECK-NOT: tosa.dim
+  // CHECK-NOT: tosa.add_shape
+  // CHECK: %[[SHAPE:.+]] = tosa.const_shape {values = dense<84> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK: %[[RESHAPE:.+]] = tosa.reshape %arg0, %[[SHAPE]] : (tensor<7x12xi32>, !tosa.shape<1>) -> tensor<84xi32>
+  // CHECK: %[[TILE:.+]] = tosa.tile %[[RESHAPE]], %[[SHAPE]] : (tensor<84xi32>, !tosa.shape<1>) -> tensor<7056xi32>
+  // DEFAULT: %[[CAST:.+]] = tensor.cast %[[TILE]] : tensor<7056xi32> to tensor<?xi32>
+  // DEFAULT: return %[[CAST]] : tensor<?xi32>
+  // FUNCBOUND: return %[[TILE]] : tensor<7056xi32>
+  %a = tosa.dim %arg1 {axis = 0: i32} : (tensor<80xi32>) -> !tosa.shape<1>
+  %b = tosa.dim %arg2 {axis = 0: i32} : (tensor<4xi32>) -> !tosa.shape<1>
+  %c = tosa.add_shape %a, %b : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<1>
+  %d = tosa.reshape %arg0, %c : (tensor<7x12xi32>, !tosa.shape<1>) -> tensor<?xi32>
+  %e = tosa.dim %d {axis = 0: i32} : (tensor<?xi32>) -> !tosa.shape<1>
+  %f = tosa.tile %d, %e : (tensor<?xi32>, !tosa.shape<1>) -> tensor<?xi32>
+  return %f : tensor<?xi32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cond_if_with_shape_expressions
+func.func @test_cond_if_with_shape_expressions(%arg0 : tensor<3xf32>, %arg1 : tensor<3xf32>, %arg2 : tensor<i1>) -> () {
+  // CHECK: %[[CONST_SHAPE:.*]] = tosa.const_shape {values = dense<3> : tensor<1xindex>} : () -> !tosa.shape<1>
+  // CHECK: tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> {
+  %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) : tensor<i1> (tensor<3xf32>, tensor<3xf32>) -> tensor<?xf32> {
+    // CHECK: ^bb0(%arg3: tensor<3xf32>, %arg4: tensor<3xf32>)
+    ^bb0(%arg3 : tensor<?xf32>, %arg4 : tensor<?xf32>):
+      // CHECK-NOT: tosa.dim
+      %0 = tosa.dim %arg3 {axis = 0 : i32} : (tensor<?xf32>) -> !tosa.shape<1>
+      // CHECK: %[[RESHAPE:.*]] = tosa.reshape %arg3, %[[CONST_SHAPE]] : (tensor<3xf32>, !tosa.shape<1>) -> tensor<3xf32>
+      %1 = tosa.reshape %arg3, %0 : (tensor<?xf32>, !tosa.shape<1>) -> tensor<?xf32>
+      // CHECK: tosa.yield %[[RESHAPE]] : tensor<3xf32>
+      tosa.yield %1 : tensor<?xf32>
+  } else {
+    // CHECK: ^bb0(%arg3: tensor<3xf32>, %arg4: tensor<3xf32>)
+    ^bb0(%arg3 : tensor<?xf32>, %arg4 : tensor<?xf32>):
+      // CHECK: tosa.yield %arg4 : tensor<3xf32>
+      tosa.yield %arg4 : tensor<?xf32>
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: test_no_fold_shape_expression
+func.func @test_no_fold_shape_expression(%arg0: tensor<1x?x3xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  /...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/178418


More information about the Mlir-commits mailing list