[llvm] [SROA] Unfold gep of index select (PR #80983)

via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 7 04:06:31 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Nikita Popov (nikic)

<details>
<summary>Changes</summary>

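SROA currently supports converting a GEP of a select into a select of GEPs, but only when the select is the GEP's pointer operand. This patch extends that support to a select used as an index operand. The fold applies only when the GEP has exactly one select operand and all of its indices are constant after the transform.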

This is intended to address the regression reported in https://github.com/llvm/llvm-project/pull/68882#issuecomment-1924909922.

---
Full diff: https://github.com/llvm/llvm-project/pull/80983.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Scalar/SROA.cpp (+46-14) 
- (modified) llvm/test/Transforms/SROA/select-gep.ll (+19-13) 


``````````diff
diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp
index bdbaf4f55c96d0..53d3be85ff8594 100644
--- a/llvm/lib/Transforms/Scalar/SROA.cpp
+++ b/llvm/lib/Transforms/Scalar/SROA.cpp
@@ -3937,30 +3937,63 @@ class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> {
     return false;
   }
 
-  // Fold gep (select cond, ptr1, ptr2) => select cond, gep(ptr1), gep(ptr2)
+  // Fold gep (select cond, ptr1, ptr2), idx
+  //   => select cond, gep(ptr1, idx), gep(ptr2, idx)
+  // and  gep ptr, (select cond, idx1, idx2)
+  //   => select cond, gep(ptr, idx1), gep(ptr, idx2)
   bool foldGEPSelect(GetElementPtrInst &GEPI) {
-    if (!GEPI.hasAllConstantIndices())
-      return false;
+    // Check whether the GEP has exactly one select operand and all indices
+    // will become constant after the transform.
+    auto IsValidOp = [](Value *Op) {
+      return Op->getType()->isPointerTy() || isa<ConstantInt>(Op);
+    };
+
+    SelectInst *Sel = nullptr;
+    for (Value *Op : GEPI.operands()) {
+      if (auto *SI = dyn_cast<SelectInst>(Op)) {
+        if (Sel)
+          return false;
+        Sel = SI;
+        continue;
+      }
 
-    SelectInst *Sel = cast<SelectInst>(GEPI.getPointerOperand());
+      if (!IsValidOp(Op))
+        return false;
+    }
+
+    if (!Sel || !IsValidOp(Sel->getTrueValue()) ||
+        !IsValidOp(Sel->getFalseValue()))
+      return false;
 
     LLVM_DEBUG(dbgs() << "  Rewriting gep(select) -> select(gep):"
                       << "\n    original: " << *Sel
                       << "\n              " << GEPI);
 
+    auto GetNewOps = [&](Value *SelOp) {
+      SmallVector<Value *> NewOps;
+      for (Value *Op : GEPI.operands())
+        if (Op == Sel)
+          NewOps.push_back(SelOp);
+        else
+          NewOps.push_back(Op);
+      return NewOps;
+    };
+
+    Value *True = Sel->getTrueValue();
+    Value *False = Sel->getFalseValue();
+    SmallVector<Value *> TrueOps = GetNewOps(True);
+    SmallVector<Value *> FalseOps = GetNewOps(False);
+
     IRB.SetInsertPoint(&GEPI);
-    SmallVector<Value *, 4> Index(GEPI.indices());
     bool IsInBounds = GEPI.isInBounds();
 
     Type *Ty = GEPI.getSourceElementType();
-    Value *True = Sel->getTrueValue();
-    Value *NTrue = IRB.CreateGEP(Ty, True, Index, True->getName() + ".sroa.gep",
-                                 IsInBounds);
-
-    Value *False = Sel->getFalseValue();
+    Value *NTrue = IRB.CreateGEP(Ty, TrueOps[0], ArrayRef(TrueOps).drop_front(),
+                                 True->getName() + ".sroa.gep", IsInBounds);
 
-    Value *NFalse = IRB.CreateGEP(Ty, False, Index,
-                                  False->getName() + ".sroa.gep", IsInBounds);
+    Value *NFalse =
+        IRB.CreateGEP(Ty, FalseOps[0], ArrayRef(FalseOps).drop_front(),
+                      False->getName() + ".sroa.gep", IsInBounds);
 
     Value *NSel = IRB.CreateSelect(Sel->getCondition(), NTrue, NFalse,
                                    Sel->getName() + ".sroa.sel");
@@ -4034,8 +4067,7 @@ class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> {
   }
 
   bool visitGetElementPtrInst(GetElementPtrInst &GEPI) {
-    if (isa<SelectInst>(GEPI.getPointerOperand()) &&
-        foldGEPSelect(GEPI))
+    if (foldGEPSelect(GEPI))
       return true;
 
     if (isa<PHINode>(GEPI.getPointerOperand()) &&
diff --git a/llvm/test/Transforms/SROA/select-gep.ll b/llvm/test/Transforms/SROA/select-gep.ll
index 56924a0a771b0c..6a204621c90616 100644
--- a/llvm/test/Transforms/SROA/select-gep.ll
+++ b/llvm/test/Transforms/SROA/select-gep.ll
@@ -158,11 +158,20 @@ bb:
 
 define i32 @test_select_idx_memcpy(i1 %c, ptr %p) {
 ; CHECK-LABEL: @test_select_idx_memcpy(
-; CHECK-NEXT:    [[ALLOCA:%.*]] = alloca [20 x i64], align 8
-; CHECK-NEXT:    call void @llvm.memcpy.p0.p0.i64(ptr [[ALLOCA]], ptr [[P:%.*]], i64 160, i1 false)
+; CHECK-NEXT:    [[ALLOCA_SROA_0:%.*]] = alloca [4 x i8], align 8
+; CHECK-NEXT:    [[ALLOCA_SROA_2:%.*]] = alloca [20 x i8], align 4
+; CHECK-NEXT:    [[ALLOCA_SROA_22:%.*]] = alloca [4 x i8], align 8
+; CHECK-NEXT:    [[ALLOCA_SROA_3:%.*]] = alloca [132 x i8], align 4
+; CHECK-NEXT:    call void @llvm.memcpy.p0.p0.i64(ptr align 8 [[ALLOCA_SROA_0]], ptr align 1 [[P:%.*]], i64 4, i1 false)
+; CHECK-NEXT:    [[ALLOCA_SROA_2_0_P_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[P]], i64 4
+; CHECK-NEXT:    call void @llvm.memcpy.p0.p0.i64(ptr align 4 [[ALLOCA_SROA_2]], ptr align 1 [[ALLOCA_SROA_2_0_P_SROA_IDX]], i64 20, i1 false)
+; CHECK-NEXT:    [[ALLOCA_SROA_22_0_P_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[P]], i64 24
+; CHECK-NEXT:    call void @llvm.memcpy.p0.p0.i64(ptr align 8 [[ALLOCA_SROA_22]], ptr align 1 [[ALLOCA_SROA_22_0_P_SROA_IDX]], i64 4, i1 false)
+; CHECK-NEXT:    [[ALLOCA_SROA_3_0_P_SROA_IDX:%.*]] = getelementptr inbounds i8, ptr [[P]], i64 28
+; CHECK-NEXT:    call void @llvm.memcpy.p0.p0.i64(ptr align 4 [[ALLOCA_SROA_3]], ptr align 1 [[ALLOCA_SROA_3_0_P_SROA_IDX]], i64 132, i1 false)
 ; CHECK-NEXT:    [[IDX:%.*]] = select i1 [[C:%.*]], i64 24, i64 0
-; CHECK-NEXT:    [[GEP:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 [[IDX]]
-; CHECK-NEXT:    [[RES:%.*]] = load i32, ptr [[GEP]], align 4
+; CHECK-NEXT:    [[IDX_SROA_SEL:%.*]] = select i1 [[C]], ptr [[ALLOCA_SROA_22]], ptr [[ALLOCA_SROA_0]]
+; CHECK-NEXT:    [[RES:%.*]] = load i32, ptr [[IDX_SROA_SEL]], align 4
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %alloca = alloca [20 x i64], align 8
@@ -175,14 +184,9 @@ define i32 @test_select_idx_memcpy(i1 %c, ptr %p) {
 
 define i32 @test_select_idx_mem2reg(i1 %c) {
 ; CHECK-LABEL: @test_select_idx_mem2reg(
-; CHECK-NEXT:    [[ALLOCA:%.*]] = alloca [20 x i64], align 8
-; CHECK-NEXT:    store i32 1, ptr [[ALLOCA]], align 4
-; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 24
-; CHECK-NEXT:    store i32 2, ptr [[GEP1]], align 4
 ; CHECK-NEXT:    [[IDX:%.*]] = select i1 [[C:%.*]], i64 24, i64 0
-; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 [[IDX]]
-; CHECK-NEXT:    [[RES:%.*]] = load i32, ptr [[GEP2]], align 4
-; CHECK-NEXT:    ret i32 [[RES]]
+; CHECK-NEXT:    [[RES_SROA_SPECULATED:%.*]] = select i1 [[C]], i32 2, i32 1
+; CHECK-NEXT:    ret i32 [[RES_SROA_SPECULATED]]
 ;
   %alloca = alloca [20 x i64], align 8
   store i32 1, ptr %alloca
@@ -202,8 +206,10 @@ define i32 @test_select_idx_escaped(i1 %c, ptr %p) {
 ; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 24
 ; CHECK-NEXT:    store i32 2, ptr [[GEP1]], align 4
 ; CHECK-NEXT:    [[IDX:%.*]] = select i1 [[C:%.*]], i64 24, i64 0
-; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 [[IDX]]
-; CHECK-NEXT:    [[RES:%.*]] = load i32, ptr [[GEP2]], align 4
+; CHECK-NEXT:    [[DOTSROA_GEP:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 24
+; CHECK-NEXT:    [[DOTSROA_GEP1:%.*]] = getelementptr inbounds i8, ptr [[ALLOCA]], i64 0
+; CHECK-NEXT:    [[IDX_SROA_SEL:%.*]] = select i1 [[C]], ptr [[DOTSROA_GEP]], ptr [[DOTSROA_GEP1]]
+; CHECK-NEXT:    [[RES:%.*]] = load i32, ptr [[IDX_SROA_SEL]], align 4
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %alloca = alloca [20 x i64], align 8

``````````

</details>


https://github.com/llvm/llvm-project/pull/80983


More information about the llvm-commits mailing list