[Mlir-commits] [mlir] [mlir][tosa] Print generic `cond_if` when block arguments are present (PR #144859)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 19 02:35:06 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

<details>
<summary>Changes</summary>

The generic printer/parser preserves the block arguments of the then/else regions, whereas the simplified form does not. Because the simplified printer is currently used by default, block-argument information can be lost during a parse/print round-trip. This commit changes that behaviour so that the generic printer is used whenever block arguments are present, avoiding the loss of information.

---
Full diff: https://github.com/llvm/llvm-project/pull/144859.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+20-5) 
- (added) mlir/test/Dialect/Tosa/controlflow.mlir (+72) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index a32e4ccbed594..d79a8760b5498 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3649,17 +3649,22 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
 
 // parse and print of IfOp refer to the implementation of SCF dialect.
 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::UnresolvedOperand cond;
+  // Fall back to the generic IfOp parser when no condition operand
+  // follows the op name. NOTE(review): a malformed operand is not rejected.
+  if (!parser.parseOptionalOperand(cond).has_value()) {
+    return parser.parseGenericOperationAfterOpName(result);
+  }
+
   // Create the regions for 'then'.
   result.regions.reserve(2);
   Region *thenRegion = result.addRegion();
   Region *elseRegion = result.addRegion();
 
   auto &builder = parser.getBuilder();
-  OpAsmParser::UnresolvedOperand cond;
   // Create a i1 tensor type for the boolean condition.
   Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
-  if (parser.parseOperand(cond) ||
-      parser.resolveOperand(cond, i1Type, result.operands))
+  if (parser.resolveOperand(cond, i1Type, result.operands))
     return failure();
   // Parse optional results type list.
   if (parser.parseOptionalArrowTypeList(result.types))
@@ -3681,6 +3686,17 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
 }
 
 void IfOp::print(OpAsmPrinter &p) {
+  // The simplified syntax drops the entry-block arguments of
+  // the then/else regions, so fall back to the generic printer
+  // when both regions have them. NOTE(review): should be `||`?
+  Region &thenRegion = getThenGraph();
+  Region &elseRegion = getElseGraph();
+  if (!thenRegion.empty() && thenRegion.front().getNumArguments() > 0 &&
+      !elseRegion.empty() && elseRegion.front().getNumArguments() > 0) {
+    p.printGenericOp(*this, false);
+    return;
+  }
+
   bool printBlockTerminators = false;
 
   p << " " << getCondition();
@@ -3690,12 +3706,11 @@ void IfOp::print(OpAsmPrinter &p) {
     printBlockTerminators = true;
   }
   p << ' ';
-  p.printRegion(getThenGraph(),
+  p.printRegion(thenRegion,
                 /*printEntryBlockArgs=*/false,
                 /*printBlockTerminators=*/printBlockTerminators);
 
   // Print the 'else' regions if it exists and has a block.
-  auto &elseRegion = getElseGraph();
   if (!elseRegion.empty()) {
     p << " else ";
     p.printRegion(elseRegion,
diff --git a/mlir/test/Dialect/Tosa/controlflow.mlir b/mlir/test/Dialect/Tosa/controlflow.mlir
new file mode 100644
index 0000000000000..3bc088d02b22c
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/controlflow.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: test_cond_if_generic_form
+// CHECK: %[[OUT:.*]] = tosa.cond_if(%[[COND:.*]], %[[IN0:.*]], %[[IN1:.*]]) ({
+// CHECK: ^bb0(%[[INA:.*]]: tensor<f32>, %[[INB:.*]]: tensor<f32>):
+// CHECK:    %[[THEN_TERM:.*]] = tosa.add %[[INA]], %[[INB]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK:    tosa.yield %[[THEN_TERM]] : tensor<f32>
+// CHECK: }, {
+// CHECK: ^bb0(%[[INC:.*]]: tensor<f32>, %[[IND:.*]]: tensor<f32>):
+// CHECK:    %[[ELSE_TERM:.*]] = tosa.sub %[[INC]], %[[IND]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK:    tosa.yield %[[ELSE_TERM]] : tensor<f32>
+// CHECK: }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK: return %[[OUT]] : tensor<f32>
+func.func @test_cond_if_generic_form(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+  %0 = tosa.cond_if(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cond_if_generic_form_no_block_arguments
+// COM: With no block arguments, the simplified form can be printed.
+func.func @test_cond_if_generic_form_no_block_arguments(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+  // CHECK: tosa.cond_if %arg2 -> (tensor<f32>)
+  %0 = tosa.cond_if(%arg2) ({
+  ^bb0():
+    tosa.yield %arg0 : tensor<f32>
+  },  {
+  ^bb0():
+    tosa.yield %arg1 : tensor<f32>
+  }) : (tensor<i1>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cond_if_simplified_form
+// CHECK: tosa.cond_if %arg2 -> (tensor<f32>)
+func.func @test_cond_if_simplified_form(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  } else {
+    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    tosa.yield %1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cond_if_simplified_form_just_yield
+// CHECK: tosa.cond_if %arg2 -> (tensor<f32>)
+func.func @test_cond_if_simplified_form_just_yield(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+  %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+    tosa.yield %arg0 : tensor<f32>
+  } else {
+    tosa.yield %arg1 : tensor<f32>
+  }
+  return %0 : tensor<f32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/144859


More information about the Mlir-commits mailing list