[llvm] [CGP] Scalarize non-constant indices for geps feeding gather/scatter (PR #145952)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 26 12:07:46 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Philip Reames (preames)
<details>
<summary>Changes</summary>
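Implements an existing FIXME in `optimizeGatherScatterInst`: the GEP indices before the final one only need to be scalars or splats, not constant zeroes. This also simplifies the code in the process.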
---
Full diff: https://github.com/llvm/llvm-project/pull/145952.diff
2 Files Affected:
- (modified) llvm/lib/CodeGen/CodeGenPrepare.cpp (+9-23)
- (modified) llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll (+6-5)
``````````diff
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 43574a54c37dd..4786b0dc0cb6e 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -6239,7 +6239,7 @@ bool CodeGenPrepare::optimizeMemoryInst(Instruction *MemoryInst, Value *Addr,
/// pointer there's nothing we can do.
///
/// If we have a GEP with more than 2 indices where the middle indices are all
-/// zeroes, we can replace it with 2 GEPs where the second has 2 operands.
+/// scalarizable, we can replace it with 2 GEPs where the second has 2 operands.
///
/// If the final index isn't a vector or is a splat, we can emit a scalar GEP
/// followed by a GEP with an all zeroes vector index. This will enable
@@ -6262,30 +6262,16 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
SmallVector<Value *, 2> Ops(GEP->operands());
bool RewriteGEP = false;
-
- if (Ops[0]->getType()->isVectorTy()) {
- Ops[0] = getSplatValue(Ops[0]);
- if (!Ops[0])
- return false;
- RewriteGEP = true;
- }
-
unsigned FinalIndex = Ops.size() - 1;
- // Ensure all but the last index is 0.
- // FIXME: This isn't strictly required. All that's required is that they are
- // all scalars or splats.
- for (unsigned i = 1; i < FinalIndex; ++i) {
- auto *C = dyn_cast<Constant>(Ops[i]);
- if (!C)
- return false;
- if (isa<VectorType>(C->getType()))
- C = C->getSplatValue();
- auto *CI = dyn_cast_or_null<ConstantInt>(C);
- if (!CI || !CI->isZero())
- return false;
- // Scalarize the index if needed.
- Ops[i] = CI;
+ // Ensure all but the last index are scalar
+ for (unsigned i = 0; i < FinalIndex; ++i) {
+ if (isa<VectorType>(Ops[i]->getType())) {
+ Ops[i] = getSplatValue(Ops[i]);
+ if (!Ops[i])
+ return false;
+ RewriteGEP = true;
+ }
}
// Try to scalarize the final index.
diff --git a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll
index e27d5d772a7a4..b7b23c60f7fc1 100644
--- a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll
+++ b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll
@@ -112,7 +112,8 @@ define void @splat_ptr_scatter(ptr %ptr, <4 x i1> %mask, <4 x i32> %val) {
define <4 x i32> @scalar_prefix(ptr %base, i64 %index, <4 x i64> %vecidx) {
; CHECK-LABEL: @scalar_prefix(
-; CHECK-NEXT: [[TMP2:%.*]] = getelementptr [256 x i32], ptr [[BASE:%.*]], i64 [[INDEX:%.*]], <4 x i64> [[VECIDX:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr [256 x i32], ptr [[BASE:%.*]], i64 [[INDEX:%.*]], i64 0
+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i32, ptr [[TMP1]], <4 x i64> [[VECIDX:%.*]]
; CHECK-NEXT: [[RES:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> [[TMP2]], i32 4, <4 x i1> splat (i1 true), <4 x i32> undef)
; CHECK-NEXT: ret <4 x i32> [[RES]]
;
@@ -123,9 +124,8 @@ define <4 x i32> @scalar_prefix(ptr %base, i64 %index, <4 x i64> %vecidx) {
define <4 x i32> @scalar_prefix_with_splat(ptr %base, i64 %index, <4 x i64> %vecidx) {
; CHECK-LABEL: @scalar_prefix_with_splat(
-; CHECK-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <4 x i64> poison, i64 [[INDEX:%.*]], i32 0
-; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i64> [[BROADCAST_SPLATINSERT]], <4 x i64> poison, <4 x i32> zeroinitializer
-; CHECK-NEXT: [[TMP2:%.*]] = getelementptr [256 x i32], ptr [[BASE:%.*]], <4 x i64> [[BROADCAST_SPLAT]], <4 x i64> [[VECIDX:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr [256 x i32], ptr [[BASE:%.*]], i64 [[INDEX:%.*]], i64 0
+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i32, ptr [[TMP1]], <4 x i64> [[VECIDX:%.*]]
; CHECK-NEXT: [[RES:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> [[TMP2]], i32 4, <4 x i1> splat (i1 true), <4 x i32> undef)
; CHECK-NEXT: ret <4 x i32> [[RES]]
;
@@ -139,7 +139,8 @@ define <4 x i32> @scalar_prefix_with_splat(ptr %base, i64 %index, <4 x i64> %vec
define <4 x i32> @scalar_prefix_with_constant_splat(ptr %base, <4 x i64> %vecidx) {
; CHECK-LABEL: @scalar_prefix_with_constant_splat(
-; CHECK-NEXT: [[TMP2:%.*]] = getelementptr [256 x i32], ptr [[BASE:%.*]], <4 x i64> splat (i64 20), <4 x i64> [[VECIDX:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr [256 x i32], ptr [[BASE:%.*]], i64 20, i64 0
+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i32, ptr [[TMP1]], <4 x i64> [[VECIDX:%.*]]
; CHECK-NEXT: [[RES:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> [[TMP2]], i32 4, <4 x i1> splat (i1 true), <4 x i32> undef)
; CHECK-NEXT: ret <4 x i32> [[RES]]
;
``````````
</details>
https://github.com/llvm/llvm-project/pull/145952
More information about the llvm-commits mailing list