[llvm] [InstCombine] Preserve inbounds when canonicalizing gep+add (PR #90160)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 25 20:10:21 PDT 2024


https://github.com/nikic created https://github.com/llvm/llvm-project/pull/90160

When canonicalizing gep+add into gep+gep, we can preserve inbounds if the original gep is inbounds, the add is nsw, and both add operands are non-negative (or both negative, but I don't think that's practically relevant).

Proof: https://alive2.llvm.org/ce/z/tJLBta

From b6d4c9068435b4270d39fc8921bec21f6f1f5b30 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Thu, 25 Apr 2024 15:43:13 +0900
Subject: [PATCH] [InstCombine] Preserve inbounds when canonicalizing gep+add

When canonicalizing gep+add into gep+gep, we can preserve inbounds
if the original gep is inbounds, the add is nsw, and both add
operands are non-negative (or both negative, but I don't think
that's practically relevant).

Proof: https://alive2.llvm.org/ce/z/tJLBta
---
 .../InstCombine/InstructionCombining.cpp      | 34 ++++++++++++++-----
 llvm/test/Transforms/InstCombine/array.ll     |  8 ++---
 2 files changed, 30 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 58b2d8e9dec1c3..0858116cd91104 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2948,6 +2948,14 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
     return nullptr;
 
   if (GEP.getNumIndices() == 1) {
+    // We can only preserve inbounds if the original gep is inbounds, the add
+    // is nsw, and the add operands are non-negative.
+    auto CanPreserveInBounds = [&](bool AddIsNSW, Value *Idx1, Value *Idx2) {
+      SimplifyQuery Q = SQ.getWithInstruction(&GEP);
+      return GEP.isInBounds() && AddIsNSW && isKnownNonNegative(Idx1, Q) &&
+             isKnownNonNegative(Idx2, Q);
+    };
+
     // Try to replace ADD + GEP with GEP + GEP.
     Value *Idx1, *Idx2;
     if (match(GEP.getOperand(1),
@@ -2957,10 +2965,15 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
       // as:
       //   %newptr = getelementptr i32, ptr %ptr, i64 %idx1
       //   %newgep = getelementptr i32, ptr %newptr, i64 %idx2
-      auto *NewPtr = Builder.CreateGEP(GEP.getResultElementType(),
-                                       GEP.getPointerOperand(), Idx1);
-      return GetElementPtrInst::Create(GEP.getResultElementType(), NewPtr,
-                                       Idx2);
+      bool IsInBounds = CanPreserveInBounds(
+          cast<OverflowingBinaryOperator>(GEP.getOperand(1))->hasNoSignedWrap(),
+          Idx1, Idx2);
+      auto *NewPtr =
+          Builder.CreateGEP(GEP.getResultElementType(), GEP.getPointerOperand(),
+                            Idx1, "", IsInBounds);
+      return replaceInstUsesWith(
+          GEP, Builder.CreateGEP(GEP.getResultElementType(), NewPtr, Idx2, "",
+                                 IsInBounds));
     }
     ConstantInt *C;
     if (match(GEP.getOperand(1), m_OneUse(m_SExtLike(m_OneUse(m_NSWAdd(
@@ -2971,12 +2984,17 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
       // as:
       // %newptr = getelementptr i32, ptr %ptr, i32 %idx1
       // %newgep = getelementptr i32, ptr %newptr, i32 idx2
+      bool IsInBounds = CanPreserveInBounds(
+          /*IsNSW=*/true, Idx1, C);
       auto *NewPtr = Builder.CreateGEP(
           GEP.getResultElementType(), GEP.getPointerOperand(),
-          Builder.CreateSExt(Idx1, GEP.getOperand(1)->getType()));
-      return GetElementPtrInst::Create(
-          GEP.getResultElementType(), NewPtr,
-          Builder.CreateSExt(C, GEP.getOperand(1)->getType()));
+          Builder.CreateSExt(Idx1, GEP.getOperand(1)->getType()), "",
+          IsInBounds);
+      return replaceInstUsesWith(
+          GEP,
+          Builder.CreateGEP(GEP.getResultElementType(), NewPtr,
+                            Builder.CreateSExt(C, GEP.getOperand(1)->getType()),
+                            "", IsInBounds));
     }
   }
 
diff --git a/llvm/test/Transforms/InstCombine/array.ll b/llvm/test/Transforms/InstCombine/array.ll
index f439d4da6080c4..4f4ae17bebc503 100644
--- a/llvm/test/Transforms/InstCombine/array.ll
+++ b/llvm/test/Transforms/InstCombine/array.ll
@@ -116,8 +116,8 @@ define ptr @gep_inbounds_add_nsw_nonneg(ptr %ptr, i64 %a, i64 %b) {
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[A_NNEG]])
 ; CHECK-NEXT:    [[B_NNEG:%.*]] = icmp sgt i64 [[B]], -1
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[B_NNEG]])
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[PTR]], i64 [[A]]
-; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i32, ptr [[TMP1]], i64 [[B]]
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr inbounds i32, ptr [[PTR]], i64 [[A]]
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds i32, ptr [[TMP1]], i64 [[B]]
 ; CHECK-NEXT:    ret ptr [[GEP]]
 ;
   %a.nneg = icmp sgt i64 %a, -1
@@ -207,8 +207,8 @@ define ptr @gep_inbounds_sext_add_nonneg(ptr %ptr, i32 %a) {
 ; CHECK-NEXT:    [[A_NNEG:%.*]] = icmp sgt i32 [[A]], -1
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[A_NNEG]])
 ; CHECK-NEXT:    [[TMP1:%.*]] = zext nneg i32 [[A]] to i64
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i32, ptr [[PTR]], i64 [[TMP1]]
-; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i8, ptr [[TMP2]], i64 40
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i32, ptr [[PTR]], i64 [[TMP1]]
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds i8, ptr [[TMP2]], i64 40
 ; CHECK-NEXT:    ret ptr [[GEP]]
 ;
   %a.nneg = icmp sgt i32 %a, -1



More information about the llvm-commits mailing list