[llvm] [SimplifyCFG] Add optimization for switches of powers of two (PR #70977)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 14 19:26:16 PST 2023


https://github.com/DKay7 updated https://github.com/llvm/llvm-project/pull/70977

>From e3e992fb3ea59c31d060658cbe6144b42f8fb294 Mon Sep 17 00:00:00 2001
From: Daniil <kalinin.de at phystech.edu>
Date: Wed, 1 Nov 2023 22:32:23 +0300
Subject: [PATCH] Added optimization for switches of powers of two

This optimization reduces the range of switches whose cases are all positive powers of two by replacing each case value with count_trailing_zeros(case).
It is performed only for switches whose default case is unreachable.

Resolves #70756
---
 llvm/lib/Transforms/Utils/SimplifyCFG.cpp     |  81 +++++++++-
 .../AArch64/switch-unreachable-default.ll     |   2 +-
 .../SimplifyCFG/switch-of-powers-of-two.ll    | 139 ++++++++++++++++++
 3 files changed, 218 insertions(+), 4 deletions(-)
 create mode 100644 llvm/test/Transforms/SimplifyCFG/switch-of-powers-of-two.ll

diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 68b5b1a78a3460e..c022c068eea58c1 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -23,6 +23,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/bit.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/CaptureTracking.h"
 #include "llvm/Analysis/ConstantFolding.h"
@@ -6792,9 +6793,6 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder,
 
   // This transform can be done speculatively because it is so cheap - it
   // results in a single rotate operation being inserted.
-  // FIXME: It's possible that optimizing a switch on powers of two might also
-  // be beneficial - flag values are often powers of two and we could use a CLZ
-  // as the key function.
 
   // countTrailingZeros(0) returns 64. As Values is guaranteed to have more than
   // one element and LLVM disallows duplicate cases, Shift is guaranteed to be
@@ -6839,6 +6837,80 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder,
   return true;
 }
 
+static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder,
+                                        const DataLayout &DL,
+                                        const TargetTransformInfo &TTI) {
+  // Rewrite a switch on power-of-two cases into a switch on cttz(Condition).
+  auto *Condition = SI->getCondition();
+  auto &Context = SI->getContext();
+  auto *CondTy = cast<IntegerType>(Condition->getType());
+
+  if (CondTy->getIntegerBitWidth() > 64 ||
+      !DL.fitsInLegalInteger(CondTy->getIntegerBitWidth()))
+    return false;
+
+  const auto CttzIntrinsicCost = TTI.getIntrinsicInstrCost(
+      IntrinsicCostAttributes(Intrinsic::cttz, CondTy,
+                              {Condition, ConstantInt::getTrue(Context)}),
+      TTI::TCK_SizeAndLatency);
+
+  if (CttzIntrinsicCost > TTI::TCC_Basic) {
+    // Inserting the cttz intrinsic is too expensive.
+    return false;
+  }
+
+  // Only bother with this optimization if there are more than 3 switch cases;
+  // SDAG will only bother creating jump tables for 4 or more cases.
+  if (SI->getNumCases() < 4)
+    return false;
+
+  // We perform this optimization only for switches with an
+  // unreachable default case.
+  // This assumption saves us from checking that `Condition` is a power of two.
+  if (!isa<UnreachableInst>(SI->getDefaultDest()->getFirstNonPHIOrDbg()))
+    return false;
+
+  // Check that switch cases are powers of two (values are sign-extended).
+  SmallVector<int64_t, 4> Values;
+  for (const auto &Case : SI->cases()) {
+    int64_t CaseValue = Case.getCaseValue()->getValue().getSExtValue();
+    if (llvm::isPowerOf2_64(CaseValue))
+      Values.push_back(CaseValue);
+    else
+      return false;
+  }
+
+  // isSwitchDense requires case values to be sorted.
+  llvm::sort(Values);
+  if (isSwitchDense(Values))
+    // The switch is already dense.
+    return false;
+
+  if (!isSwitchDense(
+          Values.size(),
+          llvm::countr_zero(static_cast<uint64_t>(Values.back())) -
+              llvm::countr_zero(static_cast<uint64_t>(Values.front())) + 1))
+    // The transformed switch would still not be dense.
+    return false;
+
+  Builder.SetInsertPoint(SI);
+
+  // Replace each case value with its number of trailing zeros.
+  for (auto &Case : SI->cases()) {
+    auto *OrigValue = Case.getCaseValue();
+    Case.setValue(ConstantInt::get(OrigValue->getType(),
+                                   OrigValue->getValue().countr_zero()));
+  }
+
+  // Replace the condition with its number of trailing zeros.
+  auto *ConditionTrailingZeros = Builder.CreateIntrinsic(
+      Intrinsic::cttz, {CondTy}, {Condition, ConstantInt::getTrue(Context)});
+
+  SI->replaceUsesOfWith(Condition, ConditionTrailingZeros);
+
+  return true;
+}
+
 bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
   BasicBlock *BB = SI->getParent();
 
@@ -6886,6 +6958,9 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
       SwitchToLookupTable(SI, Builder, DTU, DL, TTI))
     return requestResimplify();
 
+  if (simplifySwitchOfPowersOfTwo(SI, Builder, DL, TTI))
+    return requestResimplify();
+
   if (ReduceSwitchRange(SI, Builder, DL, TTI))
     return requestResimplify();
 
diff --git a/llvm/test/CodeGen/AArch64/switch-unreachable-default.ll b/llvm/test/CodeGen/AArch64/switch-unreachable-default.ll
index 9acc5a150b630f2..c38daffb7ade1f2 100644
--- a/llvm/test/CodeGen/AArch64/switch-unreachable-default.ll
+++ b/llvm/test/CodeGen/AArch64/switch-unreachable-default.ll
@@ -71,7 +71,7 @@ entry:
     i32 8,  label %bb2
     i32 16, label %bb3
     i32 32, label %bb4
-    i32 64, label %bb5
+    i32 -64, label %bb5
   ]
 
 ; The switch is lowered with a jump table for cases 1--32 and case 64 handled
diff --git a/llvm/test/Transforms/SimplifyCFG/switch-of-powers-of-two.ll b/llvm/test/Transforms/SimplifyCFG/switch-of-powers-of-two.ll
new file mode 100644
index 000000000000000..42a464b66e0e494
--- /dev/null
+++ b/llvm/test/Transforms/SimplifyCFG/switch-of-powers-of-two.ll
@@ -0,0 +1,139 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 3
+; RUN: opt %s -passes='simplifycfg<switch-to-lookup>' -simplifycfg-require-and-preserve-domtree=1 -S | FileCheck %s
+
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+
+; Check that the range of a switch of powers of two is reduced and the switch itself is lowered to a lookup table
+define i32 @switch_of_powers(i32 %x) {
+; CHECK-LABEL: define i32 @switch_of_powers(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call i32 @llvm.cttz.i32(i32 [[X]], i1 true)
+; CHECK-NEXT:    [[SWITCH_MASKINDEX:%.*]] = trunc i32 [[TMP0]] to i8
+; CHECK-NEXT:    [[SWITCH_SHIFTED:%.*]] = lshr i8 121, [[SWITCH_MASKINDEX]]
+; CHECK-NEXT:    [[SWITCH_LOBIT:%.*]] = trunc i8 [[SWITCH_SHIFTED]] to i1
+; CHECK-NEXT:    call void @llvm.assume(i1 [[SWITCH_LOBIT]])
+; CHECK-NEXT:    [[SWITCH_GEP:%.*]] = getelementptr inbounds [7 x i32], ptr @switch.table.switch_of_powers, i32 0, i32 [[TMP0]]
+; CHECK-NEXT:    [[SWITCH_LOAD:%.*]] = load i32, ptr [[SWITCH_GEP]], align 4
+; CHECK-NEXT:    ret i32 [[SWITCH_LOAD]]
+;
+
+entry:
+  switch i32 %x, label %default_case [
+  i32 1,  label %bb1
+  i32 8,  label %bb2
+  i32 16, label %bb3
+  i32 32, label %bb4
+  i32 64, label %bb5
+  ]
+
+
+default_case: unreachable
+bb1: br label %return
+bb2: br label %return
+bb3: br label %return
+bb4: br label %return
+bb5: br label %return
+
+return:
+  %p = phi i32 [ 3, %bb1 ], [ 2, %bb2 ], [ 1, %bb3 ], [ 0, %bb4 ], [ 42, %bb5 ]
+  ret i32 %p
+}
+
+; Check that the range of a switch of powers of two is not reduced if the default case is reachable
+define i32 @switch_of_powers_reachable_default(i32 %x) {
+; CHECK-LABEL: define i32 @switch_of_powers_reachable_default(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    switch i32 [[X]], label [[RETURN:%.*]] [
+; CHECK-NEXT:    i32 1, label [[BB1:%.*]]
+; CHECK-NEXT:    i32 8, label [[BB2:%.*]]
+; CHECK-NEXT:    i32 16, label [[BB3:%.*]]
+; CHECK-NEXT:    i32 32, label [[BB4:%.*]]
+; CHECK-NEXT:    i32 64, label [[BB5:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       bb1:
+; CHECK-NEXT:    br label [[RETURN]]
+; CHECK:       bb2:
+; CHECK-NEXT:    br label [[RETURN]]
+; CHECK:       bb3:
+; CHECK-NEXT:    br label [[RETURN]]
+; CHECK:       bb4:
+; CHECK-NEXT:    br label [[RETURN]]
+; CHECK:       bb5:
+; CHECK-NEXT:    br label [[RETURN]]
+; CHECK:       return:
+; CHECK-NEXT:    [[P:%.*]] = phi i32 [ 3, [[BB1]] ], [ 2, [[BB2]] ], [ 1, [[BB3]] ], [ 0, [[BB4]] ], [ 42, [[BB5]] ], [ -1, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    ret i32 [[P]]
+;
+
+entry:
+  switch i32 %x, label %default_case [
+  i32 1,  label %bb1
+  i32 8,  label %bb2
+  i32 16, label %bb3
+  i32 32, label %bb4
+  i32 64, label %bb5
+  ]
+
+
+default_case: br label %return
+bb1: br label %return
+bb2: br label %return
+bb3: br label %return
+bb4: br label %return
+bb5: br label %return
+
+return:
+  %p = phi i32 [ 3, %bb1 ], [ 2, %bb2 ], [ 1, %bb3 ], [ 0, %bb4 ], [ 42, %bb5 ], [-1, %default_case]
+  ret i32 %p
+}
+
+; Check that a switch with a zero case is not considered a switch of powers of two
+define i32 @switch_of_non_powers(i32 %x) {
+; CHECK-LABEL: define i32 @switch_of_non_powers(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    switch i32 [[X]], label [[DEFAULT_CASE:%.*]] [
+; CHECK-NEXT:    i32 0, label [[RETURN:%.*]]
+; CHECK-NEXT:    i32 1, label [[BB2:%.*]]
+; CHECK-NEXT:    i32 16, label [[BB3:%.*]]
+; CHECK-NEXT:    i32 32, label [[BB4:%.*]]
+; CHECK-NEXT:    i32 64, label [[BB5:%.*]]
+; CHECK-NEXT:    ]
+; CHECK:       default_case:
+; CHECK-NEXT:    unreachable
+; CHECK:       bb2:
+; CHECK-NEXT:    br label [[RETURN]]
+; CHECK:       bb3:
+; CHECK-NEXT:    br label [[RETURN]]
+; CHECK:       bb4:
+; CHECK-NEXT:    br label [[RETURN]]
+; CHECK:       bb5:
+; CHECK-NEXT:    br label [[RETURN]]
+; CHECK:       return:
+; CHECK-NEXT:    [[P:%.*]] = phi i32 [ 2, [[BB2]] ], [ 1, [[BB3]] ], [ 0, [[BB4]] ], [ 42, [[BB5]] ], [ 3, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    ret i32 [[P]]
+;
+
+entry:
+  switch i32 %x, label %default_case [
+  i32 0,  label %bb1
+  i32 1,  label %bb2
+  i32 16, label %bb3
+  i32 32, label %bb4
+  i32 64, label %bb5
+  ]
+
+
+default_case: unreachable
+bb1: br label %return
+bb2: br label %return
+bb3: br label %return
+bb4: br label %return
+bb5: br label %return
+
+return:
+  %p = phi i32 [ 3, %bb1 ], [ 2, %bb2 ], [ 1, %bb3 ], [ 0, %bb4 ], [ 42, %bb5 ]
+  ret i32 %p
+}



More information about the llvm-commits mailing list