[Mlir-commits] [mlir] 2ae5ded - [mlir][tosa] Update ControlFlow variable names to match with TOSA v1.0 spec (#129790)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 4 17:17:46 PST 2025
Author: Jerry-Ge
Date: 2025-03-05T01:17:42Z
New Revision: 2ae5dedd7a211869c2682b05baefe8e46cdd3c40
URL: https://github.com/llvm/llvm-project/commit/2ae5dedd7a211869c2682b05baefe8e46cdd3c40
DIFF: https://github.com/llvm/llvm-project/commit/2ae5dedd7a211869c2682b05baefe8e46cdd3c40.diff
LOG: [mlir][tosa] Update ControlFlow variable names to match with TOSA v1.0 spec (#129790)
Added:
Modified:
mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e0f2fd411bbe4..8362092021b0b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2471,12 +2471,12 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
}];
let arguments = (ins
- Tosa_I1Tensor:$cond,
- Variadic<Tosa_Tensor>:$inputs
+ Tosa_I1Tensor:$condition,
+ Variadic<Tosa_Tensor>:$input_list
);
let results = (outs
- Variadic<Tosa_Tensor>:$output
+ Variadic<Tosa_Tensor>:$output_list
);
list<Availability> availability = [
@@ -2485,8 +2485,8 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
];
let regions = (region
- SizedRegion<1>:$then_branch,
- SizedRegion<1>:$else_branch
+ SizedRegion<1>:$then_graph,
+ SizedRegion<1>:$else_graph
);
let hasCustomAssemblyFormat = 1;
@@ -2513,11 +2513,11 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
}];
let arguments = (ins
- Variadic<Tosa_Tensor>:$inputs
+ Variadic<Tosa_Tensor>:$input_list
);
let results = (outs
- Variadic<Tosa_Tensor>:$output
+ Variadic<Tosa_Tensor>:$output_list
);
list<Availability> availability = [
@@ -2526,8 +2526,8 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
];
let regions = (region
- SizedRegion<1>:$cond,
- SizedRegion<1>:$body
+ SizedRegion<1>:$cond_graph,
+ SizedRegion<1>:$body_graph
);
let hasCustomAssemblyFormat = 1;
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index 9139bf191fdf1..ef144fc7f0d54 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -68,13 +68,13 @@ class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
LogicalResult matchAndRewrite(tosa::IfOp op,
PatternRewriter &rewriter) const final {
auto condition =
- rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCond());
+ rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCondition());
auto newIf = rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
condition, true);
- inlineIfCase(op.getThenBranch(), newIf.getThenRegion(), op.getInputs(),
+ inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputList(),
rewriter);
- inlineIfCase(op.getElseBranch(), newIf.getElseRegion(), op.getInputs(),
+ inlineIfCase(op.getElseGraph(), newIf.getElseRegion(), op.getInputList(),
rewriter);
rewriter.replaceOp(op, newIf.getResults());
@@ -158,12 +158,12 @@ class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
LogicalResult matchAndRewrite(tosa::WhileOp op,
PatternRewriter &rewriter) const final {
auto newWhile = rewriter.create<scf::WhileOp>(
- op.getLoc(), op.getResultTypes(), op.getInputs());
+ op.getLoc(), op.getResultTypes(), op.getInputList());
rewriter.createBlock(&newWhile.getBefore());
rewriter.createBlock(&newWhile.getAfter());
- inlineWhileCase(op.getCond(), newWhile.getBefore(), rewriter, true);
- inlineWhileCase(op.getBody(), newWhile.getAfter(), rewriter, false);
+ inlineWhileCase(op.getCondGraph(), newWhile.getBefore(), rewriter, true);
+ inlineWhileCase(op.getBodyGraph(), newWhile.getAfter(), rewriter, false);
rewriter.replaceOp(op, newWhile.getResults());
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 8841d53b6e64d..800968e6f4766 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -127,7 +127,9 @@ struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
//===----------------------------------------------------------------------===//
/// Returns the while loop body.
-SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
+SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
+ return {&getBodyGraph()};
+}
//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
@@ -2536,7 +2538,7 @@ LogicalResult WhileOp::inferReturnTypeComponents(
WhileOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
llvm::SmallVector<tosa::YieldOp> yieldOps;
- for (auto &block : adaptor.getBody())
+ for (auto &block : adaptor.getBodyGraph())
if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
yieldOps.push_back(returnOp);
@@ -2616,19 +2618,19 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
void IfOp::print(OpAsmPrinter &p) {
bool printBlockTerminators = false;
- p << " " << getCond();
+ p << " " << getCondition();
if (!getResults().empty()) {
p << " -> (" << getResultTypes() << ")";
// Print yield explicitly if the op defines values.
printBlockTerminators = true;
}
p << ' ';
- p.printRegion(getThenBranch(),
+ p.printRegion(getThenGraph(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/printBlockTerminators);
// Print the 'else' regions if it exists and has a block.
- auto &elseRegion = getElseBranch();
+ auto &elseRegion = getElseGraph();
if (!elseRegion.empty()) {
p << " else ";
p.printRegion(elseRegion,
@@ -2726,14 +2728,15 @@ static void printInitializationList(OpAsmPrinter &parser,
}
void WhileOp::print(OpAsmPrinter &parser) {
- printInitializationList(parser, getCond().front().getArguments(), getInputs(),
- " ");
+ printInitializationList(parser, getCondGraph().front().getArguments(),
+ getInputList(), " ");
parser << " : ";
- parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes());
+ parser.printFunctionalType(getInputList().getTypes(),
+ getResults().getTypes());
parser << ' ';
- parser.printRegion(getCond(), /*printEntryBlockArgs=*/false);
+ parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
parser << " do ";
- parser.printRegion(getBody());
+ parser.printRegion(getBodyGraph());
parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 96fb054d75b66..1060f520d2930 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -371,14 +371,14 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
}
}
if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
- if (!levelCheckListSize(op, condIf.getInputs().size(), "inputs") ||
- !levelCheckListSize(op, condIf.getOutput().size(), "outputs")) {
+ if (!levelCheckListSize(op, condIf.getInputList().size(), "inputs") ||
+ !levelCheckListSize(op, condIf.getOutputList().size(), "outputs")) {
return false;
}
}
if (auto w = dyn_cast<tosa::WhileOp>(op)) {
- if (!levelCheckListSize(op, w.getInputs().size(), "inputs") ||
- !levelCheckListSize(op, w.getOutput().size(), "outputs")) {
+ if (!levelCheckListSize(op, w.getInputList().size(), "inputs") ||
+ !levelCheckListSize(op, w.getOutputList().size(), "outputs")) {
return false;
}
}
@@ -450,7 +450,7 @@ bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
auto op = tosaOp.getOperation();
// Only the condition input has rank limitation.
- if (!levelCheckRank(op, tosaOp.getCond(), "operand", tosaLevel.MAX_RANK))
+ if (!levelCheckRank(op, tosaOp.getCondition(), "operand", tosaLevel.MAX_RANK))
return false;
return true;
More information about the Mlir-commits mailing list