[Mlir-commits] [mlir] [mlir][tosa] Check for isolated regions in `tosa.while_loop` (PR #144865)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 21 05:00:12 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

<details>
<summary>Changes</summary>

As with `tosa.cond_if`, this patch checks that the cond/body regions of `tosa.while_loop` are isolated from above. This is required because the specification requires that all values used in the cond/body regions be explicitly declared within those regions.

~Note: this change is dependent on https://github.com/llvm/llvm-project/pull/143772~

---
Full diff: https://github.com/llvm/llvm-project/pull/144865.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+32-25) 
- (modified) mlir/test/Dialect/Tosa/error_if_check.mlir (+57) 


``````````diff
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 48ec28acfaaaa..32b5fb63a6ece 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1200,6 +1200,28 @@ static bool isOpIsolatedWithinRegion(Operation *op, Region *region) {
   });
 }
 
+static bool isRegionIsolatedFromAbove(Region &regionToCheck) {
+  // True when every op nested in the region passes isOpIsolatedWithinRegion,
+  // i.e. the region uses no external input operands. The walk is interrupted
+  // at the first offending op, so the walk status itself carries the answer.
+  WalkResult walkStatus = regionToCheck.walk([&](Operation *nestedOp) {
+    return isOpIsolatedWithinRegion(nestedOp, &regionToCheck)
+               ? WalkResult::advance()
+               : WalkResult::interrupt();
+  });
+  return !walkStatus.wasInterrupted();
+}
+
+/// Emits an op error on `op` (returning failure) unless `regionToCheck` is isolated from above.
+static LogicalResult checkIsolatedRegion(Operation *op, Region &regionToCheck,
+                                         StringRef regionName) {
+  if (isRegionIsolatedFromAbove(regionToCheck))
+    return success();
+  return op->emitOpError()
+         << "is not conformant to the TOSA specification. It requires the '"
+         << regionName << "' region is isolated from above.";
+}
+
 bool checkErrorIfCondIf(Operation *op) {
   auto ifOp = dyn_cast<tosa::IfOp>(op);
   if (!ifOp)
@@ -1236,32 +1258,17 @@ bool checkErrorIfCondIf(Operation *op) {
   // used in then/else regions (see 'simplified' example above), so it
   // must be rewritten to use the generic syntax in order to be conformant
   // to the specification.
+  return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
+         failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else"));
+}
 
-  // Returns true if the region uses no external input operands.
-  auto isIsolatedRegion = [](Region &regionToCheck) -> bool {
-    bool noLiveInValue = true;
-    regionToCheck.walk([&noLiveInValue, &regionToCheck](Operation *opInRegion) {
-      if (!isOpIsolatedWithinRegion(opInRegion, &regionToCheck)) {
-        noLiveInValue = false;
-        return WalkResult::interrupt();
-      }
-      return WalkResult::advance();
-    });
-    return noLiveInValue;
-  };
-
-  auto checkIsolatedRegion = [&](Region &regionToCheck,
-                                 StringRef regionName) -> LogicalResult {
-    if (isIsolatedRegion(regionToCheck))
-      return success();
-    op->emitOpError()
-        << "is not conformant to the TOSA specification. It requires the '"
-        << regionName << "' region is isolated from above.\n";
-    return failure();
-  };
+bool checkErrorIfWhileLoop(Operation *op) {
+  auto whileOp = dyn_cast<tosa::WhileOp>(op);
+  if (!whileOp)
+    return true;
 
-  return failed(checkIsolatedRegion(ifOp.getThenGraph(), "then")) ||
-         failed(checkIsolatedRegion(ifOp.getElseGraph(), "else"));
+  return failed(checkIsolatedRegion(op, whileOp.getCondGraph(), "cond")) ||
+         failed(checkIsolatedRegion(op, whileOp.getBodyGraph(), "body"));
 }
 
 bool checkErrorIfScatter(Operation *op) {
@@ -1293,7 +1300,7 @@ LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
   if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
       !checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
       !checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
-      !checkErrorIfScatter(op))
+      !checkErrorIfWhileLoop(op) || !checkErrorIfScatter(op))
     return failure();
   return success();
 }
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 8924dd9885827..eb25011ff3a9d 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -280,3 +280,60 @@ func.func @test_cond_if_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f3
     }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
+
+// -----
+
+func.func @test_while_loop_cond_not_isolated_from_above(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<f32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op is not conformant to the TOSA specification. It requires the 'cond' region is isolated from above.}}
+  %1 = "tosa.while_loop"(%0) ({
+  ^bb0(%arg3: tensor<i32>):
+    %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
+    tosa.yield %3 : tensor<i1>
+  },  {
+  ^bb0(%arg3: tensor<i32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3 : tensor<i32>
+  }) : (tensor<i32>) -> (tensor<i32>)
+  return
+}
+
+// -----
+
+func.func @test_while_loop_body_not_isolated_from_above(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<f32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op is not conformant to the TOSA specification. It requires the 'body' region is isolated from above.}}
+  %1 = "tosa.while_loop"(%0) ({
+  ^bb0(%arg3: tensor<i32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.greater_equal"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %4 = "tosa.logical_not"(%3) : (tensor<i1>) -> tensor<i1>
+    tosa.yield %4 : tensor<i1>
+  },  {
+  ^bb0(%arg3: tensor<i32>):
+    %3 = "tosa.add"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3 : tensor<i32>
+  }) : (tensor<i32>) -> (tensor<i32>)
+  return
+}
+
+// -----
+
+// Check that a while_loop whose regions use only their own block arguments is accepted.
+func.func @test_while_loop_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %1:3 = "tosa.while_loop"(%0, %arg0, %arg1) ({
+  ^bb0(%iter: tensor<i32>, %val: tensor<f32>, %limit: tensor<i32>):
+    %ge = "tosa.greater_equal"(%iter, %limit) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %keep_going = "tosa.logical_not"(%ge) : (tensor<i1>) -> tensor<i1>
+    tosa.yield %keep_going : tensor<i1>
+  },  {
+  ^bb0(%iter: tensor<i32>, %val: tensor<f32>, %limit: tensor<i32>):
+    %one = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %next = "tosa.add"(%iter, %one) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %next, %val, %limit : tensor<i32>, tensor<f32>, tensor<i32>
+  }) : (tensor<i32>, tensor<f32>, tensor<i32>) -> (tensor<i32>, tensor<f32>, tensor<i32>)
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/144865


More information about the Mlir-commits mailing list