[llvm] [InstCombine] Canonicalize `gep T, (gep i8, base, C1), (Index +nsw C2)` into `gep T, (gep i8, base, C1 + C2 * sizeof(T)), Index` (PR #76177)

via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 21 11:49:39 PST 2023


llvmbot wrote:



@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

Changes:

This patch tries to canonicalize `gep T, (gep i8, base, C1), (Index +nsw C2)` into `gep T, (gep i8, base, C1 + C2 * sizeof(T)), Index`.
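As a concrete illustration (a minimal before/after sketch; the value names and constants are only an example, not taken from the patch):

```llvm
; before: constant-offset i8 GEP feeding a GEP whose index is (%i +nsw 3)
%p1 = getelementptr i8, ptr %base, i64 4
%idx = add nsw i64 %i, 3
%p2 = getelementptr i32, ptr %p1, i64 %idx

; after: fold the constant into the i8 GEP: 4 + 3 * sizeof(i32) = 16
%p1.new = getelementptr i8, ptr %base, i64 16
%p2.new = getelementptr i32, ptr %p1.new, i64 %i
```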

Alive2: https://alive2.llvm.org/ce/z/9icv7n
Fixes regressions found in https://github.com/llvm/llvm-project/pull/68882.

Note: I found that `gep T, base, (Index +nsw C)` is already expanded into `gep T, (gep T, base, Index), C`, so a GEP of a GEP can be canonicalized by moving the constant-indexed GEP to the front/[back](https://reviews.llvm.org/D138950). In these regressions, however, the `add` instruction is likely to have multiple users, so I think it is worth handling this pattern directly.
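For example, a minimal sketch of the multi-use case (mirroring the `test_zero_multiuse_index` test added below; names are illustrative):

```llvm
; %index has a second user, so it cannot simply be sunk into the GEP chain,
; but the new fold still rewrites the GEP pair because the resulting
; constant offset is zero: -4 + 1 * sizeof(i32) == 0.
%p1 = getelementptr i8, ptr %base, i64 -4
%index = add nsw i64 %a, 1
call void @use64(i64 %index)
%p2 = getelementptr i32, ptr %p1, i64 %index
; => %p2 = getelementptr i32, ptr %base, i64 %a
```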


---
Full diff: https://github.com/llvm/llvm-project/pull/76177.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+51) 
- (added) llvm/test/Transforms/InstCombine/gepofconstgepi8.ll (+282) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 775720ab43a5c0..bfa208cc4dda5e 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2038,6 +2038,54 @@ static Instruction *foldSelectGEP(GetElementPtrInst &GEP,
   return SelectInst::Create(Cond, NewTrueC, NewFalseC, "", nullptr, Sel);
 }
 
+// Canonicalization:
+// gep T, (gep i8, base, C1), (Index +nsw C2) into
+// gep T, (gep i8, base, C1 + C2 * sizeof(T)), Index
+static Instruction *canonicalizeGEPOfConstGEPI8(GetElementPtrInst &GEP,
+                                                GEPOperator *Src,
+                                                InstCombinerImpl &IC) {
+  if (GEP.getNumIndices() != 1)
+    return nullptr;
+  auto &DL = IC.getDataLayout();
+  if (!Src->getSourceElementType()->isIntegerTy(8) ||
+      !Src->hasAllConstantIndices())
+    return nullptr;
+  Value *VarIndex;
+  const APInt *C2;
+  Type *PtrTy = Src->getType()->getScalarType();
+  unsigned IndexSizeInBits = DL.getIndexTypeSizeInBits(PtrTy);
+  if (!(GEP.getOperand(1)->getType()->getScalarSizeInBits() >=
+            IndexSizeInBits &&
+        match(GEP.getOperand(1), m_Add(m_Value(VarIndex), m_APInt(C2)))) &&
+      !match(GEP.getOperand(1),
+             m_SExtOrSelf(
+                 m_CombineOr(m_NSWAdd(m_Value(VarIndex), m_APInt(C2)),
+                             m_DisjointOr(m_Value(VarIndex), m_APInt(C2))))))
+    return nullptr;
+  Type *BaseType = GEP.getSourceElementType();
+  APInt C1(IndexSizeInBits, 0);
+  // Add the offset for Src (which is fully constant).
+  if (!Src->accumulateConstantOffset(DL, C1))
+    return nullptr;
+  APInt TypeSize(IndexSizeInBits, DL.getTypeAllocSize(BaseType));
+  bool Overflow = false;
+  APInt C3 = TypeSize.smul_ov(C2->sext(TypeSize.getBitWidth()), Overflow);
+  if (Overflow)
+    return nullptr;
+  APInt NewOffset = C1.sadd_ov(C3, Overflow);
+  if (Overflow)
+    return nullptr;
+  if (NewOffset.isZero() ||
+      (Src->hasOneUse() && GEP.getOperand(1)->hasOneUse())) {
+    Value *GEPConst =
+        IC.Builder.CreateGEP(IC.Builder.getInt8Ty(), Src->getPointerOperand(),
+                             IC.Builder.getInt(NewOffset));
+    return GetElementPtrInst::Create(BaseType, GEPConst, VarIndex);
+  }
+
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP,
                                              GEPOperator *Src) {
   // Combine Indices - If the source pointer to this getelementptr instruction
@@ -2046,6 +2094,9 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP,
   if (!shouldMergeGEPs(*cast<GEPOperator>(&GEP), *Src))
     return nullptr;
 
+  if (auto *I = canonicalizeGEPOfConstGEPI8(GEP, Src, *this))
+    return I;
+
   // For constant GEPs, use a more general offset-based folding approach.
   Type *PtrTy = Src->getType()->getScalarType();
   if (GEP.hasAllConstantIndices() &&
diff --git a/llvm/test/Transforms/InstCombine/gepofconstgepi8.ll b/llvm/test/Transforms/InstCombine/gepofconstgepi8.ll
new file mode 100644
index 00000000000000..0d280f0055748c
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/gepofconstgepi8.ll
@@ -0,0 +1,282 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -S -passes=instcombine | FileCheck %s
+
+target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64"
+target triple = "x86_64-unknown-linux-gnu"
+
+declare void @use64(i64)
+declare void @useptr(ptr)
+
+define ptr @test_zero(ptr %base, i64 %a) {
+; CHECK-LABEL: define ptr @test_zero(
+; CHECK-SAME: ptr [[BASE:%.*]], i64 [[A:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P2:%.*]] = getelementptr i32, ptr [[BASE]], i64 [[A]]
+; CHECK-NEXT:    ret ptr [[P2]]
+;
+entry:
+  %p1 = getelementptr i8, ptr %base, i64 -4
+  %index = add i64 %a, 1
+  %p2 = getelementptr i32, ptr %p1, i64 %index
+  ret ptr %p2
+}
+
+define ptr @test_nonzero(ptr %base, i64 %a) {
+; CHECK-LABEL: define ptr @test_nonzero(
+; CHECK-SAME: ptr [[BASE:%.*]], i64 [[A:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr i8, ptr [[BASE]], i64 4
+; CHECK-NEXT:    [[P2:%.*]] = getelementptr i32, ptr [[TMP0]], i64 [[A]]
+; CHECK-NEXT:    ret ptr [[P2]]
+;
+entry:
+  %p1 = getelementptr i8, ptr %base, i64 -4
+  %index = add i64 %a, 2
+  %p2 = getelementptr i32, ptr %p1, i64 %index
+  ret ptr %p2
+}
+
+define ptr @test_or_disjoint(ptr %base, i64 %a) {
+; CHECK-LABEL: define ptr @test_or_disjoint(
+; CHECK-SAME: ptr [[BASE:%.*]], i64 [[A:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P2:%.*]] = getelementptr i32, ptr [[BASE]], i64 [[A]]
+; CHECK-NEXT:    ret ptr [[P2]]
+;
+entry:
+  %p1 = getelementptr i8, ptr %base, i64 -4
+  %index = or disjoint i64 %a, 1
+  %p2 = getelementptr i32, ptr %p1, i64 %index
+  ret ptr %p2
+}
+
+define ptr @test_zero_multiuse_index(ptr %base, i64 %a) {
+; CHECK-LABEL: define ptr @test_zero_multiuse_index(
+; CHECK-SAME: ptr [[BASE:%.*]], i64 [[A:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[INDEX:%.*]] = add i64 [[A]], 1
+; CHECK-NEXT:    call void @use64(i64 [[INDEX]])
+; CHECK-NEXT:    [[P2:%.*]] = getelementptr i32, ptr [[BASE]], i64 [[A]]
+; CHECK-NEXT:    ret ptr [[P2]]
+;
+entry:
+  %p1 = getelementptr i8, ptr %base, i64 -4
+  %index = add i64 %a, 1
+  call void @use64(i64 %index)
+  %p2 = getelementptr i32, ptr %p1, i64 %index
+  ret ptr %p2
+}
+
+define ptr @test_zero_multiuse_ptr(ptr %base, i64 %a) {
+; CHECK-LABEL: define ptr @test_zero_multiuse_ptr(
+; CHECK-SAME: ptr [[BASE:%.*]], i64 [[A:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P1:%.*]] = getelementptr i8, ptr [[BASE]], i64 -4
+; CHECK-NEXT:    call void @useptr(ptr [[P1]])
+; CHECK-NEXT:    [[P2:%.*]] = getelementptr i32, ptr [[BASE]], i64 [[A]]
+; CHECK-NEXT:    ret ptr [[P2]]
+;
+entry:
+  %p1 = getelementptr i8, ptr %base, i64 -4
+  call void @useptr(ptr %p1)
+  %index = add i64 %a, 1
+  %p2 = getelementptr i32, ptr %p1, i64 %index
+  ret ptr %p2
+}
+
+define ptr @test_zero_sext_add_nsw(ptr %base, i32 %a) {
+; CHECK-LABEL: define ptr @test_zero_sext_add_nsw(
+; CHECK-SAME: ptr [[BASE:%.*]], i32 [[A:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = sext i32 [[A]] to i64
+; CHECK-NEXT:    [[P2:%.*]] = getelementptr i32, ptr [[BASE]], i64 [[TMP0]]
+; CHECK-NEXT:    ret ptr [[P2]]
+;
+entry:
+  %p1 = getelementptr i8, ptr %base, i64 -4
+  %index = add nsw i32 %a, 1
+  %p2 = getelementptr i32, ptr %p1, i32 %index
+  ret ptr %p2
+}
+
+define ptr @test_zero_trunc_add(ptr %base, i128 %a) {
+; CHECK-LABEL: define ptr @test_zero_trunc_add(
+; CHECK-SAME: ptr [[BASE:%.*]], i128 [[A:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = trunc i128 [[A]] to i64
+; CHECK-NEXT:    [[P2:%.*]] = getelementptr i32, ptr [[BASE]], i64 [[TMP0]]
+; CHECK-NEXT:    ret ptr [[P2]]
+;
+entry:
+  %p1 = getelementptr i8, ptr %base, i64 -4
+  %index = add i128 %a, 1
+  %p2 = getelementptr i32, ptr %p1, i128 %index
+  ret ptr %p2
+}
+
+; Negative tests
+
+define ptr @test_non_i8(ptr %base, i64 %a) {
+; CHECK-LABEL: define ptr @test_non_i8(
+; CHECK-SAME: ptr [[BASE:%.*]], i64 [[A:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P1:%.*]] = getelementptr i16, ptr [[BASE]], i64 -4
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr i32, ptr [[P1]], i64 [[A]]
+; CHECK-NEXT:    [[P2:%.*]] = getelementptr i32, ptr [[TMP0]], i64 1
+; CHECK-NEXT:    ret ptr [[P2]]
+;
+entry:
+  %p1 = getelementptr i16, ptr %base, i64 -4
+  %index = add i64 %a, 1
+  %p2 = getelementptr i32, ptr %p1, i64 %index
+  ret ptr %p2
+}
+
+define ptr @test_non_const(ptr %base, i64 %a, i64 %b) {
+; CHECK-LABEL: define ptr @test_non_const(
+; CHECK-SAME: ptr [[BASE:%.*]], i64 [[A:%.*]], i64 [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P1:%.*]] = getelementptr i8, ptr [[BASE]], i64 [[B]]
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr i32, ptr [[P1]], i64 [[A]]
+; CHECK-NEXT:    [[P2:%.*]] = getelementptr i32, ptr [[TMP0]], i64 1
+; CHECK-NEXT:    ret ptr [[P2]]
+;
+entry:
+  %p1 = getelementptr i8, ptr %base, i64 %b
+  %index = add i64 %a, 1
+  %p2 = getelementptr i32, ptr %p1, i64 %index
+  ret ptr %p2
+}
+
+define ptr @test_too_many_indices(ptr %base, i64 %a, i64 %b) {
+; CHECK-LABEL: define ptr @test_too_many_indices(
+; CHECK-SAME: ptr [[BASE:%.*]], i64 [[A:%.*]], i64 [[B:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P1:%.*]] = getelementptr i8, ptr [[BASE]], i64 [[B]]
+; CHECK-NEXT:    [[INDEX:%.*]] = add i64 [[A]], 1
+; CHECK-NEXT:    [[P2:%.*]] = getelementptr [8 x i32], ptr [[P1]], i64 1, i64 [[INDEX]]
+; CHECK-NEXT:    ret ptr [[P2]]
+;
+entry:
+  %p1 = getelementptr i8, ptr %base, i64 %b
+  %index = add i64 %a, 1
+  %p2 = getelementptr [8 x i32], ptr %p1, i64 1, i64 %index
+  ret ptr %p2
+}
+
+define ptr @test_wrong_op(ptr %base, i64 %a) {
+; CHECK-LABEL: define ptr @test_wrong_op(
+; CHECK-SAME: ptr [[BASE:%.*]], i64 [[A:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P1:%.*]] = getelementptr i8, ptr [[BASE]], i64 -4
+; CHECK-NEXT:    [[INDEX:%.*]] = xor i64 [[A]], 1
+; CHECK-NEXT:    [[P2:%.*]] = getelementptr i32, ptr [[P1]], i64 [[INDEX]]
+; CHECK-NEXT:    ret ptr [[P2]]
+;
+entry:
+  %p1 = getelementptr i8, ptr %base, i64 -4
+  %index = xor i64 %a, 1
+  %p2 = getelementptr i32, ptr %p1, i64 %index
+  ret ptr %p2
+}
+
+define ptr @test_sext_add_without_nsw(ptr %base, i32 %a) {
+; CHECK-LABEL: define ptr @test_sext_add_without_nsw(
+; CHECK-SAME: ptr [[BASE:%.*]], i32 [[A:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P1:%.*]] = getelementptr i8, ptr [[BASE]], i64 -4
+; CHECK-NEXT:    [[INDEX:%.*]] = add i32 [[A]], 1
+; CHECK-NEXT:    [[TMP0:%.*]] = sext i32 [[INDEX]] to i64
+; CHECK-NEXT:    [[P2:%.*]] = getelementptr i32, ptr [[P1]], i64 [[TMP0]]
+; CHECK-NEXT:    ret ptr [[P2]]
+;
+entry:
+  %p1 = getelementptr i8, ptr %base, i64 -4
+  %index = add i32 %a, 1
+  %p2 = getelementptr i32, ptr %p1, i32 %index
+  ret ptr %p2
+}
+
+define ptr @test_or_without_disjoint(ptr %base, i64 %a) {
+; CHECK-LABEL: define ptr @test_or_without_disjoint(
+; CHECK-SAME: ptr [[BASE:%.*]], i64 [[A:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P1:%.*]] = getelementptr i8, ptr [[BASE]], i64 -4
+; CHECK-NEXT:    [[INDEX:%.*]] = or i64 [[A]], 1
+; CHECK-NEXT:    [[P2:%.*]] = getelementptr i32, ptr [[P1]], i64 [[INDEX]]
+; CHECK-NEXT:    ret ptr [[P2]]
+;
+entry:
+  %p1 = getelementptr i8, ptr %base, i64 -4
+  %index = or i64 %a, 1
+  %p2 = getelementptr i32, ptr %p1, i64 %index
+  ret ptr %p2
+}
+
+define ptr @test_smul_overflow(ptr %base, i64 %a) {
+; CHECK-LABEL: define ptr @test_smul_overflow(
+; CHECK-SAME: ptr [[BASE:%.*]], i64 [[A:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P1:%.*]] = getelementptr i8, ptr [[BASE]], i64 -4
+; CHECK-NEXT:    [[INDEX:%.*]] = xor i64 [[A]], -9223372036854775808
+; CHECK-NEXT:    [[P2:%.*]] = getelementptr i32, ptr [[P1]], i64 [[INDEX]]
+; CHECK-NEXT:    ret ptr [[P2]]
+;
+entry:
+  %p1 = getelementptr i8, ptr %base, i64 -4
+  %index = add i64 %a, -9223372036854775808
+  %p2 = getelementptr i32, ptr %p1, i64 %index
+  ret ptr %p2
+}
+
+define ptr @test_sadd_overflow(ptr %base, i64 %a) {
+; CHECK-LABEL: define ptr @test_sadd_overflow(
+; CHECK-SAME: ptr [[BASE:%.*]], i64 [[A:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P1:%.*]] = getelementptr i8, ptr [[BASE]], i64 9223372036854775804
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr i32, ptr [[P1]], i64 [[A]]
+; CHECK-NEXT:    [[P2:%.*]] = getelementptr i32, ptr [[TMP0]], i64 1
+; CHECK-NEXT:    ret ptr [[P2]]
+;
+entry:
+  %p1 = getelementptr i8, ptr %base, i64 9223372036854775804
+  %index = add i64 %a, 1
+  %p2 = getelementptr i32, ptr %p1, i64 %index
+  ret ptr %p2
+}
+
+define ptr @test_nonzero_multiuse_index(ptr %base, i64 %a) {
+; CHECK-LABEL: define ptr @test_nonzero_multiuse_index(
+; CHECK-SAME: ptr [[BASE:%.*]], i64 [[A:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P1:%.*]] = getelementptr i8, ptr [[BASE]], i64 -4
+; CHECK-NEXT:    [[INDEX:%.*]] = add i64 [[A]], 2
+; CHECK-NEXT:    call void @use64(i64 [[INDEX]])
+; CHECK-NEXT:    [[P2:%.*]] = getelementptr i32, ptr [[P1]], i64 [[INDEX]]
+; CHECK-NEXT:    ret ptr [[P2]]
+;
+entry:
+  %p1 = getelementptr i8, ptr %base, i64 -4
+  %index = add i64 %a, 2
+  call void @use64(i64 %index)
+  %p2 = getelementptr i32, ptr %p1, i64 %index
+  ret ptr %p2
+}
+
+define ptr @test_nonzero_multiuse_ptr(ptr %base, i64 %a) {
+; CHECK-LABEL: define ptr @test_nonzero_multiuse_ptr(
+; CHECK-SAME: ptr [[BASE:%.*]], i64 [[A:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[P1:%.*]] = getelementptr i8, ptr [[BASE]], i64 -4
+; CHECK-NEXT:    call void @useptr(ptr [[P1]])
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr i32, ptr [[P1]], i64 [[A]]
+; CHECK-NEXT:    [[P2:%.*]] = getelementptr i32, ptr [[TMP0]], i64 2
+; CHECK-NEXT:    ret ptr [[P2]]
+;
+entry:
+  %p1 = getelementptr i8, ptr %base, i64 -4
+  call void @useptr(ptr %p1)
+  %index = add i64 %a, 2
+  %p2 = getelementptr i32, ptr %p1, i64 %index
+  ret ptr %p2
+}

``````````



https://github.com/llvm/llvm-project/pull/76177

