[llvm] [SROA] Unfold gep of index select (PR #80983)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 8 01:26:37 PST 2024


https://github.com/nikic updated https://github.com/llvm/llvm-project/pull/80983

>From b59a4fa5f26e2cdeeb8fd3fab8a0aff626839a2c Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Wed, 7 Feb 2024 12:59:09 +0100
Subject: [PATCH 1/4] [SROA] Unfold gep of index select

SROA currently supports converting a gep of a select into a select of
geps when the select is the pointer operand. This patch extends that
support to selects in an index operand.

This is intended to address the regression reported in
https://github.com/llvm/llvm-project/pull/68882#issuecomment-1924909922.
---
 llvm/lib/Transforms/Scalar/SROA.cpp     | 60 +++++++++++++++++++------
 llvm/test/Transforms/SROA/select-gep.ll | 32 +++++++------
 2 files changed, 65 insertions(+), 27 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp
index bdbaf4f55c96d0..53d3be85ff8594 100644
--- a/llvm/lib/Transforms/Scalar/SROA.cpp
+++ b/llvm/lib/Transforms/Scalar/SROA.cpp
@@ -3937,30 +3937,63 @@ class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> {
     return false;
   }
 
-  // Fold gep (select cond, ptr1, ptr2) => select cond, gep(ptr1), gep(ptr2)
+  // Fold gep (select cond, ptr1, ptr2), idx
+  //   => select cond, gep(ptr1, idx), gep(ptr2, idx)
+  // and  gep ptr, (select cond, idx1, idx2)
+  //   => select cond, gep(ptr, idx1), gep(ptr, idx2)
   bool foldGEPSelect(GetElementPtrInst &GEPI) {
-    if (!GEPI.hasAllConstantIndices())
-      return false;
+    // Check whether the GEP has exactly one select operand and all indices
+    // will become constant after the transform.
+    auto IsValidOp = [](Value *Op) {
+      return Op->getType()->isPointerTy() || isa<ConstantInt>(Op);
+    };
+
+    SelectInst *Sel = nullptr;
+    for (Value *Op : GEPI.operands()) {
+      if (auto *SI = dyn_cast<SelectInst>(Op)) {
+        if (Sel)
+          return false;
+        Sel = SI;
+        continue;
+      }
 
-    SelectInst *Sel = cast<SelectInst>(GEPI.getPointerOperand());
+      if (!IsValidOp(Op))
+        return false;
+    }
+
+    if (!Sel || !IsValidOp(Sel->getTrueValue()) ||
+        !IsValidOp(Sel->getFalseValue()))
+      return false;
 
     LLVM_DEBUG(dbgs() << "  Rewriting gep(select) -> select(gep):"
                       << "\n    original: " << *Sel
                       << "\n              " << GEPI);
 
+    auto GetNewOps = [&](Value *SelOp) {
+      SmallVector<Value *> NewOps;
+      for (Value *Op : GEPI.operands())
+        if (Op == Sel)
+          NewOps.push_back(SelOp);
+        else
+          NewOps.push_back(Op);
+      return NewOps;
+    };
+
+    Value *True = Sel->getTrueValue();
+    Value *False = Sel->getFalseValue();
+    SmallVector<Value *> TrueOps = GetNewOps(True);
+    SmallVector<Value *> FalseOps = GetNewOps(False);
+
     IRB.SetInsertPoint(&GEPI);
-    SmallVector<Value *, 4> Index(GEPI.indices());
     bool IsInBounds = GEPI.isInBounds();
 
     Type *Ty = GEPI.getSourceElementType();
-    Value *True = Sel->getTrueValue();
-    Value *NTrue = IRB.CreateGEP(Ty, True, Index, True->getName() + ".sroa.gep",
-                                 IsInBounds);
-
-    Value *False = Sel->getFalseValue();
+    Value *NTrue = IRB.CreateGEP(Ty, TrueOps[0], ArrayRef(TrueOps).drop_front(),
+                                 True->getName() + ".sroa.gep", IsInBounds);
 
-    Value *NFalse = IRB.CreateGEP(Ty, False, Index,
-                                  False->getName() + ".sroa.gep", IsInBounds);
+    Value *NFalse =
+        IRB.CreateGEP(Ty, FalseOps[0], ArrayRef(FalseOps).drop_front(),
+                      False->getName() + ".sroa.gep", IsInBounds);
 
     Value *NSel = IRB.CreateSelect(Sel->getCondition(), NTrue, NFalse,
                                    Sel->getName() + ".sroa.sel");
@@ -4034,8 +4067,7 @@ class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> {
   }
 
   bool visitGetElementPtrInst(GetElementPtrInst &GEPI) {
-    if (isa<SelectInst>(GEPI.getPointerOperand()) &&
-        foldGEPSelect(GEPI))
+    if (foldGEPSelect(GEPI))
       return true;
 
     if (isa<PHINode>(GEPI.getPointerOperand()) &&
diff --git a/llvm/test/Transforms/SROA/select-gep.ll b/llvm/test/Transforms/SROA/select-gep.ll
index 56924a0a771b0c..6a204621c90616 100644
--- a/llvm/test/Transforms/SROA/select-gep.ll
+++ b/llvm/test/Transforms/SROA/select-gep.ll
@@ -158,11 +158,20 @@ bb:
 
 define i32 @test_select_idx_memcpy(i1 %c, ptr %p) {
 ; CHECK-LABEL: @test_select_idx_memcpy(
-; CHECK-NEXT:    [[ALLOCA:%.*]] = alloca [20 x i64], align 8
-; CHECK-NEXT:    call void @llvm.memcpy.p0.p0.i64(ptr [[ALLOCA]], ptr [[P:%.*]], i64 160, i1 false)
+; CHECK-NEXT:    [[ALLOCA_SROA_0:%.*]] = alloca [4 x i8], align 8
+; CHECK-NEXT:    [[ALLOCA_SROA_2:%.*]] = alloca [20 x i8], align 4
+; CHECK-NEXT:    [[ALLOCA_SROA_22:%.*]] = alloca [4 x i8], align 8
+; CHECK-NEXT:    [[ALLOCA_SROA_3:%.*]] = alloca [132 x i8], align 4
+; CHECK-NEXT:    call void @llvm.memcpy.p0.p0.i64(ptr align 8 [[ALLOCA_SROA_0]], ptr align 1 [[P:%.*]], i64 4, i1 false)
+; CHECK-NEXT:    [[ALLOCA_SROA_2_0_P_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[P]], i64 4
+; CHECK-NEXT:    call void @llvm.memcpy.p0.p0.i64(ptr align 4 [[ALLOCA_SROA_2]], ptr align 1 [[ALLOCA_SROA_2_0_P_SROA_IDX]], i64 20, i1 false)
+; CHECK-NEXT:    [[ALLOCA_SROA_22_0_P_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[P]], i64 24
+; CHECK-NEXT:    call void @llvm.memcpy.p0.p0.i64(ptr align 8 [[ALLOCA_SROA_22]], ptr align 1 [[ALLOCA_SROA_22_0_P_SROA_IDX]], i64 4, i1 false)
+; CHECK-NEXT:    [[ALLOCA_SROA_3_0_P_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[P]], i64 28
+; CHECK-NEXT:    call void @llvm.memcpy.p0.p0.i64(ptr align 4 [[ALLOCA_SROA_3]], ptr align 1 [[ALLOCA_SROA_3_0_P_SROA_IDX]], i64 132, i1 false)
 ; CHECK-NEXT:    [[IDX:%.*]] = select i1 [[C:%.*]], i64 24, i64 0
-; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 [[IDX]]
-; CHECK-NEXT:    [[RES:%.*]] = load i32, ptr [[GEP]], align 4
+; CHECK-NEXT:    [[IDX_SROA_SEL:%.*]] = select i1 [[C]], ptr [[ALLOCA_SROA_22]], ptr [[ALLOCA_SROA_0]]
+; CHECK-NEXT:    [[RES:%.*]] = load i32, ptr [[IDX_SROA_SEL]], align 4
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %alloca = alloca [20 x i64], align 8
@@ -175,14 +184,9 @@ define i32 @test_select_idx_memcpy(i1 %c, ptr %p) {
 
 define i32 @test_select_idx_mem2reg(i1 %c) {
 ; CHECK-LABEL: @test_select_idx_mem2reg(
-; CHECK-NEXT:    [[ALLOCA:%.*]] = alloca [20 x i64], align 8
-; CHECK-NEXT:    store i32 1, ptr [[ALLOCA]], align 4
-; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 24
-; CHECK-NEXT:    store i32 2, ptr [[GEP1]], align 4
 ; CHECK-NEXT:    [[IDX:%.*]] = select i1 [[C:%.*]], i64 24, i64 0
-; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 [[IDX]]
-; CHECK-NEXT:    [[RES:%.*]] = load i32, ptr [[GEP2]], align 4
-; CHECK-NEXT:    ret i32 [[RES]]
+; CHECK-NEXT:    [[RES_SROA_SPECULATED:%.*]] = select i1 [[C]], i32 2, i32 1
+; CHECK-NEXT:    ret i32 [[RES_SROA_SPECULATED]]
 ;
   %alloca = alloca [20 x i64], align 8
   store i32 1, ptr %alloca
@@ -202,8 +206,10 @@ define i32 @test_select_idx_escaped(i1 %c, ptr %p) {
 ; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 24
 ; CHECK-NEXT:    store i32 2, ptr [[GEP1]], align 4
 ; CHECK-NEXT:    [[IDX:%.*]] = select i1 [[C:%.*]], i64 24, i64 0
-; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 [[IDX]]
-; CHECK-NEXT:    [[RES:%.*]] = load i32, ptr [[GEP2]], align 4
+; CHECK-NEXT:    [[DOTSROA_GEP:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 24
+; CHECK-NEXT:    [[DOTSROA_GEP1:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 0
+; CHECK-NEXT:    [[IDX_SROA_SEL:%.*]] = select i1 [[C]], ptr [[DOTSROA_GEP]], ptr [[DOTSROA_GEP1]]
+; CHECK-NEXT:    [[RES:%.*]] = load i32, ptr [[IDX_SROA_SEL]], align 4
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %alloca = alloca [20 x i64], align 8

>From 019aad8b14e49ec3b8a33923ae969008ce6c0d18 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Thu, 8 Feb 2024 10:18:32 +0100
Subject: [PATCH 2/4] Add extra test comments

---
 llvm/test/Transforms/SROA/select-gep.ll | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/llvm/test/Transforms/SROA/select-gep.ll b/llvm/test/Transforms/SROA/select-gep.ll
index 6a204621c90616..0100b910dc5a89 100644
--- a/llvm/test/Transforms/SROA/select-gep.ll
+++ b/llvm/test/Transforms/SROA/select-gep.ll
@@ -155,7 +155,8 @@ bb:
   ret i32 %load
 }
 
-
+; Test gep of index select unfolding on an alloca that is splittable, but not
+; promotable. The allocas here will be optimized away by subsequent passes.
 define i32 @test_select_idx_memcpy(i1 %c, ptr %p) {
 ; CHECK-LABEL: @test_select_idx_memcpy(
 ; CHECK-NEXT:    [[ALLOCA_SROA_0:%.*]] = alloca [4 x i8], align 8
@@ -182,6 +183,8 @@ define i32 @test_select_idx_memcpy(i1 %c, ptr %p) {
   ret i32 %res
 }
 
+; Test gep of index select unfolding on an alloca that is splittable and
+; promotable.
 define i32 @test_select_idx_mem2reg(i1 %c) {
 ; CHECK-LABEL: @test_select_idx_mem2reg(
 ; CHECK-NEXT:    [[IDX:%.*]] = select i1 [[C:%.*]], i64 24, i64 0
@@ -198,6 +201,9 @@ define i32 @test_select_idx_mem2reg(i1 %c) {
   ret i32 %res
 }
 
+; Test gep of index select unfolding on an alloca that escaped, and as such
+; is not splittable or promotable.
+; FIXME: Ideally, no transform would take place in this case.
 define i32 @test_select_idx_escaped(i1 %c, ptr %p) {
 ; CHECK-LABEL: @test_select_idx_escaped(
 ; CHECK-NEXT:    [[ALLOCA:%.*]] = alloca [20 x i64], align 8
@@ -223,6 +229,9 @@ define i32 @test_select_idx_escaped(i1 %c, ptr %p) {
   ret i32 %res
 }
 
+; The following cases involve non-constant indices and should not be
+; transformed.
+
 define i32 @test_select_idx_not_constant1(i1 %c, ptr %p, i64 %arg) {
 ; CHECK-LABEL: @test_select_idx_not_constant1(
 ; CHECK-NEXT:    [[ALLOCA:%.*]] = alloca [20 x i64], align 8

>From c15562c2f7001d5e2f447cb1617c5fb737c0cc0c Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Thu, 8 Feb 2024 10:21:17 +0100
Subject: [PATCH 3/4] Handle pointer operand separately

---
 llvm/lib/Transforms/Scalar/SROA.cpp | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp
index 53d3be85ff8594..8d19ae4f39c799 100644
--- a/llvm/lib/Transforms/Scalar/SROA.cpp
+++ b/llvm/lib/Transforms/Scalar/SROA.cpp
@@ -3944,25 +3944,24 @@ class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> {
   bool foldGEPSelect(GetElementPtrInst &GEPI) {
     // Check whether the GEP has exactly one select operand and all indices
     // will become constant after the transform.
-    auto IsValidOp = [](Value *Op) {
-      return Op->getType()->isPointerTy() || isa<ConstantInt>(Op);
-    };
-
-    SelectInst *Sel = nullptr;
-    for (Value *Op : GEPI.operands()) {
+    SelectInst *Sel = dyn_cast<SelectInst>(GEPI.getPointerOperand());
+    for (Value *Op : GEPI.indices()) {
       if (auto *SI = dyn_cast<SelectInst>(Op)) {
         if (Sel)
           return false;
+
         Sel = SI;
+        if (!isa<ConstantInt>(Sel->getTrueValue()) ||
+            !isa<ConstantInt>(Sel->getFalseValue()))
+          return false;
         continue;
       }
 
-      if (!IsValidOp(Op))
+      if (!isa<ConstantInt>(Op))
         return false;
     }
 
-    if (!Sel || !IsValidOp(Sel->getTrueValue()) ||
-        !IsValidOp(Sel->getFalseValue()))
+    if (!Sel)
       return false;
 
     LLVM_DEBUG(dbgs() << "  Rewriting gep(select) -> select(gep):"

>From d23ec21bc86e040d0eb7a192923606eab89a180f Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Thu, 8 Feb 2024 10:26:06 +0100
Subject: [PATCH 4/4] Add test for nested selects

---
 llvm/test/Transforms/SROA/select-gep.ll | 29 +++++++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/llvm/test/Transforms/SROA/select-gep.ll b/llvm/test/Transforms/SROA/select-gep.ll
index 0100b910dc5a89..1342a2ca4ea2be 100644
--- a/llvm/test/Transforms/SROA/select-gep.ll
+++ b/llvm/test/Transforms/SROA/select-gep.ll
@@ -229,6 +229,35 @@ define i32 @test_select_idx_escaped(i1 %c, ptr %p) {
   ret i32 %res
 }
 
+; FIXME: Should we allow recursive select unfolding if all the leaves are
+; constants?
+define i32 @test_select_idx_nested(i1 %c, i1 %c2) {
+; CHECK-LABEL: @test_select_idx_nested(
+; CHECK-NEXT:    [[ALLOCA:%.*]] = alloca [20 x i64], align 8
+; CHECK-NEXT:    store i32 1, ptr [[ALLOCA]], align 4
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 8
+; CHECK-NEXT:    store i32 2, ptr [[GEP1]], align 4
+; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 24
+; CHECK-NEXT:    store i32 3, ptr [[GEP2]], align 4
+; CHECK-NEXT:    [[IDX1:%.*]] = select i1 [[C:%.*]], i64 24, i64 0
+; CHECK-NEXT:    [[IDX2:%.*]] = select i1 [[C2:%.*]], i64 [[IDX1]], i64 8
+; CHECK-NEXT:    [[GEP3:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 [[IDX2]]
+; CHECK-NEXT:    [[RES:%.*]] = load i32, ptr [[GEP3]], align 4
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %alloca = alloca [20 x i64], align 8
+  store i32 1, ptr %alloca
+  %gep1 = getelementptr inbounds i8, ptr %alloca, i64 8
+  store i32 2, ptr %gep1
+  %gep2 = getelementptr inbounds i8, ptr %alloca, i64 24
+  store i32 3, ptr %gep2
+  %idx1 = select i1 %c, i64 24, i64 0
+  %idx2 = select i1 %c2, i64 %idx1, i64 8
+  %gep3 = getelementptr inbounds i8, ptr %alloca, i64 %idx2
+  %res = load i32, ptr %gep3, align 4
+  ret i32 %res
+}
+
 ; The following cases involve non-constant indices and should not be
 ; transformed.
 



More information about the llvm-commits mailing list