[Mlir-commits] [mlir] 2ae5ded - [mlir][tosa] Update ControlFlow variable names to match with TOSA v1.0 spec (#129790)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 4 17:17:46 PST 2025


Author: Jerry-Ge
Date: 2025-03-05T01:17:42Z
New Revision: 2ae5dedd7a211869c2682b05baefe8e46cdd3c40

URL: https://github.com/llvm/llvm-project/commit/2ae5dedd7a211869c2682b05baefe8e46cdd3c40
DIFF: https://github.com/llvm/llvm-project/commit/2ae5dedd7a211869c2682b05baefe8e46cdd3c40.diff

LOG: [mlir][tosa] Update ControlFlow variable names to match with TOSA v1.0 spec (#129790)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
    mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
    mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
    mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e0f2fd411bbe4..8362092021b0b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -2471,12 +2471,12 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
   }];
 
   let arguments = (ins
-    Tosa_I1Tensor:$cond,
-    Variadic<Tosa_Tensor>:$inputs
+    Tosa_I1Tensor:$condition,
+    Variadic<Tosa_Tensor>:$input_list
   );
 
   let results = (outs
-    Variadic<Tosa_Tensor>:$output
+    Variadic<Tosa_Tensor>:$output_list
   );
 
   list<Availability> availability = [
@@ -2485,8 +2485,8 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
   ];
 
   let regions = (region
-    SizedRegion<1>:$then_branch,
-    SizedRegion<1>:$else_branch
+    SizedRegion<1>:$then_graph,
+    SizedRegion<1>:$else_graph
   );
 
   let hasCustomAssemblyFormat = 1;
@@ -2513,11 +2513,11 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
   }];
 
   let arguments = (ins
-    Variadic<Tosa_Tensor>:$inputs
+    Variadic<Tosa_Tensor>:$input_list
   );
 
   let results = (outs
-    Variadic<Tosa_Tensor>:$output
+    Variadic<Tosa_Tensor>:$output_list
   );
 
   list<Availability> availability = [
@@ -2526,8 +2526,8 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
   ];
 
   let regions = (region
-    SizedRegion<1>:$cond,
-    SizedRegion<1>:$body
+    SizedRegion<1>:$cond_graph,
+    SizedRegion<1>:$body_graph
   );
 
   let hasCustomAssemblyFormat = 1;

diff  --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index 9139bf191fdf1..ef144fc7f0d54 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -68,13 +68,13 @@ class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
   LogicalResult matchAndRewrite(tosa::IfOp op,
                                 PatternRewriter &rewriter) const final {
     auto condition =
-        rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCond());
+        rewriter.create<tensor::ExtractOp>(op.getLoc(), op.getCondition());
     auto newIf = rewriter.create<scf::IfOp>(op.getLoc(), op.getResultTypes(),
                                             condition, true);
 
-    inlineIfCase(op.getThenBranch(), newIf.getThenRegion(), op.getInputs(),
+    inlineIfCase(op.getThenGraph(), newIf.getThenRegion(), op.getInputList(),
                  rewriter);
-    inlineIfCase(op.getElseBranch(), newIf.getElseRegion(), op.getInputs(),
+    inlineIfCase(op.getElseGraph(), newIf.getElseRegion(), op.getInputList(),
                  rewriter);
 
     rewriter.replaceOp(op, newIf.getResults());
@@ -158,12 +158,12 @@ class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
   LogicalResult matchAndRewrite(tosa::WhileOp op,
                                 PatternRewriter &rewriter) const final {
     auto newWhile = rewriter.create<scf::WhileOp>(
-        op.getLoc(), op.getResultTypes(), op.getInputs());
+        op.getLoc(), op.getResultTypes(), op.getInputList());
     rewriter.createBlock(&newWhile.getBefore());
     rewriter.createBlock(&newWhile.getAfter());
 
-    inlineWhileCase(op.getCond(), newWhile.getBefore(), rewriter, true);
-    inlineWhileCase(op.getBody(), newWhile.getAfter(), rewriter, false);
+    inlineWhileCase(op.getCondGraph(), newWhile.getBefore(), rewriter, true);
+    inlineWhileCase(op.getBodyGraph(), newWhile.getAfter(), rewriter, false);
 
     rewriter.replaceOp(op, newWhile.getResults());
 

diff  --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 8841d53b6e64d..800968e6f4766 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -127,7 +127,9 @@ struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
 //===----------------------------------------------------------------------===//
 
 /// Returns the while loop body.
-SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
+SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
+  return {&getBodyGraph()};
+}
 
 //===----------------------------------------------------------------------===//
 // Tosa dialect initialization.
@@ -2536,7 +2538,7 @@ LogicalResult WhileOp::inferReturnTypeComponents(
     WhileOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   llvm::SmallVector<tosa::YieldOp> yieldOps;
-  for (auto &block : adaptor.getBody())
+  for (auto &block : adaptor.getBodyGraph())
     if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
       yieldOps.push_back(returnOp);
 
@@ -2616,19 +2618,19 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
 void IfOp::print(OpAsmPrinter &p) {
   bool printBlockTerminators = false;
 
-  p << " " << getCond();
+  p << " " << getCondition();
   if (!getResults().empty()) {
     p << " -> (" << getResultTypes() << ")";
     // Print yield explicitly if the op defines values.
     printBlockTerminators = true;
   }
   p << ' ';
-  p.printRegion(getThenBranch(),
+  p.printRegion(getThenGraph(),
                 /*printEntryBlockArgs=*/false,
                 /*printBlockTerminators=*/printBlockTerminators);
 
   // Print the 'else' regions if it exists and has a block.
-  auto &elseRegion = getElseBranch();
+  auto &elseRegion = getElseGraph();
   if (!elseRegion.empty()) {
     p << " else ";
     p.printRegion(elseRegion,
@@ -2726,14 +2728,15 @@ static void printInitializationList(OpAsmPrinter &parser,
 }
 
 void WhileOp::print(OpAsmPrinter &parser) {
-  printInitializationList(parser, getCond().front().getArguments(), getInputs(),
-                          " ");
+  printInitializationList(parser, getCondGraph().front().getArguments(),
+                          getInputList(), " ");
   parser << " : ";
-  parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes());
+  parser.printFunctionalType(getInputList().getTypes(),
+                             getResults().getTypes());
   parser << ' ';
-  parser.printRegion(getCond(), /*printEntryBlockArgs=*/false);
+  parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
   parser << " do ";
-  parser.printRegion(getBody());
+  parser.printRegion(getBodyGraph());
   parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
 }
 

diff  --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 96fb054d75b66..1060f520d2930 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -371,14 +371,14 @@ struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
       }
     }
     if (auto condIf = dyn_cast<tosa::IfOp>(op)) {
-      if (!levelCheckListSize(op, condIf.getInputs().size(), "inputs") ||
-          !levelCheckListSize(op, condIf.getOutput().size(), "outputs")) {
+      if (!levelCheckListSize(op, condIf.getInputList().size(), "inputs") ||
+          !levelCheckListSize(op, condIf.getOutputList().size(), "outputs")) {
         return false;
       }
     }
     if (auto w = dyn_cast<tosa::WhileOp>(op)) {
-      if (!levelCheckListSize(op, w.getInputs().size(), "inputs") ||
-          !levelCheckListSize(op, w.getOutput().size(), "outputs")) {
+      if (!levelCheckListSize(op, w.getInputList().size(), "inputs") ||
+          !levelCheckListSize(op, w.getOutputList().size(), "outputs")) {
         return false;
       }
     }
@@ -450,7 +450,7 @@ bool TosaValidation::levelCheckRanks(tosa::IfOp tosaOp) {
   auto op = tosaOp.getOperation();
 
   // Only the condition input has rank limitation.
-  if (!levelCheckRank(op, tosaOp.getCond(), "operand", tosaLevel.MAX_RANK))
+  if (!levelCheckRank(op, tosaOp.getCondition(), "operand", tosaLevel.MAX_RANK))
     return false;
 
   return true;


        


More information about the Mlir-commits mailing list