[llvm] [SPIR-V] Fix BB ordering & register lifetime (PR #111026)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 21 06:16:19 PDT 2024
Nathan Gauër <brioche at google.com>
In-Reply-To: <llvm.org/llvm/llvm-project/pull/111026 at github.com>
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-spir-v
Author: Nathan Gauër (Keenuts)
<details>
<summary>Changes</summary>
The "topological" sorting behaved incorrectly in some cases:
the exit of a loop could have a lower rank than a node inside the loop.
This causes issues when structurizing some patterns, and also codegen
issues, as we could emit BBs in an order that is incorrect with regard
to the SPIR-V spec.
Fixing this ordering alone broke other parts of the structurizer, which
had only worked by luck. Those had to be fixed as well.
Added more test cases, especially to test basic patterns.
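To make the ordering problem concrete, here is a minimal sketch of a rank computation in which a loop exit waits for the loop's latch and takes its rank from it. It follows the general shape of the `CanBeVisited`/`GetNodeRank` logic in this patch, but it is not the patch's code: `ToyCFG`, `ToyLoop` and `computeRanks` are hypothetical names, blocks are plain integers, loops are not nested, back-edges are identified from the explicit latch/header pair rather than from a dominator tree, and every block is assumed reachable from the entry.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>
#include <queue>
#include <set>
#include <vector>

// Hypothetical toy loop description: one header, one latch (single back-edge).
struct ToyLoop {
  std::set<int> Blocks; // every block that belongs to the loop
  int Header;
  int Latch;
};

// Hypothetical toy CFG: blocks are ints, loops are listed explicitly.
struct ToyCFG {
  std::map<int, std::vector<int>> Succs;
  std::vector<ToyLoop> Loops; // assumption: no nested loops

  const ToyLoop *loopFor(int BB) const {
    for (const ToyLoop &L : Loops)
      if (L.Blocks.count(BB))
        return &L;
    return nullptr;
  }

  // In this sketch the only back-edge of a loop is latch -> header.
  bool isBackEdge(int From, int To) const {
    const ToyLoop *L = loopFor(From);
    return L && L->Latch == From && L->Header == To;
  }

  std::vector<int> preds(int BB) const {
    std::vector<int> Out;
    for (const auto &[Src, Dsts] : Succs)
      for (int D : Dsts)
        if (D == BB)
          Out.push_back(Src);
    return Out;
  }
};

// Breadth-first walk that re-queues a block until all of its forward
// predecessors are ranked. A loop exit uses the latch's rank as its base.
std::map<int, std::size_t> computeRanks(const ToyCFG &G, int Entry) {
  std::map<int, std::size_t> Rank;
  std::queue<int> ToVisit;
  std::set<int> Queued{Entry};
  ToVisit.push(Entry);

  while (!ToVisit.empty()) {
    int BB = ToVisit.front();
    ToVisit.pop();

    bool Ready = true;
    std::size_t R = 0;
    for (int P : G.preds(BB)) {
      if (G.isBackEdge(P, BB))
        continue;
      // Edge leaving a loop: the whole loop must be done, so use its latch.
      const ToyLoop *L = G.loopFor(P);
      int Base = (L && !L->Blocks.count(BB)) ? L->Latch : P;
      if (!Rank.count(P) || !Rank.count(Base)) {
        Ready = false;
        break;
      }
      R = std::max(R, Rank[Base] + 1);
    }
    if (!Ready) {
      ToVisit.push(BB); // try again once more blocks are ranked
      continue;
    }

    Rank[BB] = R;
    auto It = G.Succs.find(BB);
    if (It == G.Succs.end())
      continue;
    for (int S : It->second)
      if (Queued.insert(S).second)
        ToVisit.push(S);
  }
  return Rank;
}
```

For a CFG `0 -> 1`, `1 -> {2, 4}`, `2 -> 3`, `3 -> 1` with the loop `{1, 2, 3}` (header 1, latch 3), the exit block 4 ends up with rank 4, above the latch's rank 3. A plain "predecessor rank + 1" rule would give it rank 2, which is the kind of misordering described above.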
I also needed to tweak/disable some tests, for 2 reasons:
- The SPIR-V backend now requires reg2mem/mem2reg to run, meaning dead
stores are optimized away. Some tests require tweaks to avoid having
the whole function removed.
- Reg2Mem will generate variables and loads/stores. This generates
G_BITCAST in several cases, and something we currently do with
G_BITCAST is wrong and causes the MIR verifier to complain.
Until this is resolved, I disabled the -verify-machineinstrs flag on
those tests.
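As a rough picture of the "variables and loads/stores" mentioned in the list above: demoting a value to memory means each block that used to feed a PHI stores into a stack slot, and the join block loads from it. The patch builds the same store/load shape by hand for the merged exit block, where it previously created a PHI. The sketch below only mimics that shape in plain C++; `routeThroughSingleExit` and the returned labels are hypothetical and do not appear in the patch.

```cpp
#include <cassert>
#include <string>

// Hypothetical illustration, not code from the patch: control reaches one
// merged exit, which picks the real target from a stored index.
std::string routeThroughSingleExit(int X) {
  int Reg = 0; // plays the role of the stack variable (the alloca)

  if (X < 0) {
    Reg = 0; // store made in the first exiting block
    goto new_exit;
  }
  if (X == 0) {
    Reg = 1; // store made in the second exiting block
    goto new_exit;
  }
  Reg = 2; // store made in the third exiting block

new_exit:
  // Load the index, then switch to the original exit target.
  switch (Reg) {
  case 1:
    return "zero";
  case 2:
    return "positive";
  default:
    return "negative";
  }
}
```

Because the index lives in memory rather than in a value selected per incoming edge, rewiring which blocks branch to the merged exit does not require patching up PHI operands.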
---
Patch is 141.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/111026.diff
97 Files Affected:
- (modified) llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp (+15-7)
- (modified) llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp (+51-50)
- (modified) llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp (+13-2)
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+77-32)
- (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+7-1)
- (modified) llvm/test/CodeGen/SPIRV/HlslBufferLoad.ll (+5-5)
- (modified) llvm/test/CodeGen/SPIRV/OpVariable_order.ll (+2-2)
- (modified) llvm/test/CodeGen/SPIRV/ShaderBufferImage.ll (+1-1)
- (modified) llvm/test/CodeGen/SPIRV/ShaderImage.ll (+1-1)
- (modified) llvm/test/CodeGen/SPIRV/basic_float_types.ll (-1)
- (modified) llvm/test/CodeGen/SPIRV/basic_int_types.ll (-3)
- (modified) llvm/test/CodeGen/SPIRV/basic_int_types_spirvdis.ll (-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_DispatchThreadID.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveGetLaneIndex.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveIsFirstLane.ll (+1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/abs.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/acos.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/all.ll (+3-2)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/any.ll (+3-2)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/asin.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/atan.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/atan2.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/ceil.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/cos.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/cosh.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/countbits.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/degrees.ll (+1-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/exp.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/exp2.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fdot.ll (+1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/floor.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fmad.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fmax.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fmin.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/frac.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll (+1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/imad.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/length.ll (+1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/lerp.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/log.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/log10.ll (+10-22)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/log2.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/normalize.ll (+1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/pow.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/radians.ll (+1-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rcp.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/reversebits.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/round.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/saturate.ll (+1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/sign.ll (+1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/sin.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/sinh.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/smax.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/smin.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/sqrt.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/step.ll (+1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/tan.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/tanh.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/trunc.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/umax.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/umin.ll (+2-1)
- (modified) llvm/test/CodeGen/SPIRV/literals.ll (-3)
- (added) llvm/test/CodeGen/SPIRV/structurizer/basic-if.ll (+53)
- (added) llvm/test/CodeGen/SPIRV/structurizer/basic-imbalanced-if.ll (+47)
- (added) llvm/test/CodeGen/SPIRV/structurizer/basic-loop.ll (+59)
- (added) llvm/test/CodeGen/SPIRV/structurizer/basic-phi.ll (+58)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.cond-op.ll (+86-99)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.do.break.ll (-1)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.do.continue.ll (-1)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.do.nested.ll (-1)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.for.break.ll (-1)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.for.continue.ll (-1)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.for.nested.ll (-1)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.for.short-circuited-cond.ll (+2-3)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.if.const-cond.ll (-1)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.if.for.ll (-1)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.if.nested.ll (-1)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.if.plain.ll (-1)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.logical-and.ll (-1)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.logical-or.ll (-1)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.return.early.ll (-1)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.switch.ifstmt.ll (-1)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.switch.ifstmt.simple.ll (-1)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.switch.ifstmt.simple2.ll (-1)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/cf.while.break.ll (+27-29)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/condition-linear.ll (+72-70)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/do-continue.ll (+77-79)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/do-nested.ll (+55-67)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/do-plain.ll (+64-60)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/logical-or.ll (+54-61)
- (added) llvm/test/CodeGen/SPIRV/structurizer/loop-continue-split.ll (+104)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/merge-exit-break.ll (+19-19)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/merge-exit-convergence-in-break.ll (+9-9)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/merge-exit-multiple-break.ll (+20-15)
- (added) llvm/test/CodeGen/SPIRV/structurizer/phi-exit.ll (+45)
- (modified) llvm/test/CodeGen/SPIRV/structurizer/return-early.ll (+14-6)
``````````diff
diff --git a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
index 9930d067173df7..c22492ec43b095 100644
--- a/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVMergeRegionExitTargets.cpp
@@ -130,6 +130,13 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
assert(false && "Unhandled terminator type.");
}
+ AllocaInst *CreateVariable(Function &F, Type *Type,
+ BasicBlock::iterator Position) {
+ const DataLayout &DL = F.getDataLayout();
+ return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg",
+ Position);
+ }
+
// Run the pass on the given convergence region, ignoring the sub-regions.
// Returns true if the CFG changed, false otherwise.
bool runOnConvergenceRegionNoRecurse(LoopInfo &LI,
@@ -152,6 +159,9 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
auto NewExitTarget = BasicBlock::Create(F->getContext(), "new.exit", F);
IRBuilder<> Builder(NewExitTarget);
+ AllocaInst *Variable = CreateVariable(*F, Builder.getInt32Ty(),
+ F->begin()->getFirstInsertionPt());
+
// CodeGen output needs to be stable. Using the set as-is would order
// the targets differently depending on the allocation pattern.
// Sorting per basic-block ordering in the function.
@@ -176,18 +186,16 @@ class SPIRVMergeRegionExitTargets : public FunctionPass {
std::vector<std::pair<BasicBlock *, Value *>> ExitToVariable;
for (auto Exit : SortedExits) {
llvm::Value *Value = createExitVariable(Exit, TargetToValue);
+ IRBuilder<> B2(Exit);
+ B2.SetInsertPoint(Exit->getFirstInsertionPt());
+ B2.CreateStore(Value, Variable);
ExitToVariable.emplace_back(std::make_pair(Exit, Value));
}
- // Gather the correct value depending on the exit we came from.
- llvm::PHINode *node =
- Builder.CreatePHI(Builder.getInt32Ty(), ExitToVariable.size());
- for (auto [BB, Value] : ExitToVariable) {
- node->addIncoming(Value, BB);
- }
+ llvm::Value *Load = Builder.CreateLoad(Builder.getInt32Ty(), Variable);
// Creating the switch to jump to the correct exit target.
- llvm::SwitchInst *Sw = Builder.CreateSwitch(node, SortedExitTargets[0],
+ llvm::SwitchInst *Sw = Builder.CreateSwitch(Load, SortedExitTargets[0],
SortedExitTargets.size() - 1);
for (size_t i = 1; i < SortedExitTargets.size(); i++) {
BasicBlock *BB = SortedExitTargets[i];
diff --git a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
index 211a060ee103bc..5b6d31782c2093 100644
--- a/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp
@@ -87,6 +87,8 @@ BasicBlock *getExitFor(const ConvergenceRegion *CR) {
// Returns the merge block designated by I if I is a merge instruction, nullptr
// otherwise.
BasicBlock *getDesignatedMergeBlock(Instruction *I) {
+ if (I == nullptr)
+ return nullptr;
IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
if (II == nullptr)
return nullptr;
@@ -102,6 +104,8 @@ BasicBlock *getDesignatedMergeBlock(Instruction *I) {
// Returns the continue block designated by I if I is an OpLoopMerge, nullptr
// otherwise.
BasicBlock *getDesignatedContinueBlock(Instruction *I) {
+ if (I == nullptr)
+ return nullptr;
IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
if (II == nullptr)
return nullptr;
@@ -284,18 +288,6 @@ void replaceBranchTargets(BasicBlock *BB, BasicBlock *OldTarget,
assert(false && "Unhandled terminator type.");
}
-// Replaces basic bloc operands |OldSrc| or OpPhi instructions in |BB| by
-// |NewSrc|. This function does not simplify the OpPhi instruction once
-// transformed.
-void replacePhiTargets(BasicBlock *BB, BasicBlock *OldSrc, BasicBlock *NewSrc) {
- for (PHINode &Phi : BB->phis()) {
- int index = Phi.getBasicBlockIndex(OldSrc);
- if (index == -1)
- continue;
- Phi.setIncomingBlock(index, NewSrc);
- }
-}
-
} // anonymous namespace
// Given a reducible CFG, produces a structurized CFG in the SPIR-V sense,
@@ -447,48 +439,40 @@ class SPIRVStructurizer : public FunctionPass {
// clang-format on
std::vector<Edge>
createAliasBlocksForComplexEdges(std::vector<Edge> Edges) {
- std::unordered_map<BasicBlock *, BasicBlock *> Seen;
+ std::unordered_set<BasicBlock *> Seen;
std::vector<Edge> Output;
Output.reserve(Edges.size());
for (auto &[Src, Dst] : Edges) {
- auto [iterator, inserted] = Seen.insert({Src, Dst});
- if (inserted) {
- Output.emplace_back(Src, Dst);
- continue;
+ auto [iterator, inserted] = Seen.insert(Src);
+ if (!inserted) {
+ // Src already a source node. Cannot have 2 edges from A to B.
+ // Creating alias source block.
+ BasicBlock *NewSrc =
+ BasicBlock::Create(F.getContext(), "new.src", &F);
+ replaceBranchTargets(Src, Dst, NewSrc);
+ IRBuilder<> Builder(NewSrc);
+ Builder.CreateBr(Dst);
+ Src = NewSrc;
}
- // The exact same edge was already seen. Ignoring.
- if (iterator->second == Dst)
- continue;
-
- // The same Src block branches to 2 distinct blocks. This will be an
- // issue for the generated OpPhi. Creating alias block.
- BasicBlock *NewSrc =
- BasicBlock::Create(F.getContext(), "new.exit.src", &F);
- replaceBranchTargets(Src, Dst, NewSrc);
- replacePhiTargets(Dst, Src, NewSrc);
-
- IRBuilder<> Builder(NewSrc);
- Builder.CreateBr(Dst);
-
- Seen.emplace(NewSrc, Dst);
- Output.emplace_back(NewSrc, Dst);
+ Output.emplace_back(Src, Dst);
}
return Output;
}
+ AllocaInst *CreateVariable(Function &F, Type *Type,
+ BasicBlock::iterator Position) {
+ const DataLayout &DL = F.getDataLayout();
+ return new AllocaInst(Type, DL.getAllocaAddrSpace(), nullptr, "reg",
+ Position);
+ }
+
// Given a construct defined by |Header|, and a list of exiting edges
// |Edges|, creates a new single exit node, fixing up those edges.
BasicBlock *createSingleExitNode(BasicBlock *Header,
std::vector<Edge> &Edges) {
- auto NewExit = BasicBlock::Create(F.getContext(), "new.exit", &F);
- IRBuilder<> ExitBuilder(NewExit);
-
- std::vector<BasicBlock *> Dsts;
- std::unordered_map<BasicBlock *, ConstantInt *> DstToIndex;
-
// Given 2 edges: Src1 -> Dst, Src2 -> Dst:
// If Dst has an PHI node, and Src1 and Src2 are both operands, both Src1
// and Src2 cannot be hidden by NewExit. Create 2 new nodes: Alias1,
@@ -496,6 +480,10 @@ class SPIRVStructurizer : public FunctionPass {
// Dst PHI node to look for Alias1 and Alias2.
std::vector<Edge> FixedEdges = createAliasBlocksForComplexEdges(Edges);
+ std::vector<BasicBlock *> Dsts;
+ std::unordered_map<BasicBlock *, ConstantInt *> DstToIndex;
+ auto NewExit = BasicBlock::Create(F.getContext(), "new.exit", &F);
+ IRBuilder<> ExitBuilder(NewExit);
for (auto &[Src, Dst] : FixedEdges) {
if (DstToIndex.count(Dst) != 0)
continue;
@@ -506,33 +494,34 @@ class SPIRVStructurizer : public FunctionPass {
if (Dsts.size() == 1) {
for (auto &[Src, Dst] : FixedEdges) {
replaceBranchTargets(Src, Dst, NewExit);
- replacePhiTargets(Dst, Src, NewExit);
}
ExitBuilder.CreateBr(Dsts[0]);
return NewExit;
}
- PHINode *PhiNode =
- ExitBuilder.CreatePHI(ExitBuilder.getInt32Ty(), FixedEdges.size());
-
+ AllocaInst *Variable = CreateVariable(F, ExitBuilder.getInt32Ty(),
+ F.begin()->getFirstInsertionPt());
for (auto &[Src, Dst] : FixedEdges) {
- PhiNode->addIncoming(DstToIndex[Dst], Src);
+ IRBuilder<> B2(Src);
+ B2.SetInsertPoint(Src->getFirstInsertionPt());
+ B2.CreateStore(DstToIndex[Dst], Variable);
replaceBranchTargets(Src, Dst, NewExit);
- replacePhiTargets(Dst, Src, NewExit);
}
+ llvm::Value *Load =
+ ExitBuilder.CreateLoad(ExitBuilder.getInt32Ty(), Variable);
+
// If we can avoid an OpSwitch, generate an OpBranch. Reason is some
// OpBranch are allowed to exist without a new OpSelectionMerge if one of
// the branch is the parent's merge node, while OpSwitches are not.
if (Dsts.size() == 2) {
- Value *Condition = ExitBuilder.CreateCmp(CmpInst::ICMP_EQ,
- DstToIndex[Dsts[0]], PhiNode);
+ Value *Condition =
+ ExitBuilder.CreateCmp(CmpInst::ICMP_EQ, DstToIndex[Dsts[0]], Load);
ExitBuilder.CreateCondBr(Condition, Dsts[0], Dsts[1]);
return NewExit;
}
- SwitchInst *Sw =
- ExitBuilder.CreateSwitch(PhiNode, Dsts[0], Dsts.size() - 1);
+ SwitchInst *Sw = ExitBuilder.CreateSwitch(Load, Dsts[0], Dsts.size() - 1);
for (auto It = Dsts.begin() + 1; It != Dsts.end(); ++It) {
Sw->addCase(DstToIndex[*It], *It);
}
@@ -576,7 +565,7 @@ class SPIRVStructurizer : public FunctionPass {
// Creates a new basic block in F with a single OpUnreachable instruction.
BasicBlock *CreateUnreachable(Function &F) {
- BasicBlock *BB = BasicBlock::Create(F.getContext(), "new.exit", &F);
+ BasicBlock *BB = BasicBlock::Create(F.getContext(), "unreachable", &F);
IRBuilder<> Builder(BB);
Builder.CreateUnreachable();
return BB;
@@ -1127,6 +1116,18 @@ class SPIRVStructurizer : public FunctionPass {
continue;
Modified = true;
+
+ if (Merge == nullptr) {
+ Merge = *successors(Header).begin();
+ IRBuilder<> Builder(Header);
+ Builder.SetInsertPoint(Header->getTerminator());
+
+ auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
+ SmallVector<Value *, 1> Args = {MergeAddress};
+ Builder.CreateIntrinsic(Intrinsic::spv_selection_merge, {}, {Args});
+ continue;
+ }
+
Instruction *SplitInstruction = Merge->getTerminator();
if (isMergeInstruction(SplitInstruction->getPrevNode()))
SplitInstruction = SplitInstruction->getPrevNode();
diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
index e5384b2eb2c2c1..133a98375d840f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
@@ -29,6 +29,7 @@
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Pass.h"
#include "llvm/Target/TargetOptions.h"
+#include "llvm/Transforms/Scalar/Reg2Mem.h"
#include "llvm/Transforms/Utils.h"
#include <optional>
@@ -162,6 +163,8 @@ void SPIRVPassConfig::addIRPasses() {
TargetPassConfig::addIRPasses();
if (TM.getSubtargetImpl()->isVulkanEnv()) {
+ addPass(createRegToMemWrapperPass());
+
// 1. Simplify loop for subsequent transformations. After this steps, loops
// have the following properties:
// - loops have a single entry edge (pre-header to loop header).
@@ -169,13 +172,21 @@ void SPIRVPassConfig::addIRPasses() {
// - loops have a single back-edge.
addPass(createLoopSimplifyPass());
- // 2. Merge the convergence region exit nodes into one. After this step,
+ // 2. Removes registers whose lifetime spans across basic blocks. Also
+ // removes phi nodes. This will greatly simplify the next steps.
+ addPass(createRegToMemWrapperPass());
+
+ // 3. Merge the convergence region exit nodes into one. After this step,
// regions are single-entry, single-exit. This will help determine the
// correct merge block.
addPass(createSPIRVMergeRegionExitTargetsPass());
- // 3. Structurize.
+ // 4. Structurize.
addPass(createSPIRVStructurizerPass());
+
+ // 5. Reduce the amount of variables required by pushing some operations
+ // back to virtual registers.
+ addPass(createPromoteMemoryToRegisterPass());
}
addPass(createSPIRVRegularizerPass());
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
index dff33b16b9cfcf..f9b361e163c909 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.cpp
@@ -460,53 +460,98 @@ PartialOrderingVisitor::getReachableFrom(BasicBlock *Start) {
return Output;
}
-size_t PartialOrderingVisitor::visit(BasicBlock *BB, size_t Rank) {
- if (Visited.count(BB) != 0)
- return Rank;
+bool PartialOrderingVisitor::CanBeVisited(BasicBlock *BB) const {
+ for (BasicBlock *P : predecessors(BB)) {
+ // Ignore back-edges.
+ if (DT.dominates(BB, P))
+ continue;
- Loop *L = LI.getLoopFor(BB);
- const bool isLoopHeader = LI.isLoopHeader(BB);
+ // One of the predecessor hasn't been visited. Not ready yet.
+ if (BlockToOrder.count(P) == 0)
+ return false;
- if (BlockToOrder.count(BB) == 0) {
- OrderInfo Info = {Rank, Visited.size()};
- BlockToOrder.emplace(BB, Info);
- } else {
- BlockToOrder[BB].Rank = std::max(BlockToOrder[BB].Rank, Rank);
+ // If the block is a loop exit, the loop must be finished before
+ // we can continue.
+ Loop *L = LI.getLoopFor(P);
+ if (L == nullptr || L->contains(BB))
+ continue;
+
+ // SPIR-V requires a single back-edge. And the backend first
+ // step transforms loops into the simplified format. If we have
+ // more than 1 back-edge, something is wrong.
+ assert(L->getNumBackEdges() <= 1);
+
+ // If the loop has no latch, loop's rank won't matter, so we can
+ // proceed.
+ BasicBlock *Latch = L->getLoopLatch();
+ assert(Latch);
+ if (Latch == nullptr)
+ continue;
+
+ // The latch is not ready yet, let's wait.
+ if (BlockToOrder.count(Latch) == 0)
+ return false;
}
- for (BasicBlock *Predecessor : predecessors(BB)) {
- if (isLoopHeader && L->contains(Predecessor)) {
+ return true;
+}
+
+size_t PartialOrderingVisitor::GetNodeRank(BasicBlock *BB) const {
+ size_t result = 0;
+
+ for (BasicBlock *P : predecessors(BB)) {
+ // Ignore back-edges.
+ if (DT.dominates(BB, P))
continue;
- }
- if (BlockToOrder.count(Predecessor) == 0) {
- return Rank;
+ auto Iterator = BlockToOrder.end();
+ Loop *L = LI.getLoopFor(P);
+ BasicBlock *Latch = L ? L->getLoopLatch() : nullptr;
+
+ // If the predecessor is either outside a loop, or part of
+ // the same loop, simply take its rank + 1.
+ if (L == nullptr || L->contains(BB) || Latch == nullptr) {
+ Iterator = BlockToOrder.find(P);
+ } else {
+ // Otherwise, take the loop's rank (highest rank in the loop) as base.
+ // Since loops have a single latch, highest rank is easy to find.
+ // If the loop has no latch, then it doesn't matter.
+ Iterator = BlockToOrder.find(Latch);
}
+
+ assert(Iterator != BlockToOrder.end());
+ result = std::max(result, Iterator->second.Rank + 1);
}
- Visited.insert(BB);
+ return result;
+}
+
+size_t PartialOrderingVisitor::visit(BasicBlock *BB, size_t Unused) {
+ ToVisit.push(BB);
+ Queued.insert(BB);
- SmallVector<BasicBlock *, 2> OtherSuccessors;
- SmallVector<BasicBlock *, 2> LoopSuccessors;
+ while (ToVisit.size() != 0) {
+ BasicBlock *BB = ToVisit.front();
+ ToVisit.pop();
- for (BasicBlock *Successor : successors(BB)) {
- // Ignoring back-edges.
- if (DT.dominates(Successor, BB))
+ if (!CanBeVisited(BB)) {
+ ToVisit.push(BB);
continue;
+ }
- if (isLoopHeader && L->contains(Successor)) {
- LoopSuccessors.push_back(Successor);
- } else
- OtherSuccessors.push_back(Successor);
- }
+ size_t Rank = GetNodeRank(BB);
+ OrderInfo Info = {Rank, BlockToOrder.size()};
+ BlockToOrder.emplace(BB, Info);
- for (BasicBlock *BB : LoopSuccessors)
- Rank = std::max(Rank, visit(BB, Rank + 1));
+ for (BasicBlock *S : successors(BB)) {
+ if (Queued.count(S) != 0)
+ continue;
+ ToVisit.push(S);
+ Queued.insert(S);
+ }
+ }
- size_t OutputRank = Rank;
- for (BasicBlock *Item : OtherSuccessors)
- OutputRank = std::max(OutputRank, visit(Item, Rank + 1));
- return OutputRank;
+ return 0;
}
PartialOrderingVisitor::PartialOrderingVisitor(Function &F) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h
index 83e717e6ea58fd..11fd3a5c61dcae 100644
--- a/llvm/lib/Target/SPIRV/SPIRVUtils.h
+++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h
@@ -18,6 +18,7 @@
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/TypedPointerType.h"
+#include <queue>
#include <string>
#include <unordered_set>
@@ -62,7 +63,9 @@ class SPIRVSubtarget;
class PartialOrderingVisitor {
DomTreeBuilder::BBDomTree DT;
LoopInfo LI;
- std::unordered_set<BasicBlock *> Visited = {};
+
+ std::unordered_set<BasicBlock *> Queued = {};
+ std::queue<BasicBlock *> ToVisit = {};
struct OrderInfo {
size_t Rank;
@@ -80,6 +83,9 @@ class PartialOrderingVisitor {
// Visits |BB| with the current rank being |Rank|.
size_t visit(BasicBlock *BB, size_t Rank);
+ size_t GetNodeRank(BasicBlock *BB) const;
+ bool CanBeVisited(BasicBlock *BB) const;
+
public:
// Build the visitor to operate on the function F.
PartialOrderingVisitor(Function &F);
diff --git a/llvm/test/CodeGen/SPIRV/HlslBufferLoad.ll b/llvm/test/CodeGen/SPIRV/HlslBufferLoad.ll
index fe960f0d6f2f9a..2c9ad1a657a1a5 100644
--- a/llvm/test/CodeGen/SPIRV/HlslBufferLoad.ll
+++ b/llvm/test/CodeGen/SPIRV/HlslBufferLoad.ll
@@ -1,4 +1,4 @@
-; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-vulkan-library %s -o - | FileCheck %s
+; RUN: llc -O0 -mtriple=spirv-vulkan-library %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-library %s -o - -filetype=obj | spirv-val %}
; CHECK-DAG: OpDecorate [[IntBufferVar:%[0-9]+]] DescriptorSet 16
@@ -18,13 +18,13 @@
; CHECK: {{%[0-9]+}} = OpFunction {{%[0-9]+}} DontInline {{%[0-9]+}}
; CHECK-NEXT: OpLabel
define void @RWBufferLoad() #0 {
-; CHECK-NEXT: [[buffer:%[0-9]+]] = OpLoad [[RWBufferTypeInt]] [[IntBufferVar]]
+; CHECK: [[buffer:%[0-9]+]] = OpLoad [[RWBufferTypeInt]] [[IntBufferVar]]
%buffer0 = call target("spirv.Image", i32, 5, 2, 0, 0, 2, 24)
@llvm.spv.handle.fromBinding.tspirv.Image_f32_5_2_0_0_2_24(
i32 16, i32 7, i32 1, i32 0, i1 false)
; Make sure we use the same variable with multiple loads.
-; CHECK-NEXT: [[buffer:%[0-9]+]] = OpLoad [[RWBufferTypeInt]] [[IntBufferVar]]
+; CHECK: [[buffer:%[0-9]+]] = OpLoad [[RWBufferTypeInt]] [[IntBufferVar]]
%buffer1 = call target("spirv.Image", i32, 5, 2, 0, 0, 2, 24)
@llvm.spv.handle.fromBinding.tspirv.Image_f32_5_2_0_0_2_24(
i32 16, i32 7, i32 1, i32 0, i1 false)
@@ -36,7 +36,7 @@ define void @RWBufferLoad() #0 {
define void @UseDifferentGlobalVar() #0 {
; Make sure we use a different variable from the first function. They have
; different types.
-; CHECK-NEXT: [[buffer:%[0-9]+]] = OpLoad [[RWBufferTypeFloat]] [[FloatBufferVar]]
+; CHECK: [[buffer:%[0-9]+]] = OpLoad [[RWBufferTypeFloat]] [[FloatBufferVar]]
%buffer0 = call target("spirv.Image", float, 5, 2, 0, 0, 2, 3)
@llvm.spv.handle.fromBinding.tspirv.Image_f32_5_2_0_0_2_3(
i32 16, i32 7, i32 1, i32 0, i1 false)
@@ -48,7 +48,7 @@ define void @UseDifferentGlobalVar() #0 {
define void @ReuseGlobalVarFromFirstFunction() #0 {
; Make sure we use the same variable as the first function. They should be the
; same in case one function calls the other.
-; CHECK-NEXT: [[buffer:%[0-9]+]] = OpLoad [[RWBufferTypeInt]] [[IntBufferVar]]
+; CHECK: [[buffer:%[0-9]+]] = OpLoad [[RWBufferTypeInt]] [[IntBufferVar]]
%buffer1 = call target("spirv.Image", i32, 5, 2, 0, 0, 2, 24)
@llvm.spv.handle.fromBinding.tspirv.Image_f32_5_2_0_0_2_24(
i32 16, i32 7, i32 1, i32 0, i1 false)
diff --git a/llvm/test/CodeGen/SPIRV/OpVariable_order.ll b/llvm/test/CodeGen/SPIRV/OpVariable_order.ll
index 6057bf38d4c4c4..c68250697c4a7b 100644
--- a/llvm/test/Code...
[truncated]
``````````
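The `createAliasBlocksForComplexEdges` hunk in the diff above can be pictured with a small stand-alone sketch: whenever a source block already starts an earlier edge, the later edge is routed through a fresh alias block. This is a toy under stated assumptions, not the structurizer's code: blocks are just names, `aliasRepeatedSources` is a hypothetical function, and "creating a block" only means inventing a new name that would branch to the destination.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

using Edge = std::pair<std::string, std::string>; // {source, destination}

// Hypothetical toy version: returns the edge list in which no source name
// appears twice, inventing alias names for repeated sources.
std::vector<Edge> aliasRepeatedSources(const std::vector<Edge> &Edges) {
  std::unordered_set<std::string> Seen;
  std::vector<Edge> Output;
  Output.reserve(Edges.size());
  std::size_t NextAlias = 0;

  for (auto [Src, Dst] : Edges) { // copies, so Src can be reassigned
    bool Inserted = Seen.insert(Src).second;
    if (!Inserted) {
      // Src already starts an earlier edge: route this one through an alias.
      Src = "new.src." + std::to_string(NextAlias++);
    }
    Output.emplace_back(Src, Dst);
  }
  return Output;
}
```

With the edges `A -> X`, `A -> Y`, `B -> X`, the second edge becomes `new.src.0 -> Y` while the other two are unchanged, so each source in the result identifies exactly one edge.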
</details>
https://github.com/llvm/llvm-project/pull/111026