[llvm] [SDAG] Prefer scalar for prefix of vector GEP expansion (PR #146719)
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 2 08:11:03 PDT 2025
https://github.com/preames created https://github.com/llvm/llvm-project/pull/146719
When generating SDAG for a getelementptr with a vector result, we were previously generating splats for each scalar operand. This essentially has the effect of aggressively vectorizing the sequence, and leaving it to later combines to scalarize if profitable.
Instead, we can keep the accumulating address as a scalar for as long as the prefix of operands allows before lazily converting to vector on the first vector operand. This both better fits hardware, which frequently has a scalar base on scatter/gather instructions, and reduces the addressing cost even when it does not, since otherwise we end up with a scalar-to-vector domain crossing for each scalar operand.
Note that constant splat offsets are treated as scalar for the above, and only variable offsets can force a conversion to vector.
>From 6cf9f3f68e34888593427472e844e792669d70bb Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Mon, 30 Jun 2025 18:36:02 -0700
Subject: [PATCH] [SDAG] Prefer scalar for prefix of vector GEP expansion
When generating SDAG for a getelementptr with a vector result, we were
previously generating splats for each scalar operand. This essentially
has the effect of aggressively vectorizing the sequence, and leaving it
to later combines to scalarize if profitable.
Instead, we can keep the accumulating address as a scalar for as long
as the prefix of operands allows before lazily converting to vector
on the first vector operand. This both better fits hardware, which
frequently has a scalar base on scatter/gather instructions,
and reduces the addressing cost even when it does not, since otherwise
we end up with a scalar-to-vector domain crossing for each scalar operand.
Note that constant splat offsets are treated as scalar for the above,
and only variable offsets can force a conversion to vector.
---
.../SelectionDAG/SelectionDAGBuilder.cpp | 33 ++++----
llvm/test/CodeGen/AArch64/ptradd.ll | 82 ++++++-------------
llvm/test/CodeGen/RISCV/rvv/mgather-sdnode.ll | 23 ++----
3 files changed, 52 insertions(+), 86 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 04d6fd5f48cc3..06e3a0717a02c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4336,19 +4336,13 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
auto &TLI = DAG.getTargetLoweringInfo();
GEPNoWrapFlags NW = cast<GEPOperator>(I).getNoWrapFlags();
- // Normalize Vector GEP - all scalar operands should be converted to the
- // splat vector.
+ // For a vector GEP, keep the prefix scalar as long as possible, then
+ // convert any scalars encountered after the first vector operand to vectors.
bool IsVectorGEP = I.getType()->isVectorTy();
ElementCount VectorElementCount =
IsVectorGEP ? cast<VectorType>(I.getType())->getElementCount()
: ElementCount::getFixed(0);
- if (IsVectorGEP && !N.getValueType().isVector()) {
- LLVMContext &Context = *DAG.getContext();
- EVT VT = EVT::getVectorVT(Context, N.getValueType(), VectorElementCount);
- N = DAG.getSplat(VT, dl, N);
- }
-
for (gep_type_iterator GTI = gep_type_begin(&I), E = gep_type_end(&I);
GTI != E; ++GTI) {
const Value *Idx = GTI.getOperand();
@@ -4396,7 +4390,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
APInt Offs = ElementMul * CI->getValue().sextOrTrunc(IdxSize);
LLVMContext &Context = *DAG.getContext();
SDValue OffsVal;
- if (IsVectorGEP)
+ if (N.getValueType().isVector())
OffsVal = DAG.getConstant(
Offs, dl, EVT::getVectorVT(Context, IdxTy, VectorElementCount));
else
@@ -4418,10 +4412,16 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
// N = N + Idx * ElementMul;
SDValue IdxN = getValue(Idx);
- if (!IdxN.getValueType().isVector() && IsVectorGEP) {
- EVT VT = EVT::getVectorVT(*Context, IdxN.getValueType(),
- VectorElementCount);
- IdxN = DAG.getSplat(VT, dl, IdxN);
+ if (IdxN.getValueType().isVector() != N.getValueType().isVector()) {
+ if (N.getValueType().isVector()) {
+ EVT VT = EVT::getVectorVT(*Context, IdxN.getValueType(),
+ VectorElementCount);
+ IdxN = DAG.getSplat(VT, dl, IdxN);
+ } else {
+ EVT VT =
+ EVT::getVectorVT(*Context, N.getValueType(), VectorElementCount);
+ N = DAG.getSplat(VT, dl, N);
+ }
}
// If the index is smaller or larger than intptr_t, truncate or extend
@@ -4442,7 +4442,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
SDValue VScale = DAG.getNode(
ISD::VSCALE, dl, VScaleTy,
DAG.getConstant(ElementMul.getZExtValue(), dl, VScaleTy));
- if (IsVectorGEP)
+ if (N.getValueType().isVector())
VScale = DAG.getSplatVector(N.getValueType(), dl, VScale);
IdxN = DAG.getNode(ISD::MUL, dl, N.getValueType(), IdxN, VScale,
ScaleFlags);
@@ -4475,6 +4475,11 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
}
}
+ if (IsVectorGEP && !N.getValueType().isVector()) {
+ EVT VT = EVT::getVectorVT(*Context, N.getValueType(), VectorElementCount);
+ N = DAG.getSplat(VT, dl, N);
+ }
+
MVT PtrTy = TLI.getPointerTy(DAG.getDataLayout(), AS);
MVT PtrMemTy = TLI.getPointerMemTy(DAG.getDataLayout(), AS);
if (IsVectorGEP) {
diff --git a/llvm/test/CodeGen/AArch64/ptradd.ll b/llvm/test/CodeGen/AArch64/ptradd.ll
index 28a8f4303765b..fc364357436e2 100644
--- a/llvm/test/CodeGen/AArch64/ptradd.ll
+++ b/llvm/test/CodeGen/AArch64/ptradd.ll
@@ -285,19 +285,11 @@ entry:
}
define <1 x ptr> @vector_gep_v1i64_c10(ptr %b) {
-; CHECK-SD-LABEL: vector_gep_v1i64_c10:
-; CHECK-SD: // %bb.0: // %entry
-; CHECK-SD-NEXT: mov w8, #10 // =0xa
-; CHECK-SD-NEXT: fmov d0, x0
-; CHECK-SD-NEXT: fmov d1, x8
-; CHECK-SD-NEXT: add d0, d0, d1
-; CHECK-SD-NEXT: ret
-;
-; CHECK-GI-LABEL: vector_gep_v1i64_c10:
-; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: add x8, x0, #10
-; CHECK-GI-NEXT: fmov d0, x8
-; CHECK-GI-NEXT: ret
+; CHECK-LABEL: vector_gep_v1i64_c10:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: add x8, x0, #10
+; CHECK-NEXT: fmov d0, x8
+; CHECK-NEXT: ret
entry:
%g = getelementptr i8, ptr %b, <1 x i64> <i64 10>
ret <1 x ptr> %g
@@ -306,10 +298,8 @@ entry:
define <2 x ptr> @vector_gep_v2i64_c10(ptr %b) {
; CHECK-SD-LABEL: vector_gep_v2i64_c10:
; CHECK-SD: // %bb.0: // %entry
-; CHECK-SD-NEXT: mov w8, #10 // =0xa
-; CHECK-SD-NEXT: dup v0.2d, x0
-; CHECK-SD-NEXT: dup v1.2d, x8
-; CHECK-SD-NEXT: add v0.2d, v0.2d, v1.2d
+; CHECK-SD-NEXT: add x8, x0, #10
+; CHECK-SD-NEXT: dup v0.2d, x8
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: vector_gep_v2i64_c10:
@@ -327,15 +317,10 @@ entry:
define <3 x ptr> @vector_gep_v3i64_c10(ptr %b) {
; CHECK-SD-LABEL: vector_gep_v3i64_c10:
; CHECK-SD: // %bb.0: // %entry
-; CHECK-SD-NEXT: mov w8, #10 // =0xa
-; CHECK-SD-NEXT: dup v0.2d, x0
-; CHECK-SD-NEXT: fmov d3, x0
-; CHECK-SD-NEXT: dup v2.2d, x8
-; CHECK-SD-NEXT: add v0.2d, v0.2d, v2.2d
-; CHECK-SD-NEXT: add d2, d3, d2
-; CHECK-SD-NEXT: ext v1.16b, v0.16b, v0.16b, #8
-; CHECK-SD-NEXT: // kill: def $d0 killed $d0 killed $q0
-; CHECK-SD-NEXT: // kill: def $d1 killed $d1 killed $q1
+; CHECK-SD-NEXT: add x8, x0, #10
+; CHECK-SD-NEXT: fmov d0, x8
+; CHECK-SD-NEXT: fmov d1, d0
+; CHECK-SD-NEXT: fmov d2, d0
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: vector_gep_v3i64_c10:
@@ -356,10 +341,8 @@ entry:
define <4 x ptr> @vector_gep_v4i64_c10(ptr %b) {
; CHECK-SD-LABEL: vector_gep_v4i64_c10:
; CHECK-SD: // %bb.0: // %entry
-; CHECK-SD-NEXT: mov w8, #10 // =0xa
-; CHECK-SD-NEXT: dup v0.2d, x0
-; CHECK-SD-NEXT: dup v1.2d, x8
-; CHECK-SD-NEXT: add v0.2d, v0.2d, v1.2d
+; CHECK-SD-NEXT: add x8, x0, #10
+; CHECK-SD-NEXT: dup v0.2d, x8
; CHECK-SD-NEXT: mov v1.16b, v0.16b
; CHECK-SD-NEXT: ret
;
@@ -377,19 +360,11 @@ entry:
}
define <1 x ptr> @vector_gep_v1i64_cm10(ptr %b) {
-; CHECK-SD-LABEL: vector_gep_v1i64_cm10:
-; CHECK-SD: // %bb.0: // %entry
-; CHECK-SD-NEXT: mov x8, #-10 // =0xfffffffffffffff6
-; CHECK-SD-NEXT: fmov d1, x0
-; CHECK-SD-NEXT: fmov d0, x8
-; CHECK-SD-NEXT: add d0, d1, d0
-; CHECK-SD-NEXT: ret
-;
-; CHECK-GI-LABEL: vector_gep_v1i64_cm10:
-; CHECK-GI: // %bb.0: // %entry
-; CHECK-GI-NEXT: sub x8, x0, #10
-; CHECK-GI-NEXT: fmov d0, x8
-; CHECK-GI-NEXT: ret
+; CHECK-LABEL: vector_gep_v1i64_cm10:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: sub x8, x0, #10
+; CHECK-NEXT: fmov d0, x8
+; CHECK-NEXT: ret
entry:
%g = getelementptr i8, ptr %b, <1 x i64> <i64 -10>
ret <1 x ptr> %g
@@ -398,10 +373,8 @@ entry:
define <2 x ptr> @vector_gep_v2i64_cm10(ptr %b) {
; CHECK-SD-LABEL: vector_gep_v2i64_cm10:
; CHECK-SD: // %bb.0: // %entry
-; CHECK-SD-NEXT: mov x8, #-10 // =0xfffffffffffffff6
-; CHECK-SD-NEXT: dup v1.2d, x0
+; CHECK-SD-NEXT: sub x8, x0, #10
; CHECK-SD-NEXT: dup v0.2d, x8
-; CHECK-SD-NEXT: add v0.2d, v1.2d, v0.2d
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: vector_gep_v2i64_cm10:
@@ -419,15 +392,10 @@ entry:
define <3 x ptr> @vector_gep_v3i64_cm10(ptr %b) {
; CHECK-SD-LABEL: vector_gep_v3i64_cm10:
; CHECK-SD: // %bb.0: // %entry
-; CHECK-SD-NEXT: mov x8, #-10 // =0xfffffffffffffff6
-; CHECK-SD-NEXT: dup v0.2d, x0
-; CHECK-SD-NEXT: fmov d3, x0
-; CHECK-SD-NEXT: dup v2.2d, x8
-; CHECK-SD-NEXT: add v0.2d, v0.2d, v2.2d
-; CHECK-SD-NEXT: add d2, d3, d2
-; CHECK-SD-NEXT: ext v1.16b, v0.16b, v0.16b, #8
-; CHECK-SD-NEXT: // kill: def $d0 killed $d0 killed $q0
-; CHECK-SD-NEXT: // kill: def $d1 killed $d1 killed $q1
+; CHECK-SD-NEXT: sub x8, x0, #10
+; CHECK-SD-NEXT: fmov d0, x8
+; CHECK-SD-NEXT: fmov d1, d0
+; CHECK-SD-NEXT: fmov d2, d0
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: vector_gep_v3i64_cm10:
@@ -448,10 +416,8 @@ entry:
define <4 x ptr> @vector_gep_v4i64_cm10(ptr %b) {
; CHECK-SD-LABEL: vector_gep_v4i64_cm10:
; CHECK-SD: // %bb.0: // %entry
-; CHECK-SD-NEXT: mov x8, #-10 // =0xfffffffffffffff6
-; CHECK-SD-NEXT: dup v1.2d, x0
+; CHECK-SD-NEXT: sub x8, x0, #10
; CHECK-SD-NEXT: dup v0.2d, x8
-; CHECK-SD-NEXT: add v0.2d, v1.2d, v0.2d
; CHECK-SD-NEXT: mov v1.16b, v0.16b
; CHECK-SD-NEXT: ret
;
diff --git a/llvm/test/CodeGen/RISCV/rvv/mgather-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/mgather-sdnode.ll
index 3057ee9293992..7c123db5665d4 100644
--- a/llvm/test/CodeGen/RISCV/rvv/mgather-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/mgather-sdnode.ll
@@ -2377,26 +2377,21 @@ define <vscale x 1 x i8> @mgather_baseidx_zext_nxv1i1_nxv1i8(ptr %base, <vscale
define <4 x i32> @scalar_prefix(ptr %base, i32 signext %index, <4 x i32> %vecidx) {
; RV32-LABEL: scalar_prefix:
; RV32: # %bb.0:
+; RV32-NEXT: slli a1, a1, 10
+; RV32-NEXT: add a0, a0, a1
; RV32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; RV32-NEXT: vmv.v.x v9, a1
-; RV32-NEXT: vsll.vi v9, v9, 10
-; RV32-NEXT: vadd.vx v9, v9, a0
; RV32-NEXT: vsll.vi v8, v8, 2
-; RV32-NEXT: vadd.vv v8, v9, v8
-; RV32-NEXT: vluxei32.v v8, (zero), v8
+; RV32-NEXT: vluxei32.v v8, (a0), v8
; RV32-NEXT: ret
;
; RV64-LABEL: scalar_prefix:
; RV64: # %bb.0:
-; RV64-NEXT: li a2, 1024
-; RV64-NEXT: vsetivli zero, 4, e64, m2, ta, ma
-; RV64-NEXT: vmv.v.x v10, a0
-; RV64-NEXT: vsetvli zero, zero, e32, m1, ta, ma
-; RV64-NEXT: vmv.v.x v9, a2
-; RV64-NEXT: vwmaccsu.vx v10, a1, v9
-; RV64-NEXT: li a0, 4
-; RV64-NEXT: vwmaccus.vx v10, a0, v8
-; RV64-NEXT: vluxei64.v v8, (zero), v10
+; RV64-NEXT: li a2, 4
+; RV64-NEXT: slli a1, a1, 10
+; RV64-NEXT: vsetivli zero, 4, e32, m1, ta, ma
+; RV64-NEXT: vwmulsu.vx v10, v8, a2
+; RV64-NEXT: add a0, a0, a1
+; RV64-NEXT: vluxei64.v v8, (a0), v10
; RV64-NEXT: ret
%gep = getelementptr [256 x i32], ptr %base, i32 %index, <4 x i32> %vecidx
%res = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> %gep, i32 4, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i32> undef)
More information about the llvm-commits
mailing list