[Mlir-commits] [mlir] [mlir][spirv] Add support for `spirv.mlir.break` (PR #138688)
llvmlistbot at llvm.org
Tue May 6 06:27:46 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-spirv
Author: Igor Wodiany (IgWod-IMG)
<details>
<summary>Changes</summary>
Because the SPIR-V dialect uses structured control flow, arbitrary branching, including a conditional early exit from a loop, is not supported. This patch introduces a new `spirv.mlir.break` operation that supports early exits from loops within structured control flow.
The main problem this PR tries to solve is the case where a branch to the loop merge block is wrapped in a selection op. Since the selection op cannot reference blocks outside it, a different approach is needed. I am open to feedback on whether a better approach exists that does not introduce a new op.
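The intended resolution of a break target can be sketched outside MLIR. The following is a minimal model, assuming a hypothetical `Region` struct and `findEnclosingLoop` helper that are not part of the patch; it only illustrates the idea of walking outwards from the selection until the closest enclosing loop, whose merge block is the branch target, is found.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for a structured control-flow op; not an MLIR class.
enum class Kind { Function, Loop, Selection };

struct Region {
  Kind kind;
  std::string mergeLabel; // Label of the region's merge block (empty if none).
  const Region *parent;   // Enclosing region, or nullptr at the top level.
};

// Walks outwards from the region that contains the break and returns the
// closest enclosing loop, or nullptr when no loop encloses it.
const Region *findEnclosingLoop(const Region *from) {
  for (const Region *r = from; r != nullptr; r = r->parent)
    if (r->kind == Kind::Loop)
      return r;
  return nullptr;
}
```

With a function containing a loop that in turn contains a selection, a break placed in the selection resolves to the loop, so the branch target is the loop's merge block rather than the selection's. A break with no enclosing loop yields `nullptr`, which corresponds to the case the patch rejects with a verifier error.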
---
Full diff: https://github.com/llvm/llvm-project/pull/138688.diff
9 Files Affected:
- (modified) mlir/docs/Dialects/SPIR-V.md (+55)
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td (+35)
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+44)
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.h (+4)
- (modified) mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp (+14)
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+1)
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.h (+2)
- (modified) mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir (+70)
- (modified) mlir/test/Target/SPIRV/loop.mlir (+68)
``````````diff
diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md
index 1e8c1c7be9f6a..526c85febb0bd 100644
--- a/mlir/docs/Dialects/SPIR-V.md
+++ b/mlir/docs/Dialects/SPIR-V.md
@@ -746,6 +746,61 @@ For example
}
```
+#### Early Exit
+
+In their current form, loops support an early exit because any block can branch
+to the merge block of the loop. A problem arises, however, when the early exit
+is conditional and the branch is sunk into a `spirv.mlir.selection` region:
+a branch inside the selection region cannot reference a block of the loop
+enclosing the selection. This pattern is nonetheless common. To support early
+loop exits within nested structured control flow, the SPIR-V dialect introduces
+the `spirv.mlir.break` operation. Its semantics are to branch to the merge
+block of the closest enclosing loop.
+
+For example
+
+```mlir
+spirv.mlir.loop {
+ spirv.Branch ^header(%zero: i32)
+
+^header(%i : i32):
+ %cmp = spirv.SLessThan %i, %count : i32
+ spirv.BranchConditional %cmp, ^body, ^merge_loop
+
+^body:
+ %cond = spirv.SGreaterThan %i, %five : i32
+ spirv.Branch ^selection
+
+^selection:
+ spirv.mlir.selection {
+ spirv.BranchConditional %cond, ^true, ^merge_sel
+ ^true:
+ spirv.mlir.break // Jump to ^merge_loop. Regular branch cannot reference ^merge_loop, as it is outside the region.
+ ^merge_sel:
+ spirv.mlir.merge
+ }
+
+ spirv.Branch ^continue
+
+^continue:
+ %new_i = spirv.IAdd %i, %one : i32
+ spirv.Branch ^header(%new_i: i32)
+
+^merge_loop:
+ spirv.mlir.merge
+}
+```
+
+The equivalent GLSL or C code would be
+
+```c
+for (int i = 0; i < count; ++i) {
+ // Early exit, matching `%cond = spirv.SGreaterThan %i, %five` above.
+ if (i > 5)
+ break;
+}
+```
+
### Block argument for Phi
There are no direct Phi operations in the SPIR-V dialect; SPIR-V `OpPhi`
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
index ef6682ab3630c..a6fc454d2fb34 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
@@ -21,6 +21,15 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
// -----
+// TODO: This is not only specific to control flow ops, so it could be moved
+// somewhere else.
+class SPIRV_HasParentOfType<string op> : PredOpTrait<
+ "op expects to be nested in " # op,
+ CPred<"getOperation()->getParentOfType<::mlir::spirv::" # op # ">() != nullptr">
+>;
+
+// -----
+
def SPIRV_BranchOp : SPIRV_Op<"Branch", [
DeclareOpInterfaceMethods<BranchOpInterface>, InFunctionScope, Pure,
Terminator]> {
@@ -535,4 +544,30 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> {
let hasRegionVerifier = 1;
}
+// -----
+
+def SPIRV_BreakOp : SPIRV_Op<"mlir.break", [
+ Pure, Terminator, SPIRV_HasParentOfType<"LoopOp">, ReturnLike]> {
+ let summary = "Early exit from a structured loop.";
+
+ let description = [{
+ Because the SPIR-V dialect relies on structured control flow, an early exit
+ using plain branches is not always possible: a branch cannot reference blocks
+ outside its region, so a block nested in a `spirv.mlir.selection` cannot
+ branch to the merge block of the enclosing loop.
+
+ To support early exits, the dialect provides the `spirv.mlir.break`
+ operation. Its semantics match those of `break` in GLSL / C / C++: the
+ operation should be treated as a branch to the merge block of the closest
+ enclosing loop.
+ }];
+
+ let arguments = (ins);
+ let results = (outs);
+ let assemblyFormat = "attr-dict";
+ let hasOpcode = 0;
+ let autogenSerialization = 0;
+ let hasVerifier = 0;
+}
+
#endif // MLIR_DIALECT_SPIRV_IR_CONTROLFLOW_OPS
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 1e867dde51001..7b1b9607e032f 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2345,6 +2345,46 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() {
return success();
}
+LogicalResult spirv::Deserializer::handleEarlyExits() {
+ SmallVector<Block *> loopMergeBlocks;
+
+ // Find all blocks that are loops' merge blocks.
+ for (auto &[_, mergeInfo] : blockMergeInfo)
+ if (mergeInfo.continueBlock)
+ loopMergeBlocks.push_back(mergeInfo.mergeBlock);
+
+ for (auto &[header, mergeInfo] : blockMergeInfo) {
+ // We look for something like `if(x) break; ...` so we only process
+ // selection for now.
+ if (!mergeInfo.continueBlock) {
+ SetVector<Block *> constructBlocks;
+ constructBlocks.insert(header);
+
+ // Iterate over all blocks in the selection. This is similar to
+ // `collectBlocksInConstruct()` but with extra logic inserting
+ // `spirv.mlir.break`. We look for any block inside the selection region
+ // that jumps directly to the loop merge and does not go through the merge
+ // block of the selection. This indicates the unstructured jump so the
+ // branch is replaced with break.
+ for (unsigned i = 0; i < constructBlocks.size(); ++i) {
+ for (Block *successor : constructBlocks[i]->getSuccessors()) {
+ Block *block = constructBlocks[i];
+ if (llvm::is_contained(loopMergeBlocks, successor)) {
+ assert(!block->empty() && block->getNumSuccessors() == 1);
+ block->back().erase();
+ OpBuilder builder(block, block->end());
+ builder.create<spirv::BreakOp>(mergeInfo.loc);
+ }
+ if (successor != mergeInfo.mergeBlock)
+ constructBlocks.insert(successor);
+ }
+ }
+ }
+ }
+
+ return success();
+}
+
LogicalResult spirv::Deserializer::structurizeControlFlow() {
LLVM_DEBUG({
logger.startLine()
@@ -2361,6 +2401,10 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
return failure();
}
+ if (failed(handleEarlyExits())) {
+ return failure();
+ }
+
// TODO: This loop is non-deterministic. Iteration order may vary between runs
// for the same shader as the key to the map is a pointer. See:
// https://github.com/llvm/llvm-project/issues/128547
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index bcc78e3e6508d..c17b4f5f1f860 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -252,6 +252,10 @@ class Deserializer {
/// selection construct and the merge block of another.
LogicalResult splitConditionalBlocks();
+ /// Detects unstructured early exits from loops and replaces those arbitrary
+ /// branches with `spirv.mlir.break` statements.
+ LogicalResult handleEarlyExits();
+
//===--------------------------------------------------------------------===//
// Type
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index ff3cc92ee8078..aedfd05701177 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -576,6 +576,20 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
return success();
}
+LogicalResult Serializer::processBreakOp(spirv::BreakOp breakOp) {
+ auto parentLoopOp = breakOp.getOperation()->getParentOfType<spirv::LoopOp>();
+
+ if (!parentLoopOp)
+ return failure();
+
+ auto *mergeBlock = parentLoopOp.getMergeBlock();
+ auto mergeID = getBlockID(mergeBlock);
+
+ encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {mergeID});
+
+ return success();
+}
+
LogicalResult Serializer::processBranchConditionalOp(
spirv::BranchConditionalOp condBranchOp) {
auto conditionID = getValueID(condBranchOp.getCondition());
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 1f4f5d7f764db..a8043a2f65086 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1255,6 +1255,7 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
return processGlobalVariableOp(op);
})
.Case([&](spirv::LoopOp op) { return processLoopOp(op); })
+ .Case([&](spirv::BreakOp op) { return processBreakOp(op); })
.Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
.Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
.Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index 9edb0f4af008d..b758dd810cbea 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -261,6 +261,8 @@ class Serializer {
LogicalResult processLoopOp(spirv::LoopOp loopOp);
+ LogicalResult processBreakOp(spirv::BreakOp breakOp);
+
LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp);
LogicalResult processBranchOp(spirv::BranchOp branchOp);
diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
index 8ec0bf5bbaacf..23fac9a140333 100644
--- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
@@ -467,6 +467,76 @@ func.func @loop_yield(%count : i32) -> () {
// -----
+func.func @loop_break(%count : i32) -> () {
+ %zero = spirv.Constant 0: i32
+ %one = spirv.Constant 1: i32
+ %five = spirv.Constant 5: i32
+
+ // CHECK: spirv.mlir.loop {
+ spirv.mlir.loop {
+ // CHECK-NEXT: spirv.Branch ^bb1({{%.*}}: i32)
+ spirv.Branch ^header(%zero: i32)
+
+ // CHECK-NEXT: ^bb1({{%.*}}: i32):
+ ^header(%i : i32):
+ %cmp = spirv.SLessThan %i, %count : i32
+ // CHECK: spirv.BranchConditional {{%.*}}, ^bb2, ^bb5
+ spirv.BranchConditional %cmp, ^body, ^merge
+
+ // CHECK-NEXT: ^bb2:
+ ^body:
+ %cond = spirv.SGreaterThan %i, %five : i32
+
+ // CHECK: spirv.Branch ^bb3
+ spirv.Branch ^selection
+
+ // CHECK-NEXT: ^bb3:
+ ^selection:
+ // CHECK-NEXT: spirv.mlir.selection {
+ spirv.mlir.selection {
+ // CHECK-NEXT: spirv.BranchConditional {{%.*}}, ^bb1, ^bb2
+ spirv.BranchConditional %cond, ^true, ^merge
+ // CHECK-NEXT: ^bb1:
+ ^true:
+ // CHECK-NEXT: spirv.mlir.break
+ spirv.mlir.break
+ // CHECK-NEXT: ^bb2:
+ ^merge:
+ // CHECK-NEXT: spirv.mlir.merge
+ spirv.mlir.merge
+ }
+
+ // CHECK: spirv.Branch ^bb4
+ spirv.Branch ^continue
+
+ // CHECK-NEXT: ^bb4:
+ ^continue:
+ %new_i = spirv.IAdd %i, %one : i32
+ // CHECK: spirv.Branch ^bb1({{%.*}}: i32)
+ spirv.Branch ^header(%new_i: i32)
+
+ // CHECK-NEXT: ^bb5:
+ ^merge:
+ // CHECK-NEXT: spirv.mlir.merge
+ spirv.mlir.merge
+ }
+
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.mlir.break
+//===----------------------------------------------------------------------===//
+
+func.func @break() -> () {
+ // expected-error @+1 {{op expects to be nested in LoopOp}}
+ spirv.mlir.break
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.mlir.merge
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/loop.mlir b/mlir/test/Target/SPIRV/loop.mlir
index 95b87b319ac2d..f6b5d44aa9c74 100644
--- a/mlir/test/Target/SPIRV/loop.mlir
+++ b/mlir/test/Target/SPIRV/loop.mlir
@@ -335,3 +335,71 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
spirv.Return
}
}
+
+// -----
+
+// Loop with break statement
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
+ spirv.func @loop_break(%count : i32) -> () "None" {
+ %zero = spirv.Constant 0: i32
+ %one = spirv.Constant 1: i32
+ %five = spirv.Constant 5: i32
+
+ // CHECK: spirv.mlir.loop {
+ spirv.mlir.loop {
+ // CHECK-NEXT: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32)
+ spirv.Branch ^header(%zero: i32)
+
+ // CHECK-NEXT: ^[[HEADER:.+]]({{%.*}}: i32):
+ ^header(%i : i32):
+ %cmp = spirv.SLessThan %i, %count : i32
+ // CHECK: spirv.BranchConditional {{%.*}}, ^[[BODY:.+]], ^[[MERGE:.+]]
+ spirv.BranchConditional %cmp, ^body, ^merge
+
+ // CHECK-NEXT: ^[[BODY:.+]]:
+ ^body:
+ %cond = spirv.SGreaterThan %i, %five : i32
+
+ // CHECK: spirv.Branch ^[[LINK:.+]]
+ spirv.Branch ^selection
+
+ // COM: Artificial block introduced by block splitting in the deserializer.
+ // CHECK-NEXT: ^[[LINK:.+]]:
+ // CHECK-NEXT: spirv.Branch ^[[SELECTION:.+]]
+
+ // CHECK-NEXT: ^[[SELECTION:.+]]:
+ ^selection:
+ // CHECK-NEXT: spirv.mlir.selection {
+ spirv.mlir.selection {
+ // CHECK-NEXT: spirv.BranchConditional {{%.*}}, ^[[TRUE:.+]], ^[[FALSE:.+]]
+ spirv.BranchConditional %cond, ^true, ^merge
+ // CHECK-NEXT: ^[[TRUE:.+]]:
+ ^true:
+ // CHECK-NEXT: spirv.mlir.break
+ spirv.mlir.break
+ // CHECK-NEXT: ^[[MERGE:.+]]:
+ ^merge:
+ // CHECK-NEXT: spirv.mlir.merge
+ spirv.mlir.merge
+ }
+
+ // CHECK: spirv.Branch ^[[CONTINUE:.+]]
+ spirv.Branch ^continue
+
+ // CHECK-NEXT: ^[[CONTINUE:.+]]:
+ ^continue:
+ %new_i = spirv.IAdd %i, %one : i32
+ // CHECK: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32)
+ spirv.Branch ^header(%new_i: i32)
+
+ // CHECK-NEXT: ^[[MERGE:.+]]:
+ ^merge:
+ // CHECK-NEXT: spirv.mlir.merge
+ spirv.mlir.merge
+ }
+
+ // CHECK: spirv.Return
+ spirv.Return
+ }
+}
``````````
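The worklist traversal used by `handleEarlyExits()` in the diff above can be modelled without MLIR. This is a rough sketch under stated assumptions: the `Cfg` alias and `findEarlyExits` function are hypothetical names, blocks are plain integer ids, and the function only reports which blocks would have their terminator replaced by `spirv.mlir.break` instead of rewriting anything. Unlike the patch, the sketch does not continue the walk past a loop merge block.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

// Toy CFG: succ[b] lists the successor ids of block b. Hypothetical model,
// not an MLIR type.
using Cfg = std::vector<std::vector<int>>;

// Returns the ids of blocks that are reachable from `header` without passing
// through `selectionMerge` and that branch directly to a loop merge block.
std::vector<int> findEarlyExits(const Cfg &succ, int header, int selectionMerge,
                                const std::set<int> &loopMerges) {
  std::vector<int> worklist{header}; // Insertion-ordered, like a SetVector.
  std::set<int> seen{header};
  std::vector<int> exits;
  for (std::size_t i = 0; i < worklist.size(); ++i) {
    int block = worklist[i];
    for (int s : succ[block]) {
      if (loopMerges.count(s)) {
        exits.push_back(block);
        continue; // Assumption of this sketch: stop at the loop merge block.
      }
      // Stay inside the selection construct: never walk into its merge block.
      if (s != selectionMerge && seen.insert(s).second)
        worklist.push_back(s);
    }
  }
  return exits;
}
```

For the documentation example, numbering the selection header as 0, `^true` as 1, `^merge_sel` as 2, and the loop merge as 3, the only reported block is 1, which is the block that receives the break. A selection whose branches all rejoin at its own merge block produces an empty result.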
</details>
https://github.com/llvm/llvm-project/pull/138688