[llvm] b9c4733 - [DAG] Move one-use add of splat to base of scatter/gather

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 22 18:46:12 PDT 2022


Author: Philip Reames
Date: 2022-09-22T18:45:12-07:00
New Revision: b9c473307954e62f5f756aa7af315d0ffe707634

URL: https://github.com/llvm/llvm-project/commit/b9c473307954e62f5f756aa7af315d0ffe707634
DIFF: https://github.com/llvm/llvm-project/commit/b9c473307954e62f5f756aa7af315d0ffe707634.diff

LOG: [DAG] Move one-use add of splat to base of scatter/gather

This extends the uniform base transform used with scatter/gather to support a one-use vector add-of-splat with a non-zero base. The effect is essentially to reassociate the add of the splatted scalar from the vector domain to the scalar domain.
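
For example (an illustrative LLVM IR fragment modeled on the updated tests below; %base, %offset, and %step are placeholder names):

    %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
    %splat = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
    %indices = add <vscale x 4 x i64> %splat, %step
    %ptrs = getelementptr i8, i8* %base, <vscale x 4 x i64> %indices
    ; Before: the scatter keeps BasePtr = %base and Index = %indices, so the
    ; add of the splat is performed per-element in vector registers.
    ; After: when the add has one use, BasePtr becomes %base + %offset (one
    ; scalar add) and Index becomes %step.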

The motivation is to improve the lowering of scatter/gather operations fed by complex GEPs.
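
For instance, in the RISC-V test below the scatter address is

    %gep = getelementptr inbounds %struct, ptr %p, <vscale x 2 x i64> %vec.ind, i32 5

whose lowering previously splatted %p into a vector and folded it into the scaled indices with vmadd.vx/vmacc.vx. With this change the splat of %p is reassociated onto the scalar base (the addi a0, a0, 28 / addi a0, a0, 32 in the updated output), leaving only a vmul.vx on the index vector.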

Differential Revision: https://reviews.llvm.org/D134472

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
    llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 2aad144e9a7c6..dbaf9545de7d0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10668,23 +10668,33 @@ static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
 }
 
 bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
-                       SelectionDAG &DAG) {
-  if (!isNullConstant(BasePtr) || Index.getOpcode() != ISD::ADD)
+                       SelectionDAG &DAG, const SDLoc &DL) {
+  if (Index.getOpcode() != ISD::ADD)
     return false;
 
   // Only perform the transformation when existing operands can be reused.
   if (IndexIsScaled)
     return false;
 
+  if (!isNullConstant(BasePtr) && !Index.hasOneUse())
+    return false;
+
+  EVT VT = BasePtr.getValueType();
   if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(0));
-      SplatVal && SplatVal.getValueType() == BasePtr.getValueType()) {
-    BasePtr = SplatVal;
+      SplatVal && SplatVal.getValueType() == VT) {
+    if (isNullConstant(BasePtr))
+      BasePtr = SplatVal;
+    else
+      BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
     Index = Index.getOperand(1);
     return true;
   }
   if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(1));
-      SplatVal && SplatVal.getValueType() == BasePtr.getValueType()) {
-    BasePtr = SplatVal;
+      SplatVal && SplatVal.getValueType() == VT) {
+    if (isNullConstant(BasePtr))
+      BasePtr = SplatVal;
+    else
+      BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
     Index = Index.getOperand(0);
     return true;
   }
@@ -10739,7 +10749,7 @@ SDValue DAGCombiner::visitVPSCATTER(SDNode *N) {
   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
     return Chain;
 
-  if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG)) {
+  if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG, DL)) {
     SDValue Ops[] = {Chain, StoreVal, BasePtr, Index, Scale, Mask, VL};
     return DAG.getScatterVP(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
                             DL, Ops, MSC->getMemOperand(), IndexType);
@@ -10769,7 +10779,7 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
     return Chain;
 
-  if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG)) {
+  if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG, DL)) {
     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
     return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
                                 DL, Ops, MSC->getMemOperand(), IndexType,
@@ -10861,7 +10871,7 @@ SDValue DAGCombiner::visitVPGATHER(SDNode *N) {
   ISD::MemIndexType IndexType = MGT->getIndexType();
   SDLoc DL(N);
 
-  if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG)) {
+  if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG, DL)) {
     SDValue Ops[] = {Chain, BasePtr, Index, Scale, Mask, VL};
     return DAG.getGatherVP(
         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
@@ -10893,7 +10903,7 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
     return CombineTo(N, PassThru, MGT->getChain());
 
-  if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG)) {
+  if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG, DL)) {
     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
     return DAG.getMaskedGather(
         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,

diff --git a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
index 9257a6a54ba82..bdede039a1202 100644
--- a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
+++ b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
@@ -105,18 +105,16 @@ define void @scatter_i8_index_offset_maximum_plus_one(i8* %base, i64 %offset, <v
 ; CHECK-NEXT:    rdvl x8, #1
 ; CHECK-NEXT:    mov w9, #67108864
 ; CHECK-NEXT:    lsr x8, x8, #4
-; CHECK-NEXT:    add x10, x0, x1
+; CHECK-NEXT:    add x11, x0, x1
+; CHECK-NEXT:    mov w10, #33554432
 ; CHECK-NEXT:    punpklo p1.h, p0.b
-; CHECK-NEXT:    uunpklo z3.d, z0.s
-; CHECK-NEXT:    mul x8, x8, x9
-; CHECK-NEXT:    mov w9, #33554432
+; CHECK-NEXT:    madd x8, x8, x9, x11
+; CHECK-NEXT:    uunpklo z2.d, z0.s
 ; CHECK-NEXT:    punpkhi p0.h, p0.b
 ; CHECK-NEXT:    uunpkhi z0.d, z0.s
-; CHECK-NEXT:    index z1.d, #0, x9
-; CHECK-NEXT:    mov z2.d, x8
-; CHECK-NEXT:    st1b { z3.d }, p1, [x10, z1.d]
-; CHECK-NEXT:    add z2.d, z1.d, z2.d
-; CHECK-NEXT:    st1b { z0.d }, p0, [x10, z2.d]
+; CHECK-NEXT:    index z1.d, #0, x10
+; CHECK-NEXT:    st1b { z2.d }, p1, [x11, z1.d]
+; CHECK-NEXT:    st1b { z0.d }, p0, [x8, z1.d]
 ; CHECK-NEXT:    ret
   %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
   %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
@@ -138,18 +136,16 @@ define void @scatter_i8_index_offset_minimum_minus_one(i8* %base, i64 %offset, <
 ; CHECK-NEXT:    mov x9, #-2
 ; CHECK-NEXT:    lsr x8, x8, #4
 ; CHECK-NEXT:    movk x9, #64511, lsl #16
-; CHECK-NEXT:    add x10, x0, x1
+; CHECK-NEXT:    add x11, x0, x1
+; CHECK-NEXT:    mov x10, #-33554433
+; CHECK-NEXT:    madd x8, x8, x9, x11
 ; CHECK-NEXT:    punpklo p1.h, p0.b
-; CHECK-NEXT:    mul x8, x8, x9
-; CHECK-NEXT:    mov x9, #-33554433
-; CHECK-NEXT:    uunpklo z3.d, z0.s
+; CHECK-NEXT:    uunpklo z2.d, z0.s
 ; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    index z1.d, #0, x10
 ; CHECK-NEXT:    uunpkhi z0.d, z0.s
-; CHECK-NEXT:    index z1.d, #0, x9
-; CHECK-NEXT:    mov z2.d, x8
-; CHECK-NEXT:    st1b { z3.d }, p1, [x10, z1.d]
-; CHECK-NEXT:    add z2.d, z1.d, z2.d
-; CHECK-NEXT:    st1b { z0.d }, p0, [x10, z2.d]
+; CHECK-NEXT:    st1b { z2.d }, p1, [x11, z1.d]
+; CHECK-NEXT:    st1b { z0.d }, p0, [x8, z1.d]
 ; CHECK-NEXT:    ret
   %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
   %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
@@ -170,18 +166,16 @@ define void @scatter_i8_index_stride_too_big(i8* %base, i64 %offset, <vscale x 4
 ; CHECK-NEXT:    rdvl x8, #1
 ; CHECK-NEXT:    mov x9, #-9223372036854775808
 ; CHECK-NEXT:    lsr x8, x8, #4
-; CHECK-NEXT:    add x10, x0, x1
+; CHECK-NEXT:    add x11, x0, x1
+; CHECK-NEXT:    mov x10, #4611686018427387904
 ; CHECK-NEXT:    punpklo p1.h, p0.b
-; CHECK-NEXT:    uunpklo z3.d, z0.s
-; CHECK-NEXT:    mul x8, x8, x9
-; CHECK-NEXT:    mov x9, #4611686018427387904
+; CHECK-NEXT:    madd x8, x8, x9, x11
+; CHECK-NEXT:    uunpklo z2.d, z0.s
 ; CHECK-NEXT:    punpkhi p0.h, p0.b
 ; CHECK-NEXT:    uunpkhi z0.d, z0.s
-; CHECK-NEXT:    index z1.d, #0, x9
-; CHECK-NEXT:    mov z2.d, x8
-; CHECK-NEXT:    st1b { z3.d }, p1, [x10, z1.d]
-; CHECK-NEXT:    add z2.d, z1.d, z2.d
-; CHECK-NEXT:    st1b { z0.d }, p0, [x10, z2.d]
+; CHECK-NEXT:    index z1.d, #0, x10
+; CHECK-NEXT:    st1b { z2.d }, p1, [x11, z1.d]
+; CHECK-NEXT:    st1b { z0.d }, p0, [x8, z1.d]
 ; CHECK-NEXT:    ret
   %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
   %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer

diff --git a/llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll b/llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll
index cc7b2815897fa..7991eb98571c1 100644
--- a/llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll
@@ -10,25 +10,23 @@ define void @complex_gep(ptr %p, <vscale x 2 x i64> %vec.ind, <vscale x 2 x i1>
 ; RV32-LABEL: complex_gep:
 ; RV32:       # %bb.0:
 ; RV32-NEXT:    vsetvli a1, zero, e32, m1, ta, mu
-; RV32-NEXT:    vmv.v.x v10, a0
-; RV32-NEXT:    vnsrl.wi v11, v8, 0
-; RV32-NEXT:    li a0, 48
-; RV32-NEXT:    vmadd.vx v11, a0, v10
-; RV32-NEXT:    vmv.v.i v8, 0
-; RV32-NEXT:    li a0, 28
-; RV32-NEXT:    vsoxei32.v v8, (a0), v11, v0.t
+; RV32-NEXT:    vnsrl.wi v10, v8, 0
+; RV32-NEXT:    li a1, 48
+; RV32-NEXT:    vmul.vx v8, v10, a1
+; RV32-NEXT:    addi a0, a0, 28
+; RV32-NEXT:    vmv.v.i v9, 0
+; RV32-NEXT:    vsoxei32.v v9, (a0), v8, v0.t
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: complex_gep:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetvli a1, zero, e64, m2, ta, mu
-; RV64-NEXT:    vmv.v.x v10, a0
-; RV64-NEXT:    li a0, 56
-; RV64-NEXT:    vmacc.vx v10, a0, v8
+; RV64-NEXT:    li a1, 56
+; RV64-NEXT:    vsetvli a2, zero, e64, m2, ta, mu
+; RV64-NEXT:    vmul.vx v8, v8, a1
+; RV64-NEXT:    addi a0, a0, 32
 ; RV64-NEXT:    vsetvli zero, zero, e32, m1, ta, mu
-; RV64-NEXT:    vmv.v.i v8, 0
-; RV64-NEXT:    li a0, 32
-; RV64-NEXT:    vsoxei64.v v8, (a0), v10, v0.t
+; RV64-NEXT:    vmv.v.i v10, 0
+; RV64-NEXT:    vsoxei64.v v10, (a0), v8, v0.t
 ; RV64-NEXT:    ret
   %gep = getelementptr inbounds %struct, ptr %p, <vscale x 2 x i64> %vec.ind, i32 5
   call void @llvm.masked.scatter.nxv2i32.nxv2p0(<vscale x 2 x i32> zeroinitializer, <vscale x 2 x ptr> %gep, i32 8, <vscale x 2 x i1> %m)
