[llvm] Changes to make Float-to-int scalar transform codegen deterministic (PR #92551)

Phil Camp via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 25 06:32:29 PDT 2024


https://github.com/FlameTop updated https://github.com/llvm/llvm-project/pull/92551

>From 0d5ab41e62f3b5e50a622d7143bc73c0f1fd9852 Mon Sep 17 00:00:00 2001
From: Phil Camp <phil.camp at sony.com>
Date: Fri, 17 May 2024 14:22:28 +0100
Subject: [PATCH 1/2] Changes to make Float-to-int scalar transform codegen
 deterministic

---
 .../llvm/Transforms/Scalar/Float2Int.h        | 24 +++++++++++++++++-
 llvm/lib/Transforms/Scalar/Float2Int.cpp      | 25 ++++++++++++++-----
 2 files changed, 42 insertions(+), 7 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Scalar/Float2Int.h b/llvm/include/llvm/Transforms/Scalar/Float2Int.h
index 337e229efcf37..6922917624e78 100644
--- a/llvm/include/llvm/Transforms/Scalar/Float2Int.h
+++ b/llvm/include/llvm/Transforms/Scalar/Float2Int.h
@@ -28,6 +28,25 @@ class LLVMContext;
 class Type;
 class Value;
 
+class OrderedInstruction { // Instruction plus an ordering key.
+  Instruction *Ins;
+  unsigned int Order;
+
+public:
+  OrderedInstruction(Instruction *Inst, unsigned int Ord) : Ins(Inst), Order(Ord) {}
+
+  Instruction *getInstruction() { return Ins; }
+  unsigned int getOrder() { return Order; }
+};
+
+template <class T> struct OrderedInstructionLess { // Compares by getOrder().
+  bool operator()(const T &lhs, const T &rhs) const {
+    OrderedInstruction lhsOrder = lhs; // Copies: getOrder() is non-const.
+    OrderedInstruction rhsOrder = rhs;
+    return rhsOrder.getOrder() < lhsOrder.getOrder(); // Descending order.
+  }
+};
+
 class Float2IntPass : public PassInfoMixin<Float2IntPass> {
 public:
   PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
@@ -36,6 +55,7 @@ class Float2IntPass : public PassInfoMixin<Float2IntPass> {
   bool runImpl(Function &F, const DominatorTree &DT);
 
 private:
+  unsigned int insOrder(Instruction *I);
   void findRoots(Function &F, const DominatorTree &DT);
   void seen(Instruction *I, ConstantRange R);
   ConstantRange badRange();
@@ -50,7 +70,9 @@ class Float2IntPass : public PassInfoMixin<Float2IntPass> {
 
   MapVector<Instruction *, ConstantRange> SeenInsts;
   SmallSetVector<Instruction *, 8> Roots;
-  EquivalenceClasses<Instruction *> ECs;
+  EquivalenceClasses<OrderedInstruction,
+                     OrderedInstructionLess<OrderedInstruction>> ECs;
+  MapVector<Instruction *, unsigned int> InstructionOrders;
   MapVector<Instruction *, Value *> ConvertedInsts;
   LLVMContext *Ctx;
 };
diff --git a/llvm/lib/Transforms/Scalar/Float2Int.cpp b/llvm/lib/Transforms/Scalar/Float2Int.cpp
index da4d39b4e3ed4..cc69b78e32dc1 100644
--- a/llvm/lib/Transforms/Scalar/Float2Int.cpp
+++ b/llvm/lib/Transforms/Scalar/Float2Int.cpp
@@ -84,6 +84,16 @@ static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) {
   }
 }
 
+// Instruction order - return deterministic order suitable as set
+// order for EquivalenceClasses.
+unsigned int Float2IntPass::insOrder(Instruction* I) {
+  static unsigned int order = 0;
+  if (InstructionOrders.find(I) != InstructionOrders.end())
+    return InstructionOrders[I];
+  InstructionOrders[I] = order++; // FIXME: 'order' is static; never reset.
+  return order - 1; // NOTE(review): racy if called concurrently.
+}
+
 // Find the roots - instructions that convert from the FP domain to
 // integer domain.
 void Float2IntPass::findRoots(Function &F, const DominatorTree &DT) {
@@ -191,7 +201,7 @@ void Float2IntPass::walkBackwards() {
     for (Value *O : I->operands()) {
       if (Instruction *OI = dyn_cast<Instruction>(O)) {
         // Unify def-use chains if they interfere.
-        ECs.unionSets(I, OI);
+        ECs.unionSets(OrderedInstruction(I, insOrder(I)), OrderedInstruction(OI, insOrder(OI)));
         if (SeenInsts.find(I)->second != badRange())
           Worklist.push_back(OI);
       } else if (!isa<ConstantFP>(O)) {
@@ -323,7 +333,8 @@ bool Float2IntPass::validateAndTransform(const DataLayout &DL) {
     // For every member of the partition, union all the ranges together.
     for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
          MI != ME; ++MI) {
-      Instruction *I = *MI;
+      OrderedInstruction OMI = *MI;
+      Instruction *I = OMI.getInstruction();
       auto SeenI = SeenInsts.find(I);
       if (SeenI == SeenInsts.end())
         continue;
@@ -392,9 +403,10 @@ bool Float2IntPass::validateAndTransform(const DataLayout &DL) {
       }
     }
 
-    for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
-         MI != ME; ++MI)
-      convert(*MI, Ty);
+    for (auto MI = ECs.member_begin(It), ME = ECs.member_end(); MI != ME; ++MI) {
+      OrderedInstruction OMI = *MI;
+      convert(OMI.getInstruction(), Ty);
+    }
     MadeChange = true;
   }
 
@@ -485,8 +497,9 @@ void Float2IntPass::cleanup() {
 bool Float2IntPass::runImpl(Function &F, const DominatorTree &DT) {
   LLVM_DEBUG(dbgs() << "F2I: Looking at function " << F.getName() << "\n");
   // Clear out all state.
-  ECs = EquivalenceClasses<Instruction*>();
+  ECs = EquivalenceClasses<OrderedInstruction, OrderedInstructionLess<OrderedInstruction> >();
   SeenInsts.clear();
+  InstructionOrders.clear();
   ConvertedInsts.clear();
   Roots.clear();
 

>From 1db272ff8239f711c796269475bf591424f23ea1 Mon Sep 17 00:00:00 2001
From: Phil Camp <phil.camp at sony.com>
Date: Tue, 25 Jun 2024 14:31:21 +0100
Subject: [PATCH 2/2] Corrected formatting

---
 llvm/include/llvm/Transforms/Scalar/Float2Int.h |  6 ++++--
 llvm/lib/Transforms/Scalar/Float2Int.cpp        | 11 +++++++----
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Scalar/Float2Int.h b/llvm/include/llvm/Transforms/Scalar/Float2Int.h
index 6922917624e78..ad994f6bf324d 100644
--- a/llvm/include/llvm/Transforms/Scalar/Float2Int.h
+++ b/llvm/include/llvm/Transforms/Scalar/Float2Int.h
@@ -33,7 +33,8 @@ class OrderedInstruction {
   unsigned int Order;
 
 public:
-  OrderedInstruction(Instruction *Inst, unsigned int Ord) : Ins(Inst), Order(Ord) {}
+  OrderedInstruction(Instruction *Inst, unsigned int Ord)
+      : Ins(Inst), Order(Ord) {}
 
   Instruction *getInstruction() { return Ins; }
   unsigned int getOrder() { return Order; }
@@ -71,7 +72,8 @@ class Float2IntPass : public PassInfoMixin<Float2IntPass> {
   MapVector<Instruction *, ConstantRange> SeenInsts;
   SmallSetVector<Instruction *, 8> Roots;
   EquivalenceClasses<OrderedInstruction,
-                     OrderedInstructionLess<OrderedInstruction>> ECs;
+                     OrderedInstructionLess<OrderedInstruction>>
+      ECs;
   MapVector<Instruction *, unsigned int> InstructionOrders;
   MapVector<Instruction *, Value *> ConvertedInsts;
   LLVMContext *Ctx;
diff --git a/llvm/lib/Transforms/Scalar/Float2Int.cpp b/llvm/lib/Transforms/Scalar/Float2Int.cpp
index cc69b78e32dc1..216600e412667 100644
--- a/llvm/lib/Transforms/Scalar/Float2Int.cpp
+++ b/llvm/lib/Transforms/Scalar/Float2Int.cpp
@@ -86,7 +86,7 @@ static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) {
 
 // Instruction order - return deterministic order suitable as set
 // order for EquivalenceClasses.
-unsigned int Float2IntPass::insOrder(Instruction* I) {
+unsigned int Float2IntPass::insOrder(Instruction *I) {
   static unsigned int order = 0;
   if (InstructionOrders.find(I) != InstructionOrders.end())
     return InstructionOrders[I];
@@ -201,7 +201,8 @@ void Float2IntPass::walkBackwards() {
     for (Value *O : I->operands()) {
       if (Instruction *OI = dyn_cast<Instruction>(O)) {
         // Unify def-use chains if they interfere.
-        ECs.unionSets(OrderedInstruction(I, insOrder(I)), OrderedInstruction(OI, insOrder(OI)));
+        ECs.unionSets(OrderedInstruction(I, insOrder(I)),
+                      OrderedInstruction(OI, insOrder(OI)));
         if (SeenInsts.find(I)->second != badRange())
           Worklist.push_back(OI);
       } else if (!isa<ConstantFP>(O)) {
@@ -403,7 +404,8 @@ bool Float2IntPass::validateAndTransform(const DataLayout &DL) {
       }
     }
 
-    for (auto MI = ECs.member_begin(It), ME = ECs.member_end(); MI != ME; ++MI) {
+    for (auto MI = ECs.member_begin(It), ME = ECs.member_end(); MI != ME;
+         ++MI) {
       OrderedInstruction OMI = *MI;
       convert(OMI.getInstruction(), Ty);
     }
@@ -497,7 +499,8 @@ void Float2IntPass::cleanup() {
 bool Float2IntPass::runImpl(Function &F, const DominatorTree &DT) {
   LLVM_DEBUG(dbgs() << "F2I: Looking at function " << F.getName() << "\n");
   // Clear out all state.
-  ECs = EquivalenceClasses<OrderedInstruction, OrderedInstructionLess<OrderedInstruction> >();
+  ECs = EquivalenceClasses<OrderedInstruction,
+                           OrderedInstructionLess<OrderedInstruction> >();
   SeenInsts.clear();
   InstructionOrders.clear();
   ConvertedInsts.clear();



More information about the llvm-commits mailing list