[Mlir-commits] [mlir] [mlir][tosa] Fix check for isolated regions in `tosa.cond_if` (PR #143772)
Luke Hutton
llvmlistbot at llvm.org
Fri Jul 18 07:35:07 PDT 2025
https://github.com/lhutton1 updated https://github.com/llvm/llvm-project/pull/143772
>From 68d93e575c7340ea39064fac979d8eeebc2ad3c2 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 11 Jun 2025 09:04:53 +0000
Subject: [PATCH 1/2] [mlir][tosa] Fix check for isolated regions in
`tosa.cond_if`
This commit fixes a check in the validation pass that was intended
to validate whether a `tosa.cond_if` operation conforms to
the specification. The specification requires that all values used in
the then/else regions be explicitly declared within those regions.
This change checks that these regions are 'isolated from above'
to ensure this requirement is met.
Change-Id: I1b6eac1ed571e6b1eda4a58f0677c80e22977e58
---
.../Tosa/Transforms/TosaValidation.cpp | 68 ++++++++++++-------
mlir/test/Dialect/Tosa/error_if_check.mlir | 40 +++++++++--
2 files changed, 77 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 3f27849b8c90c..4c321a5d5b181 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1193,12 +1193,11 @@ bool checkErrorIfPad(Operation *op) {
return true;
}
-// Returns true if the operation takes no input operands, excluding attributes.
-static bool isNullaryOperation(Operation *op) {
- if (isa<tosa::ConstOp>(op) || isa<tosa::ConstShapeOp>(op) ||
- isa<tosa::YieldOp>(op) || isa<tosa::VariableOp>(op))
- return true;
- return false;
+static bool isOpIsolatedFromAbove(Operation *op, Region *region) {
+ return llvm::all_of(op->getOperands(), [&](auto operand) {
+ Region *operandRegion = operand.getParentRegion();
+ return region->isAncestor(operandRegion);
+ });
}
bool checkErrorIfCondIf(Operation *op) {
@@ -1206,19 +1205,43 @@ bool checkErrorIfCondIf(Operation *op) {
if (!ifOp)
return true;
- // Whether the types and shapes of operands between the input/output list and
- // internal regions are validated by the operation verifier. However, with
- // support for the simplified form - where redundant operand notations are
- // omitted - is not conformant to the specification. According to the
- // specification, all operands passed into an operation must be explicitly
- // declared at each operation's structure. This code section verify that the
- // operation's form complies with this requirement.
+ // Currently the dialect supports declaring cond_if operations that
+ // have then/else regions that reference values from outside these
+ // regions. According to the specification, all values used by the
+ // then/else regions must be explicitly declared within the regions.
+ // Therefore we must check that the then/else regions are
+ // "isolated from above", in order to be conformant to the
+ // specification.
+ //
+ // Note: the dialect currently supports two styles of syntax for
+ // declaring "cond_if" operations. We'll refer to these as follows:
+ //
+ // Generic:
+ // %0 = "tosa.cond_if"(%arg0, %arg1, %arg2) ({
+ // ^bb0(%arg3, %arg4):
+ // tosa.yield %arg3
+ // }, {
+ // ^bb0(%arg3, %arg4):
+ // tosa.yield %arg4
+ // })
+ //
+ // Simplified:
+ // %0 = tosa.cond_if %arg2 {
+ // tosa.yield %arg0
+ // } else {
+ // tosa.yield %arg1
+ // }
+ //
+ // Unfortunately, the simplified syntax does not encapsulate values
+ // used in then/else regions (see 'simplified' example above), so it
+ // must be rewritten to use the generic syntax in order to be conformant
+ // to the specification.
// Returns true if the region uses no external input operands.
- auto isNullaryRegion = [](Region &region) -> bool {
+ auto isIsolatedRegion = [](Region &region) -> bool {
bool noLiveInValue = true;
- region.walk([&noLiveInValue](Operation *op) {
- if (!isNullaryOperation(op)) {
+ region.walk([&noLiveInValue, &region](Operation *op) {
+ if (!isOpIsolatedFromAbove(op, &region)) {
noLiveInValue = false;
return WalkResult::interrupt();
}
@@ -1229,18 +1252,15 @@ bool checkErrorIfCondIf(Operation *op) {
mlir::Region &thenGraph = ifOp.getThenGraph();
mlir::Region &elseGraph = ifOp.getElseGraph();
- bool isThenGraphNullaryRegion = isNullaryRegion(thenGraph);
- bool isElseGraphNullaryRegion = isNullaryRegion(elseGraph);
- bool isInputListEmpty = ifOp.getInputList().size() == 0;
+ bool isThenGraphIsolatedRegion = isIsolatedRegion(thenGraph);
+ bool isElseGraphIsolatedRegion = isIsolatedRegion(elseGraph);
- if ((isInputListEmpty != isThenGraphNullaryRegion) ||
- (isInputListEmpty != isElseGraphNullaryRegion)) {
+ if (!isThenGraphIsolatedRegion || !isElseGraphIsolatedRegion) {
op->emitOpError()
- << "the current simplified form is not strictly conformant to the "
- "spec, please use the generic format\n";
+ << "is not conformant to the TOSA specification. It requires the "
+ "then/else regions are isolated from above.\n";
return false;
}
-
return true;
}
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 1f25132d6bcf3..00c891d4afaa0 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -227,15 +227,41 @@ func.func @test_error_i32_unsigned_output(%arg0: tensor<1xi8>) -> tensor<1xi32>
}
// -----
-// CHECK-LABEL: cond_if_simplified_form
-func.func @test_cond_if_simplified_form(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
- // expected-error at +1 {{'tosa.cond_if' op the current simplified form is not strictly conformant to the spec, please use the generic format}}
+
+func.func @test_cond_if_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // expected-error at +1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the then/else regions are isolated from above.}}
+ %0 = "tosa.cond_if"(%arg2) ({
+ ^bb0():
+ tosa.yield %arg0 : tensor<f32>
+ }, {
+ ^bb0():
+ tosa.yield %arg1 : tensor<f32>
+ }) : (tensor<i1>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_simplified_form_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // expected-error at +1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the then/else regions are isolated from above.}}
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
- %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
- tosa.yield %1 : tensor<f32>
+ tosa.yield %arg0 : tensor<f32>
} else {
- %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
- tosa.yield %1 : tensor<f32>
+ tosa.yield %arg1 : tensor<f32>
}
return %0 : tensor<f32>
}
+
+// -----
+
+// COM: Check isolated cond_if's are valid
+func.func @test_cond_if_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ tosa.yield %arg3 : tensor<f32>
+ }, {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ tosa.yield %arg4 : tensor<f32>
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
>From fd0487ec6f3ceb367e498fb9fc0bbf0f124f4d55 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Fri, 18 Jul 2025 14:03:53 +0000
Subject: [PATCH 2/2] Address review comments
- check then/else region isolation independently
- renamed isOpIsolatedFromAbove to isOpIsolatedWithinRegion
- remove "COM" prefix in lit test
Change-Id: I5988ba2e75c9aa81c57321628d49d35634847f03
---
.../Tosa/Transforms/TosaValidation.cpp | 31 ++++++++--------
mlir/test/Dialect/Tosa/error_if_check.mlir | 35 +++++++++++++------
2 files changed, 40 insertions(+), 26 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 4c321a5d5b181..3f941d10bfc59 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1193,10 +1193,10 @@ bool checkErrorIfPad(Operation *op) {
return true;
}
-static bool isOpIsolatedFromAbove(Operation *op, Region *region) {
+static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
return llvm::all_of(op->getOperands(), [&](auto operand) {
Region *operandRegion = operand.getParentRegion();
- return region->isAncestor(operandRegion);
+ return operandRegion && region->isAncestor(operandRegion);
});
}
@@ -1238,10 +1238,10 @@ bool checkErrorIfCondIf(Operation *op) {
// to the specification.
// Returns true if the region uses no external input operands.
- auto isIsolatedRegion = [](Region &region) -> bool {
+ auto isIsolatedRegion = [](Region &regionToCheck) -> bool {
bool noLiveInValue = true;
- region.walk([&noLiveInValue, &region](Operation *op) {
- if (!isOpIsolatedFromAbove(op, &region)) {
+ regionToCheck.walk([&noLiveInValue, &regionToCheck](Operation *opInRegion) {
+ if (!isOpIsolatedWithinRegion(opInRegion, &regionToCheck)) {
noLiveInValue = false;
return WalkResult::interrupt();
}
@@ -1250,18 +1250,17 @@ bool checkErrorIfCondIf(Operation *op) {
return noLiveInValue;
};
- mlir::Region &thenGraph = ifOp.getThenGraph();
- mlir::Region &elseGraph = ifOp.getElseGraph();
- bool isThenGraphIsolatedRegion = isIsolatedRegion(thenGraph);
- bool isElseGraphIsolatedRegion = isIsolatedRegion(elseGraph);
-
- if (!isThenGraphIsolatedRegion || !isElseGraphIsolatedRegion) {
+ auto checkIsolatedRegion = [&](Region &regionToCheck, StringRef regionName) -> LogicalResult {
+ if (isIsolatedRegion(regionToCheck))
+ return success();
op->emitOpError()
- << "is not conformant to the TOSA specification. It requires the "
- "then/else regions are isolated from above.\n";
- return false;
- }
- return true;
+ << "is not conformant to the TOSA specification. It requires the '"
+ << regionName << "' region is isolated from above.\n";
+ return failure();
+ };
+
+ return failed(checkIsolatedRegion(ifOp.getThenGraph(), "then")) ||
+ failed(checkIsolatedRegion(ifOp.getElseGraph(), "else"));
}
bool checkErrorIfScatter(Operation *op) {
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 00c891d4afaa0..8924dd9885827 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -228,22 +228,37 @@ func.func @test_error_i32_unsigned_output(%arg0: tensor<1xi8>) -> tensor<1xi32>
// -----
-func.func @test_cond_if_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
- // expected-error at +1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the then/else regions are isolated from above.}}
- %0 = "tosa.cond_if"(%arg2) ({
- ^bb0():
- tosa.yield %arg0 : tensor<f32>
- }, {
- ^bb0():
+func.func @test_cond_if_then_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // expected-error at +1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the 'then' region is isolated from above.}}
+ %0 = "tosa.cond_if"(%arg2, %arg1) ({
+ ^bb0(%arg3: tensor<f32>):
tosa.yield %arg1 : tensor<f32>
- }) : (tensor<i1>) -> tensor<f32>
+ }, {
+ ^bb0(%arg3: tensor<f32>):
+ tosa.yield %arg3 : tensor<f32>
+ }) : (tensor<i1>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
// -----
+func.func @test_cond_if_else_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+ // expected-error at +1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the 'else' region is isolated from above.}}
+ %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ tosa.yield %arg3 : tensor<f32>
+ }, {
+ ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+ %add = tosa.add %arg0, %arg4 : (tensor<f32>, tensor<f32>) -> tensor<f32>
+ tosa.yield %add : tensor<f32>
+ }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+ return %0 : tensor<f32>
+}
+
+// -----
+
func.func @test_cond_if_simplified_form_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
- // expected-error at +1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the then/else regions are isolated from above.}}
+ // expected-error at +1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the 'then' region is isolated from above.}}
%0 = tosa.cond_if %arg2 -> (tensor<f32>) {
tosa.yield %arg0 : tensor<f32>
} else {
@@ -254,7 +269,7 @@ func.func @test_cond_if_simplified_form_not_isolated_from_above(%arg0: tensor<f3
// -----
-// COM: Check isolated cond_if's are valid
+// Check isolated cond_if's are valid
func.func @test_cond_if_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
%0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
More information about the Mlir-commits
mailing list