[Mlir-commits] [mlir] [mlir][spirv] Add support for `spirv.mlir.break` (PR #138688)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 6 06:27:46 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-spirv

Author: Igor Wodiany (IgWod-IMG)

<details>
<summary>Changes</summary>

Since the SPIR-V dialect uses structured control flow, arbitrary branching, which includes conditional early exits from loops, is not supported. This patch introduces a new `break` operation that supports early exits from loops within structured control flow.

The main problem this PR tries to solve is the case where a branch to the loop merge block is wrapped in a selection op. Since a branch inside the selection op cannot reference blocks outside the selection's region, a different approach is needed. I am open to feedback on whether a better approach exists that does not introduce a new op.

---
Full diff: https://github.com/llvm/llvm-project/pull/138688.diff


9 Files Affected:

- (modified) mlir/docs/Dialects/SPIR-V.md (+55) 
- (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td (+35) 
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+44) 
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.h (+4) 
- (modified) mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp (+14) 
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+1) 
- (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.h (+2) 
- (modified) mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir (+70) 
- (modified) mlir/test/Target/SPIRV/loop.mlir (+68) 


``````````diff
diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md
index 1e8c1c7be9f6a..526c85febb0bd 100644
--- a/mlir/docs/Dialects/SPIR-V.md
+++ b/mlir/docs/Dialects/SPIR-V.md
@@ -746,6 +746,61 @@ For example
 }
 ```
 
+#### Early Exit
+
+In its current form, a loop does support an early exit, since any block in the
+loop can branch to the loop's merge block. The problem arises when such an
+early exit is conditional and the branch is sunk into a `spirv.mlir.selection`
+region: a branch inside the selection region cannot reference blocks of the
+enclosing loop's region. At the same time, such a pattern is not unusual. To
+support early loop exits from within nested structured control flow, the
+SPIR-V dialect introduces the `spirv.mlir.break` operation. Its semantics are
+to branch to the merge block of the nearest enclosing loop.
+
+For example
+
+```mlir
+spirv.mlir.loop {
+  spirv.Branch ^header(%zero: i32)
+
+^header(%i : i32):
+  %cmp = spirv.SLessThan %i, %count : i32
+  spirv.BranchConditional %cmp, ^body, ^merge_loop
+
+^body:
+  %cond = spirv.SGreaterThan %i, %five : i32
+  spirv.Branch ^selection
+
+^selection:
+  spirv.mlir.selection {
+    spirv.BranchConditional %cond, ^true, ^merge_sel
+  ^true:
+    spirv.mlir.break // Jump to ^merge_loop. Regular branch cannot reference ^merge_loop, as it is outside the region.
+  ^merge_sel:
+    spirv.mlir.merge
+  }
+
+  spirv.Branch ^continue
+
+^continue:
+  %new_i = spirv.IAdd %i, %one : i32
+  spirv.Branch ^header(%new_i: i32)
+
+^merge_loop:
+  spirv.mlir.merge
+}
+```
+
+The equivalent GLSL or C code would be:
+
+```c
+for (int i = 0; i < count; ++i) {
+  if (i > 5) {
+    break;
+  }
+}
+```
+
 ### Block argument for Phi
 
 There are no direct Phi operations in the SPIR-V dialect; SPIR-V `OpPhi`
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
index ef6682ab3630c..a6fc454d2fb34 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
@@ -21,6 +21,15 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 
 // -----
 
+// Requires an ancestor op (at any nesting depth) of type ::mlir::spirv::<op>.
+// TODO: Not specific to control flow ops, so it could be moved elsewhere.
+class SPIRV_HasParentOfType<string op> : PredOpTrait<
+  "op expects to be nested in " # op,
+  CPred<"getOperation()->getParentOfType<::mlir::spirv::" # op # ">() != nullptr">
+>;
+
+// -----
+
 def SPIRV_BranchOp : SPIRV_Op<"Branch", [
     DeclareOpInterfaceMethods<BranchOpInterface>, InFunctionScope, Pure,
     Terminator]> {
@@ -535,4 +544,30 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> {
   let hasRegionVerifier = 1;
 }
 
+// -----
+
+def SPIRV_BreakOp : SPIRV_Op<"mlir.break", [
+    Pure, Terminator, SPIRV_HasParentOfType<"LoopOp">, ReturnLike]> {
+  let summary = "Early exit from a structured loop.";
+
+  let description = [{
+    Since the SPIR-V dialect relies on structured control flow, an early loop
+    exit cannot always be expressed with a plain branch: a branch cannot
+    reference blocks outside its region, so a block nested in a
+    `spirv.mlir.selection` cannot branch to the merge block of the loop.
+
+    To support early exits, the dialect provides the `spirv.mlir.break`
+    operation. Its semantics match `break` in GLSL / C / C++: it is treated as
+    a branch to the merge block of the nearest enclosing `spirv.mlir.loop`.
+    The operation takes no operands and produces no results.
+  }];
+
+  let arguments = (ins);
+  let results = (outs);
+  let assemblyFormat = "attr-dict";
+  let hasOpcode = 0;
+  let autogenSerialization = 0;
+  let hasVerifier = 0;
+}
+
 #endif // MLIR_DIALECT_SPIRV_IR_CONTROLFLOW_OPS
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 1e867dde51001..7b1b9607e032f 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2345,6 +2345,46 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() {
   return success();
 }
 
+LogicalResult spirv::Deserializer::handleEarlyExits() {
+  // Map every loop merge block to its header, so that exits of loops nested
+  // inside a selection are not mistaken for early exits of an outer loop.
+  llvm::DenseMap<Block *, Block *> loopMergeToHeader;
+  for (auto &[header, mergeInfo] : blockMergeInfo)
+    if (mergeInfo.continueBlock)
+      loopMergeToHeader[mergeInfo.mergeBlock] = header;
+  // Walk each selection (`if (x) break;`) up to its merge, like
+  // `collectBlocksInConstruct()`, turning jumps to a loop merge into breaks.
+  for (auto &[header, mergeInfo] : blockMergeInfo) {
+    if (mergeInfo.continueBlock)
+      continue;
+    SetVector<Block *> constructBlocks;
+    constructBlocks.insert(header);
+    for (unsigned i = 0; i < constructBlocks.size(); ++i) {
+      Block *block = constructBlocks[i];
+      // Iterate over a copy: the terminator may be erased below.
+      for (Block *successor : llvm::to_vector(block->getSuccessors())) {
+        if (successor == mergeInfo.mergeBlock)
+          continue;
+        auto loopIt = loopMergeToHeader.find(successor);
+        if (loopIt == loopMergeToHeader.end() ||
+            constructBlocks.contains(loopIt->second)) {
+          constructBlocks.insert(successor); // Still inside the selection.
+          continue;
+        }
+        // Only an operand-less `spirv.Branch` can become `spirv.mlir.break`.
+        Operation *terminator = block->getTerminator();
+        if (!isa<spirv::BranchOp>(terminator) || terminator->getNumOperands())
+          return emitError(mergeInfo.loc, "unsupported early exit from loop");
+        terminator->erase();
+        OpBuilder builder(block, block->end());
+        builder.create<spirv::BreakOp>(mergeInfo.loc);
+        break; // The loop merge lies outside the selection; do not walk it.
+      }
+    }
+  }
+  return success();
+}
+
 LogicalResult spirv::Deserializer::structurizeControlFlow() {
   LLVM_DEBUG({
     logger.startLine()
@@ -2361,6 +2401,10 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
     return failure();
   }
 
+  if (failed(handleEarlyExits())) {
+    return failure();
+  }
+
   // TODO: This loop is non-deterministic. Iteration order may vary between runs
   // for the same shader as the key to the map is a pointer. See:
   // https://github.com/llvm/llvm-project/issues/128547
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index bcc78e3e6508d..c17b4f5f1f860 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -252,6 +252,10 @@ class Deserializer {
   /// selection construct and the merge block of another.
   LogicalResult splitConditionalBlocks();
 
+  /// Detects branches from a selection straight to a loop merge block (early
+  /// loop exits) and replaces them with `spirv.mlir.break` ops.
+  LogicalResult handleEarlyExits();
+
   //===--------------------------------------------------------------------===//
   // Type
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index ff3cc92ee8078..aedfd05701177 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -576,6 +576,20 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
   return success();
 }
 
+LogicalResult Serializer::processBreakOp(spirv::BreakOp breakOp) {
+  // A break is encoded as an OpBranch to the innermost enclosing loop's merge.
+  auto parentLoopOp = breakOp->getParentOfType<spirv::LoopOp>();
+  if (!parentLoopOp)
+    return breakOp.emitError("expected to be nested in a spirv.mlir.loop");
+  // NOTE(review): assumes getBlockID() returns 0 for a block without an <id>.
+  auto mergeID = getBlockID(parentLoopOp.getMergeBlock());
+  if (mergeID == 0)
+    return breakOp.emitError("merge block of the enclosing loop has no <id>");
+
+  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {mergeID});
+  return success();
+}
+
 LogicalResult Serializer::processBranchConditionalOp(
     spirv::BranchConditionalOp condBranchOp) {
   auto conditionID = getValueID(condBranchOp.getCondition());
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 1f4f5d7f764db..a8043a2f65086 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1255,6 +1255,7 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
         return processGlobalVariableOp(op);
       })
       .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
+      .Case([&](spirv::BreakOp op) { return processBreakOp(op); })
       .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
       .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
       .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index 9edb0f4af008d..b758dd810cbea 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -261,6 +261,8 @@ class Serializer {
 
   LogicalResult processLoopOp(spirv::LoopOp loopOp);
 
+  LogicalResult processBreakOp(spirv::BreakOp breakOp);
+
   LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp);
 
   LogicalResult processBranchOp(spirv::BranchOp branchOp);
diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
index 8ec0bf5bbaacf..23fac9a140333 100644
--- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
@@ -467,6 +467,76 @@ func.func @loop_yield(%count : i32) -> () {
 
 // -----
 
+func.func @loop_break(%count : i32) -> () {
+  %c0 = spirv.Constant 0: i32
+  %c1 = spirv.Constant 1: i32
+  %c5 = spirv.Constant 5: i32
+
+  // CHECK: spirv.mlir.loop {
+  spirv.mlir.loop {
+    // CHECK-NEXT: spirv.Branch ^bb1({{%.*}}: i32)
+    spirv.Branch ^loop_header(%c0 : i32)
+
+  // CHECK-NEXT: ^bb1({{%.*}}: i32):
+  ^loop_header(%iv : i32):
+    %in_range = spirv.SLessThan %iv, %count : i32
+    // CHECK: spirv.BranchConditional {{%.*}}, ^bb2, ^bb5
+    spirv.BranchConditional %in_range, ^loop_body, ^loop_merge
+
+  // CHECK-NEXT: ^bb2:
+  ^loop_body:
+    %exit_early = spirv.SGreaterThan %iv, %c5 : i32
+
+    // CHECK: spirv.Branch ^bb3
+    spirv.Branch ^if_header
+
+  // CHECK-NEXT: ^bb3:
+  ^if_header:
+    // CHECK-NEXT: spirv.mlir.selection {
+    spirv.mlir.selection {
+      // CHECK-NEXT: spirv.BranchConditional {{%.*}}, ^bb1, ^bb2
+      spirv.BranchConditional %exit_early, ^if_then, ^if_merge
+    // CHECK-NEXT: ^bb1:
+    ^if_then:
+      // CHECK-NEXT: spirv.mlir.break
+      spirv.mlir.break
+    // CHECK-NEXT: ^bb2:
+    ^if_merge:
+      // CHECK-NEXT: spirv.mlir.merge
+      spirv.mlir.merge
+    }
+
+    // CHECK: spirv.Branch ^bb4
+    spirv.Branch ^loop_continue
+
+  // CHECK-NEXT: ^bb4:
+  ^loop_continue:
+    %iv_next = spirv.IAdd %iv, %c1 : i32
+    // CHECK: spirv.Branch ^bb1({{%.*}}: i32)
+    spirv.Branch ^loop_header(%iv_next : i32)
+
+  // CHECK-NEXT: ^bb5:
+  ^loop_merge:
+    // CHECK-NEXT: spirv.mlir.merge
+    spirv.mlir.merge
+  }
+
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.mlir.break
+//===----------------------------------------------------------------------===//
+
+func.func @break() {
+  // expected-error @+1 {{op expects to be nested in LoopOp}}
+  spirv.mlir.break
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.merge
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/loop.mlir b/mlir/test/Target/SPIRV/loop.mlir
index 95b87b319ac2d..f6b5d44aa9c74 100644
--- a/mlir/test/Target/SPIRV/loop.mlir
+++ b/mlir/test/Target/SPIRV/loop.mlir
@@ -335,3 +335,71 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
     spirv.Return
   }
 }
+
+// -----
+
+// Loop with break statement
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
+  spirv.func @loop_break(%count : i32) -> () "None" {
+    %zero = spirv.Constant 0: i32
+    %one  = spirv.Constant 1: i32
+    %five = spirv.Constant 5: i32
+
+    // CHECK: spirv.mlir.loop {
+    spirv.mlir.loop {
+      // CHECK-NEXT: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32)
+      spirv.Branch ^header(%zero: i32)
+
+    // CHECK-NEXT: ^[[HEADER]]({{%.*}}: i32):
+    ^header(%i : i32):
+      %cmp = spirv.SLessThan %i, %count : i32
+      // CHECK: spirv.BranchConditional {{%.*}}, ^[[BODY:.+]], ^[[MERGE:.+]]
+      spirv.BranchConditional %cmp, ^body, ^merge
+
+    // CHECK-NEXT: ^[[BODY]]:
+    ^body:
+      %cond = spirv.SGreaterThan %i, %five : i32
+
+      // CHECK: spirv.Branch ^[[LINK:.+]]
+      spirv.Branch ^selection
+
+    // COM: Artificial block introduced by block splitting in the deserializer.
+    // CHECK-NEXT: ^[[LINK]]:
+    // CHECK-NEXT: spirv.Branch ^[[SELECTION:.+]]
+
+    // CHECK-NEXT: ^[[SELECTION]]:
+    ^selection:
+      // CHECK-NEXT: spirv.mlir.selection {
+      spirv.mlir.selection {
+        // CHECK-NEXT: spirv.BranchConditional {{%.*}}, ^[[TRUE:.+]], ^[[FALSE:.+]]
+        spirv.BranchConditional %cond, ^true, ^merge
+      // CHECK-NEXT: ^[[TRUE]]:
+      ^true:
+        // CHECK-NEXT: spirv.mlir.break
+        spirv.mlir.break
+      // CHECK-NEXT: ^[[FALSE]]:
+      ^merge:
+        // CHECK-NEXT: spirv.mlir.merge
+        spirv.mlir.merge
+      }
+
+      // CHECK: spirv.Branch ^[[CONTINUE:.+]]
+      spirv.Branch ^continue
+
+    // CHECK-NEXT: ^[[CONTINUE]]:
+    ^continue:
+      %new_i = spirv.IAdd %i, %one : i32
+      // CHECK: spirv.Branch ^[[HEADER]]({{%.*}}: i32)
+      spirv.Branch ^header(%new_i: i32)
+
+    // CHECK-NEXT: ^[[MERGE]]:
+    ^merge:
+      // CHECK-NEXT: spirv.mlir.merge
+      spirv.mlir.merge
+    }
+
+    // CHECK: spirv.Return
+    spirv.Return
+  }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/138688


More information about the Mlir-commits mailing list