[llvm] b9c4733 - [DAG] Move one-use add of splat to base of scatter/gather
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Thu Sep 22 18:46:12 PDT 2022
Author: Philip Reames
Date: 2022-09-22T18:45:12-07:00
New Revision: b9c473307954e62f5f756aa7af315d0ffe707634
URL: https://github.com/llvm/llvm-project/commit/b9c473307954e62f5f756aa7af315d0ffe707634
DIFF: https://github.com/llvm/llvm-project/commit/b9c473307954e62f5f756aa7af315d0ffe707634.diff
LOG: [DAG] Move one-use add of splat to base of scatter/gather
This extends the uniform base transform used with scatter/gather to support one-use vector adds-of-splats with a non-zero base. This effectively reassociates the add from the vector domain to the scalar domain.
The motivation is to improve the lowering of scatter/gather operations fed by complex geps.
Differential Revision: https://reviews.llvm.org/D134472
Added:
Modified:
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 2aad144e9a7c6..dbaf9545de7d0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10668,23 +10668,33 @@ static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
}
bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
- SelectionDAG &DAG) {
- if (!isNullConstant(BasePtr) || Index.getOpcode() != ISD::ADD)
+ SelectionDAG &DAG, const SDLoc &DL) {
+ if (Index.getOpcode() != ISD::ADD)
return false;
// Only perform the transformation when existing operands can be reused.
if (IndexIsScaled)
return false;
+ if (!isNullConstant(BasePtr) && !Index.hasOneUse())
+ return false;
+
+ EVT VT = BasePtr.getValueType();
if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(0));
- SplatVal && SplatVal.getValueType() == BasePtr.getValueType()) {
- BasePtr = SplatVal;
+ SplatVal && SplatVal.getValueType() == VT) {
+ if (isNullConstant(BasePtr))
+ BasePtr = SplatVal;
+ else
+ BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
Index = Index.getOperand(1);
return true;
}
if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(1));
- SplatVal && SplatVal.getValueType() == BasePtr.getValueType()) {
- BasePtr = SplatVal;
+ SplatVal && SplatVal.getValueType() == VT) {
+ if (isNullConstant(BasePtr))
+ BasePtr = SplatVal;
+ else
+ BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
Index = Index.getOperand(0);
return true;
}
@@ -10739,7 +10749,7 @@ SDValue DAGCombiner::visitVPSCATTER(SDNode *N) {
if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
return Chain;
- if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG)) {
+ if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG, DL)) {
SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
return DAG.getScatterVP(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
DL, Ops, MSC->getMemOperand(), IndexType);
@@ -10769,7 +10779,7 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
return Chain;
- if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG)) {
+ if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG, DL)) {
SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
DL, Ops, MSC->getMemOperand(), IndexType,
@@ -10861,7 +10871,7 @@ SDValue DAGCombiner::visitVPGATHER(SDNode *N) {
ISD::MemIndexType IndexType = MGT->getIndexType();
SDLoc DL(N);
- if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG)) {
+ if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG, DL)) {
SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
return DAG.getGatherVP(
DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
@@ -10893,7 +10903,7 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
return CombineTo(N, PassThru, MGT->getChain());
- if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG)) {
+ if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG, DL)) {
SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
return DAG.getMaskedGather(
DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
diff --git a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
index 9257a6a54ba82..bdede039a1202 100644
--- a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
+++ b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
@@ -105,18 +105,16 @@ define void @scatter_i8_index_offset_maximum_plus_one(i8* %base, i64 %offset, <v
; CHECK-NEXT: rdvl x8, #1
; CHECK-NEXT: mov w9, #67108864
; CHECK-NEXT: lsr x8, x8, #4
-; CHECK-NEXT: add x10, x0, x1
+; CHECK-NEXT: add x11, x0, x1
+; CHECK-NEXT: mov w10, #33554432
; CHECK-NEXT: punpklo p1.h, p0.b
-; CHECK-NEXT: uunpklo z3.d, z0.s
-; CHECK-NEXT: mul x8, x8, x9
-; CHECK-NEXT: mov w9, #33554432
+; CHECK-NEXT: madd x8, x8, x9, x11
+; CHECK-NEXT: uunpklo z2.d, z0.s
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: uunpkhi z0.d, z0.s
-; CHECK-NEXT: index z1.d, #0, x9
-; CHECK-NEXT: mov z2.d, x8
-; CHECK-NEXT: st1b { z3.d }, p1, [x10, z1.d]
-; CHECK-NEXT: add z2.d, z1.d, z2.d
-; CHECK-NEXT: st1b { z0.d }, p0, [x10, z2.d]
+; CHECK-NEXT: index z1.d, #0, x10
+; CHECK-NEXT: st1b { z2.d }, p1, [x11, z1.d]
+; CHECK-NEXT: st1b { z0.d }, p0, [x8, z1.d]
; CHECK-NEXT: ret
%t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
%t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
@@ -138,18 +136,16 @@ define void @scatter_i8_index_offset_minimum_minus_one(i8* %base, i64 %offset, <
; CHECK-NEXT: mov x9, #-2
; CHECK-NEXT: lsr x8, x8, #4
; CHECK-NEXT: movk x9, #64511, lsl #16
-; CHECK-NEXT: add x10, x0, x1
+; CHECK-NEXT: add x11, x0, x1
+; CHECK-NEXT: mov x10, #-33554433
+; CHECK-NEXT: madd x8, x8, x9, x11
; CHECK-NEXT: punpklo p1.h, p0.b
-; CHECK-NEXT: mul x8, x8, x9
-; CHECK-NEXT: mov x9, #-33554433
-; CHECK-NEXT: uunpklo z3.d, z0.s
+; CHECK-NEXT: uunpklo z2.d, z0.s
; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: index z1.d, #0, x10
; CHECK-NEXT: uunpkhi z0.d, z0.s
-; CHECK-NEXT: index z1.d, #0, x9
-; CHECK-NEXT: mov z2.d, x8
-; CHECK-NEXT: st1b { z3.d }, p1, [x10, z1.d]
-; CHECK-NEXT: add z2.d, z1.d, z2.d
-; CHECK-NEXT: st1b { z0.d }, p0, [x10, z2.d]
+; CHECK-NEXT: st1b { z2.d }, p1, [x11, z1.d]
+; CHECK-NEXT: st1b { z0.d }, p0, [x8, z1.d]
; CHECK-NEXT: ret
%t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
%t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
@@ -170,18 +166,16 @@ define void @scatter_i8_index_stride_too_big(i8* %base, i64 %offset, <vscale x 4
; CHECK-NEXT: rdvl x8, #1
; CHECK-NEXT: mov x9, #-9223372036854775808
; CHECK-NEXT: lsr x8, x8, #4
-; CHECK-NEXT: add x10, x0, x1
+; CHECK-NEXT: add x11, x0, x1
+; CHECK-NEXT: mov x10, #4611686018427387904
; CHECK-NEXT: punpklo p1.h, p0.b
-; CHECK-NEXT: uunpklo z3.d, z0.s
-; CHECK-NEXT: mul x8, x8, x9
-; CHECK-NEXT: mov x9, #4611686018427387904
+; CHECK-NEXT: madd x8, x8, x9, x11
+; CHECK-NEXT: uunpklo z2.d, z0.s
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: uunpkhi z0.d, z0.s
-; CHECK-NEXT: index z1.d, #0, x9
-; CHECK-NEXT: mov z2.d, x8
-; CHECK-NEXT: st1b { z3.d }, p1, [x10, z1.d]
-; CHECK-NEXT: add z2.d, z1.d, z2.d
-; CHECK-NEXT: st1b { z0.d }, p0, [x10, z2.d]
+; CHECK-NEXT: index z1.d, #0, x10
+; CHECK-NEXT: st1b { z2.d }, p1, [x11, z1.d]
+; CHECK-NEXT: st1b { z0.d }, p0, [x8, z1.d]
; CHECK-NEXT: ret
%t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
%t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
diff --git a/llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll b/llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll
index cc7b2815897fa..7991eb98571c1 100644
--- a/llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll
@@ -10,25 +10,23 @@ define void @complex_gep(ptr %p, <vscale x 2 x i64> %vec.ind, <vscale x 2 x i1>
; RV32-LABEL: complex_gep:
; RV32: # %bb.0:
; RV32-NEXT: vsetvli a1, zero, e32, m1, ta, mu
-; RV32-NEXT: vmv.v.x v10, a0
-; RV32-NEXT: vnsrl.wi v11, v8, 0
-; RV32-NEXT: li a0, 48
-; RV32-NEXT: vmadd.vx v11, a0, v10
-; RV32-NEXT: vmv.v.i v8, 0
-; RV32-NEXT: li a0, 28
-; RV32-NEXT: vsoxei32.v v8, (a0), v11, v0.t
+; RV32-NEXT: vnsrl.wi v10, v8, 0
+; RV32-NEXT: li a1, 48
+; RV32-NEXT: vmul.vx v8, v10, a1
+; RV32-NEXT: addi a0, a0, 28
+; RV32-NEXT: vmv.v.i v9, 0
+; RV32-NEXT: vsoxei32.v v9, (a0), v8, v0.t
; RV32-NEXT: ret
;
; RV64-LABEL: complex_gep:
; RV64: # %bb.0:
-; RV64-NEXT: vsetvli a1, zero, e64, m2, ta, mu
-; RV64-NEXT: vmv.v.x v10, a0
-; RV64-NEXT: li a0, 56
-; RV64-NEXT: vmacc.vx v10, a0, v8
+; RV64-NEXT: li a1, 56
+; RV64-NEXT: vsetvli a2, zero, e64, m2, ta, mu
+; RV64-NEXT: vmul.vx v8, v8, a1
+; RV64-NEXT: addi a0, a0, 32
; RV64-NEXT: vsetvli zero, zero, e32, m1, ta, mu
-; RV64-NEXT: vmv.v.i v8, 0
-; RV64-NEXT: li a0, 32
-; RV64-NEXT: vsoxei64.v v8, (a0), v10, v0.t
+; RV64-NEXT: vmv.v.i v10, 0
+; RV64-NEXT: vsoxei64.v v10, (a0), v8, v0.t
; RV64-NEXT: ret
%gep = getelementptr inbounds %struct, ptr %p, <vscale x 2 x i64> %vec.ind, i32 5
call void @llvm.masked.scatter.nxv2i32.nxv2p0(<vscale x 2 x i32> zeroinitializer, <vscale x 2 x ptr> %gep, i32 8, <vscale x 2 x i1> %m)
More information about the llvm-commits mailing list