[llvm] Added optimization for switches of powers of two (PR #70977)

via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 1 12:51:11 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Daniil (DKay7)

<details>
<summary>Changes</summary>

This optimization reduces the range of switches whose cases are all positive powers of two by replacing each case value with count_trailing_zeros(case).

Resolves #<!-- -->70756

---
Full diff: https://github.com/llvm/llvm-project/pull/70977.diff


1 Files Affected:

- (modified) llvm/lib/Transforms/Utils/SimplifyCFG.cpp (+73-3) 


``````````diff
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 68b5b1a78a3460e..5e44015e1e93595 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -50,6 +50,7 @@
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
@@ -6792,9 +6793,6 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder,
 
   // This transform can be done speculatively because it is so cheap - it
   // results in a single rotate operation being inserted.
-  // FIXME: It's possible that optimizing a switch on powers of two might also
-  // be beneficial - flag values are often powers of two and we could use a CLZ
-  // as the key function.
 
   // countTrailingZeros(0) returns 64. As Values is guaranteed to have more than
   // one element and LLVM disallows duplicate cases, Shift is guaranteed to be
@@ -6839,6 +6837,75 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder,
   return true;
 }
 
+/// Returns true if every value in \p Values is a strictly positive power of 2
+/// (clearing the lowest set bit of a positive power of two leaves zero).
+static bool isSwitchOfPowersOfTwo(ArrayRef<int64_t> Values) {
+  for (int64_t V : Values)
+    if (V <= 0 || (V & (V - 1)) != 0)
+      return false;
+  return true;
+}
+
+/// Rewrites a switch whose case values are all positive powers of two so that
+/// it switches on cttz(Condition) instead, shrinking the case range to
+/// [0, BitWidth).
+static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder,
+                                        const DataLayout &DL) {
+  auto *CondTy = cast<IntegerType>(SI->getCondition()->getType());
+  unsigned BitWidth = CondTy->getIntegerBitWidth();
+  if (BitWidth > 64 || !DL.fitsInLegalInteger(BitWidth))
+    return false;
+
+  // Only bother with this optimization if there are more than 3 switch cases;
+  // SDAG will only bother creating jump tables for 4 or more cases.
+  if (SI->getNumCases() < 4)
+    return false;
+
+  SmallVector<int64_t, 4> Values;
+  for (const auto &Case : SI->cases())
+    Values.push_back(Case.getCaseValue()->getValue().getSExtValue());
+  if (!isSwitchOfPowersOfTwo(Values))
+    return false;
+
+  Builder.SetInsertPoint(SI);
+  Value *Condition = SI->getCondition();
+  Function *Cttz =
+      Intrinsic::getDeclaration(SI->getModule(), Intrinsic::cttz, {CondTy});
+
+  // FIXME: If the default destination is unreachable, the power-of-two check
+  // below is unnecessary and could be omitted.
+
+  // A condition that is not a power of two must still reach the default
+  // destination. Every case value is a positive power of two that fits in a
+  // signed int64_t, so every rewritten case value (its cttz) is strictly less
+  // than BitWidth. Feeding 0 to cttz (with zero not poison) yields BitWidth,
+  // which matches no case. Setting the low bit instead would yield cttz == 0
+  // and wrongly select the destination of `case 1`.
+  // A zero condition passes the check unchanged and likewise yields BitWidth.
+  Value *One = ConstantInt::get(CondTy, 1);
+  Value *Zero = ConstantInt::get(CondTy, 0);
+  Value *LowBitCleared =
+      Builder.CreateAnd(Condition, Builder.CreateSub(Condition, One));
+  Value *IsNotPowerOfTwo = Builder.CreateICmpNE(LowBitCleared, Zero);
+  Value *CttzInput = Builder.CreateSelect(IsNotPowerOfTwo, Zero, Condition);
+
+  // FIXME: Check whether cttz is cheap on the target architecture.
+  Value *TrailingZeros =
+      Builder.CreateCall(Cttz, {CttzInput, Builder.getFalse()});
+  // setCondition only touches the condition operand; replaceUsesOfWith could
+  // also rewrite a case-value operand if the condition were that constant.
+  SI->setCondition(TrailingZeros);
+
+  // Renumber every case to its trailing-zero count. Distinct powers of two
+  // have distinct counts, so no duplicate cases are introduced.
+  for (auto &Case : SI->cases()) {
+    const APInt &OrigValue = Case.getCaseValue()->getValue();
+    Case.setValue(ConstantInt::get(CondTy, OrigValue.countr_zero()));
+  }
+
+  return true;
+}
+
 bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
   BasicBlock *BB = SI->getParent();
 
@@ -6886,6 +6953,9 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
       SwitchToLookupTable(SI, Builder, DTU, DL, TTI))
     return requestResimplify();
 
+  if (simplifySwitchOfPowersOfTwo(SI, Builder, DL))
+    return requestResimplify();
+
   if (ReduceSwitchRange(SI, Builder, DL, TTI))
     return requestResimplify();
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/70977


More information about the llvm-commits mailing list