[llvm] [CGP] Scalarize non-constant indices for geps feeding gather/scatter (PR #145952)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 26 12:07:46 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Philip Reames (preames)

<details>
<summary>Changes</summary>

Implementing an existing FIXME and simplifying the code in the process.

---
Full diff: https://github.com/llvm/llvm-project/pull/145952.diff


2 Files Affected:

- (modified) llvm/lib/CodeGen/CodeGenPrepare.cpp (+9-23) 
- (modified) llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll (+6-5) 


``````````diff
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 43574a54c37dd..4786b0dc0cb6e 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -6239,7 +6239,7 @@ bool CodeGenPrepare::optimizeMemoryInst(Instruction *MemoryInst, Value *Addr,
 /// pointer there's nothing we can do.
 ///
 /// If we have a GEP with more than 2 indices where the middle indices are all
-/// zeroes, we can replace it with 2 GEPs where the second has 2 operands.
+/// scalarizable, we can replace it with 2 GEPs where the second has 2 operands.
 ///
 /// If the final index isn't a vector or is a splat, we can emit a scalar GEP
 /// followed by a GEP with an all zeroes vector index. This will enable
@@ -6262,30 +6262,16 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
     SmallVector<Value *, 2> Ops(GEP->operands());
 
     bool RewriteGEP = false;
-
-    if (Ops[0]->getType()->isVectorTy()) {
-      Ops[0] = getSplatValue(Ops[0]);
-      if (!Ops[0])
-        return false;
-      RewriteGEP = true;
-    }
-
     unsigned FinalIndex = Ops.size() - 1;
 
-    // Ensure all but the last index is 0.
-    // FIXME: This isn't strictly required. All that's required is that they are
-    // all scalars or splats.
-    for (unsigned i = 1; i < FinalIndex; ++i) {
-      auto *C = dyn_cast<Constant>(Ops[i]);
-      if (!C)
-        return false;
-      if (isa<VectorType>(C->getType()))
-        C = C->getSplatValue();
-      auto *CI = dyn_cast_or_null<ConstantInt>(C);
-      if (!CI || !CI->isZero())
-        return false;
-      // Scalarize the index if needed.
-      Ops[i] = CI;
+    // Ensure all but the last index are scalar
+    for (unsigned i = 0; i < FinalIndex; ++i) {
+      if (isa<VectorType>(Ops[i]->getType())) {
+        Ops[i] = getSplatValue(Ops[i]);
+        if (!Ops[i])
+          return false;
+        RewriteGEP = true;
+      }
     }
 
     // Try to scalarize the final index.
diff --git a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll
index e27d5d772a7a4..b7b23c60f7fc1 100644
--- a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll
+++ b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll
@@ -112,7 +112,8 @@ define void @splat_ptr_scatter(ptr %ptr, <4 x i1> %mask, <4 x i32> %val) {
 
 define <4 x i32> @scalar_prefix(ptr %base, i64 %index, <4 x i64> %vecidx) {
 ; CHECK-LABEL: @scalar_prefix(
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr [256 x i32], ptr [[BASE:%.*]], i64 [[INDEX:%.*]], <4 x i64> [[VECIDX:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr [256 x i32], ptr [[BASE:%.*]], i64 [[INDEX:%.*]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i32, ptr [[TMP1]], <4 x i64> [[VECIDX:%.*]]
 ; CHECK-NEXT:    [[RES:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> [[TMP2]], i32 4, <4 x i1> splat (i1 true), <4 x i32> undef)
 ; CHECK-NEXT:    ret <4 x i32> [[RES]]
 ;
@@ -123,9 +124,8 @@ define <4 x i32> @scalar_prefix(ptr %base, i64 %index, <4 x i64> %vecidx) {
 
 define <4 x i32> @scalar_prefix_with_splat(ptr %base, i64 %index, <4 x i64> %vecidx) {
 ; CHECK-LABEL: @scalar_prefix_with_splat(
-; CHECK-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <4 x i64> poison, i64 [[INDEX:%.*]], i32 0
-; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i64> [[BROADCAST_SPLATINSERT]], <4 x i64> poison, <4 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr [256 x i32], ptr [[BASE:%.*]], <4 x i64> [[BROADCAST_SPLAT]], <4 x i64> [[VECIDX:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr [256 x i32], ptr [[BASE:%.*]], i64 [[INDEX:%.*]], i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i32, ptr [[TMP1]], <4 x i64> [[VECIDX:%.*]]
 ; CHECK-NEXT:    [[RES:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> [[TMP2]], i32 4, <4 x i1> splat (i1 true), <4 x i32> undef)
 ; CHECK-NEXT:    ret <4 x i32> [[RES]]
 ;
@@ -139,7 +139,8 @@ define <4 x i32> @scalar_prefix_with_splat(ptr %base, i64 %index, <4 x i64> %vec
 
 define <4 x i32> @scalar_prefix_with_constant_splat(ptr %base, <4 x i64> %vecidx) {
 ; CHECK-LABEL: @scalar_prefix_with_constant_splat(
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr [256 x i32], ptr [[BASE:%.*]], <4 x i64> splat (i64 20), <4 x i64> [[VECIDX:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr [256 x i32], ptr [[BASE:%.*]], i64 20, i64 0
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i32, ptr [[TMP1]], <4 x i64> [[VECIDX:%.*]]
 ; CHECK-NEXT:    [[RES:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> [[TMP2]], i32 4, <4 x i1> splat (i1 true), <4 x i32> undef)
 ; CHECK-NEXT:    ret <4 x i32> [[RES]]
 ;

``````````

</details>


https://github.com/llvm/llvm-project/pull/145952


More information about the llvm-commits mailing list