[Mlir-commits] [mlir] [mlir][tosa] Print generic `cond_if` when block arguments are present (PR #144859)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 19 02:35:04 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Luke Hutton (lhutton1)
<details>
<summary>Changes</summary>
The generic printer/parser captures information about block arguments for then/else regions, while the simplified version does not. Currently the simplified printer is preferred by default, which means information about block arguments can be lost during a parse/print round-trip. This commit changes that behaviour so that the generic printer is preferred when block arguments have been provided, thus avoiding loss of information.
---
Full diff: https://github.com/llvm/llvm-project/pull/144859.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+20-5)
- (added) mlir/test/Dialect/Tosa/controlflow.mlir (+72)
``````````diff
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index a32e4ccbed594..d79a8760b5498 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3649,17 +3649,22 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
// parse and print of IfOp refer to the implementation of SCF dialect.
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
+ OpAsmParser::UnresolvedOperand cond;
+ // Fallback to generic IfOp parser when no immediate conditional
+ // operand is provided.
+ if (!parser.parseOptionalOperand(cond).has_value()) {
+ return parser.parseGenericOperationAfterOpName(result);
+ }
+
// Create the regions for 'then'.
result.regions.reserve(2);
Region *thenRegion = result.addRegion();
Region *elseRegion = result.addRegion();
auto &builder = parser.getBuilder();
- OpAsmParser::UnresolvedOperand cond;
// Create a i1 tensor type for the boolean condition.
Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
- if (parser.parseOperand(cond) ||
- parser.resolveOperand(cond, i1Type, result.operands))
+ if (parser.resolveOperand(cond, i1Type, result.operands))
return failure();
// Parse optional results type list.
if (parser.parseOptionalArrowTypeList(result.types))
@@ -3681,6 +3686,17 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
}
void IfOp::print(OpAsmPrinter &p) {
+ // The simplified syntax drops block-level arguments
+ // to the then/else regions. Fallback to the generic
+ // parser if these are found
+ Region &thenRegion = getThenGraph();
+ Region &elseRegion = getElseGraph();
+ if (!thenRegion.empty() && thenRegion.front().getNumArguments() > 0 &&
+ !elseRegion.empty() && elseRegion.front().getNumArguments() > 0) {
+ p.printGenericOp(*this, false);
+ return;
+ }
+
bool printBlockTerminators = false;
p << " " << getCondition();
@@ -3690,12 +3706,11 @@ void IfOp::print(OpAsmPrinter &p) {
printBlockTerminators = true;
}
p << ' ';
- p.printRegion(getThenGraph(),
+ p.printRegion(thenRegion,
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/printBlockTerminators);
// Print the 'else' regions if it exists and has a block.
- auto &elseRegion = getElseGraph();
if (!elseRegion.empty()) {
p << " else ";
p.printRegion(elseRegion,
diff --git a/mlir/test/Dialect/Tosa/controlflow.mlir b/mlir/test/Dialect/Tosa/controlflow.mlir
new file mode 100644
index 0000000000000..3bc088d02b22c
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/controlflow.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: test_cond_if_generic_form
+// CHECK: %[[OUT:.*]] = tosa.cond_if(%[[COND:.*]], %[[IN0:.*]], %[[IN1:.*]]) ({
+// CHECK: ^bb0(%[[INA:.*]]: tensor<f32>, %[[INB:.*]]: tensor<f32>):
+// CHECK: %[[THEN_TERM:.*]] = tosa.add %[[INA]], %[[INB]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK: tosa.yield %[[THEN_TERM]] : tensor<f32>
+// CHECK: }, {
+// CHECK: ^bb0(%[[INC:.*]]: tensor<f32>, %[[IND:.*]]: tensor<f32>):
+// CHECK: %[[ELSE_TERM:.*]] = tosa.sub %[[INC]], %[[IND]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK: tosa.yield %[[ELSE_TERM]] : tensor<f32>
+// CHECK: }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK: return %[[OUT]] : tensor<f32>
+func.func @test_cond_if_generic_form(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+ %0 = tosa.cond_if(%arg2, %arg0, %arg1) ({
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.add %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }, {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %1 = tosa.sub %arg3, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cond_if_generic_form_no_block_arguments
+// COM: Neither region has block arguments, so the simplified form is expected to be printed without losing information.
+func.func @test_cond_if_generic_form_no_block_arguments(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+ // CHECK: tosa.cond_if %arg2 -> (tensor<f32>)
+ %0 = tosa.cond_if(%arg2) ({
+ ^bb0():
+ tosa.yield %arg0 : tensor<f32>
+ }, {
+ ^bb0():
+ tosa.yield %arg1 : tensor<f32>
+ }) : (tensor<i1>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cond_if_simplified_form
+// CHECK: tosa.cond_if %arg2 -> (tensor<f32>)
+func.func @test_cond_if_simplified_form(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+ %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ } else {
+ %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: test_cond_if_simplified_form_just_yield
+// CHECK: tosa.cond_if %arg2 -> (tensor<f32>)
+func.func @test_cond_if_simplified_form_just_yield(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<i1>) -> tensor<f32> {
+ %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
+ tosa.yield %arg0 : tensor<f32>
+ } else {
+ tosa.yield %arg1 : tensor<f32>
+ }
+ return %0 : tensor<f32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/144859
More information about the Mlir-commits mailing list