[Mlir-commits] [mlir] [mlir][tosa] Check for isolated regions in `tosa.while_loop` (PR #144865)
Luke Hutton
llvmlistbot at llvm.org
Thu Jun 19 03:12:18 PDT 2025
https://github.com/lhutton1 created https://github.com/llvm/llvm-project/pull/144865
As with `tosa.cond_if`, this patch checks that the cond/body regions of `tosa.while_loop` are isolated from above. This is required because the specification requires that all values used in the cond/body regions be explicitly declared within those regions.
Note: this change is dependent on https://github.com/llvm/llvm-project/pull/143772
>From 25dc942e88ec0ddb8ad54505ff027d6b3be61f64 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Wed, 11 Jun 2025 09:04:53 +0000
Subject: [PATCH 1/2] [mlir][tosa] Fix check for isolated regions in
`tosa.cond_if`
This commit fixes a check in the validation pass that is intended
to validate whether a `tosa.cond_if` operation is conformant to
the specification. The specification requires that all values used
in the then/else regions be explicitly declared within the regions.
This change checks that these regions are 'isolated from above',
to ensure this requirement is met.
Change-Id: I1b6eac1ed571e6b1eda4a58f0677c80e22977e58
---
.../Tosa/Transforms/TosaValidation.cpp | 68 ++++++++++++-------
mlir/test/Dialect/Tosa/error_if_check.mlir | 40 +++++++++--
2 files changed, 77 insertions(+), 31 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index d33fc902de3a1..067ee7d5a5c5a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1193,12 +1193,11 @@ bool checkErrorIfPad(Operation *op) {
   return true;
 }
 
-// Returns true if the operation takes no input operands, excluding attributes.
-static bool isNullaryOperation(Operation *op) {
-  if (isa<tosa::ConstOp>(op) || isa<tosa::ConstShapeOp>(op) ||
-      isa<tosa::YieldOp>(op) || isa<tosa::VariableOp>(op))
-    return true;
-  return false;
+static bool isOpIsolatedFromAbove(Operation *op, Region *region) {
+  return llvm::all_of(op->getOperands(), [&](auto operand) {
+    Region *operandRegion = operand.getParentRegion();
+    return region->isAncestor(operandRegion);
+  });
 }
 
 bool checkErrorIfCondIf(Operation *op) {
@@ -1206,19 +1205,43 @@ bool checkErrorIfCondIf(Operation *op) {
   if (!ifOp)
     return true;
 
-  // Whether the types and shapes of operands between the input/output list and
-  // internal regions are validated by the operation verifier. However, with
-  // support for the simplified form - where redundant operand notations are
-  // omitted - is not conformant to the specification. According to the
-  // specification, all operands passed into an operation must be explicitly
-  // declared at each operation's structure. This code section verify that the
-  // operation's form complies with this requirement.
+  // Currently the dialect supports declaring cond_if operations that
+  // have then/else regions that reference values from outside these
+  // regions. According to the specification, all values used by the
+  // then/else regions must be explicitly declared within the regions.
+  // Therefore we must check that the then/else regions are
+  // "isolated from above", in order to be conformant to the
+  // specification.
+  //
+  // Note: the dialect currently supports two styles of syntax for
+  // declaring "cond_if" operations. We'll refer to these as follows:
+  //
+  // Generic:
+  // %0 = "tosa.cond_if"(%arg0, %arg1, %arg2) ({
+  // ^bb0(%arg3, %arg4):
+  //   tosa.yield %arg3
+  // }, {
+  // ^bb0(%arg3, %arg4):
+  //   tosa.yield %arg4
+  // })
+  //
+  // Simplified:
+  // %0 = tosa.cond_if %arg2 {
+  //   tosa.yield %arg0
+  // } else {
+  //   tosa.yield %arg1
+  // }
+  //
+  // Unfortunately, the simplified syntax does not encapsulate values
+  // used in then/else regions (see 'simplified' example above), so it
+  // must be rewritten to use the generic syntax in order to be conformant
+  // to the specification.
 
   // Returns true if the region uses no external input operands.
-  auto isNullaryRegion = [](Region &region) -> bool {
+  auto isIsolatedRegion = [](Region &region) -> bool {
     bool noLiveInValue = true;
-    region.walk([&noLiveInValue](Operation *op) {
-      if (!isNullaryOperation(op)) {
+    region.walk([&noLiveInValue, &region](Operation *op) {
+      if (!isOpIsolatedFromAbove(op, &region)) {
         noLiveInValue = false;
         return WalkResult::interrupt();
       }
@@ -1229,18 +1252,15 @@ bool checkErrorIfCondIf(Operation *op) {
 
   mlir::Region &thenGraph = ifOp.getThenGraph();
   mlir::Region &elseGraph = ifOp.getElseGraph();
-  bool isThenGraphNullaryRegion = isNullaryRegion(thenGraph);
-  bool isElseGraphNullaryRegion = isNullaryRegion(elseGraph);
-  bool isInputListEmpty = ifOp.getInputList().size() == 0;
+  bool isThenGraphIsolatedRegion = isIsolatedRegion(thenGraph);
+  bool isElseGraphIsolatedRegion = isIsolatedRegion(elseGraph);
 
-  if ((isInputListEmpty != isThenGraphNullaryRegion) ||
-      (isInputListEmpty != isElseGraphNullaryRegion)) {
+  if (!isThenGraphIsolatedRegion || !isElseGraphIsolatedRegion) {
     op->emitOpError()
-        << "the current simplified form is not strictly conformant to the "
-           "spec, please use the generic format\n";
+        << "is not conformant to the TOSA specification. It requires the "
+           "then/else regions are isolated from above.\n";
     return false;
   }
-
   return true;
 }
 
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 1f25132d6bcf3..00c891d4afaa0 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -227,15 +227,41 @@ func.func @test_error_i32_unsigned_output(%arg0: tensor<1xi8>) -> tensor<1xi32>
 }
 
 // -----
-// CHECK-LABEL: cond_if_simplified_form
-func.func @test_cond_if_simplified_form(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
-  // expected-error@+1 {{'tosa.cond_if' op the current simplified form is not strictly conformant to the spec, please use the generic format}}
+
+func.func @test_cond_if_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the then/else regions are isolated from above.}}
+  %0 = "tosa.cond_if"(%arg2) ({
+  ^bb0():
+    tosa.yield %arg0 : tensor<f32>
+  }, {
+  ^bb0():
+    tosa.yield %arg1 : tensor<f32>
+  }) : (tensor<i1>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+func.func @test_cond_if_simplified_form_not_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  // expected-error@+1 {{'tosa.cond_if' op is not conformant to the TOSA specification. It requires the then/else regions are isolated from above.}}
   %0 = tosa.cond_if %arg2 -> (tensor<f32>) {
-    %1 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
-    tosa.yield %1 : tensor<f32>
+    tosa.yield %arg0 : tensor<f32>
   } else {
-    %1 = tosa.sub %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
-    tosa.yield %1 : tensor<f32>
+    tosa.yield %arg1 : tensor<f32>
   }
   return %0 : tensor<f32>
 }
+
+// -----
+
+// COM: Check isolated cond_if's are valid
+func.func @test_cond_if_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    tosa.yield %arg3 : tensor<f32>
+  }, {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
+    tosa.yield %arg4 : tensor<f32>
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
>From 7bd13d2e0528fe22c7d1ff584dbc8daa10f17184 Mon Sep 17 00:00:00 2001
From: Luke Hutton <luke.hutton at arm.com>
Date: Mon, 16 Jun 2025 10:28:08 +0000
Subject: [PATCH 2/2] [mlir][tosa] Check for isolated regions in
`tosa.while_loop`
As with `tosa.cond_if`, this patch checks that the cond/body
regions of `tosa.while_loop` are isolated from above. This is required
because the specification requires that all values used in the cond/body
regions be explicitly declared within the regions.
Change-Id: Ia7396b9811db54805ec33befd24ab97d1b605905
---
.../Tosa/Transforms/TosaValidation.cpp | 60 ++++++++++++-------
mlir/test/Dialect/Tosa/error_if_check.mlir | 38 ++++++++++++
2 files changed, 77 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 067ee7d5a5c5a..30e85ba92494c 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1193,13 +1193,25 @@ bool checkErrorIfPad(Operation *op) {
   return true;
 }
 
-static bool isOpIsolatedFromAbove(Operation *op, Region *region) {
+static bool isOpIsolatedFromAbove(Operation *op, Region &region) {
   return llvm::all_of(op->getOperands(), [&](auto operand) {
     Region *operandRegion = operand.getParentRegion();
-    return region->isAncestor(operandRegion);
+    return region.isAncestor(operandRegion);
   });
 }
 
+static bool isRegionIsolatedFromAbove(Region &region) {
+  bool noLiveInValue = true;
+  region.walk([&noLiveInValue, &region](Operation *op) {
+    if (!isOpIsolatedFromAbove(op, region)) {
+      noLiveInValue = false;
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+  return noLiveInValue;
+}
+
 bool checkErrorIfCondIf(Operation *op) {
   auto ifOp = dyn_cast<tosa::IfOp>(op);
   if (!ifOp)
@@ -1236,24 +1248,10 @@ bool checkErrorIfCondIf(Operation *op) {
   // used in then/else regions (see 'simplified' example above), so it
   // must be rewritten to use the generic syntax in order to be conformant
   // to the specification.
-
-  // Returns true if the region uses no external input operands.
-  auto isIsolatedRegion = [](Region &region) -> bool {
-    bool noLiveInValue = true;
-    region.walk([&noLiveInValue, &region](Operation *op) {
-      if (!isOpIsolatedFromAbove(op, &region)) {
-        noLiveInValue = false;
-        return WalkResult::interrupt();
-      }
-      return WalkResult::advance();
-    });
-    return noLiveInValue;
-  };
-
-  mlir::Region &thenGraph = ifOp.getThenGraph();
-  mlir::Region &elseGraph = ifOp.getElseGraph();
-  bool isThenGraphIsolatedRegion = isIsolatedRegion(thenGraph);
-  bool isElseGraphIsolatedRegion = isIsolatedRegion(elseGraph);
+  Region &thenGraph = ifOp.getThenGraph();
+  Region &elseGraph = ifOp.getElseGraph();
+  bool isThenGraphIsolatedRegion = isRegionIsolatedFromAbove(thenGraph);
+  bool isElseGraphIsolatedRegion = isRegionIsolatedFromAbove(elseGraph);
 
   if (!isThenGraphIsolatedRegion || !isElseGraphIsolatedRegion) {
     op->emitOpError()
@@ -1264,10 +1262,30 @@ bool checkErrorIfCondIf(Operation *op) {
   return true;
 }
 
+bool checkErrorIfWhileLoop(Operation *op) {
+  auto whileOp = dyn_cast<tosa::WhileOp>(op);
+  if (!whileOp)
+    return true;
+
+  Region &condGraph = whileOp.getCondGraph();
+  Region &bodyGraph = whileOp.getBodyGraph();
+  bool isCondGraphIsolatedRegion = isRegionIsolatedFromAbove(condGraph);
+  bool isBodyGraphIsolatedRegion = isRegionIsolatedFromAbove(bodyGraph);
+
+  if (!isCondGraphIsolatedRegion || !isBodyGraphIsolatedRegion) {
+    op->emitOpError()
+        << "is not conformant to the TOSA specification. It requires the "
+           "cond/body regions are isolated from above.\n";
+    return false;
+  }
+  return true;
+}
+
 LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
   if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
       !checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
-      !checkErrorIfPad(op) || !checkErrorIfCondIf(op))
+      !checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
+      !checkErrorIfWhileLoop(op))
     return failure();
   return success();
 }
diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir
index 00c891d4afaa0..77830c7be2e9e 100644
--- a/mlir/test/Dialect/Tosa/error_if_check.mlir
+++ b/mlir/test/Dialect/Tosa/error_if_check.mlir
@@ -265,3 +265,41 @@ func.func @test_cond_if_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<f3
   }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
+
+// -----
+
+func.func @test_while_loop_not_isolated_from_above(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<f32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  // expected-error@+1 {{'tosa.while_loop' op is not conformant to the TOSA specification. It requires the cond/body regions are isolated from above.}}
+  %1 = "tosa.while_loop"(%0) ({
+  ^bb0(%arg3: tensor<i32>):
+    %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
+    tosa.yield %3 : tensor<i1>
+  }, {
+  ^bb0(%arg3: tensor<i32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    tosa.yield %3 : tensor<i32>
+  }) : (tensor<i32>) -> (tensor<i32>)
+  return
+}
+
+// -----
+
+// COM: Check isolated while_loops are valid
+func.func @test_while_loop_isolated_from_above(%arg0: tensor<f32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {values = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %1:3 = "tosa.while_loop"(%0, %arg0, %arg1) ({
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<i32>):
+    %2 = "tosa.greater_equal"(%arg3, %arg5) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
+    "tosa.yield"(%3) : (tensor<i1>) -> ()
+  }, {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<f32>, %arg5: tensor<i32>):
+    %2 = "tosa.const"() {values = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    "tosa.yield"(%3, %arg4, %arg5) : (tensor<i32>, tensor<f32>, tensor<i32>) -> ()
+  }) : (tensor<i32>, tensor<f32>, tensor<i32>) -> (tensor<i32>, tensor<f32>, tensor<i32>)
+  return
+}
More information about the Mlir-commits
mailing list