[Mlir-commits] [mlir] [mlir][tosa] Enhance TosaInferShapes pass for simple shape inference (PR #178418)
Sayan Saha
llvmlistbot at llvm.org
Thu Feb 5 05:14:14 PST 2026
================
@@ -333,13 +161,240 @@ void validateSameOperandsAndResultRankTrait(Region &region) {
struct TosaInferShapes
: public tosa::impl::TosaInferShapesPassBase<TosaInferShapes> {
public:
+  explicit TosaInferShapes() = default;
+
+  /// Builds the pass from an explicit options bundle. Each option value is
+  /// copied into the pass's own option member of the same name; everything
+  /// else is default-initialized exactly as in the default constructor.
+  explicit TosaInferShapes(const TosaInferShapesPassOptions &options) {
+    convertFunctionBoundaries = options.convertFunctionBoundaries;
+    foldShapeExpressions = options.foldShapeExpressions;
+  }
+
+ /// Pass entry point: runs shape inference over the body of the anchored
+ /// func::FuncOp.
void runOnOperation() override {
func::FuncOp func = getOperation();
TypeModificationState state;
+ // Refine value types throughout the body, then keep the recorded type
+ // modifications (commit) rather than rolling them back.
propagateShapesInRegion(func.getBody(), state);
state.commit();
validateSameOperandsAndResultRankTrait(func.getBody());
+
+ // Opt-in (pass option): also update the function's return types.
+ // NOTE(review): inferred from the helper's name; its body is not shown here.
+ if (convertFunctionBoundaries)
+ convertFunctionReturnTypes(func);
+ }
+
+private:
+  /// Refines the entry-block argument types of every region of an IfOp from
+  /// the op's operands, then recursively propagates shapes through that
+  /// region. Operations that are not an IfOp are ignored.
+  ///
+  /// Operand 0 is not mapped to any block argument (presumably the
+  /// condition - TODO confirm); operand k+1 feeds entry-block argument k.
+  /// If a region's argument count does not line up with the operands, the
+  /// whole op is abandoned (remaining regions are not visited either).
+  void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
+    auto ifOp = dyn_cast<IfOp>(op);
+    if (!ifOp)
+      return;
+
+    for (Region &branch : op.getRegions()) {
+      Block &entry = branch.front();
+      const unsigned numArgs = entry.getNumArguments();
+      if (numArgs + 1 != ifOp.getNumOperands())
+        return;
+
+      // Pass 1: for each ranked operand, re-shape the matching block
+      // argument by cloning the argument's type with the operand's shape.
+      for (unsigned argIdx = 0; argIdx < numArgs; ++argIdx) {
+        auto operandTy = cast<ShapedType>(op.getOperand(argIdx + 1).getType());
+        BlockArgument arg = entry.getArgument(argIdx);
+        auto argTy = cast<ShapedType>(arg.getType());
+        if (!operandTy.hasRank())
+          continue;
+        Type refined = argTy.clone(operandTy.getShape());
+        state.setType(arg, refined);
+      }
+
+      // Pass 2: join what is known about each operand with what is now known
+      // about its block argument; only a valid join result is applied.
+      for (unsigned argIdx = 0; argIdx < numArgs; ++argIdx) {
+        BlockArgument arg = entry.getArgument(argIdx);
+        ValueKnowledge fromOperand = ValueKnowledge::getKnowledgeFromType(
+            ifOp.getOperand(argIdx + 1).getType());
+        ValueKnowledge fromArg =
+            ValueKnowledge::getKnowledgeFromType(arg.getType());
+        ValueKnowledge merged = ValueKnowledge::join(fromOperand, fromArg);
+        if (!merged)
+          continue;
+        state.setType(arg, merged.getType());
+      }
+
+      propagateShapesInRegion(branch, state);
+    }
+  }
+
+  /// Infers block-argument types for a WhileOp's regions and propagates
+  /// shapes through them. Operations that are not a WhileOp are ignored.
+  ///
+  /// Works as a fixed-point iteration. The current guess for the loop-carried
+  /// types is speculatively pushed through the body region (region 1). It is
+  /// then combined, via ValueKnowledge::meet, with the types of the
+  /// tosa.yield operands. The speculative changes are rolled back, and the
+  /// process repeats until the guess stops changing. Only the converged
+  /// types are recorded in the caller's `state`.
+  void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
+    WhileOp whileOp = dyn_cast<WhileOp>(op);
+    if (!whileOp)
+      return;
+
+    // Determine what the expected argument types are to the cond/body blocks.
+    // The expected arguments should be compatible with every iteration of the
+    // loop body / condition for tosa.while.
+    SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
+
+    bool hasNewTypes = true;
+    while (hasNewTypes) {
+      TypeModificationState localState;
+
+      // Seed the body's entry-block arguments with the current guess.
+      Region &bodyRegion = op.getRegion(1);
+      Block &entryBlock = bodyRegion.front();
+      for (int i = 0, s = argTypes.size(); i < s; i++)
+        localState.setType(entryBlock.getArgument(i), argTypes[i]);
+
+      // Propagate to the end.
+      propagateShapesInRegion(bodyRegion, localState);
+
+      // Find all the tosa yield types and verify there is a single one.
+      llvm::SmallVector<YieldOp> yieldOps;
+      for (Block &bodyBlock : bodyRegion)
+        if (auto yieldOp = dyn_cast<YieldOp>(bodyBlock.getTerminator()))
+          yieldOps.push_back(yieldOp);
+
+      assert(yieldOps.size() == 1 && "missing or non-unique yield op");
+
+      // Using the new tosa.yield operand types, infer the new subtypes,
+      // starting from the current guess.
+      llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
+      yieldTypeInfo.reserve(argTypes.size());
+      for (Type ty : argTypes)
+        yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));
+
+      for (YieldOp yieldOp : yieldOps) {
+        // yieldTypeInfo holds exactly one entry per loop-carried value, so a
+        // yield with a different operand count must be rejected BEFORE it is
+        // indexed (extra operands would write out of bounds). Undo the
+        // speculative type changes so none leak out of this function.
+        if (yieldOp.getNumOperands() != argTypes.size()) {
+          localState.rollBack();
+          op.emitWarning(
+              "has a tosa.yield with the incorrect number of operands");
+          return;
+        }
+
+        for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
+          auto newKnowledge =
+              ValueKnowledge::getKnowledgeFromType(it.value().getType());
+          yieldTypeInfo[it.index()] =
+              ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
+        }
+      }
+
+      // Determine the new block args and see if any changed.
+      hasNewTypes = false;
+      for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
+        Type newType = yieldTypeInfo[i].getType();
+        hasNewTypes |= (newType != argTypes[i]);
+        argTypes[i] = newType;
+      }
+
+      // Roll back all changes made during the speculative part of the
+      // algorithm.
+      localState.rollBack();
+    }
+
+    // The guess has converged. Record the final argument types on the entry
+    // block of every region of the op (the cond and body regions), in the
+    // caller's state, and propagate shapes through each region.
+    for (auto &region : op.getRegions()) {
+      for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
+        state.setType(region.front().getArgument(i), argTypes[i]);
+      }
+
+      propagateShapesInRegion(region, state);
+    }
+  }
+
+ void propagateShapesInRegion(Region ®ion, TypeModificationState &state) {
+ MLIRContext *ctx = region.getContext();
+ Dialect *tosaDialect = ctx->getLoadedDialect<TosaDialect>();
+ OperationFolder folder(ctx);
+
+ for (auto &block : region) {
+ for (auto it = block.begin(); it != block.end();) {
+ Operation &op = *it++;
----------------
sahas3 wrote:
Thanks for the clarification. I think the new behavior is subtle, and adding a comment would improve readability.
https://github.com/llvm/llvm-project/pull/178418
More information about the Mlir-commits
mailing list