[Mlir-commits] [mlir] [mlir][spirv] Allow yielding values from selection regions (PR #133702)
Igor Wodiany
llvmlistbot at llvm.org
Mon Mar 31 04:19:50 PDT 2025
https://github.com/IgWod-IMG updated https://github.com/llvm/llvm-project/pull/133702
>From e43db564a42f4796b46696c878110b2a578ce25a Mon Sep 17 00:00:00 2001
From: Igor Wodiany <igor.wodiany at imgtec.com>
Date: Thu, 27 Mar 2025 12:18:51 +0000
Subject: [PATCH] [mlir][spirv] Allow yielding values from selection regions
There are cases in SPIR-V shaders where values need to be
yielded from the selection region to make valid MLIR. For
example (part of the SPIR-V shader decompiled to GLSL):
```
bool _115;
if (_107)
{
// ...
float _200 = fma(...);
// ...
_115 = _200 < _174;
}
else
{
_115 = _107;
}
bool _123;
if (_115)
{
// ...
float _213 = fma(...);
// ...
_123 = _213 < _174;
}
else
{
_123 = _115;
}
```
This patch extends `mlir.selection` so it can return values.
`mlir.merge` is used as a "yield" operation. This allows
maintaining compatibility with code that does not yield any values,
as well as maintaining the assumption that `mlir.merge` is the
only operation in the merge block of the selection region.
---
.../Dialect/SPIRV/IR/SPIRVControlFlowOps.td | 11 +-
mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp | 6 +-
mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp | 17 ++-
.../SPIRV/Deserialization/Deserializer.cpp | 100 ++++++++++++++++--
.../SPIRV/Serialization/SerializeOps.cpp | 14 +++
.../Dialect/SPIRV/IR/control-flow-ops.mlir | 38 +++++++
mlir/test/Target/SPIRV/selection.mlir | 52 +++++++++
7 files changed, 219 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
index cb7d27e8d4b9a..5a0db18b2e4b0 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
@@ -352,7 +352,7 @@ def SPIRV_LoopOp : SPIRV_Op<"mlir.loop", [InFunctionScope]> {
// -----
def SPIRV_MergeOp : SPIRV_Op<"mlir.merge", [
- Pure, Terminator, ParentOneOf<["SelectionOp", "LoopOp"]>]> {
+ Pure, Terminator, ParentOneOf<["SelectionOp", "LoopOp"]>, ReturnLike]> {
let summary = "A special terminator for merging a structured selection/loop.";
let description = [{
@@ -360,14 +360,15 @@ def SPIRV_MergeOp : SPIRV_Op<"mlir.merge", [
This op is a terminator used inside their regions to mean jumping to the
merge point, which is the next op following the `spirv.mlir.selection` or
`spirv.mlir.loop` op. This op does not have a corresponding instruction in the
- SPIR-V binary format; it's solely for structural purpose.
+ SPIR-V binary format; it's solely for structural purposes. The op is also
+ used to yield values out of the selection/loop region.
}];
- let arguments = (ins);
+ let arguments = (ins Variadic<AnyType>:$operands);
let results = (outs);
- let assemblyFormat = "attr-dict";
+ let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
let hasOpcode = 0;
@@ -471,7 +472,7 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> {
SPIRV_SelectionControlAttr:$selection_control
);
- let results = (outs);
+ let results = (outs Variadic<AnyType>:$results);
let regions = (region AnyRegion:$body);
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index 31d8cd2206148..e1fac7e532e3d 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -227,11 +227,11 @@ struct IfOpConversion : SCFToSPIRVPattern<scf::IfOp> {
// Create `spirv.selection` operation, selection header block and merge
// block.
- auto selectionOp =
- rewriter.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
+ auto selectionOp = rewriter.create<spirv::SelectionOp>(
+ loc, TypeRange(), spirv::SelectionControl::None);
auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(),
selectionOp.getBody().end());
- rewriter.create<spirv::MergeOp>(loc);
+ rewriter.create<spirv::MergeOp>(loc, ValueRange());
OpBuilder::InsertionGuard guard(rewriter);
auto *selectionHeaderBlock =
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 2959d67b366b9..b154867d77f4a 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -383,7 +383,7 @@ void LoopOp::addEntryAndMergeBlock(OpBuilder &builder) {
builder.createBlock(&getBody());
// Add a spirv.mlir.merge op into the merge block.
- builder.create<spirv::MergeOp>(getLoc());
+ builder.create<spirv::MergeOp>(getLoc(), ValueRange());
}
//===----------------------------------------------------------------------===//
@@ -452,6 +452,11 @@ ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) {
if (parseControlAttribute<spirv::SelectionControlAttr,
spirv::SelectionControl>(parser, result))
return failure();
+
+ if (succeeded(parser.parseOptionalArrow()))
+ if (parser.parseTypeList(result.types))
+ return failure();
+
return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
}
@@ -459,6 +464,10 @@ void SelectionOp::print(OpAsmPrinter &printer) {
auto control = getSelectionControl();
if (control != spirv::SelectionControl::None)
printer << " control(" << spirv::stringifySelectionControl(control) << ")";
+ if (getNumResults() > 0) {
+ printer << " -> ";
+ printer << getResultTypes();
+ }
printer << ' ';
printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
@@ -526,15 +535,15 @@ void SelectionOp::addMergeBlock(OpBuilder &builder) {
builder.createBlock(&getBody());
// Add a spirv.mlir.merge op into the merge block.
- builder.create<spirv::MergeOp>(getLoc());
+ builder.create<spirv::MergeOp>(getLoc(), ValueRange());
}
SelectionOp
SelectionOp::createIfThen(Location loc, Value condition,
function_ref<void(OpBuilder &builder)> thenBody,
OpBuilder &builder) {
- auto selectionOp =
- builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
+ auto selectionOp = builder.create<spirv::SelectionOp>(
+ loc, TypeRange(), spirv::SelectionControl::None);
selectionOp.addMergeBlock(builder);
Block *mergeBlock = selectionOp.getMergeBlock();
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 273817d53d308..74ed3bb7ea33a 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1843,7 +1843,8 @@ ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
OpBuilder builder(&mergeBlock->front());
auto control = static_cast<spirv::SelectionControl>(selectionControl);
- auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
+ auto selectionOp =
+ builder.create<spirv::SelectionOp>(location, TypeRange(), control);
selectionOp.addMergeBlock(builder);
return selectionOp;
@@ -1992,6 +1993,29 @@ LogicalResult ControlFlowStructurizer::structurize() {
ArrayRef<Value>(blockArgs));
}
+ // Values defined inside the selection region that need to be yielded
+ // outside the region.
+ SmallVector<Value> valuesToYield;
+ // Outside uses of values sunk into the selection region. Those uses will be
+ // replaced with values returned by the SelectionOp.
+ SmallVector<Value> outsideUses;
+
+ // Move block arguments of the original block (`mergeBlock`) into the merge
+ // block inside the selection (`body.back()`). Values produced by block
+ // arguments will be yielded by the selection region. We do not update uses or
+ // erase original block arguments yet. It will be done later in the code.
+ if (!isLoop) {
+ for (BlockArgument blockArg : mergeBlock->getArguments()) {
+ // Create new block arguments in the last block ("merge block") of the
+ // selection region. We create one argument for each argument in
+ // `mergeBlock`. This new value will need to be yielded, and the original
+ // value replaced, so add them to appropriate vectors.
+ body.back().addArgument(blockArg.getType(), blockArg.getLoc());
+ valuesToYield.push_back(body.back().getArguments().back());
+ outsideUses.push_back(blockArg);
+ }
+ }
+
// All the blocks cloned into the SelectionOp/LoopOp's region can now be
// cleaned up.
LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
@@ -2000,17 +2024,79 @@ LogicalResult ControlFlowStructurizer::structurize() {
for (auto *block : constructBlocks)
block->dropAllReferences();
+ // All internal uses should be removed from original blocks by now, so
+ // whatever is left is an outside use and will need to be yielded from
+ // the newly created selection region.
+ if (!isLoop) {
+ for (Block *block : constructBlocks) {
+ for (Operation &op : *block) {
+ if (!op.use_empty())
+ for (Value result : op.getResults()) {
+ valuesToYield.push_back(mapper.lookupOrNull(result));
+ outsideUses.push_back(result);
+ }
+ }
+ for (BlockArgument &arg : block->getArguments()) {
+ if (!arg.use_empty()) {
+ valuesToYield.push_back(mapper.lookupOrNull(arg));
+ outsideUses.push_back(arg);
+ }
+ }
+ }
+ }
+
+ assert(valuesToYield.size() == outsideUses.size());
+
+ // If we need to yield any values from the selection region we will take
+ // care of it here.
+ if (!isLoop && !valuesToYield.empty()) {
+ LLVM_DEBUG(logger.startLine()
+ << "[cf] yielding values from the selection region\n");
+
+ // Update `mlir.merge` with the values to be yielded.
+ auto mergeOps = body.back().getOps<spirv::MergeOp>();
+ assert(std::next(mergeOps.begin()) == mergeOps.end());
+ Operation *merge = *mergeOps.begin();
+ merge->setOperands(valuesToYield);
+
+ // MLIR does not allow changing the number of results of an operation, so
+ // we create a new SelectionOp with required list of results and move
+ // the region from the initial SelectionOp. The initial operation is then
+ // removed. Since we move the region to the new op all links between blocks
+ // and remapping we have previously done should be preserved.
+ builder.setInsertionPoint(&mergeBlock->front());
+ auto selectionOp = builder.create<spirv::SelectionOp>(
+ location, TypeRange(outsideUses),
+ static_cast<spirv::SelectionControl>(control));
+ selectionOp->getRegion(0).takeBody(body);
+
+ // Remove initial op and swap the pointer to the newly created one.
+ op->erase();
+ op = selectionOp;
+
+ // Update all outside uses to use results of the SelectionOp and remove
+ // block arguments from the original merge block.
+ for (size_t i = 0; i < outsideUses.size(); i++)
+ outsideUses[i].replaceAllUsesWith(selectionOp.getResult(i));
+ for (size_t i = 0; i < mergeBlock->getNumArguments(); ++i)
+ mergeBlock->eraseArgument(i);
+ }
+
// Check that whether some op in the to-be-erased blocks still has uses. Those
// uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
// region. We cannot handle such cases given that once a value is sinked into
- // the SelectionOp/LoopOp's region, there is no escape for it:
- // SelectionOp/LooOp does not support yield values right now.
+ // the SelectionOp/LoopOp's region, there is no escape for it.
for (auto *block : constructBlocks) {
for (Operation &op : *block)
if (!op.use_empty())
- return op.emitOpError(
- "failed control flow structurization: it has uses outside of the "
- "enclosing selection/loop construct");
+ return op.emitOpError("failed control flow structurization: value has "
+ "uses outside of the "
+ "enclosing selection/loop construct");
+ for (BlockArgument &arg : block->getArguments())
+ if (!arg.use_empty())
+ return emitError(arg.getLoc(), "failed control flow structurization: "
+ "block argument has uses outside of the "
+ "enclosing selection/loop construct");
}
// Then erase all old blocks.
@@ -2236,7 +2322,7 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
auto *mergeBlock = mergeInfo.mergeBlock;
assert(mergeBlock && "merge block cannot be nullptr");
- if (!mergeBlock->args_empty())
+ if (mergeInfo.continueBlock != nullptr && !mergeBlock->args_empty())
return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
LLVM_DEBUG({
logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 4c15523a05fa8..641f139740c67 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -447,6 +447,14 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
auto mergeID = getBlockID(mergeBlock);
auto loc = selectionOp.getLoc();
+ // Before doing anything else, replace uses of the selection op's results with
+ // the values yielded by its `spirv.mlir.merge`. The selection op is being
+ // flattened, so values defined in its region may now be used outside it.
+ auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
+ assert(selectionOp.getNumResults() == mergeOp.getNumOperands());
+ for (unsigned i = 0; i < selectionOp.getNumResults(); ++i)
+ selectionOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
+
// This SelectionOp is in some MLIR block with preceding and following ops. In
// the binary format, it should reside in separate SPIR-V blocks from its
// preceding and following ops. So we need to emit unconditional branches to
@@ -483,6 +491,12 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
// instruction to start a new SPIR-V block for ops following this SelectionOp.
// The block should use the <id> for the merge block.
encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
+
+ // We do not process the mergeBlock but we still need to generate phi
+ // functions from its block arguments.
+ if (failed(emitPhiForBlockArguments(mergeBlock)))
+ return failure();
+
LLVM_DEBUG(llvm::dbgs() << "done merge ");
LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << "\n");
diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
index 188a55d755fd2..79c6b50fec73f 100644
--- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
@@ -765,6 +765,44 @@ func.func @missing_entry_block() -> () {
// -----
+func.func @selection_yield(%cond: i1) -> () {
+ %zero = spirv.Constant 0: i32
+ %var1 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+ %var2 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+ // CHECK: {{%.*}}:2 = spirv.mlir.selection -> i32, i32 {
+ %yield:2 = spirv.mlir.selection -> i32, i32 {
+ // CHECK-NEXT: spirv.BranchConditional {{%.*}}, ^bb1, ^bb2
+ spirv.BranchConditional %cond, ^then, ^else
+
+ // CHECK: ^bb1
+ ^then:
+ %one = spirv.Constant 1: i32
+ %three = spirv.Constant 3: i32
+ // CHECK: spirv.Branch ^bb3({{%.*}}, {{%.*}} : i32, i32)
+ spirv.Branch ^merge(%one, %three : i32, i32)
+
+ // CHECK: ^bb2
+ ^else:
+ %two = spirv.Constant 2: i32
+ %four = spirv.Constant 4 : i32
+ // CHECK: spirv.Branch ^bb3({{%.*}}, {{%.*}} : i32, i32)
+ spirv.Branch ^merge(%two, %four : i32, i32)
+
+ // CHECK: ^bb3({{%.*}}: i32, {{%.*}}: i32)
+ ^merge(%merged_1_2: i32, %merged_3_4: i32):
+ // CHECK-NEXT: spirv.mlir.merge {{%.*}}, {{%.*}} : i32, i32
+ spirv.mlir.merge %merged_1_2, %merged_3_4 : i32, i32
+ }
+
+ spirv.Store "Function" %var1, %yield#0 : i32
+ spirv.Store "Function" %var2, %yield#1 : i32
+
+ spirv.Return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.Unreachable
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir
index 24abb12998d06..5c9b9cdf89f8a 100644
--- a/mlir/test/Target/SPIRV/selection.mlir
+++ b/mlir/test/Target/SPIRV/selection.mlir
@@ -150,3 +150,55 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.Return
}
}
+
+// -----
+
+// Selection yielding values
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+// CHECK-LABEL: @selection_yield
+ spirv.func @selection_yield(%cond: i1) -> () "None" {
+// CHECK-NEXT: spirv.Constant 0
+// CHECK-NEXT: spirv.Variable
+ %zero = spirv.Constant 0 : i32
+ %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+// CHECK: spirv.Branch ^[[BB:.+]]
+// CHECK-NEXT: ^[[BB]]:
+// CHECK-NEXT: spirv.mlir.selection -> i32
+ %yield = spirv.mlir.selection -> i32 {
+// CHECK-NEXT: spirv.BranchConditional %{{.*}} [5, 10], ^[[THEN:.+]], ^[[ELSE:.+]]
+ spirv.BranchConditional %cond [5, 10], ^then, ^else
+
+// CHECK-NEXT: ^[[THEN]]:
+ ^then:
+// CHECK-NEXT: spirv.Constant 1
+ %one = spirv.Constant 1: i32
+
+// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]]({{%.*}} : i32)
+ spirv.Branch ^merge(%one : i32)
+
+// CHECK-NEXT: ^[[ELSE]]:
+ ^else:
+// CHECK-NEXT: spirv.Constant 2
+ %two = spirv.Constant 2: i32
+// CHECK-NEXT: spirv.Branch ^[[MERGE]]({{%.*}} : i32)
+ spirv.Branch ^merge(%two : i32)
+
+// CHECK-NEXT: ^[[MERGE]]({{%.*}}: i32):
+ ^merge(%merged: i32):
+// CHECK-NEXT: spirv.mlir.merge {{%.*}} : i32
+ spirv.mlir.merge %merged : i32
+ }
+
+ spirv.Store "Function" %var, %yield : i32
+
+ spirv.Return
+ }
+
+ spirv.func @main() -> () "None" {
+ spirv.Return
+ }
+ spirv.EntryPoint "GLCompute" @main
+ spirv.ExecutionMode @main "LocalSize", 1, 1, 1
+}
More information about the Mlir-commits
mailing list