[Mlir-commits] [mlir] 2a90631 - [mlir][spirv] Allow yielding values from selection regions (#133702)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 2 06:35:26 PDT 2025
Author: Igor Wodiany
Date: 2025-04-02T14:35:22+01:00
New Revision: 2a90631841284ba538029975e59876bb2e1714af
URL: https://github.com/llvm/llvm-project/commit/2a90631841284ba538029975e59876bb2e1714af
DIFF: https://github.com/llvm/llvm-project/commit/2a90631841284ba538029975e59876bb2e1714af.diff
LOG: [mlir][spirv] Allow yielding values from selection regions (#133702)
There are cases in SPIR-V shaders where values need to be yielded from
the selection region to produce valid MLIR. For example (part of a SPIR-V
shader decompiled to GLSL):
```
bool _115;
if (_107)
{
// ...
float _200 = fma(...);
// ...
_115 = _200 < _174;
}
else
{
_115 = _107;
}
bool _123;
if (_115)
{
// ...
float _213 = fma(...);
// ...
_123 = _213 < _174;
}
else
{
_123 = _115;
}
```
This patch extends `mlir.selection` so it can return values, with
`mlir.merge` acting as the "yield" operation. This maintains compatibility
with code that does not yield any values, and preserves the assumption that
`mlir.merge` is the only operation in the merge block of the selection
region.
Added:
Modified:
mlir/docs/Dialects/SPIR-V.md
mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
mlir/test/Target/SPIRV/selection.mlir
Removed:
################################################################################
diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md
index c1d50f0fbe407..ae9afbd9fdfe5 100644
--- a/mlir/docs/Dialects/SPIR-V.md
+++ b/mlir/docs/Dialects/SPIR-V.md
@@ -528,7 +528,7 @@ MLIR system.
We introduce a `spirv.mlir.selection` and `spirv.mlir.loop` op for structured selections and
loops, respectively. The merge targets are the next ops following them. Inside
their regions, a special terminator, `spirv.mlir.merge` is introduced for branching to
-the merge target.
+the merge target and yielding values.
### Selection
@@ -603,7 +603,43 @@ func.func @selection(%cond: i1) -> () {
// ...
}
+```
+
+The selection can return values by yielding them with `spirv.mlir.merge`. This
+mechanism allows values defined within the selection region to be used outside of it.
+Without this, values that were sunk into the selection region, but used outside, would
+not be able to escape it.
+
+For example:
+
+```mlir
+func.func @selection(%cond: i1) -> () {  // example: a selection that yields two values
+ %zero = spirv.Constant 0: i32
+ %var1 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+ %var2 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+ %yield:2 = spirv.mlir.selection -> i32, i32 {  // two i32 results, one per spirv.mlir.merge operand
+ spirv.BranchConditional %cond, ^then, ^else
+ ^then:
+ %one = spirv.Constant 1: i32
+ %three = spirv.Constant 3: i32
+ spirv.Branch ^merge(%one, %three : i32, i32)
+
+ ^else:
+ %two = spirv.Constant 2: i32
+ %four = spirv.Constant 4 : i32
+ spirv.Branch ^merge(%two, %four : i32, i32)
+
+ ^merge(%merged_1_2: i32, %merged_3_4: i32):  // merge block: its arguments receive the values from ^then / ^else
+ spirv.mlir.merge %merged_1_2, %merged_3_4 : i32, i32  // yields both values; they become %yield#0 and %yield#1
+ }
+
+ spirv.Store "Function" %var1, %yield#0 : i32
+ spirv.Store "Function" %var2, %yield#1 : i32
+
+ spirv.Return
+}
```
### Loop
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
index cb7d27e8d4b9a..039af03871411 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
@@ -352,7 +352,7 @@ def SPIRV_LoopOp : SPIRV_Op<"mlir.loop", [InFunctionScope]> {
// -----
def SPIRV_MergeOp : SPIRV_Op<"mlir.merge", [
- Pure, Terminator, ParentOneOf<["SelectionOp", "LoopOp"]>]> {
+ Pure, Terminator, ParentOneOf<["SelectionOp", "LoopOp"]>, ReturnLike]> {
let summary = "A special terminator for merging a structured selection/loop.";
let description = [{
@@ -361,13 +361,23 @@ def SPIRV_MergeOp : SPIRV_Op<"mlir.merge", [
merge point, which is the next op following the `spirv.mlir.selection` or
`spirv.mlir.loop` op. This op does not have a corresponding instruction in the
SPIR-V binary format; it's solely for structural purpose.
+
+ The instruction is also used to yield values from inside the selection/loop region
+ to the outside, as values that were sunk into the region cannot otherwise escape it.
}];
- let arguments = (ins);
+ let arguments = (ins Variadic<AnyType>:$operands);
let results = (outs);
- let assemblyFormat = "attr-dict";
+ let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+
+ let builders = [
+ OpBuilder<(ins),
+ [{
+ build($_builder, $_state, ValueRange());
+ }]>
+ ];
let hasOpcode = 0;
@@ -465,13 +475,17 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> {
header block, and one selection merge. The selection header block should be
the first block. The selection merge block should be the last block.
The merge block should only contain a `spirv.mlir.merge` op.
+
+ Values defined inside the selection regions cannot be directly used
+ outside of them; however, the selection region can yield values. These values are
+ yielded using a `spirv.mlir.merge` op and returned as a result of the selection op.
}];
let arguments = (ins
SPIRV_SelectionControlAttr:$selection_control
);
- let results = (outs);
+ let results = (outs Variadic<AnyType>:$results);
let regions = (region AnyRegion:$body);
@@ -494,6 +508,13 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> {
OpBuilder &builder);
}];
+ let builders = [
+ OpBuilder<(ins "spirv::SelectionControl":$selectionControl),
+ [{
+ build($_builder, $_state, TypeRange(), selectionControl);
+ }]>
+ ];
+
let hasOpcode = 0;
let autogenSerialization = 0;
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index 6ffe7afbc727c..ed9a30086deca 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -452,6 +452,11 @@ ParseResult SelectionOp::parse(OpAsmParser &parser, OperationState &result) {
if (parseControlAttribute<spirv::SelectionControlAttr,
spirv::SelectionControl>(parser, result))
return failure();
+
+ if (succeeded(parser.parseOptionalArrow()))
+ if (parser.parseTypeList(result.types))
+ return failure();
+
return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
}
@@ -459,6 +464,10 @@ void SelectionOp::print(OpAsmPrinter &printer) {
auto control = getSelectionControl();
if (control != spirv::SelectionControl::None)
printer << " control(" << spirv::stringifySelectionControl(control) << ")";
+ if (getNumResults() > 0) {
+ printer << " -> ";
+ printer << getResultTypes();
+ }
printer << ' ';
printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 273817d53d308..db838fa68a38c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1992,6 +1992,29 @@ LogicalResult ControlFlowStructurizer::structurize() {
ArrayRef<Value>(blockArgs));
}
+ // Values defined inside the selection region that need to be yielded outside
+ // the region.
+ SmallVector<Value> valuesToYield;
+ // Outside uses of values that were sunk into the selection region. Those uses
+ // will be replaced with values returned by the SelectionOp.
+ SmallVector<Value> outsideUses;
+
+ // Move block arguments of the original block (`mergeBlock`) into the merge
+ // block inside the selection (`body.back()`). Values produced by block
+ // arguments will be yielded by the selection region. We do not update uses or
+ // erase original block arguments yet. It will be done later in the code.
+ if (!isLoop) {
+ for (BlockArgument blockArg : mergeBlock->getArguments()) {
+ // Create new block arguments in the last block ("merge block") of the
+ // selection region. We create one argument for each argument in
+ // `mergeBlock`. This new value will need to be yielded, and the original
+ // value replaced, so add them to appropriate vectors.
+ body.back().addArgument(blockArg.getType(), blockArg.getLoc());
+ valuesToYield.push_back(body.back().getArguments().back());
+ outsideUses.push_back(blockArg);
+ }
+ }
+
// All the blocks cloned into the SelectionOp/LoopOp's region can now be
// cleaned up.
LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
@@ -2000,17 +2023,79 @@ LogicalResult ControlFlowStructurizer::structurize() {
for (auto *block : constructBlocks)
block->dropAllReferences();
+ // All internal uses should be removed from original blocks by now, so
+ // whatever is left is an outside use and will need to be yielded from
+ // the newly created selection region.
+ if (!isLoop) {
+ for (Block *block : constructBlocks) {
+ for (Operation &op : *block) {
+ if (!op.use_empty())
+ for (Value result : op.getResults()) {
+ valuesToYield.push_back(mapper.lookupOrNull(result));
+ outsideUses.push_back(result);
+ }
+ }
+ for (BlockArgument &arg : block->getArguments()) {
+ if (!arg.use_empty()) {
+ valuesToYield.push_back(mapper.lookupOrNull(arg));
+ outsideUses.push_back(arg);
+ }
+ }
+ }
+ }
+
+ assert(valuesToYield.size() == outsideUses.size());
+
+ // If we need to yield any values from the selection region we will take
+ // care of it here.
+ if (!isLoop && !valuesToYield.empty()) {
+ LLVM_DEBUG(logger.startLine()
+ << "[cf] yielding values from the selection region\n");
+
+ // Update `mlir.merge` with the values to be yielded.
+ auto mergeOps = body.back().getOps<spirv::MergeOp>();
+ Operation *merge = llvm::getSingleElement(mergeOps);
+ assert(merge);
+ merge->setOperands(valuesToYield);
+
+ // MLIR does not allow changing the number of results of an operation, so
+ // we create a new SelectionOp with the required list of results and move
+ // the region from the initial SelectionOp. The initial operation is then
+ // removed. Since the region is moved to the new op, all links between
+ // blocks and all remapping done previously should remain valid.
+ builder.setInsertionPoint(&mergeBlock->front());
+ auto selectionOp = builder.create<spirv::SelectionOp>(
+ location, TypeRange(outsideUses),
+ static_cast<spirv::SelectionControl>(control));
+ selectionOp->getRegion(0).takeBody(body);
+
+ // Remove initial op and swap the pointer to the newly created one.
+ op->erase();
+ op = selectionOp;
+
+ // Update all outside uses to use results of the SelectionOp and remove
+ // block arguments from the original merge block.
+ for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
+ outsideUses[i].replaceAllUsesWith(selectionOp.getResult(i));
+ for (unsigned i = 0, e = mergeBlock->getNumArguments(); i != e; ++i)
+ mergeBlock->eraseArgument(i);
+ }
+
// Check that whether some op in the to-be-erased blocks still has uses. Those
// uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
// region. We cannot handle such cases given that once a value is sinked into
- // the SelectionOp/LoopOp's region, there is no escape for it:
- // SelectionOp/LooOp does not support yield values right now.
+ // the SelectionOp/LoopOp's region, there is no escape for it.
for (auto *block : constructBlocks) {
for (Operation &op : *block)
if (!op.use_empty())
- return op.emitOpError(
- "failed control flow structurization: it has uses outside of the "
- "enclosing selection/loop construct");
+ return op.emitOpError("failed control flow structurization: value has "
+ "uses outside of the "
+ "enclosing selection/loop construct");
+ for (BlockArgument &arg : block->getArguments())
+ if (!arg.use_empty())
+ return emitError(arg.getLoc(), "failed control flow structurization: "
+ "block argument has uses outside of the "
+ "enclosing selection/loop construct");
}
// Then erase all old blocks.
@@ -2236,7 +2321,7 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
auto *mergeBlock = mergeInfo.mergeBlock;
assert(mergeBlock && "merge block cannot be nullptr");
- if (!mergeBlock->args_empty())
+ if (mergeInfo.continueBlock && !mergeBlock->args_empty())
return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
LLVM_DEBUG({
logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 4c15523a05fa8..5ed59a4134d37 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -447,6 +447,15 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
auto mergeID = getBlockID(mergeBlock);
auto loc = selectionOp.getLoc();
+ // Before doing anything else, replace the results of the selection operation
+ // with the values yielded (with `mlir.merge`) from inside the region. The
+ // selection op is being flattened, so we no longer have to worry about values
+ // being defined inside a region and used outside of it.
+ auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
+ assert(selectionOp.getNumResults() == mergeOp.getNumOperands());
+ for (unsigned i = 0, e = selectionOp.getNumResults(); i != e; ++i)
+ selectionOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
+
// This SelectionOp is in some MLIR block with preceding and following ops. In
// the binary format, it should reside in separate SPIR-V blocks from its
// preceding and following ops. So we need to emit unconditional branches to
@@ -483,6 +492,12 @@ LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
// instruction to start a new SPIR-V block for ops following this SelectionOp.
// The block should use the <id> for the merge block.
encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
+
+ // We do not process the mergeBlock but we still need to generate phi
+ // functions from its block arguments.
+ if (failed(emitPhiForBlockArguments(mergeBlock)))
+ return failure();
+
LLVM_DEBUG(llvm::dbgs() << "done merge ");
LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
LLVM_DEBUG(llvm::dbgs() << "\n");
diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
index 188a55d755fd2..107c8a3207b02 100644
--- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
@@ -765,6 +765,46 @@ func.func @missing_entry_block() -> () {
// -----
+func.func @selection_yield(%cond: i1) -> () {  // selection op yielding two i32 results; the directives below match the printed output
+ %zero = spirv.Constant 0: i32
+ %var1 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+ %var2 = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+ // CHECK: {{%.*}}:2 = spirv.mlir.selection -> i32, i32 {
+ %yield:2 = spirv.mlir.selection -> i32, i32 {  // result types follow the arrow
+ // CHECK-NEXT: spirv.BranchConditional {{%.*}}, ^bb1, ^bb2
+ spirv.BranchConditional %cond, ^then, ^else
+
+ // CHECK: ^bb1
+ ^then:
+ %one = spirv.Constant 1: i32
+ %three = spirv.Constant 3: i32
+ // CHECK: spirv.Branch ^bb3({{%.*}}, {{%.*}} : i32, i32)
+ spirv.Branch ^merge(%one, %three : i32, i32)
+
+ // CHECK: ^bb2
+ ^else:
+ %two = spirv.Constant 2: i32
+ %four = spirv.Constant 4 : i32
+ // CHECK: spirv.Branch ^bb3({{%.*}}, {{%.*}} : i32, i32)
+ spirv.Branch ^merge(%two, %four : i32, i32)
+
+ // CHECK: ^bb3({{%.*}}: i32, {{%.*}}: i32)
+ ^merge(%merged_1_2: i32, %merged_3_4: i32):  // block arguments carry the per-branch values
+ // CHECK-NEXT: spirv.mlir.merge {{%.*}}, {{%.*}} : i32, i32
+ spirv.mlir.merge %merged_1_2, %merged_3_4 : i32, i32  // operands become %yield#0 and %yield#1
+ }
+
+ // CHECK: spirv.Store "Function" {{%.*}}, {{%.*}}#0 : i32
+ spirv.Store "Function" %var1, %yield#0 : i32
+ // CHECK: spirv.Store "Function" {{%.*}}, {{%.*}}#1 : i32
+ spirv.Store "Function" %var2, %yield#1 : i32
+
+ spirv.Return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.Unreachable
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir
index 24abb12998d06..44625cc299230 100644
--- a/mlir/test/Target/SPIRV/selection.mlir
+++ b/mlir/test/Target/SPIRV/selection.mlir
@@ -150,3 +150,56 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.Return
}
}
+
+// -----
+
+// Selection yielding values
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+// CHECK-LABEL: @selection_yield
+ spirv.func @selection_yield(%cond: i1) -> () "None" {
+// CHECK-NEXT: spirv.Constant 0
+// CHECK-NEXT: spirv.Variable
+ %zero = spirv.Constant 0 : i32
+ %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+// CHECK: spirv.Branch ^[[BB:.+]]
+// CHECK-NEXT: ^[[BB]]:
+// CHECK-NEXT: {{%.*}} = spirv.mlir.selection -> i32
+ %yield = spirv.mlir.selection -> i32 {  // one i32 result, taken from the spirv.mlir.merge operand
+// CHECK-NEXT: spirv.BranchConditional %{{.*}} [5, 10], ^[[THEN:.+]], ^[[ELSE:.+]]
+ spirv.BranchConditional %cond [5, 10], ^then, ^else
+
+// CHECK-NEXT: ^[[THEN]]:
+ ^then:
+// CHECK-NEXT: spirv.Constant 1
+ %one = spirv.Constant 1: i32
+
+// CHECK-NEXT: spirv.Branch ^[[MERGE:.+]]({{%.*}} : i32)
+ spirv.Branch ^merge(%one : i32)
+
+// CHECK-NEXT: ^[[ELSE]]:
+ ^else:
+// CHECK-NEXT: spirv.Constant 2
+ %two = spirv.Constant 2: i32
+// CHECK-NEXT: spirv.Branch ^[[MERGE]]({{%.*}} : i32)
+ spirv.Branch ^merge(%two : i32)
+
+// CHECK-NEXT: ^[[MERGE]]({{%.*}}: i32):
+ ^merge(%merged: i32):  // merge-block argument; the serializer emits a phi for it
+// CHECK-NEXT: spirv.mlir.merge {{%.*}} : i32
+ spirv.mlir.merge %merged : i32  // yielded value becomes %yield
+ }
+
+// CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+ spirv.Store "Function" %var, %yield : i32
+
+ spirv.Return
+ }
+
+ spirv.func @main() -> () "None" {  // trivial entry point; presumably required for a valid module (confirm)
+ spirv.Return
+ }
+ spirv.EntryPoint "GLCompute" @main
+ spirv.ExecutionMode @main "LocalSize", 1, 1, 1
+}
More information about the Mlir-commits
mailing list