[Mlir-commits] [mlir] e4dee7e - [MLIR][SPIRV] Properly (de-)serialize BranchConditionalOp.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 7 00:01:01 PDT 2021
Author: KareemErgawy-TomTom
Date: 2021-05-07T09:00:50+02:00
New Revision: e4dee7e7309a060bd8dd3c9df0a708157fc935d4
URL: https://github.com/llvm/llvm-project/commit/e4dee7e7309a060bd8dd3c9df0a708157fc935d4
DIFF: https://github.com/llvm/llvm-project/commit/e4dee7e7309a060bd8dd3c9df0a708157fc935d4.diff
LOG: [MLIR][SPIRV] Properly (de-)serialize BranchConditionalOp.
Implements proper (de-)serialization logic for BranchConditionalOp when
such ops pass block arguments (operands) to their true/false targets.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D101602
Added:
Modified:
mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
mlir/test/Target/SPIRV/phi.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index d48349630b5ba..bbe16717fa022 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1573,7 +1573,8 @@ LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
uint32_t value = operands[i];
Block *predecessor = getOrCreateBlock(operands[i + 1]);
- blockPhiInfo[predecessor].push_back(value);
+ std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
+ blockPhiInfo[predecessorTargetPair].push_back(value);
LLVM_DEBUG(llvm::dbgs() << "[phi] predecessor @ " << predecessor
<< " with arg id = " << value << '\n');
}
@@ -1853,7 +1854,8 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
OpBuilder::InsertionGuard guard(opBuilder);
for (const auto &info : blockPhiInfo) {
- Block *block = info.first;
+ Block *block = info.first.first;
+ Block *target = info.first.second;
const BlockPhiInfo &phiInfo = info.second;
LLVM_DEBUG(llvm::dbgs() << "[phi] block " << block << "\n");
LLVM_DEBUG(llvm::dbgs() << "[phi] before creating block argument:\n");
@@ -1882,6 +1884,24 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
blockArgs);
branchOp.erase();
+ } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
+ assert((branchCondOp.getTrueBlock() == target ||
+ branchCondOp.getFalseBlock() == target) &&
+ "expected target to be either the true or false target");
+ if (target == branchCondOp.trueTarget())
+ opBuilder.create<spirv::BranchConditionalOp>(
+ branchCondOp.getLoc(), branchCondOp.condition(), blockArgs,
+ branchCondOp.getFalseBlockArguments(),
+ branchCondOp.branch_weightsAttr(), branchCondOp.trueTarget(),
+ branchCondOp.falseTarget());
+ else
+ opBuilder.create<spirv::BranchConditionalOp>(
+ branchCondOp.getLoc(), branchCondOp.condition(),
+ branchCondOp.getTrueBlockArguments(), blockArgs,
+ branchCondOp.branch_weightsAttr(), branchCondOp.getTrueBlock(),
+ branchCondOp.getFalseBlock());
+
+ branchCondOp.erase();
} else {
return emitError(unknownLoc, "unimplemented terminator for Phi creation");
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index ac4846d63cad5..17060dddc9198 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -560,8 +560,10 @@ class Deserializer {
// Header block to its merge (and continue) target mapping.
BlockMergeInfoMap blockMergeInfo;
- // Block to its phi (block argument) mapping.
- DenseMap<Block *, BlockPhiInfo> blockPhiInfo;
+ // For each pair of {predecessor, target} blocks, maps the pair of blocks to
+ // the list of phi arguments passed from predecessor to target.
+ DenseMap<std::pair<Block * /*predecessor*/, Block * /*target*/>, BlockPhiInfo>
+ blockPhiInfo;
// Result <id> to value mapping.
DenseMap<uint32_t, Value> valueMap;
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index ab35315e7f144..773fa863c0811 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -959,7 +959,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
// OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
// So we need to collect all predecessor blocks and the arguments they send
// to this block.
- SmallVector<std::pair<Block *, Operation::operand_iterator>, 4> predecessors;
+ SmallVector<std::pair<Block *, OperandRange>, 4> predecessors;
for (Block *predecessor : block->getPredecessors()) {
auto *terminator = predecessor->getTerminator();
// The predecessor here is the immediate one according to MLIR's IR
@@ -971,7 +971,21 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
// structured control flow op's merge block.
predecessor = getPhiIncomingBlock(predecessor);
if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
- predecessors.emplace_back(predecessor, branchOp.operand_begin());
+ predecessors.emplace_back(predecessor, branchOp.getOperands());
+ } else if (auto branchCondOp =
+ dyn_cast<spirv::BranchConditionalOp>(terminator)) {
+ Optional<OperandRange> blockOperands;
+
+ for (auto successorIdx :
+ llvm::seq<unsigned>(0, predecessor->getNumSuccessors()))
+ if (predecessor->getSuccessors()[successorIdx] == block) {
+ blockOperands = branchCondOp.getSuccessorOperands(successorIdx);
+ break;
+ }
+
+ assert(blockOperands && !blockOperands->empty() &&
+ "expected non-empty block operand range");
+ predecessors.emplace_back(predecessor, *blockOperands);
} else {
return terminator->emitError("unimplemented terminator for Phi creation");
}
@@ -996,7 +1010,7 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
phiArgs.push_back(phiID);
for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
- Value value = *(predecessors[predIndex].second + argIndex);
+ Value value = predecessors[predIndex].second[argIndex];
uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
<< ") value " << value << ' ');
diff --git a/mlir/test/Target/SPIRV/phi.mlir b/mlir/test/Target/SPIRV/phi.mlir
index 807783ae74ec4..63236aa495bb4 100644
--- a/mlir/test/Target/SPIRV/phi.mlir
+++ b/mlir/test/Target/SPIRV/phi.mlir
@@ -286,3 +286,60 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
spv.EntryPoint "GLCompute" @fmul_kernel
spv.ExecutionMode @fmul_kernel "LocalSize", 32, 1, 1
}
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+// CHECK-LABEL: @cond_branch_true_argument
+ spv.func @cond_branch_true_argument() -> () "None" {
+ %true = spv.Constant true
+ %zero = spv.Constant 0 : i32
+ %one = spv.Constant 1 : i32
+// CHECK: spv.BranchConditional %{{.*}}, ^[[true1:.*]](%{{.*}}, %{{.*}} : i32, i32), ^[[false1:.*]]
+ spv.BranchConditional %true, ^true1(%zero, %zero: i32, i32), ^false1
+// CHECK: [[true1]](%{{.*}}: i32, %{{.*}}: i32)
+ ^true1(%arg0: i32, %arg1: i32):
+ spv.Return
+// CHECK: [[false1]]:
+ ^false1:
+ spv.Return
+ }
+}
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+// CHECK-LABEL: @cond_branch_false_argument
+ spv.func @cond_branch_false_argument() -> () "None" {
+ %true = spv.Constant true
+ %zero = spv.Constant 0 : i32
+ %one = spv.Constant 1 : i32
+// CHECK: spv.BranchConditional %{{.*}}, ^[[true1:.*]], ^[[false1:.*]](%{{.*}}, %{{.*}} : i32, i32)
+ spv.BranchConditional %true, ^true1, ^false1(%zero, %zero: i32, i32)
+// CHECK: [[true1]]:
+ ^true1:
+ spv.Return
+// CHECK: [[false1]](%{{.*}}: i32, %{{.*}}: i32):
+ ^false1(%arg0: i32, %arg1: i32):
+ spv.Return
+ }
+}
+
+// -----
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+// CHECK-LABEL: @cond_branch_true_and_false_argument
+ spv.func @cond_branch_true_and_false_argument() -> () "None" {
+ %true = spv.Constant true
+ %zero = spv.Constant 0 : i32
+ %one = spv.Constant 1 : i32
+// CHECK: spv.BranchConditional %{{.*}}, ^[[true1:.*]](%{{.*}} : i32), ^[[false1:.*]](%{{.*}}, %{{.*}} : i32, i32)
+ spv.BranchConditional %true, ^true1(%one: i32), ^false1(%zero, %zero: i32, i32)
+// CHECK: [[true1]](%{{.*}}: i32):
+ ^true1(%arg0: i32):
+ spv.Return
+// CHECK: [[false1]](%{{.*}}: i32, %{{.*}}: i32):
+ ^false1(%arg1: i32, %arg2: i32):
+ spv.Return
+ }
+}
More information about the Mlir-commits
mailing list