[llvm] Changes to make Float-to-int scalar transform codegen deterministic (PR #92551)
Phil Camp via llvm-commits
llvm-commits at lists.llvm.org
Tue Jun 25 06:32:29 PDT 2024
https://github.com/FlameTop updated https://github.com/llvm/llvm-project/pull/92551
From 0d5ab41e62f3b5e50a622d7143bc73c0f1fd9852 Mon Sep 17 00:00:00 2001
From: Phil Camp <phil.camp at sony.com>
Date: Fri, 17 May 2024 14:22:28 +0100
Subject: [PATCH 1/2] Changes to make Float-to-int scalar transform codegen deterministic
---
.../llvm/Transforms/Scalar/Float2Int.h | 24 +++++++++++++++++-
llvm/lib/Transforms/Scalar/Float2Int.cpp | 25 ++++++++++++++-----
2 files changed, 42 insertions(+), 7 deletions(-)
diff --git a/llvm/include/llvm/Transforms/Scalar/Float2Int.h b/llvm/include/llvm/Transforms/Scalar/Float2Int.h
index 337e229efcf37..6922917624e78 100644
--- a/llvm/include/llvm/Transforms/Scalar/Float2Int.h
+++ b/llvm/include/llvm/Transforms/Scalar/Float2Int.h
@@ -28,6 +28,25 @@ class LLVMContext;
class Type;
class Value;
+class OrderedInstruction {
+ Instruction *Ins;
+ unsigned int Order;
+
+public:
+ OrderedInstruction(Instruction *Inst, unsigned int Ord) : Ins(Inst), Order(Ord) {}
+
+ Instruction *getInstruction() { return Ins; }
+ unsigned int getOrder() { return Order; }
+};
+
+template <class T> struct OrderedInstructionLess {
+ bool operator()(const T &lhs, const T &rhs) const {
+ OrderedInstruction lhsOrder = lhs;
+ OrderedInstruction rhsOrder = rhs;
+ return rhsOrder.getOrder() < lhsOrder.getOrder();
+ }
+};
+
class Float2IntPass : public PassInfoMixin<Float2IntPass> {
public:
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
@@ -36,6 +55,7 @@ class Float2IntPass : public PassInfoMixin<Float2IntPass> {
bool runImpl(Function &F, const DominatorTree &DT);
private:
+ unsigned int insOrder(Instruction *I);
void findRoots(Function &F, const DominatorTree &DT);
void seen(Instruction *I, ConstantRange R);
ConstantRange badRange();
@@ -50,7 +70,9 @@ class Float2IntPass : public PassInfoMixin<Float2IntPass> {
MapVector<Instruction *, ConstantRange> SeenInsts;
SmallSetVector<Instruction *, 8> Roots;
- EquivalenceClasses<Instruction *> ECs;
+ EquivalenceClasses<OrderedInstruction,
+ OrderedInstructionLess<OrderedInstruction>> ECs;
+ MapVector<Instruction *, unsigned int> InstructionOrders;
MapVector<Instruction *, Value *> ConvertedInsts;
LLVMContext *Ctx;
};
diff --git a/llvm/lib/Transforms/Scalar/Float2Int.cpp b/llvm/lib/Transforms/Scalar/Float2Int.cpp
index da4d39b4e3ed4..cc69b78e32dc1 100644
--- a/llvm/lib/Transforms/Scalar/Float2Int.cpp
+++ b/llvm/lib/Transforms/Scalar/Float2Int.cpp
@@ -84,6 +84,16 @@ static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) {
}
}
+// Instruction order - return deterministic order suitable as set
+// order for EquivalenceClasses.
+unsigned int Float2IntPass::insOrder(Instruction* I) {
+ static unsigned int order = 0;
+ if (InstructionOrders.find(I) != InstructionOrders.end())
+ return InstructionOrders[I];
+ InstructionOrders[I] = order++;
+ return order - 1;
+}
+
// Find the roots - instructions that convert from the FP domain to
// integer domain.
void Float2IntPass::findRoots(Function &F, const DominatorTree &DT) {
@@ -191,7 +201,7 @@ void Float2IntPass::walkBackwards() {
for (Value *O : I->operands()) {
if (Instruction *OI = dyn_cast<Instruction>(O)) {
// Unify def-use chains if they interfere.
- ECs.unionSets(I, OI);
+ ECs.unionSets(OrderedInstruction(I, insOrder(I)), OrderedInstruction(OI, insOrder(OI)));
if (SeenInsts.find(I)->second != badRange())
Worklist.push_back(OI);
} else if (!isa<ConstantFP>(O)) {
@@ -323,7 +333,8 @@ bool Float2IntPass::validateAndTransform(const DataLayout &DL) {
// For every member of the partition, union all the ranges together.
for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
MI != ME; ++MI) {
- Instruction *I = *MI;
+ OrderedInstruction OMI = *MI;
+ Instruction *I = OMI.getInstruction();
auto SeenI = SeenInsts.find(I);
if (SeenI == SeenInsts.end())
continue;
@@ -392,9 +403,10 @@ bool Float2IntPass::validateAndTransform(const DataLayout &DL) {
}
}
- for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
- MI != ME; ++MI)
- convert(*MI, Ty);
+ for (auto MI = ECs.member_begin(It), ME = ECs.member_end(); MI != ME; ++MI) {
+ OrderedInstruction OMI = *MI;
+ convert(OMI.getInstruction(), Ty);
+ }
MadeChange = true;
}
@@ -485,8 +497,9 @@ void Float2IntPass::cleanup() {
bool Float2IntPass::runImpl(Function &F, const DominatorTree &DT) {
LLVM_DEBUG(dbgs() << "F2I: Looking at function " << F.getName() << "\n");
// Clear out all state.
- ECs = EquivalenceClasses<Instruction*>();
+ ECs = EquivalenceClasses<OrderedInstruction, OrderedInstructionLess<OrderedInstruction> >();
SeenInsts.clear();
+ InstructionOrders.clear();
ConvertedInsts.clear();
Roots.clear();
From 1db272ff8239f711c796269475bf591424f23ea1 Mon Sep 17 00:00:00 2001
From: Phil Camp <phil.camp at sony.com>
Date: Tue, 25 Jun 2024 14:31:21 +0100
Subject: [PATCH 2/2] Corrected formatting
---
llvm/include/llvm/Transforms/Scalar/Float2Int.h | 6 ++++--
llvm/lib/Transforms/Scalar/Float2Int.cpp | 11 +++++++----
2 files changed, 11 insertions(+), 6 deletions(-)
diff --git a/llvm/include/llvm/Transforms/Scalar/Float2Int.h b/llvm/include/llvm/Transforms/Scalar/Float2Int.h
index 6922917624e78..ad994f6bf324d 100644
--- a/llvm/include/llvm/Transforms/Scalar/Float2Int.h
+++ b/llvm/include/llvm/Transforms/Scalar/Float2Int.h
@@ -33,7 +33,8 @@ class OrderedInstruction {
unsigned int Order;
public:
- OrderedInstruction(Instruction *Inst, unsigned int Ord) : Ins(Inst), Order(Ord) {}
+ OrderedInstruction(Instruction *Inst, unsigned int Ord)
+ : Ins(Inst), Order(Ord) {}
Instruction *getInstruction() { return Ins; }
unsigned int getOrder() { return Order; }
@@ -71,7 +72,8 @@ class Float2IntPass : public PassInfoMixin<Float2IntPass> {
MapVector<Instruction *, ConstantRange> SeenInsts;
SmallSetVector<Instruction *, 8> Roots;
EquivalenceClasses<OrderedInstruction,
- OrderedInstructionLess<OrderedInstruction>> ECs;
+ OrderedInstructionLess<OrderedInstruction>>
+ ECs;
MapVector<Instruction *, unsigned int> InstructionOrders;
MapVector<Instruction *, Value *> ConvertedInsts;
LLVMContext *Ctx;
diff --git a/llvm/lib/Transforms/Scalar/Float2Int.cpp b/llvm/lib/Transforms/Scalar/Float2Int.cpp
index cc69b78e32dc1..216600e412667 100644
--- a/llvm/lib/Transforms/Scalar/Float2Int.cpp
+++ b/llvm/lib/Transforms/Scalar/Float2Int.cpp
@@ -86,7 +86,7 @@ static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) {
// Instruction order - return deterministic order suitable as set
// order for EquivalenceClasses.
-unsigned int Float2IntPass::insOrder(Instruction* I) {
+unsigned int Float2IntPass::insOrder(Instruction *I) {
static unsigned int order = 0;
if (InstructionOrders.find(I) != InstructionOrders.end())
return InstructionOrders[I];
@@ -201,7 +201,8 @@ void Float2IntPass::walkBackwards() {
for (Value *O : I->operands()) {
if (Instruction *OI = dyn_cast<Instruction>(O)) {
// Unify def-use chains if they interfere.
- ECs.unionSets(OrderedInstruction(I, insOrder(I)), OrderedInstruction(OI, insOrder(OI)));
+ ECs.unionSets(OrderedInstruction(I, insOrder(I)),
+ OrderedInstruction(OI, insOrder(OI)));
if (SeenInsts.find(I)->second != badRange())
Worklist.push_back(OI);
} else if (!isa<ConstantFP>(O)) {
@@ -403,7 +404,8 @@ bool Float2IntPass::validateAndTransform(const DataLayout &DL) {
}
}
- for (auto MI = ECs.member_begin(It), ME = ECs.member_end(); MI != ME; ++MI) {
+ for (auto MI = ECs.member_begin(It), ME = ECs.member_end(); MI != ME;
+ ++MI) {
OrderedInstruction OMI = *MI;
convert(OMI.getInstruction(), Ty);
}
@@ -497,7 +499,8 @@ void Float2IntPass::cleanup() {
bool Float2IntPass::runImpl(Function &F, const DominatorTree &DT) {
LLVM_DEBUG(dbgs() << "F2I: Looking at function " << F.getName() << "\n");
// Clear out all state.
- ECs = EquivalenceClasses<OrderedInstruction, OrderedInstructionLess<OrderedInstruction> >();
+ ECs = EquivalenceClasses<OrderedInstruction,
+ OrderedInstructionLess<OrderedInstruction> >();
SeenInsts.clear();
InstructionOrders.clear();
ConvertedInsts.clear();
More information about the llvm-commits mailing list