[llvm] Changes to make Float-to-int scalar transform codegen deterministic (PR #92551)

Phil Camp via llvm-commits llvm-commits at lists.llvm.org
Fri May 17 07:09:51 PDT 2024


https://github.com/FlameTop created https://github.com/llvm/llvm-project/pull/92551

The current scalar transform Float2Int features non-deterministic code generation based on memory layout. While the resultant code is functionally the same, it is not identical. This can interfere with build systems that cache compilation results. 

The cause is the member 'ECs', a collection of equivalence classes of Instruction pointers. Under the hood, EquivalenceClasses stores its elements in an ordered set (std::set). As no comparator is provided, the pointers are ordered by their value, i.e. by the heap address at which each Instruction was allocated. Many heap allocators hand out addresses in ascending order, so the resulting order is the same between compilations. However, we use a heap allocation scheme that does not guarantee such consistency between runs. 

This change simply orders the 'ECs' set by the order in which the Instructions are found, which should be the same between different runs. It uses a new member in the Float2Int class so as not to burden the Instruction object with more fields that are only used in this transform. 

>From 0d5ab41e62f3b5e50a622d7143bc73c0f1fd9852 Mon Sep 17 00:00:00 2001
From: Phil Camp <phil.camp at sony.com>
Date: Fri, 17 May 2024 14:22:28 +0100
Subject: [PATCH] Changes to make Float-to-int scalar transform codegen
 deterministic

---
 .../llvm/Transforms/Scalar/Float2Int.h        | 24 +++++++++++++++++-
 llvm/lib/Transforms/Scalar/Float2Int.cpp      | 25 ++++++++++++++-----
 2 files changed, 42 insertions(+), 7 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Scalar/Float2Int.h b/llvm/include/llvm/Transforms/Scalar/Float2Int.h
index 337e229efcf37..6922917624e78 100644
--- a/llvm/include/llvm/Transforms/Scalar/Float2Int.h
+++ b/llvm/include/llvm/Transforms/Scalar/Float2Int.h
@@ -28,6 +28,25 @@ class LLVMContext;
 class Type;
 class Value;
 
+/// An Instruction paired with a deterministic index, used as the ECs element.
+class OrderedInstruction {
+  Instruction *Ins;
+  unsigned Order;
+
+public:
+  OrderedInstruction(Instruction *Inst, unsigned Ord) : Ins(Inst), Order(Ord) {}
+
+  Instruction *getInstruction() const { return Ins; }
+  unsigned getOrder() const { return Order; }
+};
+
+/// Strict weak ordering on OrderedInstruction: ascending by getOrder().
+template <class T> struct OrderedInstructionLess {
+  bool operator()(const T &LHS, const T &RHS) const {
+    return LHS.getOrder() < RHS.getOrder();
+  }
+};
+
 class Float2IntPass : public PassInfoMixin<Float2IntPass> {
 public:
   PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
@@ -36,6 +55,7 @@ class Float2IntPass : public PassInfoMixin<Float2IntPass> {
   bool runImpl(Function &F, const DominatorTree &DT);
 
 private:
+  unsigned int insOrder(Instruction *I); // Sort key of I within ECs.
   void findRoots(Function &F, const DominatorTree &DT);
   void seen(Instruction *I, ConstantRange R);
   ConstantRange badRange();
@@ -50,7 +70,9 @@ class Float2IntPass : public PassInfoMixin<Float2IntPass> {
 
   MapVector<Instruction *, ConstantRange> SeenInsts;
   SmallSetVector<Instruction *, 8> Roots;
-  EquivalenceClasses<Instruction *> ECs;
+  EquivalenceClasses<OrderedInstruction,
+                     OrderedInstructionLess<OrderedInstruction>> ECs;
+  MapVector<Instruction *, unsigned int> InstructionOrders; // ECs sort keys.
   MapVector<Instruction *, Value *> ConvertedInsts;
   LLVMContext *Ctx;
 };
diff --git a/llvm/lib/Transforms/Scalar/Float2Int.cpp b/llvm/lib/Transforms/Scalar/Float2Int.cpp
index da4d39b4e3ed4..cc69b78e32dc1 100644
--- a/llvm/lib/Transforms/Scalar/Float2Int.cpp
+++ b/llvm/lib/Transforms/Scalar/Float2Int.cpp
@@ -84,6 +84,12 @@ static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) {
   }
 }
 
+// Return I's first-seen index within the current function (the ECs key).
+unsigned int Float2IntPass::insOrder(Instruction *I) {
+  unsigned NextOrder = InstructionOrders.size();
+  return InstructionOrders.insert({I, NextOrder}).first->second;
+}
+
 // Find the roots - instructions that convert from the FP domain to
 // integer domain.
 void Float2IntPass::findRoots(Function &F, const DominatorTree &DT) {
@@ -191,7 +197,11 @@ void Float2IntPass::walkBackwards() {
     for (Value *O : I->operands()) {
       if (Instruction *OI = dyn_cast<Instruction>(O)) {
         // Unify def-use chains if they interfere.
-        ECs.unionSets(I, OI);
+        // Number I before OI: C++ leaves argument evaluation order
+        // unspecified, so nested insOrder() calls are compiler-dependent.
+        unsigned IOrder = insOrder(I);
+        ECs.unionSets(OrderedInstruction(I, IOrder),
+                      OrderedInstruction(OI, insOrder(OI)));
         if (SeenInsts.find(I)->second != badRange())
           Worklist.push_back(OI);
       } else if (!isa<ConstantFP>(O)) {
@@ -323,7 +333,8 @@ bool Float2IntPass::validateAndTransform(const DataLayout &DL) {
     // For every member of the partition, union all the ranges together.
     for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
          MI != ME; ++MI) {
-      Instruction *I = *MI;
+      OrderedInstruction OMI = *MI;
+      Instruction *I = OMI.getInstruction(); // Drop the ordering key.
       auto SeenI = SeenInsts.find(I);
       if (SeenI == SeenInsts.end())
         continue;
@@ -392,9 +403,10 @@ bool Float2IntPass::validateAndTransform(const DataLayout &DL) {
       }
     }
 
-    for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
-         MI != ME; ++MI)
-      convert(*MI, Ty);
+    // Call convert() with Ty on every member of this equivalence class.
+    for (auto MI = ECs.member_begin(It), ME = ECs.member_end(); MI != ME;
+         ++MI)
+      convert(OrderedInstruction(*MI).getInstruction(), Ty);
     MadeChange = true;
   }
 
@@ -485,8 +497,9 @@ void Float2IntPass::cleanup() {
 bool Float2IntPass::runImpl(Function &F, const DominatorTree &DT) {
   LLVM_DEBUG(dbgs() << "F2I: Looking at function " << F.getName() << "\n");
   // Clear out all state.
-  ECs = EquivalenceClasses<Instruction*>();
+  ECs = decltype(ECs)();
   SeenInsts.clear();
+  InstructionOrders.clear();
   ConvertedInsts.clear();
   Roots.clear();
 



More information about the llvm-commits mailing list