[Mlir-commits] [mlir] [mlir][tosa] Enhance TosaInferShapes pass for simple shape inference (PR #178418)

Sayan Saha llvmlistbot at llvm.org
Thu Feb 5 05:14:14 PST 2026


================
@@ -333,13 +161,240 @@ void validateSameOperandsAndResultRankTrait(Region &region) {
 struct TosaInferShapes
     : public tosa::impl::TosaInferShapesPassBase<TosaInferShapes> {
 public:
+  explicit TosaInferShapes() = default; // Sets no pass options itself.
+  explicit TosaInferShapes(const TosaInferShapesPassOptions &options)
+      : TosaInferShapes() { // Delegate, then copy each option over.
+    this->foldShapeExpressions = options.foldShapeExpressions;
+    this->convertFunctionBoundaries = options.convertFunctionBoundaries;
+  }
+
   void runOnOperation() override {
     func::FuncOp func = getOperation(); // The function this pass runs on.
     TypeModificationState state;
     propagateShapesInRegion(func.getBody(), state);
     state.commit(); // Presumably makes the recorded type changes final.
 
     validateSameOperandsAndResultRankTrait(func.getBody());
+
+    if (convertFunctionBoundaries) // Only when the pass option is set.
+      convertFunctionReturnTypes(func);
+  }
+
+private:
+  void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
+    IfOp ifOp = dyn_cast<IfOp>(op); // Only IfOp is handled; else no-op.
+    if (!ifOp)
+      return;
+
+    for (auto &region : op.getRegions()) {
+      Block &frontBlock = region.front();
+      if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
+        return; // Block args must pair up with operands 1..N.
+
+      for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
+        auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
+        auto blockArg = frontBlock.getArgument(i - 1); // Operand 0 is skipped.
+        auto oldType = cast<ShapedType>(blockArg.getType());
+
+        if (inferredTy.hasRank()) { // Unranked: no shape to copy.
+          Type newType = oldType.clone(inferredTy.getShape());
+          state.setType(blockArg, newType);
+        }
+      }
+
+      for (int i = 0, e = frontBlock.getNumArguments(); i < e; i++) {
+        ValueKnowledge operandKnowledge = ValueKnowledge::getKnowledgeFromType(
+            ifOp.getOperand(i + 1).getType());
+        ValueKnowledge blockKnowledge = ValueKnowledge::getKnowledgeFromType(
+            frontBlock.getArgument(i).getType());
+        ValueKnowledge joinedKnowledge =
+            ValueKnowledge::join(operandKnowledge, blockKnowledge);
+        if (!joinedKnowledge) // No usable join: leave this argument as is.
+          continue;
+        state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
+      }
+
+      propagateShapesInRegion(region, state); // Then infer inside the region.
+    }
+  }
+
+  void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
+    WhileOp whileOp = dyn_cast<WhileOp>(op);
+    if (!whileOp)
+      return;
+
+    // Find block-argument types for tosa.while that are compatible with every
+    // iteration: start from the operand types and repeatedly meet them with
+    // the types the body yields until no type changes.
+    SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
+
+    bool hasNewTypes = true;
+    while (hasNewTypes) {
+      TypeModificationState localState; // Scratch state, rolled back below.
+
+      // Seed the body block arguments with the current candidate types.
+      Region &bodyRegion = op.getRegion(1);
+      Block &block = bodyRegion.front();
+      for (int i = 0, s = argTypes.size(); i < s; i++) {
+        localState.setType(block.getArgument(i), argTypes[i]);
+      }
+
+      // Propagate shapes through the whole body region.
+      propagateShapesInRegion(bodyRegion, localState);
+
+      // Collect every block's tosa.yield terminator; exactly one is expected.
+      llvm::SmallVector<YieldOp> yieldOps;
+      for (auto &block : bodyRegion)
+        if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
+          yieldOps.push_back(yieldOp);
+
+      assert(yieldOps.size() == 1 && "missing or non-unique yield op");
+      // Meet each tosa.yield operand type with the matching candidate type.
+      llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
+      for (auto ty : argTypes) {
+        yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
+      }
+
+      for (auto yieldOp : yieldOps) {
+        for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
+          auto newKnowledge =
+              ValueKnowledge::getKnowledgeFromType(it.value().getType());
+          yieldTypeInfo[it.index()] =
+              ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
+        }
+      }
+
+      // Cannot fire: yieldTypeInfo was built with one entry per argType.
+      if (yieldTypeInfo.size() != argTypes.size()) {
+        op.emitWarning(
+            "has a tosa.yield with the incorrect number of operands");
+        return;
+      }
+
+      // Adopt the met types and loop again if any candidate type changed.
+      hasNewTypes = false;
+      for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
+        Type newType = yieldTypeInfo[i].getType();
+        hasNewTypes |= (newType != argTypes[i]);
+        argTypes[i] = newType;
+      }
+
+      // Discard the speculative type changes recorded in localState for this
+      // round.
+      localState.rollBack();
+    }
+
+    // The candidates are now stable. Apply them to the front block arguments
+    // of every region of the op, then run shape propagation inside each
+    // region.
+    for (auto &region : op.getRegions()) {
+      for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
+        state.setType(region.front().getArgument(i), argTypes[i]);
+      }
+
+      propagateShapesInRegion(region, state);
+    }
+  }
+
+  void propagateShapesInRegion(Region &region, TypeModificationState &state) {
+    MLIRContext *ctx = region.getContext();
+    Dialect *tosaDialect = ctx->getLoadedDialect<TosaDialect>();
+    OperationFolder folder(ctx);
+
+    for (auto &block : region) {
+      for (auto it = block.begin(); it != block.end();) {
+        Operation &op = *it++;
----------------
sahas3 wrote:

Thanks for the clarification. I think the new behavior is subtle and adding a comment will improve readability.

https://github.com/llvm/llvm-project/pull/178418


More information about the Mlir-commits mailing list