[llvm] Add a pass to convert jump tables to switches. (PR #77709)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 9 13:12:33 PST 2024


================
@@ -0,0 +1,190 @@
+//===- JumpTableToSwitch.cpp ----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+using namespace llvm;
+
+/// Upper bound on the number of entries a jump table may have for it to be
+/// considered for conversion into a switch; larger tables are rejected.
+static cl::opt<unsigned>
+    JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden,
+                           cl::desc("Only split jump tables with size less "
+                                    "than or equal to this threshold."),
+                           cl::init(10));
+
+// TODO: Consider adding a cost model for profitability analysis of this
+// transformation. Currently we replace a jump table with a switch only if
+// every function in the jump table is no larger than the provided threshold.
+static cl::opt<unsigned> FunctionSizeThreshold(
+    "jump-table-to-switch-function-size-threshold", cl::Hidden,
+    cl::desc("Only split jump tables containing functions whose sizes are less "
+             "than or equal to this threshold."),
+    cl::init(50));
+
+#define DEBUG_TYPE "jump-table-to-switch"
+
+namespace {
+/// A jump table recognized by parseJumpTable: the value that indexes into the
+/// table and the functions loaded from its entries.
+struct JumpTableTy {
+  /// The single variable offset operand of the indexing GEP. Defaults to
+  /// nullptr so a default-constructed JumpTableTy never holds a garbage
+  /// pointer.
+  Value *Index = nullptr;
+  /// Functions stored in the table; parseJumpTable reserves one slot per
+  /// table entry (presumably filled in entry order -- confirm in the loop).
+  SmallVector<Function *, 10> Funcs;
+};
+} // anonymous namespace
+
+static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
+                                                 PointerType *PtrTy) {
+  Constant *Ptr = dyn_cast<Constant>(GEP->getPointerOperand());
+  if (!Ptr)
+    return std::nullopt;
+
+  GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr);
+  if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
+    return std::nullopt;
+
+  Function &F = *GEP->getParent()->getParent();
+  const DataLayout &DL = F.getParent()->getDataLayout();
+  const unsigned BitWidth =
+      DL.getIndexSizeInBits(GEP->getPointerAddressSpace());
+  MapVector<Value *, APInt> VariableOffsets;
+  APInt ConstantOffset(BitWidth, 0);
+  if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset))
+    return std::nullopt;
+  if (VariableOffsets.size() != 1)
+    return std::nullopt;
+  // TODO: consider supporting more general patterns
+  if (!ConstantOffset.isZero())
+    return std::nullopt;
+  APInt StrideBytes = VariableOffsets.front().second;
+  const uint64_t JumpTableSizeBytes = DL.getTypeAllocSize(GV->getValueType());
+  if (JumpTableSizeBytes % StrideBytes.getZExtValue() != 0)
+    return std::nullopt;
+  const uint64_t N = JumpTableSizeBytes / StrideBytes.getZExtValue();
+  if (N > JumpTableSizeThreshold)
+    return std::nullopt;
+
+  JumpTableTy JumpTable;
+  JumpTable.Index = VariableOffsets.front().first;
+  JumpTable.Funcs.reserve(N);
+  for (uint64_t Index = 0; Index < N; ++Index) {
+    // ConstantOffset is zero.
+    APInt Offset = Index * StrideBytes;
+    Constant *C = ConstantFoldLoadFromConst(
+        cast<Constant>(GV->getInitializer()), PtrTy, Offset, DL);
----------------
nikic wrote:

This `cast<Constant>` should not be needed: `GV->getInitializer()` already returns a `Constant *`.

https://github.com/llvm/llvm-project/pull/77709


More information about the llvm-commits mailing list