[Mlir-commits] [mlir] [mlir][tosa] Enhance TosaInferShapes pass for simple shape inference (PR #178418)

Sayan Saha llvmlistbot at llvm.org
Wed Jan 28 09:05:56 PST 2026


================
@@ -333,13 +161,240 @@ void validateSameOperandsAndResultRankTrait(Region &region) {
 struct TosaInferShapes
     : public tosa::impl::TosaInferShapesPassBase<TosaInferShapes> {
 public:
+  explicit TosaInferShapes() = default;
+  explicit TosaInferShapes(const TosaInferShapesPassOptions &options)
+      : TosaInferShapes() {
+    this->foldShapeExpressions = options.foldShapeExpressions;
+    this->convertFunctionBoundaries = options.convertFunctionBoundaries;
+  }
+
   void runOnOperation() override {
     func::FuncOp func = getOperation();
     TypeModificationState state;
     propagateShapesInRegion(func.getBody(), state);
     state.commit();
 
     validateSameOperandsAndResultRankTrait(func.getBody());
+
+    if (convertFunctionBoundaries)
+      convertFunctionReturnTypes(func);
+  }
+
+private:
+  void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
+    IfOp ifOp = dyn_cast<IfOp>(op);
+    if (!ifOp)
+      return;
+
+    for (auto &region : op.getRegions()) {
+      Block &frontBlock = region.front();
+      if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
+        return;
+
+      for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
+        auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
+        auto blockArg = frontBlock.getArgument(i - 1);
+        auto oldType = cast<ShapedType>(blockArg.getType());
+
+        if (inferredTy.hasRank()) {
+          Type newType = oldType.clone(inferredTy.getShape());
+          state.setType(blockArg, newType);
+        }
+      }
+
+      for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
+        ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
+            ifOp.getOperand(i + 1).getType());
+        ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
+            frontBlock.getArgument(i).getType());
+        ValueKnowledge joinedKnowledge =
+            ValueKnowledge::join(operandKnowledge, blockKnowledge);
+        if (!joinedKnowledge)
+          continue;
+        state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
+      }
+
+      propagateShapesInRegion(region, state);
+    }
+  }
+
+  void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
+    WhileOp whileOp = dyn_cast<WhileOp>(op);
+    if (!whileOp)
+      return;
+
+    // Determine what the expected argument types are to the cond/body blocks.
+    // The expected arguments should be compatible with every iteration of the
+    // loop body / condition for tosa.while.
+    SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
+
+    bool hasNewTypes = true;
+    while (hasNewTypes) {
+      TypeModificationState localState;
+
+      // Set types on the block args.
+      Region &bodyRegion = op.getRegion(1);
+      Block &block = bodyRegion.front();
+      for (int i = 0, s = argTypes.size(); i < s; i++) {
+        localState.setType(block.getArgument(i), argTypes[i]);
+      }
+
+      // Propagate to the end.
+      propagateShapesInRegion(bodyRegion, localState);
+
+      // Find all the tosa yield types and verify there is a single one.
+      llvm::SmallVector<YieldOp> yieldOps;
+      for (auto &block : bodyRegion)
+        if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
+          yieldOps.push_back(yieldOp);
+
+      assert(yieldOps.size() == 1 && "missing or non-unique yield op");
+      // Using the new tosa.yield operand types, infer the new subtypes.
+      llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
+      for (auto ty : argTypes) {
+        yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
+      }
+
+      for (auto yieldOp : yieldOps) {
+        for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
+          auto newKnowledge =
+              ValueKnowledge::getKnowledgeFromType(it.value().getType());
+          yieldTypeInfo[it.index()] =
+              ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
+        }
+      }
+
+      // This should never happen.
+      if (yieldTypeInfo.size() != argTypes.size()) {
+        op.emitWarning(
+            "has a tosa.yield with the incorrect number of operands");
+        return;
+      }
+
+      // Determine the new block args and see if any changed.
+      hasNewTypes = false;
+      for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
+        Type newType = yieldTypeInfo[i].getType();
+        hasNewTypes |= (newType != argTypes[i]);
+        argTypes[i] = newType;
+      }
+
+      // Roll back all changes made during the speculative part of the
+      // algorithm.
+      localState.rollBack();
+    }
+
+    // We now set the block arguments according to the most recent shape
+    // inference results. This gives us the block arg types for the next
+    // iteration.
+    for (auto &region : op.getRegions()) {
+      for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
+        state.setType(region.front().getArgument(i), argTypes[i]);
+      }
+
+      propagateShapesInRegion(region, state);
+    }
+  }
+
+  void propagateShapesInRegion(Region &region, TypeModificationState &state) {
+    MLIRContext *ctx = region.getContext();
+    Dialect *tosaDialect = ctx->getLoadedDialect<TosaDialect>();
+    OperationFolder folder(ctx);
+
+    for (auto &block : region) {
+      for (auto it = block.begin(); it != block.end();) {
+        Operation &op = *it++;
+        if (op.getDialect() != tosaDialect)
+          continue;
+
+        propagateShapesToTosaIf(op, state);
+        propagateShapesToTosaWhile(op, state);
+
+        if (foldShapeExpressions &&
+            op.hasTrait<OpTrait::tosa::TosaShapeOperator>()) {
+          (void)folder.tryToFold(&op);
+          continue;
+        }
+
+        InferShapedTypeOpInterface shapeInterface =
+            dyn_cast<InferShapedTypeOpInterface>(op);
+        if (!shapeInterface)
+          continue;
+
+        SmallVector<ShapedTypeComponents> returnedShapes;
+
+        if (shapeInterface
+                .inferReturnTypeComponents(
+                    op.getContext(), op.getLoc(), op.getOperands(),
+                    op.getDiscardableAttrDictionary(),
+                    op.getPropertiesStorage(), op.getRegions(), returnedShapes)
+                .succeeded()) {
+          for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
+            Value result = std::get<0>(it);
+            ShapedTypeComponents predictedShape = std::get<1>(it);
+
+            // Determine the knowledge based on the output type.
+            // TODO: should also query WIP type probably
+            Type resultTy = result.getType();
+            auto currentKnowledge =
+                ValueKnowledge::getKnowledgeFromType(resultTy);
+
+            // Compute the knowledge based on the inferred type.
+            auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
+            inferredKnowledge.dtype =
+                cast<ShapedType>(resultTy).getElementType();
+            inferredKnowledge.hasRank = predictedShape.hasRank();
+            if (predictedShape.hasRank()) {
+              for (auto dim : predictedShape.getDims()) {
+                inferredKnowledge.sizes.push_back(dim);
+              }
+            }
+
+            // Compute the new type based on the joined version.
+            auto newKnowledge =
+                ValueKnowledge::join(currentKnowledge, inferredKnowledge);
+            if (!newKnowledge)
+              continue;
+
+            // Set new type
+            state.setType(result, newKnowledge.getType());
+          }
+        }
+      }
+    }
+  }
+
+  void convertFunctionReturnTypes(func::FuncOp func) {
+    IRRewriter rewriter(func.getContext());
+    SmallVector<Type> newReturnTypes;
+
+    // Rewrite func.return ops, removing dead tensor.cast ops if possible
+    func.walk([&rewriter, &newReturnTypes](func::ReturnOp ret) {
+      SmallVector<Value> newReturnValues;
+      OperandRange returnOperands = ret.getOperands();
+      newReturnValues.reserve(returnOperands.size());
+      newReturnTypes.reserve(returnOperands.size());
----------------
sahas3 wrote:

```suggestion
      newReturnTypes.reserve(newReturnTypes.size() + returnOperands.size());
```

https://github.com/llvm/llvm-project/pull/178418


More information about the Mlir-commits mailing list