[Mlir-commits] [mlir] [mlir][tosa] Check for isolated regions in `tosa.while_loop` (PR #144865)

Luke Hutton llvmlistbot at llvm.org
Thu Jun 19 03:12:18 PDT 2025


https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/144865

As with `tosa.cond_if`, this patch checks that the cond/body regions of `tosa.while_loop` are isolated from above. This is required because the specification requires that all values used in the cond/body regions be explicitly declared within those regions.

Note: this change depends on https://github.com/llvm/llvm-project/pull/143772

>From 25dc942e88ec0ddb8ad54505ff027d6b3be61f64 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 11 Jun 2025 09:04:53 +0000
Subject: [PATCH 1/2] [mlir][tosa] Fix check for isolated regions in
 `tosa.cond_if`

This commit fixes a check in the validation pass that is intended
to validate whether a `tosa.cond_if` operation conforms to the
specification. The specification requires that all values used in
the then/else regions be explicitly declared within those regions.
This change checks that these regions are 'isolated from above',
ensuring that this requirement holds.

Change-Id: I1b6eac1ed571e6b1eda4a58f0677c80e22977e58
---
 .../Tosa/Transforms/TosaValidation.cpp        | 68 ++++++++++++-------
 mlir/test/Dialect/Tosa/error_if_check.mlir    | 40 +++++++++--
 2 files changed, 77 insertions(+), 31 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index d33fc902de3a1..067ee7d5a5c5a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1193,12 +1193,11 @@ bool checkErrorIfPad(Operation *op) {
   return true;
 }
 
-// Returns true if the operation takes no input operands, excluding attributes.
-static bool isNullaryOperation(Operation *op) {
-  if (isa<tosa::ConstOp>(op) || isa<tosa::ConstShapeOp>(op) ||
-      isa<tosa::YieldOp>(op) || isa<tosa::VariableOp>(op))
-    return true;
-  return false;
+static bool isOpIsolatedFromAbove(Operation *op, Region *region) {
+  return llvm::all_of(op->getOperands(), [&](auto operand) {
+    Region *operandRegion = operand.getParentRegion();
+    return region->isAncestor(operandRegion);
+  });
 }
 
 bool checkErrorIfCondIf(Operation *op) {
@@ -1206,19 +1205,43 @@ bool checkErrorIfCondIf(Operation *op) {
   if (!ifOp)
     return true;
 
-  // Whether the types and shapes of operands between the input/output list and
-  // internal regions are validated by the operation verifier. However, with
-  // support for the simplified form - where redundant operand notations are
-  // omitted - is not conformant to the specification. According to the
-  // specification, all operands passed into an operation must be explicitly
-  // declared at each operation's structure. This code section verify that the
-  // operation's form complies with this requirement.
+  // Currently the dialect supports declaring cond_if operations that
+  // have then/else regions that reference values from outside these
+  // regions. According to the specification, all values used by the
+  // then/else regions must be explicitly declared within the regions.
+  // Therefore we must check that the then/else regions are
+  // "isolated from above", in order to be conformant to the
+  // specification.
+  //
+  // Note: the dialect currently supports two styles of syntax for
+  // declaring "cond_if" operations. We'll refer to these as follows:
+  //
+  // Generic:
+  // %0 = "tosa.cond_if"(%arg0, %arg1, %arg2) ({
+  //   ^bb0(%arg3, %arg4):
+  //     tosa.yield %arg3
+  // },  {
+  //   ^bb0(%arg3, %arg4):
+  //     tosa.yield %arg4
+  // })
+  //
+  // Simplified:
+  // %0 = tosa.cond_if %arg2 {
+  //   tosa.yield %arg0
+  // } else {
+  //   tosa.yield %arg1
+  // }
+  //
+  // Unfortunately, the simplified syntax does not encapsulate values
+  // used in then/else regions (see 'simplified' example above), so it
+  // must be rewritten to use the generic syntax in order to be conformant
+  // to the specification.
 
   // Returns true if the region uses no external input operands.
-  auto isNullaryRegion = [](Region &region) -> bool {
+  auto isIsolatedRegion = [](Region &region) -> bool {
     bool noLiveInValue = true;
-    region.walk([&noLiveInValue](Operation *op) {
-      if (!isNullaryOperation(op)) {
+    region.walk([&noLiveInValue, &region](Operation *op) {
+      if (!isOpIsolatedFromAbove(op, &region)) {
         noLiveInValue = false;
         return WalkResult::interrupt();
       }
@@ -1229,18 +1252,15 @@ bool checkErrorIfCondIf(Operation *op) {
 
   mlir::Region &thenGraph = ifOp.getThenGraph();
   mlir::Region &elseGraph = ifOp.getElseGraph();
-  bool isThenGraphNullaryRegion = isNullaryRegion(thenGraph);
-  bool isElseGraphNullaryRegion = isNullaryRegion(elseGraph);
-  bool isInputListEmpty = ifOp.getInputList().size() == 0;
+  bool isThenGraphIsolatedRegion = isIsolatedRegion(thenGraph);
+  bool isElseGraphIsolatedRegion = isIsolatedRegion(elseGraph);
 
-  if ((isInputListEmpty != isThenGraphNullaryRegion) ||
-      (isInputListEmpty != isElseGraphNullaryRegion)) {
+  if (!isThenGraphIsolatedRegion || !isElseGraphIsolatedRegion) {
     op->emitOpError()
-        << "the current simplified form is not strictly conformant to the "
-           "spec, please use the generic format\n";
+        << "is not conformant to the TOSA specification. It requires the "
+           "then/else regions are isolated from above.\n";
     return false;
   }
-
   return true;
 }
 
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 1f25132d6bcf3..00c891d4afaa0 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -227,15 +227,41 @@ func.func @test_error_i32_unsigned_output(%arg0: tensor<1xi8>) -> tensor<1xi32>
 }
 
 // -----
-// CHECK-LABEL: cond_if_simplified_form
-func.func @test_cond_if_simplified_form(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
-  // expected-error at +1 {{'tosa.cond_if' op the current simplified form is not strictly conformant to the spec, please use the generic format}}
+
+func.func @test_cond_if_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+    // expected-error at +1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the then/else regions are isolated from above.}}
+    %0 = "tosa.cond_if"(%arg2) ({
+      ^bb0():
+        tosa.yield %arg0 : tensor<f32>
+      },  {
+      ^bb0():
+        tosa.yield %arg1 : tensor<f32>
+      }) : (tensor<i1>) -> tensor<f32>
+    return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_simplified_form_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error at +1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the then/else regions are isolated from above.}}
   %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
-    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
-    tosa.yield %1 : tensor<f32>
+    tosa.yield %arg0 : tensor<f32>
   } else {
-    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
-    tosa.yield %1 : tensor<f32>
+    tosa.yield %arg1 : tensor<f32>
   }
   return %0 : tensor<f32>
 }
+
+// -----
+
+// COM: Check isolated cond_if's are valid
+func.func @test_cond_if_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+    ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+      tosa.yield %arg3 : tensor<f32>
+    },  {
+    ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+      tosa.yield %arg4 : tensor<f32>
+    }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}

>From 7bd13d2e0528fe22c7d1ff584dbc8daa10f17184 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Mon, 16 Jun 2025 10:28:08 +0000
Subject: [PATCH 2/2] [mlir][tosa] Check for isolated regions in
 `tosa.while_loop`

As with `tosa.cond_if`, this patch checks that the cond/body
regions of `tosa.while_loop` are isolated from above. This is required
because the specification requires that all values used in the cond/body
regions be explicitly declared within those regions.

Change-Id: Ia7396b9811db54805ec33befd24ab97d1b605905
---
 .../Tosa/Transforms/TosaValidation.cpp        | 60 ++++++++++++-------
 mlir/test/Dialect/Tosa/error_if_check.mlir    | 38 ++++++++++++
 2 files changed, 77 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 067ee7d5a5c5a..30e85ba92494c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1193,13 +1193,25 @@ bool checkErrorIfPad(Operation *op) {
   return true;
 }
 
-static bool isOpIsolatedFromAbove(Operation *op, Region *region) {
+static bool isOpIsolatedFromAbove(Operation *op, Region &region) {
   return llvm::all_of(op->getOperands(), [&](auto operand) {
     Region *operandRegion = operand.getParentRegion();
-    return region->isAncestor(operandRegion);
+    return region.isAncestor(operandRegion);
   });
 }
 
+static bool isRegionIsolatedFromAbove(Region &region) {
+  bool noLiveInValue = true;
+  region.walk([&noLiveInValue, &region](Operation *op) {
+    if (!isOpIsolatedFromAbove(op, region)) {
+      noLiveInValue = false;
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+  return noLiveInValue;
+}
+
 bool checkErrorIfCondIf(Operation *op) {
   auto ifOp = dyn_cast<tosa::IfOp>(op);
   if (!ifOp)
@@ -1236,24 +1248,10 @@ bool checkErrorIfCondIf(Operation *op) {
   // used in then/else regions (see 'simplified' example above), so it
   // must be rewritten to use the generic syntax in order to be conformant
   // to the specification.
-
-  // Returns true if the region uses no external input operands.
-  auto isIsolatedRegion = [](Region &region) -> bool {
-    bool noLiveInValue = true;
-    region.walk([&noLiveInValue, &region](Operation *op) {
-      if (!isOpIsolatedFromAbove(op, &region)) {
-        noLiveInValue = false;
-        return WalkResult::interrupt();
-      }
-      return WalkResult::advance();
-    });
-    return noLiveInValue;
-  };
-
-  mlir::Region &thenGraph = ifOp.getThenGraph();
-  mlir::Region &elseGraph = ifOp.getElseGraph();
-  bool isThenGraphIsolatedRegion = isIsolatedRegion(thenGraph);
-  bool isElseGraphIsolatedRegion = isIsolatedRegion(elseGraph);
+  Region &thenGraph = ifOp.getThenGraph();
+  Region &elseGraph = ifOp.getElseGraph();
+  bool isThenGraphIsolatedRegion = isRegionIsolatedFromAbove(thenGraph);
+  bool isElseGraphIsolatedRegion = isRegionIsolatedFromAbove(elseGraph);
 
   if (!isThenGraphIsolatedRegion || !isElseGraphIsolatedRegion) {
     op->emitOpError()
@@ -1264,10 +1262,30 @@ bool checkErrorIfCondIf(Operation *op) {
   return true;
 }
 
+bool checkErrorIfWhileLoop(Operation *op) {
+  auto whileOp = dyn_cast<tosa::WhileOp>(op);
+  if (!whileOp)
+    return true;
+
+  Region &condGraph = whileOp.getCondGraph();
+  Region &bodyGraph = whileOp.getBodyGraph();
+  bool isCondGraphIsolatedRegion = isRegionIsolatedFromAbove(condGraph);
+  bool isBodyGraphIsolatedRegion = isRegionIsolatedFromAbove(bodyGraph);
+
+  if (!isCondGraphIsolatedRegion || !isBodyGraphIsolatedRegion) {
+    op->emitOpError()
+        << "is not conformant to the TOSA specification. It requires the "
+           "cond/body regions are isolated from above.\n";
+    return false;
+  }
+  return true;
+}
+
 LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
   if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
       !checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
-      !checkErrorIfPad(op) || !checkErrorIfCondIf(op))
+      !checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
+      !checkErrorIfWhileLoop(op))
     return failure();
   return success();
 }
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 00c891d4afaa0..77830c7be2e9e 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -265,3 +265,41 @@ func.func @test_cond_if_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f3
     }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
+
+// -----
+
+func.func @test_while_loop_not_isolated_from_above(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<f32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error at +1 {{'tosa.while_loop' op is not conformant to the TOSA specification. It requires the cond/body regions are isolated from above.}}
+  %1 = "tosa.while_loop"(%0) ({
+  ^bb0(%arg3: tensor<i32>):
+    %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
+    tosa.yield %3 : tensor<i1>
+  },  {
+  ^bb0(%arg3: tensor<i32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3 : tensor<i32>
+  }) : (tensor<i32>) -> (tensor<i32>)
+  return
+}
+
+// -----
+
+// COM: Check isolated while_loops are valid
+func.func @test_while_loop_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %1:3 = "tosa.while_loop"(%0, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<i32>):
+    %2 = "tosa.greater_equal"(%arg3, %arg5) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
+    "tosa.yield"(%3) : (tensor<i1>) -> ()
+  },  {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<i32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    "tosa.yield"(%3, %arg4, %arg5) : (tensor<i32>, tensor<f32>, tensor<i32>) -> ()
+  }) : (tensor<i32>, tensor<f32>, tensor<i32>) -> (tensor<i32>, tensor<f32>, tensor<i32>)
+  return
+}



More information about the Mlir-commits mailing list