[llvm] [InferAlignment] Add variable offset alignment analysis for GEPs in InferAlignment (PR #185980)

via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 17 19:18:59 PDT 2026


https://github.com/yasmincs updated https://github.com/llvm/llvm-project/pull/185980

>From 6144fa5281162a9c48d5fa1414952af115ea0f1e Mon Sep 17 00:00:00 2001
From: ysarita <ysarita at nvidia.com>
Date: Tue, 17 Mar 2026 22:52:57 +0000
Subject: [PATCH 1/3] Add test for variable offset alignment inference

---
 .../variable-offset-alignment.ll              | 84 +++++++++++++++++++
 1 file changed, 84 insertions(+)
 create mode 100644 llvm/test/Transforms/InferAlignment/variable-offset-alignment.ll

diff --git a/llvm/test/Transforms/InferAlignment/variable-offset-alignment.ll b/llvm/test/Transforms/InferAlignment/variable-offset-alignment.ll
new file mode 100644
index 0000000000000..b5e59e73bdf88
--- /dev/null
+++ b/llvm/test/Transforms/InferAlignment/variable-offset-alignment.ll
@@ -0,0 +1,84 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt < %s -passes=infer-alignment -S | FileCheck %s
+
+define void @test_assume_then_load(ptr %base, i64 %idx) {
+; CHECK-LABEL: define void @test_assume_then_load
+; CHECK-SAME: (ptr [[BASE:%.*]], i64 [[IDX:%.*]]) {
+; CHECK-NEXT:    [[IDX_SHIFTED:%.*]] = shl i64 [[IDX]], 2
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr inbounds float, ptr [[BASE]], i64 [[IDX_SHIFTED]]
+; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[GEP1]], i64 16) ]
+; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr inbounds float, ptr [[BASE]], i64 [[IDX_SHIFTED]]
+; CHECK-NEXT:    [[LOAD:%.*]] = load float, ptr [[GEP2]], align 4
+; CHECK-NEXT:    ret void
+;
+  %idx.shifted = shl i64 %idx, 2
+  %gep1 = getelementptr inbounds float, ptr %base, i64 %idx.shifted
+  call void @llvm.assume(i1 true) [ "align"(ptr %gep1, i64 16) ]
+  %gep2 = getelementptr inbounds float, ptr %base, i64 %idx.shifted
+  %load = load float, ptr %gep2, align 4
+  ret void
+}
+
+define void @test_assume_addrspace_cast(ptr %base, i64 %idx) {
+; CHECK-LABEL: define void @test_assume_addrspace_cast
+; CHECK-SAME: (ptr [[BASE:%.*]], i64 [[IDX:%.*]]) {
+; CHECK-NEXT:    [[IDX_SHIFTED:%.*]] = shl i64 [[IDX]], 2
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr inbounds float, ptr [[BASE]], i64 [[IDX_SHIFTED]]
+; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[GEP1]], i64 16) ]
+; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast ptr [[BASE]] to ptr addrspace(1)
+; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr inbounds float, ptr addrspace(1) [[CAST]], i64 [[IDX_SHIFTED]]
+; CHECK-NEXT:    [[LOAD:%.*]] = load float, ptr addrspace(1) [[GEP2]], align 4
+; CHECK-NEXT:    ret void
+;
+  %idx.shifted = shl i64 %idx, 2
+  %gep1 = getelementptr inbounds float, ptr %base, i64 %idx.shifted
+  call void @llvm.assume(i1 true) [ "align"(ptr %gep1, i64 16) ]
+  %cast = addrspacecast ptr %base to ptr addrspace(1)
+  %gep2 = getelementptr inbounds float, ptr addrspace(1) %cast, i64 %idx.shifted
+  %load = load float, ptr addrspace(1) %gep2, align 4
+  ret void
+}
+
+define void @test_multiple_assumes(ptr %base, i64 %idx1, i64 %idx2) {
+; CHECK-LABEL: define void @test_multiple_assumes
+; CHECK-SAME: (ptr [[BASE:%.*]], i64 [[IDX1:%.*]], i64 [[IDX2:%.*]]) {
+; CHECK-NEXT:    [[IDX1_SHIFTED:%.*]] = shl i64 [[IDX1]], 2
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr inbounds float, ptr [[BASE]], i64 [[IDX1_SHIFTED]]
+; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[GEP1]], i64 16) ]
+; CHECK-NEXT:    [[IDX2_SHIFTED:%.*]] = shl i64 [[IDX2]], 1
+; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr inbounds float, ptr [[BASE]], i64 [[IDX2_SHIFTED]]
+; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[GEP2]], i64 8) ]
+; CHECK-NEXT:    [[IDX3_SHIFTED:%.*]] = shl i64 [[IDX1]], 2
+; CHECK-NEXT:    [[GEP3:%.*]] = getelementptr inbounds float, ptr [[BASE]], i64 [[IDX3_SHIFTED]]
+; CHECK-NEXT:    [[LOAD:%.*]] = load float, ptr [[GEP3]], align 4
+; CHECK-NEXT:    ret void
+;
+  %idx1.shifted = shl i64 %idx1, 2
+  %gep1 = getelementptr inbounds float, ptr %base, i64 %idx1.shifted
+  call void @llvm.assume(i1 true) [ "align"(ptr %gep1, i64 16) ]
+  %idx2.shifted = shl i64 %idx2, 1
+  %gep2 = getelementptr inbounds float, ptr %base, i64 %idx2.shifted
+  call void @llvm.assume(i1 true) [ "align"(ptr %gep2, i64 8) ]
+  %idx3.shifted = shl i64 %idx1, 2
+  %gep3 = getelementptr inbounds float, ptr %base, i64 %idx3.shifted
+  %load = load float, ptr %gep3, align 4
+  ret void
+}
+
+define void @test_base_align_then_variable_gep(ptr %base, i64 %idx) {
+; CHECK-LABEL: define void @test_base_align_then_variable_gep
+; CHECK-SAME: (ptr [[BASE:%.*]], i64 [[IDX:%.*]]) {
+; CHECK-NEXT:    [[LOAD1:%.*]] = load float, ptr [[BASE]], align 16
+; CHECK-NEXT:    [[IDX_SHIFTED:%.*]] = shl i64 [[IDX]], 1
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds float, ptr [[BASE]], i64 [[IDX_SHIFTED]]
+; CHECK-NEXT:    [[LOAD2:%.*]] = load float, ptr [[GEP]], align 4
+; CHECK-NEXT:    ret void
+;
+  %load1 = load float, ptr %base, align 16
+  %idx.shifted = shl i64 %idx, 1
+  %gep = getelementptr inbounds float, ptr %base, i64 %idx.shifted
+  %load2 = load float, ptr %gep, align 4
+  ret void
+}
+
+declare void @llvm.assume(i1)

>From 29686343019d3461efb5f72701b8a33a1df889f3 Mon Sep 17 00:00:00 2001
From: ysarita <ysarita at nvidia.com>
Date: Tue, 17 Mar 2026 22:53:24 +0000
Subject: [PATCH 2/3] Add variable offset alignment computations to
 InferAlignment

---
 llvm/lib/Transforms/Scalar/InferAlignment.cpp | 194 ++++++++++++++++--
 .../variable-offset-alignment.ll              |   8 +-
 2 files changed, 186 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/InferAlignment.cpp b/llvm/lib/Transforms/Scalar/InferAlignment.cpp
index a08f70ac188a6..02c097c5292b4 100644
--- a/llvm/lib/Transforms/Scalar/InferAlignment.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAlignment.cpp
@@ -21,6 +21,7 @@
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/KnownBits.h"
+#include "llvm/Support/MathExtras.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/Local.h"
 
@@ -119,32 +120,131 @@ bool inferAlignment(Function &F, AssumptionCache &AC, DominatorTree &DT) {
     return Align(1ull << std::min(Known.getBitWidth() - 1, TrailZ));
   };
 
+  // Helper function to compute variable offset alignment and base pointer
+  // If ConstOffset > 0, the effective offset alignment is limited by the
+  // constant offset
+  auto computeVariableOffsetAlignment =
+      [&](Value *Ptr, Instruction *I,
+          uint64_t ConstOffset = 0) -> std::pair<Value *, Align> {
+    Align VarOffsetAlign = Align(1);
+    Value *VarBasePtr = Ptr;
+    bool FirstGEP = true;
+    while (true) {
+      if (auto *GEP = dyn_cast<GEPOperator>(VarBasePtr)) {
+        // We can only handle GEPs with a single index
+        if (GEP->getNumIndices() != 1)
+          break;
+
+        Value *Idx = GEP->idx_begin()->get();
+        KnownBits Known = computeKnownBits(Idx, DL, &AC, I, &DT);
+        unsigned TrailZ = std::min(Known.countMinTrailingZeros(),
+                                   +Value::MaxAlignmentExponent);
+        Align IndexAlign(1ull << std::min(Known.getBitWidth() - 1, TrailZ));
+        Type *EltTy = GEP->getSourceElementType();
+        TypeSize EltSizeType = DL.getTypeAllocSize(EltTy);
+
+        // If we encounter a scalable type, we can't compute alignment for the
+        // chain
+        if (EltSizeType.isScalable())
+          break;
+
+        uint64_t EltSize = EltSizeType.getFixedValue();
+
+        // Compute offset alignment: multiply index alignment by element size,
+        // then take the greatest power of 2 that divides the product
+        uint64_t Product = IndexAlign.value() * EltSize;
+        uint64_t ProductAlignValue = Product > 0 ? (Product & (~Product + 1))
+                                                 : 1; // Extract lowest set bit
+        Align GEPOffsetAlign = Align(ProductAlignValue);
+        if (FirstGEP) {
+          VarOffsetAlign = GEPOffsetAlign;
+          FirstGEP = false;
+        } else {
+          // When combining offsets that are added together, use GCD
+          // (commonAlignment)
+          VarOffsetAlign =
+              commonAlignment(VarOffsetAlign, GEPOffsetAlign.value());
+        }
+
+        VarBasePtr = GEP->getPointerOperand();
+      } else {
+        break;
+      }
+    }
+    VarBasePtr = VarBasePtr->stripPointerCasts();
+
+    // If we have a constant offset, the effective alignment is the GCD of both
+    if (ConstOffset > 0) {
+      VarOffsetAlign = commonAlignment(VarOffsetAlign, ConstOffset);
+    }
+
+    return {VarBasePtr, VarOffsetAlign};
+  };
+
   // Propagate alignment between loads and stores that originate from the
   // same base pointer.
   DenseMap<Value *, Align> BestBasePointerAligns;
-  auto InferFromBasePointer = [&](Value *PtrOp, Align LoadStoreAlign) {
+
+  // Compute final alignment from a base pointer and offset.
+  // UseConstOffset: if true, use ConstOffset; if false, use VarOffsetAlign
+  auto computeFinalAlign = [&](Value *BasePtr, Align FallbackAlign,
+                               bool UseConstOffset, uint64_t ConstOffset,
+                               Align VarOffsetAlign) -> Align {
+    Align StoredBaseAlign = Align(1);
+    if (auto It = BestBasePointerAligns.find(BasePtr);
+        It != BestBasePointerAligns.end()) {
+      StoredBaseAlign = It->second;
+    }
+
+    Align BaseAlign =
+        StoredBaseAlign > Align(1) ? StoredBaseAlign : FallbackAlign;
+
+    // Apply offset alignment (either constant or variable)
+    if (UseConstOffset) {
+      return commonAlignment(BaseAlign, ConstOffset);
+    } else {
+      return commonAlignment(BaseAlign, VarOffsetAlign.value());
+    }
+  };
+
+  auto InferFromBasePointer = [&](Value *PtrOp, Align LoadStoreAlign,
+                                  Instruction *I) {
+    // Handle constant offsets
     APInt OffsetFromBase(DL.getIndexTypeSizeInBits(PtrOp->getType()), 0);
-    PtrOp = PtrOp->stripAndAccumulateConstantOffsets(DL, OffsetFromBase, true);
+    Value *ConstBasePtr =
+        PtrOp->stripAndAccumulateConstantOffsets(DL, OffsetFromBase, true);
+    uint64_t ConstOffsetVal = OffsetFromBase.abs().getLimitedValue();
+
     // Derive the base pointer alignment from the load/store alignment
     // and the offset from the base pointer.
-    Align BasePointerAlign =
-        commonAlignment(LoadStoreAlign, OffsetFromBase.getLimitedValue());
+    Align BasePointerAlign = commonAlignment(LoadStoreAlign, ConstOffsetVal);
 
     auto [It, Inserted] =
-        BestBasePointerAligns.try_emplace(PtrOp, BasePointerAlign);
+        BestBasePointerAligns.try_emplace(ConstBasePtr, BasePointerAlign);
     if (!Inserted) {
       // If the stored base pointer alignment is better than the
       // base pointer alignment we derived, we may be able to use it
       // to improve the load/store alignment. If not, store the
       // improved base pointer alignment for future iterations.
-      if (It->second > BasePointerAlign) {
-        Align BetterLoadStoreAlign =
-            commonAlignment(It->second, OffsetFromBase.getLimitedValue());
-        return BetterLoadStoreAlign;
+      if (It->second < BasePointerAlign) {
+        It->second = BasePointerAlign;
       }
-      It->second = BasePointerAlign;
     }
-    return LoadStoreAlign;
+
+    // Handle variable offsets (constant offset is handled inside the function)
+    auto [VarBasePtr, VarOffsetAlign] =
+        computeVariableOffsetAlignment(ConstBasePtr, I, ConstOffsetVal);
+
+    // Compute final alignment for constant method
+    Align ConstFinalAlign = computeFinalAlign(ConstBasePtr, LoadStoreAlign,
+                                              true, ConstOffsetVal, Align(1));
+
+    // Compute final alignment for variable method
+    Align VarFinalAlign =
+        computeFinalAlign(VarBasePtr, LoadStoreAlign, false, 0, VarOffsetAlign);
+
+    // Return the larger of the two final alignments
+    return std::max(ConstFinalAlign, VarFinalAlign);
   };
 
   for (BasicBlock &BB : F) {
@@ -156,11 +256,81 @@ bool inferAlignment(Function &F, AssumptionCache &AC, DominatorTree &DT) {
     // single basic block is correct too.
     BestBasePointerAligns.clear();
 
+    // Record the base-pointer alignments implied by the "align" operand
+    // bundles of an assume, so that later loads/stores can use them through
+    // BestBasePointerAligns. Bundles whose alignment is not a constant power
+    // of two, whose offset is not constant, or whose first input is not a
+    // pointer are ignored.
+    //
+    // This is only called from the in-order walk over BB below, so an assume
+    // informs just the instructions that follow it. Applying it to earlier
+    // instructions (as a separate pre-pass would) is unsound: execution may
+    // never reach the assume, e.g. when an intervening call does not return.
+    auto recordAssumeAlignments = [&](AssumeInst &Assume) {
+      // Raise the recorded alignment of Base to at least A.
+      auto recordBase = [&](Value *Base, Align A) {
+        auto [It, Inserted] = BestBasePointerAligns.try_emplace(Base, A);
+        if (!Inserted && It->second < A)
+          It->second = A;
+      };
+
+      for (unsigned Idx = 0, E = Assume.getNumOperandBundles(); Idx != E;
+           ++Idx) {
+        // The bundle has the form "align"(P, A) or "align"(P, A, O).
+        OperandBundleUse OB = Assume.getOperandBundleAt(Idx);
+        if (OB.getTagID() != LLVMContext::OB_align || OB.Inputs.size() < 2)
+          continue;
+
+        Value *AAPtr = OB.Inputs[0].get();
+        auto *AlignC = dyn_cast<ConstantInt>(OB.Inputs[1].get());
+        if (!AAPtr->getType()->isPointerTy() || !AlignC ||
+            !AlignC->getValue().isPowerOf2())
+          continue;
+
+        // Clamp to the IR maximum: the result may end up on a load/store,
+        // and setting an alignment above Value::MaximumAlignment asserts.
+        Align AssumedAlign(AlignC->getLimitedValue(Value::MaximumAlignment));
+
+        // "align"(P, A, O) states that P - O is A-aligned, so P itself is
+        // only MinAlign(A, O)-aligned; ignoring O would over-claim alignment.
+        // Give up on a non-constant offset.
+        if (OB.Inputs.size() > 2) {
+          auto *OffsetC = dyn_cast<ConstantInt>(OB.Inputs[2].get());
+          if (!OffsetC)
+            continue;
+          AssumedAlign =
+              commonAlignment(AssumedAlign, OffsetC->getLimitedValue());
+        }
+
+        // If P = Base + Off, P is A-aligned and Off is O-aligned, then Base is
+        // MinAlign(A, O)-aligned (commonAlignment computes MinAlign). Record
+        // this for the base left after stripping constant offsets...
+        APInt OffsetFromBase(DL.getIndexTypeSizeInBits(AAPtr->getType()), 0);
+        Value *ConstBasePtr =
+            AAPtr->stripAndAccumulateConstantOffsets(DL, OffsetFromBase, true);
+        uint64_t ConstOffsetVal = OffsetFromBase.abs().getLimitedValue();
+        recordBase(ConstBasePtr, commonAlignment(AssumedAlign, ConstOffsetVal));
+
+        // ...and for the base below a single variable-index GEP. The returned
+        // offset alignment already accounts for the constant offset, and
+        // recordBase keeps the larger value if both bases are the same.
+        auto [VarBasePtr, VarOffsetAlign] = computeVariableOffsetAlignment(
+            ConstBasePtr, &Assume, ConstOffsetVal);
+        recordBase(VarBasePtr,
+                   commonAlignment(AssumedAlign, VarOffsetAlign.value()));
+      }
+    };
+
+    // Walk BB in order: each assume only affects the loads/stores after it,
+    // and each load/store sees the base alignments recorded so far (from
+    // earlier assumes and from earlier loads/stores via InferFromBasePointer).
     for (Instruction &I : BB) {
+      if (auto *Assume = dyn_cast<AssumeInst>(&I))
+        recordAssumeAlignments(*Assume);
       Changed |= tryToImproveAlign(
           DL, &I, [&](Value *PtrOp, Align OldAlign, Align PrefAlign) {
             return std::max(InferFromKnownBits(I, PtrOp),
-                            InferFromBasePointer(PtrOp, OldAlign));
+                            InferFromBasePointer(PtrOp, OldAlign, &I));
           });
       }
   }
diff --git a/llvm/test/Transforms/InferAlignment/variable-offset-alignment.ll b/llvm/test/Transforms/InferAlignment/variable-offset-alignment.ll
index b5e59e73bdf88..9796e2f86a204 100644
--- a/llvm/test/Transforms/InferAlignment/variable-offset-alignment.ll
+++ b/llvm/test/Transforms/InferAlignment/variable-offset-alignment.ll
@@ -8,7 +8,7 @@ define void @test_assume_then_load(ptr %base, i64 %idx) {
 ; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr inbounds float, ptr [[BASE]], i64 [[IDX_SHIFTED]]
 ; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[GEP1]], i64 16) ]
 ; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr inbounds float, ptr [[BASE]], i64 [[IDX_SHIFTED]]
-; CHECK-NEXT:    [[LOAD:%.*]] = load float, ptr [[GEP2]], align 4
+; CHECK-NEXT:    [[LOAD:%.*]] = load float, ptr [[GEP2]], align 16
 ; CHECK-NEXT:    ret void
 ;
   %idx.shifted = shl i64 %idx, 2
@@ -27,7 +27,7 @@ define void @test_assume_addrspace_cast(ptr %base, i64 %idx) {
 ; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[GEP1]], i64 16) ]
 ; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast ptr [[BASE]] to ptr addrspace(1)
 ; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr inbounds float, ptr addrspace(1) [[CAST]], i64 [[IDX_SHIFTED]]
-; CHECK-NEXT:    [[LOAD:%.*]] = load float, ptr addrspace(1) [[GEP2]], align 4
+; CHECK-NEXT:    [[LOAD:%.*]] = load float, ptr addrspace(1) [[GEP2]], align 16
 ; CHECK-NEXT:    ret void
 ;
   %idx.shifted = shl i64 %idx, 2
@@ -50,7 +50,7 @@ define void @test_multiple_assumes(ptr %base, i64 %idx1, i64 %idx2) {
 ; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[GEP2]], i64 8) ]
 ; CHECK-NEXT:    [[IDX3_SHIFTED:%.*]] = shl i64 [[IDX1]], 2
 ; CHECK-NEXT:    [[GEP3:%.*]] = getelementptr inbounds float, ptr [[BASE]], i64 [[IDX3_SHIFTED]]
-; CHECK-NEXT:    [[LOAD:%.*]] = load float, ptr [[GEP3]], align 4
+; CHECK-NEXT:    [[LOAD:%.*]] = load float, ptr [[GEP3]], align 16
 ; CHECK-NEXT:    ret void
 ;
   %idx1.shifted = shl i64 %idx1, 2
@@ -71,7 +71,7 @@ define void @test_base_align_then_variable_gep(ptr %base, i64 %idx) {
 ; CHECK-NEXT:    [[LOAD1:%.*]] = load float, ptr [[BASE]], align 16
 ; CHECK-NEXT:    [[IDX_SHIFTED:%.*]] = shl i64 [[IDX]], 1
 ; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds float, ptr [[BASE]], i64 [[IDX_SHIFTED]]
-; CHECK-NEXT:    [[LOAD2:%.*]] = load float, ptr [[GEP]], align 4
+; CHECK-NEXT:    [[LOAD2:%.*]] = load float, ptr [[GEP]], align 8
 ; CHECK-NEXT:    ret void
 ;
   %load1 = load float, ptr %base, align 16

>From 8344debf38792a99bcc750d60dbef6a6fb227d40 Mon Sep 17 00:00:00 2001
From: ysarita <ysarita at nvidia.com>
Date: Wed, 18 Mar 2026 02:17:04 +0000
Subject: [PATCH 3/3] Reduce compile-time overhead of GEP index alignment analysis

---
 llvm/lib/Transforms/Scalar/InferAlignment.cpp | 99 ++++++++++++-------
 1 file changed, 66 insertions(+), 33 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/InferAlignment.cpp b/llvm/lib/Transforms/Scalar/InferAlignment.cpp
index 02c097c5292b4..3432ae9ef2a36 100644
--- a/llvm/lib/Transforms/Scalar/InferAlignment.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAlignment.cpp
@@ -24,6 +24,7 @@
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/Local.h"
+#include <optional>
 
 using namespace llvm;
 using namespace llvm::PatternMatch;
@@ -120,6 +121,50 @@ bool inferAlignment(Function &F, AssumptionCache &AC, DominatorTree &DT) {
     return Align(1ull << std::min(Known.getBitWidth() - 1, TrailZ));
   };
 
+  // Cheaply compute a power-of-two divisor of an integer GEP index by
+  // matching common patterns instead of calling computeKnownBits. Looks
+  // through add/sub of a constant and sext/zext, up to a small fixed depth.
+  // Returns std::nullopt when no useful (> 1) alignment is found.
+  auto getIndexAlignmentFromPattern = [](Value *Idx) -> std::optional<Align> {
+    // Upper bound imposed by constants added/subtracted along the chain.
+    Align Cap(Value::MaximumAlignment);
+    for (unsigned Depth = 0; Depth < 8; ++Depth) {
+      const APInt *C;
+      Value *Op;
+      // shl X, N: the low N bits are zero. getLimitedValue clamps huge or
+      // >64-bit shift amounts (shifting by >= the bit width is poison anyway).
+      if (match(Idx, m_Shl(m_Value(), m_APInt(C)))) {
+        uint64_t Shift = C->getLimitedValue(Value::MaxAlignmentExponent);
+        if (Shift == 0)
+          return std::nullopt;
+        return std::min(Cap, Align(1ull << Shift));
+      }
+      // mul X, C: the trailing zero bits of C carry over to the product, so
+      // C need not be a power of two.
+      if (match(Idx, m_c_Mul(m_Value(), m_APInt(C)))) {
+        unsigned TZ = std::min(C->countr_zero(), +Value::MaxAlignmentExponent);
+        if (TZ == 0)
+          return std::nullopt;
+        return std::min(Cap, Align(1ull << TZ));
+      }
+      // add/sub X, C (C != 0): the result is at most as aligned as C. Using
+      // countr_zero avoids getZExtValue asserting on a >64-bit constant.
+      if (match(Idx, m_c_Add(m_Value(Op), m_APInt(C))) ||
+          match(Idx, m_Sub(m_Value(Op), m_APInt(C)))) {
+        if (!C->isZero())
+          Cap = std::min(Cap, Align(1ull << std::min(C->countr_zero(), 63u)));
+        Idx = Op;
+        continue;
+      }
+      // sext/zext preserve the trailing zero bits of their operand.
+      if (!match(Idx, m_ZExtOrSExt(m_Value(Op))))
+        return std::nullopt;
+      Idx = Op;
+    }
+    // Chain too long: give up instead of recursing without bound.
+    return std::nullopt;
+  };
+
   // Helper function to compute variable offset alignment and base pointer
   // If ConstOffset > 0, the effective offset alignment is limited by the
   // constant offset
@@ -128,47 +173,35 @@ bool inferAlignment(Function &F, AssumptionCache &AC, DominatorTree &DT) {
           uint64_t ConstOffset = 0) -> std::pair<Value *, Align> {
     Align VarOffsetAlign = Align(1);
     Value *VarBasePtr = Ptr;
-    bool FirstGEP = true;
-    while (true) {
-      if (auto *GEP = dyn_cast<GEPOperator>(VarBasePtr)) {
-        // We can only handle GEPs with a single index
-        if (GEP->getNumIndices() != 1)
-          break;
 
+    if (auto *GEP = dyn_cast<GEPOperator>(VarBasePtr)) {
+      // We can only handle GEPs with a single index
+      if (GEP->getNumIndices() == 1) {
         Value *Idx = GEP->idx_begin()->get();
-        KnownBits Known = computeKnownBits(Idx, DL, &AC, I, &DT);
-        unsigned TrailZ = std::min(Known.countMinTrailingZeros(),
-                                   +Value::MaxAlignmentExponent);
-        Align IndexAlign(1ull << std::min(Known.getBitWidth() - 1, TrailZ));
+        Align IndexAlign(1);
+
+        if (auto PatternAlign = getIndexAlignmentFromPattern(Idx)) {
+          IndexAlign = *PatternAlign;
+        }
+        // If pattern matching fails, IndexAlign remains 1 (no alignment from
+        // variable offset)
         Type *EltTy = GEP->getSourceElementType();
         TypeSize EltSizeType = DL.getTypeAllocSize(EltTy);
 
-        // If we encounter a scalable type, we can't compute alignment for the
-        // chain
-        if (EltSizeType.isScalable())
-          break;
-
-        uint64_t EltSize = EltSizeType.getFixedValue();
-
-        // Compute offset alignment: multiply index alignment by element size,
-        // then take the greatest power of 2 that divides the product
-        uint64_t Product = IndexAlign.value() * EltSize;
-        uint64_t ProductAlignValue = Product > 0 ? (Product & (~Product + 1))
-                                                 : 1; // Extract lowest set bit
-        Align GEPOffsetAlign = Align(ProductAlignValue);
-        if (FirstGEP) {
-          VarOffsetAlign = GEPOffsetAlign;
-          FirstGEP = false;
-        } else {
-          // When combining offsets that are added together, use GCD
-          // (commonAlignment)
-          VarOffsetAlign =
-              commonAlignment(VarOffsetAlign, GEPOffsetAlign.value());
+        // If we encounter a scalable type, we can't compute alignment
+        if (!EltSizeType.isScalable()) {
+          uint64_t EltSize = EltSizeType.getFixedValue();
+
+          // Compute offset alignment: multiply index alignment by element size,
+          // then take the greatest power of 2 that divides the product
+          uint64_t Product = IndexAlign.value() * EltSize;
+          uint64_t ProductAlignValue = Product > 0
+                                           ? (Product & (~Product + 1))
+                                           : 1; // Extract lowest set bit
+          VarOffsetAlign = Align(ProductAlignValue);
         }
 
         VarBasePtr = GEP->getPointerOperand();
-      } else {
-        break;
       }
     }
     VarBasePtr = VarBasePtr->stripPointerCasts();



More information about the llvm-commits mailing list