[Mlir-commits] [mlir] [mlir][spirv] Allow yielding values from selection regions (PR #133702)

Igor Wodiany llvmlistbot at llvm.org
Mon Mar 31 04:19:50 PDT 2025


https://github.com/IgWod-IMG updated https://github.com/llvm/llvm-project/pull/133702

>From e43db564a42f4796b46696c878110b2a578ce25a Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Thu, 27 Mar 2025 12:18:51 +0000
Subject: [PATCH] [mlir][spirv] Allow yielding values from selection regions

There are cases in SPIR-V shaders where values need to be
yielded from a selection region to produce valid MLIR. For
example (part of a SPIR-V shader decompiled to GLSL):

```
bool _115
if (_107)
{
    // ...
    float _200 = fma(...);
    // ...
    _115 = _200 < _174;
}
else
{
    _115 = _107;
}
bool _123;
if (_115)
{
    // ...
    float _213 = fma(...);
    // ...
    _123 = _213 < _174;
}
else
{
    _123 = _115;
}
```

This patch extends `mlir.selection` so that it can return values,
using `mlir.merge` as the "yield" operation. This keeps the op
compatible with code that does not yield any values, and it
preserves the assumption that `mlir.merge` is the only operation
in the merge block of the selection region.
---
 .../Dialect/SPIRV/IR/SPIRVControlFlowOps.td   |  11 +-
 mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp |   6 +-
 mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp  |  17 ++-
 .../SPIRV/Deserialization/Deserializer.cpp    | 100 ++++++++++++++++--
 .../SPIRV/Serialization/SerializeOps.cpp      |  14 +++
 .../Dialect/SPIRV/IR/control-flow-ops.mlir    |  38 +++++++
 mlir/test/Target/SPIRV/selection.mlir         |  52 +++++++++
 7 files changed, 219 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
index cb7d27e8d4b9a..5a0db18b2e4b0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
@@ -352,7 +352,7 @@ def SPIRV_LoopOp : SPIRV_Op<"mlir.loop", [InFunctionScope]> {
 // -----
 
 def SPIRV_MergeOp : SPIRV_Op<"mlir.merge", [
-    Pure, Terminator, ParentOneOf<["SelectionOp", "LoopOp"]>]> {
+    Pure, Terminator, ParentOneOf<["SelectionOp", "LoopOp"]>, ReturnLike]> {
   let summary = "A special terminator for merging a structured selection/loop.";
 
   let description = [{
@@ -360,14 +360,15 @@ def SPIRV_MergeOp : SPIRV_Op<"mlir.merge", [
     This op is a terminator used inside their regions to mean jumping to the
     merge point, which is the next op following the `spirv.mlir.selection` or
     `spirv.mlir.loop` op. This op does not have a corresponding instruction in the
-    SPIR-V binary format; it's solely for structural purpose.
+    SPIR-V binary format; it's solely for structural purpose. The instruction is also
+    used to yield values outside the selection/loop region.
   }];
 
-  let arguments = (ins);
+  let arguments = (ins Variadic<AnyType>:$operands);
 
   let results = (outs);
 
-  let assemblyFormat = "attr-dict";
+  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
 
   let hasOpcode = 0;
 
@@ -471,7 +472,7 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> {
     SPIRV_SelectionControlAttr:$selection_control
   );
 
-  let results = (outs);
+  let results = (outs Variadic<AnyType>:$results);
 
   let regions = (region AnyRegion:$body);
 
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index 31d8cd2206148..e1fac7e532e3d 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -227,11 +227,11 @@ struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
 
     // Create `spirv.selection` operation, selection header block and merge
     // block.
-    auto selectionOp =
-        rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
+    auto selectionOp = rewriter.create<spirv::SelectionOp>(
+        loc, TypeRange(), spirv::SelectionControl::None);
     auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(),
                                             selectionOp.getBody().end());
-    rewriter.create<spirv::MergeOp>(loc);
+    rewriter.create<spirv::MergeOp>(loc, ValueRange());
 
     OpBuilder::InsertionGuard guard(rewriter);
     auto *selectionHeaderBlock =
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 2959d67b366b9..b154867d77f4a 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -383,7 +383,7 @@ void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
   builder.createBlock(&getBody());
 
   // Add a spirv.mlir.merge op into the merge block.
-  builder.create<spirv::MergeOp>(getLoc());
+  builder.create<spirv::MergeOp>(getLoc(), ValueRange());
 }
 
 //===----------------------------------------------------------------------===//
@@ -452,6 +452,11 @@ ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) {
   if (parseControlAttribute<spirv::SelectionControlAttr,
                             spirv::SelectionControl>(parser, result))
     return failure();
+
+  if (succeeded(parser.parseOptionalArrow()))
+    if (parser.parseTypeList(result.types))
+      return failure();
+
   return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
 }
 
@@ -459,6 +464,10 @@ void SelectionOp::print(OpAsmPrinter &printer) {
   auto control = getSelectionControl();
   if (control != spirv::SelectionControl::None)
     printer << " control(" << spirv::stringifySelectionControl(control) << ")";
+  if (getNumResults() > 0) {
+    printer << " -> ";
+    printer << getResultTypes();
+  }
   printer << ' ';
   printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                       /*printBlockTerminators=*/true);
@@ -526,15 +535,15 @@ void SelectionOp::addMergeBlock(OpBuilder &builder) {
   builder.createBlock(&getBody());
 
   // Add a spirv.mlir.merge op into the merge block.
-  builder.create<spirv::MergeOp>(getLoc());
+  builder.create<spirv::MergeOp>(getLoc(), ValueRange());
 }
 
 SelectionOp
 SelectionOp::createIfThen(Location loc, Value condition,
                           function_ref<void(OpBuilder &builder)> thenBody,
                           OpBuilder &builder) {
-  auto selectionOp =
-      builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
+  auto selectionOp = builder.create<spirv::SelectionOp>(
+      loc, TypeRange(), spirv::SelectionControl::None);
 
   selectionOp.addMergeBlock(builder);
   Block *mergeBlock = selectionOp.getMergeBlock();
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 273817d53d308..74ed3bb7ea33a 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1843,7 +1843,8 @@ ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
   OpBuilder builder(&mergeBlock->front());
 
   auto control = static_cast<spirv::SelectionControl>(selectionControl);
-  auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
+  auto selectionOp =
+      builder.create<spirv::SelectionOp>(location, TypeRange(), control);
   selectionOp.addMergeBlock(builder);
 
   return selectionOp;
@@ -1992,6 +1993,29 @@ LogicalResult ControlFlowStructurizer::structurize() {
                                     ArrayRef<Value>(blockArgs));
   }
 
+  // Values defined inside the selection region that need to be yielded outside
+  // the region.
+  SmallVector<Value> valuesToYield;
+  // Outside uses of values sunk into the selection region. Those uses will be
+  // replaced with values returned by the SelectionOp.
+  SmallVector<Value> outsideUses;
+
+  // Move block arguments of the original block (`mergeBlock`) into the merge
+  // block inside the selection (`body.back()`). Values produced by block
+  // arguments will be yielded by the selection region. We do not update uses or
+  // erase original block arguments yet. It will be done later in the code.
+  if (!isLoop) {
+    for (BlockArgument blockArg : mergeBlock->getArguments()) {
+      // Create new block arguments in the last block ("merge block") of the
+      // selection region. We create one argument for each argument in
+      // `mergeBlock`. This new value will need to be yielded, and the original
+      // value replaced, so add them to appropriate vectors.
+      body.back().addArgument(blockArg.getType(), blockArg.getLoc());
+      valuesToYield.push_back(body.back().getArguments().back());
+      outsideUses.push_back(blockArg);
+    }
+  }
+
   // All the blocks cloned into the SelectionOp/LoopOp's region can now be
   // cleaned up.
   LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
@@ -2000,17 +2024,79 @@ LogicalResult ControlFlowStructurizer::structurize() {
   for (auto *block : constructBlocks)
     block->dropAllReferences();
 
+  // All internal uses should be removed from original blocks by now, so
+  // whatever is left is an outside use and will need to be yielded from
+  // the newly created selection region.
+  if (!isLoop) {
+    for (Block *block : constructBlocks) {
+      for (Operation &op : *block) {
+        for (Value result : op.getResults()) {
+          if (result.use_empty())
+            continue;
+          valuesToYield.push_back(mapper.lookup(result));
+          outsideUses.push_back(result);
+        }
+      }
+      for (BlockArgument &arg : block->getArguments()) {
+        if (!arg.use_empty()) {
+          valuesToYield.push_back(mapper.lookup(arg));
+          outsideUses.push_back(arg);
+        }
+      }
+    }
+  }
+  assert(valuesToYield.size() == outsideUses.size());
+
+  // If we need to yield any values from the selection region we will take
+  // care of it here.
+  if (!isLoop && !valuesToYield.empty()) {
+    LLVM_DEBUG(logger.startLine()
+               << "[cf] yielding values from the selection region\n");
+
+    // Update `mlir.merge` with values to be yielded.
+    auto mergeOps = body.back().getOps<spirv::MergeOp>();
+    assert(llvm::hasSingleElement(mergeOps) && "expected a single mlir.merge");
+    Operation *merge = *mergeOps.begin();
+    merge->setOperands(valuesToYield);
+
+    // MLIR does not allow changing the number of results of an operation, so
+    // we create a new SelectionOp with required list of results and move
+    // the region from the initial SelectionOp. The initial operation is then
+    // removed. Since we move the region to the new op all links between blocks
+    // and remapping we have previously done should be preserved.
+    builder.setInsertionPoint(&mergeBlock->front());
+    auto selectionOp = builder.create<spirv::SelectionOp>(
+        location, TypeRange(outsideUses),
+        static_cast<spirv::SelectionControl>(control));
+    selectionOp->getRegion(0).takeBody(body);
+
+    // Remove initial op and swap the pointer to the newly created one.
+    op->erase();
+    op = selectionOp;
+
+    // Update all outside uses to use results of the SelectionOp and remove
+    // block arguments from the original merge block.
+    for (size_t i = 0; i < outsideUses.size(); i++)
+      outsideUses[i].replaceAllUsesWith(selectionOp.getResult(i));
+    // Erase in one call; erasing by an increasing index would skip arguments.
+    mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
+  }
+
   // Check that whether some op in the to-be-erased blocks still has uses. Those
   // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
   // region. We cannot handle such cases given that once a value is sinked into
-  // the SelectionOp/LoopOp's region, there is no escape for it:
-  // SelectionOp/LooOp does not support yield values right now.
+  // the SelectionOp/LoopOp's region, there is no escape for it.
   for (auto *block : constructBlocks) {
     for (Operation &op : *block)
       if (!op.use_empty())
-        return op.emitOpError(
-            "failed control flow structurization: it has uses outside of the "
-            "enclosing selection/loop construct");
+        return op.emitOpError("failed control flow structurization: value has "
+                              "uses outside of the "
+                              "enclosing selection/loop construct");
+    for (BlockArgument &arg : block->getArguments())
+      if (!arg.use_empty())
+        return emitError(arg.getLoc(), "failed control flow structurization: "
+                                       "block argument has uses outside of the "
+                                       "enclosing selection/loop construct");
   }
 
   // Then erase all old blocks.
@@ -2236,7 +2322,7 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
 
     auto *mergeBlock = mergeInfo.mergeBlock;
     assert(mergeBlock && "merge block cannot be nullptr");
-    if (!mergeBlock->args_empty())
+    if (mergeInfo.continueBlock != nullptr && !mergeBlock->args_empty())
       return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
     LLVM_DEBUG({
       logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 4c15523a05fa8..641f139740c67 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -447,6 +447,14 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
   auto mergeID = getBlockID(mergeBlock);
   auto loc = selectionOp.getLoc();
 
+  // Before doing anything else, wire the results of the selection operation to
+  // the yielded values. The selection op is being flattened, so we no longer
+  // need to worry about values defined inside a region being used outside it.
+  auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
+  assert(selectionOp.getNumResults() == mergeOp.getNumOperands());
+  for (unsigned i = 0; i < selectionOp.getNumResults(); ++i)
+    selectionOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
+
   // This SelectionOp is in some MLIR block with preceding and following ops. In
   // the binary format, it should reside in separate SPIR-V blocks from its
   // preceding and following ops. So we need to emit unconditional branches to
@@ -483,6 +491,12 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
   // instruction to start a new SPIR-V block for ops following this SelectionOp.
   // The block should use the <id> for the merge block.
   encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
+
+  // We do not process the mergeBlock but we still need to generate phi
+  // functions from its block arguments.
+  if (failed(emitPhiForBlockArguments(mergeBlock)))
+    return failure();
+
   LLVM_DEBUG(llvm::dbgs() << "done merge ");
   LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
   LLVM_DEBUG(llvm::dbgs() << "\n");
diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
index 188a55d755fd2..79c6b50fec73f 100644
--- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
@@ -765,6 +765,44 @@ func.func @missing_entry_block() -> () {
 
 // -----
 
+func.func @selection_yield(%cond: i1) -> () {
+  %zero = spirv.Constant 0: i32
+  %var1 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+  %var2 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+  // CHECK: {{%.*}}:2 = spirv.mlir.selection -> i32, i32 {
+  %yield:2 = spirv.mlir.selection -> i32, i32 {
+    // CHECK-NEXT: spirv.BranchConditional {{%.*}}, ^bb1, ^bb2
+    spirv.BranchConditional %cond, ^then, ^else
+
+  // CHECK: ^bb1
+  ^then:
+    %one = spirv.Constant 1: i32
+    %three = spirv.Constant 3: i32
+    // CHECK: spirv.Branch ^bb3({{%.*}}, {{%.*}} : i32, i32)
+    spirv.Branch ^merge(%one, %three : i32, i32)
+
+  // CHECK: ^bb2
+  ^else:
+    %two = spirv.Constant 2: i32
+    %four = spirv.Constant 4 : i32
+    // CHECK: spirv.Branch ^bb3({{%.*}}, {{%.*}} : i32, i32)
+    spirv.Branch ^merge(%two, %four : i32, i32)
+
+  // CHECK: ^bb3({{%.*}}: i32, {{%.*}}: i32)
+  ^merge(%merged_1_2: i32, %merged_3_4: i32):
+    // CHECK-NEXT: spirv.mlir.merge {{%.*}}, {{%.*}} : i32, i32
+    spirv.mlir.merge %merged_1_2, %merged_3_4 : i32, i32
+  }
+
+  spirv.Store "Function" %var1, %yield#0 : i32
+  spirv.Store "Function" %var2, %yield#1 : i32
+
+  spirv.Return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.Unreachable
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir
index 24abb12998d06..5c9b9cdf89f8a 100644
--- a/mlir/test/Target/SPIRV/selection.mlir
+++ b/mlir/test/Target/SPIRV/selection.mlir
@@ -150,3 +150,55 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     spirv.Return
   }
 }
+
+// -----
+
+// Selection yielding values
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+// CHECK-LABEL: @selection_yield
+  spirv.func @selection_yield(%cond: i1) -> () "None" {
+// CHECK-NEXT:   spirv.Constant 0
+// CHECK-NEXT:   spirv.Variable
+    %zero = spirv.Constant 0 : i32
+    %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+// CHECK:        spirv.Branch ^[[BB:.+]]
+// CHECK-NEXT: ^[[BB]]:
+// CHECK-NEXT:   spirv.mlir.selection -> i32
+    %yield = spirv.mlir.selection -> i32 {
+// CHECK-NEXT: spirv.BranchConditional %{{.*}} [5, 10], ^[[THEN:.+]], ^[[ELSE:.+]]
+      spirv.BranchConditional %cond [5, 10], ^then, ^else
+
+// CHECK-NEXT:   ^[[THEN]]:
+    ^then:
+// CHECK-NEXT:     spirv.Constant 1
+      %one = spirv.Constant 1: i32
+
+// CHECK-NEXT:     spirv.Branch ^[[MERGE:.+]]({{%.*}} : i32)
+      spirv.Branch ^merge(%one : i32)
+
+// CHECK-NEXT:   ^[[ELSE]]:
+    ^else:
+// CHECK-NEXT:     spirv.Constant 2
+      %two = spirv.Constant 2: i32
+// CHECK-NEXT:     spirv.Branch ^[[MERGE]]({{%.*}} : i32)
+      spirv.Branch ^merge(%two : i32)
+
+// CHECK-NEXT:   ^[[MERGE]]({{%.*}}: i32):
+    ^merge(%merged: i32):
+// CHECK-NEXT:     spirv.mlir.merge {{%.*}} : i32
+      spirv.mlir.merge %merged : i32
+    }
+
+    spirv.Store "Function" %var, %yield : i32
+
+    spirv.Return
+  }
+
+  spirv.func @main() -> () "None" {
+    spirv.Return
+  }
+  spirv.EntryPoint "GLCompute" @main
+  spirv.ExecutionMode @main "LocalSize", 1, 1, 1
+}



More information about the Mlir-commits mailing list