[llvm] [FuncSpec] Improve estimation of select instruction. (PR #111176)
Alexandros Lamprineas via llvm-commits
llvm-commits at lists.llvm.org
Tue Oct 8 08:02:37 PDT 2024
https://github.com/labrinea updated https://github.com/llvm/llvm-project/pull/111176
>From 4744e5922e8ef1e960ec085cd5a85cea9170e724 Mon Sep 17 00:00:00 2001
From: Alexandros Lamprineas <alexandros.lamprineas at arm.com>
Date: Fri, 4 Oct 2024 16:17:51 +0100
Subject: [PATCH 1/2] [FuncSpec] Improve estimation of select instruction.
When propagating a constant to a select instruction, we only consider
the condition operand as the use. This patch extends the logic to also
consider the true and false values, covering the case where the condition
was found to be constant in a previous propagation that then halted.
---
.../Transforms/IPO/FunctionSpecialization.cpp | 17 ++++++++++-------
1 file changed, 10 insertions(+), 7 deletions(-)
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 548335d750e33d..7d109af9091479 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -423,13 +423,16 @@ Constant *InstCostVisitor::visitGetElementPtrInst(GetElementPtrInst &I) {
Constant *InstCostVisitor::visitSelectInst(SelectInst &I) {
assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
- if (I.getCondition() != LastVisited->first)
- return nullptr;
-
- Value *V = LastVisited->second->isZeroValue() ? I.getFalseValue()
- : I.getTrueValue();
- Constant *C = findConstantFor(V, KnownConstants);
- return C;
+ if (I.getCondition() == LastVisited->first) {
+ Value *V = LastVisited->second->isZeroValue() ? I.getFalseValue()
+ : I.getTrueValue();
+ return findConstantFor(V, KnownConstants);
+ }
+ if (Constant *Condition = findConstantFor(I.getCondition(), KnownConstants))
+ if (I.getTrueValue() == LastVisited->first && Condition->isOneValue() ||
+ I.getFalseValue() == LastVisited->first && Condition->isZeroValue())
+ return LastVisited->second;
+ return nullptr;
}
Constant *InstCostVisitor::visitCastInst(CastInst &I) {
>From 2d8c3396dbba6dd8e8dba06da0bdf3da2bfba329 Mon Sep 17 00:00:00 2001
From: Alexandros Lamprineas <alexandros.lamprineas at arm.com>
Date: Tue, 8 Oct 2024 15:59:18 +0100
Subject: [PATCH 2/2] Changes from last revision:
* Added surrounding parentheses in condition to fix warning.
* Added a unittest.
---
.../Transforms/IPO/FunctionSpecialization.cpp | 4 +--
.../IPO/FunctionSpecializationTest.cpp | 28 +++++++++++++++++++
2 files changed, 30 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 7d109af9091479..bd0a337e579e48 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -429,8 +429,8 @@ Constant *InstCostVisitor::visitSelectInst(SelectInst &I) {
return findConstantFor(V, KnownConstants);
}
if (Constant *Condition = findConstantFor(I.getCondition(), KnownConstants))
- if (I.getTrueValue() == LastVisited->first && Condition->isOneValue() ||
- I.getFalseValue() == LastVisited->first && Condition->isZeroValue())
+ if ((I.getTrueValue() == LastVisited->first && Condition->isOneValue()) ||
+ (I.getFalseValue() == LastVisited->first && Condition->isZeroValue()))
return LastVisited->second;
return nullptr;
}
diff --git a/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
index 52bad210b583ed..b0ff55489e1762 100644
--- a/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
+++ b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
@@ -261,6 +261,34 @@ TEST_F(FunctionSpecializationTest, BranchInst) {
EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
}
+TEST_F(FunctionSpecializationTest, SelectInst) {
+ const char *ModuleString = R"(
+ define i32 @foo(i1 %cond, i32 %a, i32 %b) {
+ %sel = select i1 %cond, i32 %a, i32 %b
+ ret i32 %sel
+ }
+ )";
+
+ Module &M = parseModule(ModuleString);
+ Function *F = M.getFunction("foo");
+ FunctionSpecializer Specializer = getSpecializerFor(F);
+ InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
+
+ Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
+ Constant *Zero = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 0);
+ Constant *False = ConstantInt::getFalse(M.getContext());
+ Instruction &Select = *F->front().begin();
+
+ Bonus Ref = getInstCost(Select);
+ Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), False);
+ EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
+ Test = Visitor.getSpecializationBonus(F->getArg(1), One);
+ EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
+ Test = Visitor.getSpecializationBonus(F->getArg(2), Zero);
+ EXPECT_EQ(Test, Ref);
+ EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
+}
+
TEST_F(FunctionSpecializationTest, Misc) {
const char *ModuleString = R"(
%struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] }
More information about the llvm-commits
mailing list