[llvm] 6472cb1 - [FuncSpec] Improve estimation of select instruction. (#111176)

via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 9 02:25:23 PDT 2024


Author: Alexandros Lamprineas
Date: 2024-10-09T10:25:20+01:00
New Revision: 6472cb1e219f631ed504bb1c5675853168748d21

URL: https://github.com/llvm/llvm-project/commit/6472cb1e219f631ed504bb1c5675853168748d21
DIFF: https://github.com/llvm/llvm-project/commit/6472cb1e219f631ed504bb1c5675853168748d21.diff

LOG: [FuncSpec] Improve estimation of select instruction. (#111176)

When propagating a constant to a select instruction, we only consider the
condition operand as the use. This patch extends the logic to consider the
true and false values too, to cover the case where the condition was already
found to be constant by a previous propagation that halted at the select
because the operand it chooses was not yet known to be constant.

Added: 
    

Modified: 
    llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
    llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 548335d750e33d..bd0a337e579e48 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -423,13 +423,16 @@ Constant *InstCostVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
 Constant *InstCostVisitor::visitSelectInst(SelectInst &I) {
   assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
 
-  if (I.getCondition() != LastVisited->first)
-    return nullptr;
-
-  Value *V = LastVisited->second->isZeroValue() ? I.getFalseValue()
-                                                : I.getTrueValue();
-  Constant *C = findConstantFor(V, KnownConstants);
-  return C;
+  if (I.getCondition() == LastVisited->first) {
+    Value *V = LastVisited->second->isZeroValue() ? I.getFalseValue()
+                                                  : I.getTrueValue();
+    return findConstantFor(V, KnownConstants);
+  }
+  if (Constant *Condition = findConstantFor(I.getCondition(), KnownConstants))
+    if ((I.getTrueValue() == LastVisited->first && Condition->isOneValue()) ||
+        (I.getFalseValue() == LastVisited->first && Condition->isZeroValue()))
+      return LastVisited->second;
+  return nullptr;
 }
 
 Constant *InstCostVisitor::visitCastInst(CastInst &I) {

diff  --git a/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
index 52bad210b583ed..b0ff55489e1762 100644
--- a/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
+++ b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
@@ -261,6 +261,34 @@ TEST_F(FunctionSpecializationTest, BranchInst) {
   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
 }
 
+TEST_F(FunctionSpecializationTest, SelectInst) {
+  const char *ModuleString = R"(
+    define i32 @foo(i1 %cond, i32 %a, i32 %b) {
+      %sel = select i1 %cond, i32 %a, i32 %b
+      ret i32 %sel
+    }
+  )";
+
+  Module &M = parseModule(ModuleString);
+  Function *F = M.getFunction("foo");
+  FunctionSpecializer Specializer = getSpecializerFor(F);
+  InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
+
+  Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
+  Constant *Zero = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 0);
+  Constant *False = ConstantInt::getFalse(M.getContext());
+  Instruction &Select = *F->front().begin();
+
+  Bonus Ref = getInstCost(Select);
+  Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), False);
+  EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
+  Test = Visitor.getSpecializationBonus(F->getArg(1), One);
+  EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
+  Test = Visitor.getSpecializationBonus(F->getArg(2), Zero);
+  EXPECT_EQ(Test, Ref);
+  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
+}
+
 TEST_F(FunctionSpecializationTest, Misc) {
   const char *ModuleString = R"(
     %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] }


        


More information about the llvm-commits mailing list