[Mlir-commits] [mlir] [mlir][tosa] Check for isolated regions in `tosa.while_loop` (PR #144865)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 21 05:00:12 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tosa
Author: Luke Hutton (lhutton1)
<details>
<summary>Changes</summary>
Similarly to `tosa.cond_if`, this patch checks that the cond/body regions of `tosa.while_loop` are isolated from above. This is required because the specification requires that all values used in the cond/body regions be explicitly declared within those regions.
~Note: this change is dependent on https://github.com/llvm/llvm-project/pull/143772~
---
Full diff: https://github.com/llvm/llvm-project/pull/144865.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+32-25)
- (modified) mlir/test/Dialect/Tosa/error_if_check.mlir (+57)
``````````diff
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 48ec28acfaaaa..32b5fb63a6ece 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1200,6 +1200,28 @@ static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
});
}
+// Returns true if isOpIsolatedWithinRegion() holds for every operation nested
+// in `regionToCheck`, i.e. presumably no value used inside the region is
+// defined outside of it. The walk stops at the first non-isolated operation.
+static bool isRegionIsolatedFromAbove(Region &regionToCheck) {
+ bool noLiveInValue = true;
+ regionToCheck.walk([&noLiveInValue, &regionToCheck](Operation *op) {
+ if (!isOpIsolatedWithinRegion(op, &regionToCheck)) {
+ noLiveInValue = false;
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+ return noLiveInValue;
+}
+
+// Checks that the region named `regionName` of `op` is isolated from above,
+// as the TOSA specification requires for control-flow regions. Returns
+// success() if it is; otherwise emits an op error naming the region and
+// returns failure().
+static LogicalResult checkIsolatedRegion(Operation *op, Region &regionToCheck,
+ StringRef regionName) {
+ if (isRegionIsolatedFromAbove(regionToCheck))
+ return success();
+ op->emitOpError()
+ << "is not conformant to the TOSA specification. It requires the '"
+ << regionName << "' region is isolated from above.";
+ return failure();
+}
+
bool checkErrorIfCondIf(Operation *op) {
auto ifOp = dyn_cast<tosa::IfOp>(op);
if (!ifOp)
@@ -1236,32 +1258,17 @@ bool checkErrorIfCondIf(Operation *op) {
// used in then/else regions (see 'simplified' example above), so it
// must be rewritten to use the generic syntax in order to be conformant
// to the specification.
+ // Return true only when the op is conformant: applyErrorIfCheck treats a
+ // `false` result as an ERROR_IF failure.
+ return succeeded(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) &&
+ succeeded(checkIsolatedRegion(op, ifOp.getElseGraph(), "else"));
+}
- // Returns true if the region uses no external input operands.
- auto isIsolatedRegion = [](Region &regionToCheck) -> bool {
- bool noLiveInValue = true;
- regionToCheck.walk([&noLiveInValue, &regionToCheck](Operation *opInRegion) {
- if (!isOpIsolatedWithinRegion(opInRegion, &regionToCheck)) {
- noLiveInValue = false;
- return WalkResult::interrupt();
- }
- return WalkResult::advance();
- });
- return noLiveInValue;
- };
-
- auto checkIsolatedRegion = [&](Region &regionToCheck,
- StringRef regionName) -> LogicalResult {
- if (isIsolatedRegion(regionToCheck))
- return success();
- op->emitOpError()
- << "is not conformant to the TOSA specification. It requires the '"
- << regionName << "' region is isolated from above.\n";
- return failure();
- };
+// Returns true if `op` is not a tosa.while_loop, or if both its 'cond' and
+// 'body' regions are isolated from above. Otherwise an op error is emitted
+// for the first offending region and false is returned, which makes
+// applyErrorIfCheck report a failure.
+bool checkErrorIfWhileLoop(Operation *op) {
+ auto whileOp = dyn_cast<tosa::WhileOp>(op);
+ if (!whileOp)
+ return true;
- return failed(checkIsolatedRegion(ifOp.getThenGraph(), "then")) ||
- failed(checkIsolatedRegion(ifOp.getElseGraph(), "else"));
+ return succeeded(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) &&
+ succeeded(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body"));
}
bool checkErrorIfScatter(Operation *op) {
@@ -1293,7 +1300,7 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
!checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
!checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
- !checkErrorIfScatter(op))
+ !checkErrorIfWhileLoop(op) || !checkErrorIfScatter(op))
return failure();
return success();
}
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 8924dd9885827..eb25011ff3a9d 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -280,3 +280,60 @@ func.func @test_cond_if_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f3
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
+
+// -----
+
+// Check that a while_loop whose 'cond' region uses a value defined outside of
+// it (%arg1) is rejected.
+func.func @test_while_loop_cond_not_isolated_from_above(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<f32>) {
+ %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+ // expected-error@+1 {{'tosa.while_loop' op is not conformant to the TOSA specification. It requires the 'cond' region is isolated from above.}}
+ %1 = "tosa.while_loop"(%0) ({
+ ^bb0(%arg3: tensor<i32>):
+ %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
+ tosa.yield %3 : tensor<i1>
+ }, {
+ ^bb0(%arg3: tensor<i32>):
+ %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+ %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ tosa.yield %3 : tensor<i32>
+ }) : (tensor<i32>) -> (tensor<i32>)
+ return
+}
+
+// -----
+
+// Check that a while_loop whose 'body' region uses a value defined outside of
+// it (%arg1) is rejected.
+func.func @test_while_loop_body_not_isolated_from_above(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<f32>) {
+ %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+ // expected-error@+1 {{'tosa.while_loop' op is not conformant to the TOSA specification. It requires the 'body' region is isolated from above.}}
+ %1 = "tosa.while_loop"(%0) ({
+ ^bb0(%arg3: tensor<i32>):
+ %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+ %3 = "tosa.greater_equal"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %4 = "tosa.logical_not"(%3) : (tensor<i1>) -> tensor<i1>
+ tosa.yield %4 : tensor<i1>
+ }, {
+ ^bb0(%arg3: tensor<i32>):
+ %3 = "tosa.add"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ tosa.yield %3 : tensor<i32>
+ }) : (tensor<i32>) -> (tensor<i32>)
+ return
+}
+
+// -----
+
+// Check that while_loops whose 'cond' and 'body' regions only use their block
+// arguments and locally defined values (i.e. are isolated from above) are valid.
+func.func @test_while_loop_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<i32>) {
+ %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+ %1:3 = "tosa.while_loop"(%0, %arg0, %arg1) ({
+ ^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<i32>):
+ %2 = "tosa.greater_equal"(%arg3, %arg5) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+ %3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
+ "tosa.yield"(%3) : (tensor<i1>) -> ()
+ }, {
+ ^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<i32>):
+ %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+ %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+ "tosa.yield"(%3, %arg4, %arg5) : (tensor<i32>, tensor<f32>, tensor<i32>) -> ()
+ }) : (tensor<i32>, tensor<f32>, tensor<i32>) -> (tensor<i32>, tensor<f32>, tensor<i32>)
+ return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/144865
More information about the Mlir-commits
mailing list