[llvm] [FuncSpec] Improve handling of BinaryOperator instructions (PR #114534)

Hari Limaye via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 1 05:13:36 PDT 2024


https://github.com/hazzlim created https://github.com/llvm/llvm-project/pull/114534

When visiting BinaryOperator instructions during estimation of codesize savings for a candidate specialization, don't bail when the other operand is not found to be constant. This allows us to find more constants than we otherwise would, for example `and(false, x)`, which simplifies to `false` regardless of the value of `x`.

From eb3da8c8bf0876a18b1d4486760614ff46b52fbc Mon Sep 17 00:00:00 2001
From: Hari Limaye <hari.limaye at arm.com>
Date: Thu, 31 Oct 2024 13:23:10 +0000
Subject: [PATCH] [FuncSpec] Improve handling of BinaryOperator instructions

When visiting BinaryOperator instructions during estimation of codesize
savings for a candidate specialization, don't bail when the other
operand is not found to be constant. This allows us to find more
constants than we otherwise would, for example `and(false, x)`.
---
 .../Transforms/IPO/FunctionSpecialization.cpp | 17 ++++-----
 .../IPO/FunctionSpecializationTest.cpp        | 35 +++++++++++++++++++
 2 files changed, 44 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
index 919d3143a13f7e..59ae15a6071e9c 100644
--- a/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
+++ b/llvm/lib/Transforms/IPO/FunctionSpecialization.cpp
@@ -489,16 +489,17 @@ Constant *InstCostVisitor::visitUnaryOperator(UnaryOperator &I) {
 Constant *InstCostVisitor::visitBinaryOperator(BinaryOperator &I) {
   assert(LastVisited != KnownConstants.end() && "Invalid iterator!");
 
-  bool Swap = I.getOperand(1) == LastVisited->first;
-  Value *V = Swap ? I.getOperand(0) : I.getOperand(1);
+  bool ConstOnRHS = I.getOperand(1) == LastVisited->first;
+  Value *V = ConstOnRHS ? I.getOperand(0) : I.getOperand(1);
   Constant *Other = findConstantFor(V, KnownConstants);
-  if (!Other)
-    return nullptr;
+  Value *OtherVal = Other ? Other : V; // No constant found: keep the operand.
+  Value *ConstVal = LastVisited->second; // Constant for the visited operand.
 
-  Constant *Const = LastVisited->second;
-  return dyn_cast_or_null<Constant>(Swap ?
-        simplifyBinOp(I.getOpcode(), Other, Const, SimplifyQuery(DL))
-      : simplifyBinOp(I.getOpcode(), Const, Other, SimplifyQuery(DL)));
+  if (ConstOnRHS) // Put the operands back in the instruction's order.
+    std::swap(ConstVal, OtherVal);
+
+  return dyn_cast_or_null<Constant>(
+      simplifyBinOp(I.getOpcode(), ConstVal, OtherVal, SimplifyQuery(DL)));
 }
 
 Constant *FunctionSpecializer::getPromotableAlloca(AllocaInst *Alloca,
diff --git a/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
index c8fd366bfac65f..9f76e9ff11c3aa 100644
--- a/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
+++ b/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp
@@ -469,3 +469,38 @@ TEST_F(FunctionSpecializationTest, PhiNode) {
   EXPECT_TRUE(Test > 0);
 }
 
+TEST_F(FunctionSpecializationTest, BinOp) {
+  // Verify that we can handle binary operators even when only one operand is
+  // constant: with %a == false, %and1, %and2 and %sel should all fold.
+  const char *ModuleString = R"(
+    define i32 @foo(i1 %a, i1 %b) {
+      %and1 = and i1 %a, %b
+      %and2 = and i1 %b, %and1
+      %sel = select i1 %and2, i32 1, i32 0
+      ret i32 %sel
+    }
+  )";
+
+  Module &M = parseModule(ModuleString);
+  Function *F = M.getFunction("foo");
+  FunctionSpecializer Specializer = getSpecializerFor(F);
+  InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
+
+  Constant *False = ConstantInt::getFalse(M.getContext());
+  BasicBlock &BB = F->front();
+  Instruction &And1 = BB.front();             // %and1
+  Instruction &And2 = *++BB.begin();          // %and2
+  Instruction &Select = *++(++BB.begin());    // %sel (third instruction)
+
+  Cost RefCodeSize = getCodeSizeSavings(And1) + getCodeSizeSavings(And2) +
+                     getCodeSizeSavings(Select);
+  Cost RefLatency = getLatencySavings(F);
+
+  Cost TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(0), False);
+  Cost TestLatency = Visitor.getLatencySavingsForKnownConstants();
+
+  EXPECT_EQ(TestCodeSize, RefCodeSize);
+  EXPECT_TRUE(TestCodeSize > 0);
+  EXPECT_EQ(TestLatency, RefLatency);
+  EXPECT_TRUE(TestLatency > 0);
+}



More information about the llvm-commits mailing list