[Mlir-commits] [mlir] 2a90631 - [mlir][spirv] Allow yielding values from selection regions (#133702)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 2 06:35:26 PDT 2025


Author: Igor Wodiany
Date: 2025-04-02T14:35:22+01:00
New Revision: 2a90631841284ba538029975e59876bb2e1714af

URL: https://github.com/llvm/llvm-project/commit/2a90631841284ba538029975e59876bb2e1714af
DIFF: https://github.com/llvm/llvm-project/commit/2a90631841284ba538029975e59876bb2e1714af.diff

LOG: [mlir][spirv] Allow yielding values from selection regions (#133702)

There are cases in SPIR-V shaders where values need to be yielded from
a selection region for the resulting MLIR to be valid. For example (part
of a SPIR-V shader decompiled to GLSL):

```
bool _115;
if (_107)
{
    // ...
    float _200 = fma(...);
    // ...
    _115 = _200 < _174;
}
else
{
    _115 = _107;
}
bool _123;
if (_115)
{
    // ...
    float _213 = fma(...);
    // ...
    _123 = _213 < _174;
}
else
{
    _123 = _115;
}
```

This patch extends `mlir.selection` so that it can return values.
`mlir.merge` is used as the "yield" operation. This keeps the op
compatible with code that does not yield any values, and preserves the
assumption that `mlir.merge` is the only operation in the merge block of
the selection region.

Added: 
    

Modified: 
    mlir/docs/Dialects/SPIR-V.md
    mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
    mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
    mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
    mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
    mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
    mlir/test/Target/SPIRV/selection.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md
index c1d50f0fbe407..ae9afbd9fdfe5 100644
--- a/mlir/docs/Dialects/SPIR-V.md
+++ b/mlir/docs/Dialects/SPIR-V.md
@@ -528,7 +528,7 @@ MLIR system.
 We introduce a `spirv.mlir.selection` and `spirv.mlir.loop` op for structured selections and
 loops, respectively. The merge targets are the next ops following them. Inside
 their regions, a special terminator, `spirv.mlir.merge` is introduced for branching to
-the merge target.
+the merge target and yielding values.
 
 ### Selection
 
@@ -603,7 +603,43 @@ func.func @selection(%cond: i1) -> () {
 
   // ...
 }
+```
+
+The selection can return values by yielding them with `spirv.mlir.merge`. This
+mechanism allows values defined within the selection region to be used outside of it.
+Without this, values that were sunk into the selection region, but used outside, would
+not be able to escape it.
+
+For example
+
+```mlir
+func.func @selection(%cond: i1) -> () {
+  %zero = spirv.Constant 0: i32
+  %var1 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+  %var2 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+  %yield:2 = spirv.mlir.selection -> i32, i32 {
+    spirv.BranchConditional %cond, ^then, ^else
 
+  ^then:
+    %one = spirv.Constant 1: i32
+    %three = spirv.Constant 3: i32
+    spirv.Branch ^merge(%one, %three : i32, i32)
+
+  ^else:
+    %two = spirv.Constant 2: i32
+    %four = spirv.Constant 4 : i32
+    spirv.Branch ^merge(%two, %four : i32, i32)
+
+  ^merge(%merged_1_2: i32, %merged_3_4: i32):
+    spirv.mlir.merge %merged_1_2, %merged_3_4 : i32, i32
+  }
+
+  spirv.Store "Function" %var1, %yield#0 : i32
+  spirv.Store "Function" %var2, %yield#1 : i32
+
+  spirv.Return
+}
 ```
 
 ### Loop

diff  --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
index cb7d27e8d4b9a..039af03871411 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
@@ -352,7 +352,7 @@ def SPIRV_LoopOp : SPIRV_Op<"mlir.loop", [InFunctionScope]> {
 // -----
 
 def SPIRV_MergeOp : SPIRV_Op<"mlir.merge", [
-    Pure, Terminator, ParentOneOf<["SelectionOp", "LoopOp"]>]> {
+    Pure, Terminator, ParentOneOf<["SelectionOp", "LoopOp"]>, ReturnLike]> {
   let summary = "A special terminator for merging a structured selection/loop.";
 
   let description = [{
@@ -361,13 +361,23 @@ def SPIRV_MergeOp : SPIRV_Op<"mlir.merge", [
     merge point, which is the next op following the `spirv.mlir.selection` or
     `spirv.mlir.loop` op. This op does not have a corresponding instruction in the
     SPIR-V binary format; it's solely for structural purpose.
+
+    The instruction is also used to yield values from inside the selection/loop region
+    to the outside, as values that were sunk into the region cannot otherwise escape it.
   }];
 
-  let arguments = (ins);
+  let arguments = (ins Variadic<AnyType>:$operands);
 
   let results = (outs);
 
-  let assemblyFormat = "attr-dict";
+  let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+
+  let builders = [
+    OpBuilder<(ins),
+    [{
+      build($_builder, $_state, ValueRange());
+    }]>
+  ];
 
   let hasOpcode = 0;
 
@@ -465,13 +475,17 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> {
     header block, and one selection merge. The selection header block should be
     the first block. The selection merge block should be the last block.
     The merge block should only contain a `spirv.mlir.merge` op.
+
+    Values defined inside the selection regions cannot be directly used
+    outside of them; however, the selection region can yield values. These values are
+    yielded using a `spirv.mlir.merge` op and returned as a result of the selection op.
   }];
 
   let arguments = (ins
     SPIRV_SelectionControlAttr:$selection_control
   );
 
-  let results = (outs);
+  let results = (outs Variadic<AnyType>:$results);
 
   let regions = (region AnyRegion:$body);
 
@@ -494,6 +508,13 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> {
         OpBuilder &builder);
   }];
 
+  let builders = [
+    OpBuilder<(ins "spirv::SelectionControl":$selectionControl),
+    [{
+      build($_builder, $_state, TypeRange(), selectionControl);
+    }]>
+  ];
+
   let hasOpcode = 0;
 
   let autogenSerialization = 0;

diff  --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 6ffe7afbc727c..ed9a30086deca 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -452,6 +452,11 @@ ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) {
   if (parseControlAttribute<spirv::SelectionControlAttr,
                             spirv::SelectionControl>(parser, result))
     return failure();
+
+  if (succeeded(parser.parseOptionalArrow()))
+    if (parser.parseTypeList(result.types))
+      return failure();
+
   return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
 }
 
@@ -459,6 +464,10 @@ void SelectionOp::print(OpAsmPrinter &printer) {
   auto control = getSelectionControl();
   if (control != spirv::SelectionControl::None)
     printer << " control(" << spirv::stringifySelectionControl(control) << ")";
+  if (getNumResults() > 0) {
+    printer << " -> ";
+    printer << getResultTypes();
+  }
   printer << ' ';
   printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                       /*printBlockTerminators=*/true);

diff  --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 273817d53d308..db838fa68a38c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1992,6 +1992,29 @@ LogicalResult ControlFlowStructurizer::structurize() {
                                     ArrayRef<Value>(blockArgs));
   }
 
+  // Values defined inside the selection region that need to be yielded outside
+  // the region.
+  SmallVector<Value> valuesToYield;
+  // Outside uses of values that were sunk into the selection region. Those uses
+  // will be replaced with values returned by the SelectionOp.
+  SmallVector<Value> outsideUses;
+
+  // Move block arguments of the original block (`mergeBlock`) into the merge
+  // block inside the selection (`body.back()`). Values produced by block
+  // arguments will be yielded by the selection region. We do not update uses or
+  // erase original block arguments yet. It will be done later in the code.
+  if (!isLoop) {
+    for (BlockArgument blockArg : mergeBlock->getArguments()) {
+      // Create new block arguments in the last block ("merge block") of the
+      // selection region. We create one argument for each argument in
+      // `mergeBlock`. This new value will need to be yielded, and the original
+      // value replaced, so add them to appropriate vectors.
+      body.back().addArgument(blockArg.getType(), blockArg.getLoc());
+      valuesToYield.push_back(body.back().getArguments().back());
+      outsideUses.push_back(blockArg);
+    }
+  }
+
   // All the blocks cloned into the SelectionOp/LoopOp's region can now be
   // cleaned up.
   LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
@@ -2000,17 +2023,79 @@ LogicalResult ControlFlowStructurizer::structurize() {
   for (auto *block : constructBlocks)
     block->dropAllReferences();
 
+  // All internal uses should be removed from original blocks by now, so any
+  // op result or block argument that still has uses is an outside use and
+  // must be yielded from the newly created selection region.
+  if (!isLoop) {
+    for (Block *block : constructBlocks) {
+      for (Operation &op : *block) {
+        for (Value result : op.getResults())
+          if (!result.use_empty()) {
+            valuesToYield.push_back(mapper.lookupOrNull(result));
+            outsideUses.push_back(result);
+          }
+      }
+      for (BlockArgument &arg : block->getArguments()) {
+        if (!arg.use_empty()) {
+          valuesToYield.push_back(mapper.lookupOrNull(arg));
+          outsideUses.push_back(arg);
+        }
+      }
+    }
+  }
+
+  assert(valuesToYield.size() == outsideUses.size());
+
+  // If we need to yield any values from the selection region we will take
+  // care of it here.
+  if (!isLoop && !valuesToYield.empty()) {
+    LLVM_DEBUG(logger.startLine()
+               << "[cf] yielding values from the selection region\n");
+
+    // Update `mlir.merge` with values to be yielded.
+    auto mergeOps = body.back().getOps<spirv::MergeOp>();
+    Operation *merge = llvm::getSingleElement(mergeOps);
+    assert(merge);
+    merge->setOperands(valuesToYield);
+
+    // MLIR does not allow changing the number of results of an operation, so
+    // we create a new SelectionOp with required list of results and move
+    // the region from the initial SelectionOp. The initial operation is then
+    // removed. Since we move the region to the new op all links between blocks
+    // and remapping we have previously done should be preserved.
+    builder.setInsertionPoint(&mergeBlock->front());
+    auto selectionOp = builder.create<spirv::SelectionOp>(
+        location, TypeRange(outsideUses),
+        static_cast<spirv::SelectionControl>(control));
+    selectionOp->getRegion(0).takeBody(body);
+
+    // Remove initial op and swap the pointer to the newly created one.
+    op->erase();
+    op = selectionOp;
+
+    // Update all outside uses to use results of the SelectionOp and remove
+    // block arguments from the original merge block.
+    for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
+      outsideUses[i].replaceAllUsesWith(selectionOp.getResult(i));
+    // Erase in one call: erasing by ascending index would shift and overrun.
+    mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
+  }
+
   // Check that whether some op in the to-be-erased blocks still has uses. Those
   // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
   // region. We cannot handle such cases given that once a value is sinked into
-  // the SelectionOp/LoopOp's region, there is no escape for it:
-  // SelectionOp/LooOp does not support yield values right now.
+  // the SelectionOp/LoopOp's region, there is no escape for it.
   for (auto *block : constructBlocks) {
     for (Operation &op : *block)
       if (!op.use_empty())
-        return op.emitOpError(
-            "failed control flow structurization: it has uses outside of the "
-            "enclosing selection/loop construct");
+        return op.emitOpError("failed control flow structurization: value has "
+                              "uses outside of the "
+                              "enclosing selection/loop construct");
+    for (BlockArgument &arg : block->getArguments())
+      if (!arg.use_empty())
+        return emitError(arg.getLoc(), "failed control flow structurization: "
+                                       "block argument has uses outside of the "
+                                       "enclosing selection/loop construct");
   }
 
   // Then erase all old blocks.
@@ -2236,7 +2321,7 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
 
     auto *mergeBlock = mergeInfo.mergeBlock;
     assert(mergeBlock && "merge block cannot be nullptr");
-    if (!mergeBlock->args_empty())
+    if (mergeInfo.continueBlock && !mergeBlock->args_empty())
       return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
     LLVM_DEBUG({
       logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";

diff  --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 4c15523a05fa8..5ed59a4134d37 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -447,6 +447,15 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
   auto mergeID = getBlockID(mergeBlock);
   auto loc = selectionOp.getLoc();
 
+  // Before we do anything replace results of the selection operation with
+  // values yielded (with `mlir.merge`) from inside the region. The selection op
+  // is being flattened so we do not have to worry about values being defined
+  // inside a region and used outside it anymore.
+  auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
+  assert(selectionOp.getNumResults() == mergeOp.getNumOperands());
+  for (unsigned i = 0, e = selectionOp.getNumResults(); i != e; ++i)
+    selectionOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
+
   // This SelectionOp is in some MLIR block with preceding and following ops. In
   // the binary format, it should reside in separate SPIR-V blocks from its
   // preceding and following ops. So we need to emit unconditional branches to
@@ -483,6 +492,12 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
   // instruction to start a new SPIR-V block for ops following this SelectionOp.
   // The block should use the <id> for the merge block.
   encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
+
+  // We do not process the mergeBlock but we still need to generate phi
+  // functions from its block arguments.
+  if (failed(emitPhiForBlockArguments(mergeBlock)))
+    return failure();
+
   LLVM_DEBUG(llvm::dbgs() << "done merge ");
   LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
   LLVM_DEBUG(llvm::dbgs() << "\n");

diff  --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
index 188a55d755fd2..107c8a3207b02 100644
--- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
@@ -765,6 +765,46 @@ func.func @missing_entry_block() -> () {
 
 // -----
 
+func.func @selection_yield(%cond: i1) -> () {
+  %zero = spirv.Constant 0: i32
+  %var1 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+  %var2 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+  // CHECK: {{%.*}}:2 = spirv.mlir.selection -> i32, i32 {
+  %yield:2 = spirv.mlir.selection -> i32, i32 {
+    // CHECK-NEXT: spirv.BranchConditional {{%.*}}, ^bb1, ^bb2
+    spirv.BranchConditional %cond, ^then, ^else
+
+  // CHECK: ^bb1
+  ^then:
+    %one = spirv.Constant 1: i32
+    %three = spirv.Constant 3: i32
+    // CHECK: spirv.Branch ^bb3({{%.*}}, {{%.*}} : i32, i32)
+    spirv.Branch ^merge(%one, %three : i32, i32)
+
+  // CHECK: ^bb2
+  ^else:
+    %two = spirv.Constant 2: i32
+    %four = spirv.Constant 4 : i32
+    // CHECK: spirv.Branch ^bb3({{%.*}}, {{%.*}} : i32, i32)
+    spirv.Branch ^merge(%two, %four : i32, i32)
+
+  // CHECK: ^bb3({{%.*}}: i32, {{%.*}}: i32)
+  ^merge(%merged_1_2: i32, %merged_3_4: i32):
+    // CHECK-NEXT: spirv.mlir.merge {{%.*}}, {{%.*}} : i32, i32
+    spirv.mlir.merge %merged_1_2, %merged_3_4 : i32, i32
+  }
+
+  // CHECK: spirv.Store "Function" {{%.*}}, {{%.*}}#0 : i32
+  spirv.Store "Function" %var1, %yield#0 : i32
+  // CHECK: spirv.Store "Function" {{%.*}}, {{%.*}}#1 : i32
+  spirv.Store "Function" %var2, %yield#1 : i32
+
+  spirv.Return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.Unreachable
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir
index 24abb12998d06..44625cc299230 100644
--- a/mlir/test/Target/SPIRV/selection.mlir
+++ b/mlir/test/Target/SPIRV/selection.mlir
@@ -150,3 +150,56 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     spirv.Return
   }
 }
+
+// -----
+
+// Selection yielding values
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+// CHECK-LABEL: @selection_yield
+  spirv.func @selection_yield(%cond: i1) -> () "None" {
+// CHECK-NEXT:   spirv.Constant 0
+// CHECK-NEXT:   spirv.Variable
+    %zero = spirv.Constant 0 : i32
+    %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+// CHECK:        spirv.Branch ^[[BB:.+]]
+// CHECK-NEXT: ^[[BB]]:
+// CHECK-NEXT:   {{%.*}} = spirv.mlir.selection -> i32
+    %yield = spirv.mlir.selection -> i32 {
+// CHECK-NEXT: spirv.BranchConditional %{{.*}} [5, 10], ^[[THEN:.+]], ^[[ELSE:.+]]
+      spirv.BranchConditional %cond [5, 10], ^then, ^else
+
+// CHECK-NEXT:   ^[[THEN]]:
+    ^then:
+// CHECK-NEXT:     spirv.Constant 1
+      %one = spirv.Constant 1: i32
+
+// CHECK-NEXT:     spirv.Branch ^[[MERGE:.+]]({{%.*}} : i32)
+      spirv.Branch ^merge(%one : i32)
+
+// CHECK-NEXT:   ^[[ELSE]]:
+    ^else:
+// CHECK-NEXT:     spirv.Constant 2
+      %two = spirv.Constant 2: i32
+// CHECK-NEXT:     spirv.Branch ^[[MERGE]]({{%.*}} : i32)
+      spirv.Branch ^merge(%two : i32)
+
+// CHECK-NEXT:   ^[[MERGE]]({{%.*}}: i32):
+    ^merge(%merged: i32):
+// CHECK-NEXT:     spirv.mlir.merge {{%.*}} : i32
+      spirv.mlir.merge %merged : i32
+    }
+
+// CHECK:        spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+    spirv.Store "Function" %var, %yield : i32
+
+    spirv.Return
+  }
+
+  spirv.func @main() -> () "None" {
+    spirv.Return
+  }
+  spirv.EntryPoint "GLCompute" @main
+  spirv.ExecutionMode @main "LocalSize", 1, 1, 1
+}


        


More information about the Mlir-commits mailing list