[llvm] 7a0b897 - [DAGCombiner][SVE] Ensure MGATHER/MSCATTER addressing mode combines preserve index scaling

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 29 04:43:38 PDT 2022


Author: Paul Walker
Date: 2022-04-29T12:35:16+01:00
New Revision: 7a0b897e8664d11481230a69a88fca2b2ee5f904

URL: https://github.com/llvm/llvm-project/commit/7a0b897e8664d11481230a69a88fca2b2ee5f904
DIFF: https://github.com/llvm/llvm-project/commit/7a0b897e8664d11481230a69a88fca2b2ee5f904.diff

LOG: [DAGCombiner][SVE] Ensure MGATHER/MSCATTER addressing mode combines preserve index scaling

refineUniformBase and selectGatherScatterAddrMode both attempt the
transformation:

  base(0) + index(A+splat(B)) => base(B) + index(A)

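However, this is only safe when the index is not implicitly scaled. With a
scaled index the address is base + scale * (A + splat(B)), so moving B into
the base unscaled loses the scaling that was applied to B.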

Differential Revision: https://reviews.llvm.org/D123222

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 71618eb2bd7c..181ff00184b3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10426,14 +10426,19 @@ static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) {
       TopHalf->isZero() ? RHS->getOperand(1) : LHS->getOperand(1));
 }
 
-bool refineUniformBase(SDValue &BasePtr, SDValue &Index, SelectionDAG &DAG) {
+bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
+                       SelectionDAG &DAG) {
   if (!isNullConstant(BasePtr) || Index.getOpcode() != ISD::ADD)
     return false;
 
+  // Only perform the transformation when existing operands can be reused.
+  if (IndexIsScaled)
+    return false;
+
   // For now we check only the LHS of the add.
   SDValue LHS = Index.getOperand(0);
   SDValue SplatVal = DAG.getSplatValue(LHS);
-  if (!SplatVal)
+  if (!SplatVal || SplatVal.getValueType() != BasePtr.getValueType())
     return false;
 
   BasePtr = SplatVal;
@@ -10481,7 +10486,7 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
     return Chain;
 
-  if (refineUniformBase(BasePtr, Index, DAG)) {
+  if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG)) {
     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
     return DAG.getMaskedScatter(
         DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops,
@@ -10576,7 +10581,7 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
     return CombineTo(N, PassThru, MGT->getChain());
 
-  if (refineUniformBase(BasePtr, Index, DAG)) {
+  if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG)) {
     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
     return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
                                MGT->getMemoryVT(), DL, Ops,

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a5dadd112c9e..3f6a36fe49f4 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4656,10 +4656,10 @@ bool getGatherScatterIndexIsExtended(SDValue Index) {
 // VECTOR + IMMEDIATE:
 //    getelementptr nullptr, <vscale x N x T> (splat(#x)) + %indices)
 // -> getelementptr #x, <vscale x N x T> %indices
-void selectGatherScatterAddrMode(SDValue &BasePtr, SDValue &Index, EVT MemVT,
-                                 unsigned &Opcode, bool IsGather,
-                                 SelectionDAG &DAG) {
-  if (!isNullConstant(BasePtr))
+void selectGatherScatterAddrMode(SDValue &BasePtr, SDValue &Index,
+                                 bool IsScaled, EVT MemVT, unsigned &Opcode,
+                                 bool IsGather, SelectionDAG &DAG) {
+  if (!isNullConstant(BasePtr) || IsScaled)
     return;
 
   // FIXME: This will not match for fixed vector type codegen as the nodes in
@@ -4789,7 +4789,7 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
     Index = Index.getOperand(0);
 
   unsigned Opcode = getGatherVecOpcode(IsScaled, IsSigned, IdxNeedsExtend);
-  selectGatherScatterAddrMode(BasePtr, Index, MemVT, Opcode,
+  selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
                               /*isGather=*/true, DAG);
 
   if (ExtType == ISD::SEXTLOAD)
@@ -4898,7 +4898,7 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
     Index = Index.getOperand(0);
 
   unsigned Opcode = getScatterVecOpcode(IsScaled, IsSigned, NeedsExtend);
-  selectGatherScatterAddrMode(BasePtr, Index, MemVT, Opcode,
+  selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
                               /*isGather=*/false, DAG);
 
   if (IsFixedLength) {

diff  --git a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
index d06cc313ba53..4fdf4a106dbc 100644
--- a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
+++ b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
@@ -343,12 +343,13 @@ define <vscale x 2 x i64> @masked_gather_nxv2i64_const_with_vec_offsets(<vscale
   ret <vscale x 2 x i64> %data
 }
 
-; TODO: The generated code is wrong because we've lost the scaling applied to
-; %scalar_offset when it's used to calculate %ptrs.
 define <vscale x 2 x i64> @masked_gather_nxv2i64_null_with_vec_plus_scalar_offsets(<vscale x 2 x i64> %vector_offsets, i64 %scalar_offset, <vscale x 2 x i1> %pg) #0 {
 ; CHECK-LABEL: masked_gather_nxv2i64_null_with_vec_plus_scalar_offsets:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x0, z0.d, lsl #3]
+; CHECK-NEXT:    mov x8, xzr
+; CHECK-NEXT:    mov z1.d, x0
+; CHECK-NEXT:    add z0.d, z0.d, z1.d
+; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x8, z0.d, lsl #3]
 ; CHECK-NEXT:    ret
   %scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 %scalar_offset, i64 0
   %scalar_offset.splat = shufflevector <vscale x 2 x i64> %scalar_offset.ins, <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
@@ -358,12 +359,11 @@ define <vscale x 2 x i64> @masked_gather_nxv2i64_null_with_vec_plus_scalar_offse
   ret <vscale x 2 x i64> %data
 }
 
-; TODO: The generated code is wrong because we've lost the scaling applied to
-; constant scalar offset (i.e. i64 1)  when it's used to calculate %ptrs.
 define <vscale x 2 x i64> @masked_gather_nxv2i64_null_with__vec_plus_imm_offsets(<vscale x 2 x i64> %vector_offsets, <vscale x 2 x i1> %pg) #0 {
 ; CHECK-LABEL: masked_gather_nxv2i64_null_with__vec_plus_imm_offsets:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov w8, #1
+; CHECK-NEXT:    mov x8, xzr
+; CHECK-NEXT:    add z0.d, z0.d, #1 // =0x1
 ; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x8, z0.d, lsl #3]
 ; CHECK-NEXT:    ret
   %scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 1, i64 0
@@ -425,12 +425,13 @@ define void @masked_scatter_nxv2i64_const_with_vec_offsets(<vscale x 2 x i64> %v
   ret void
 }
 
-; TODO: The generated code is wrong because we've lost the scaling applied to
-; %scalar_offset when it's used to calculate %ptrs.
 define void @masked_scatter_nxv2i64_null_with_vec_plus_scalar_offsets(<vscale x 2 x i64> %vector_offsets, i64 %scalar_offset, <vscale x 2 x i1> %pg, <vscale x 2 x i64> %data) #0 {
 ; CHECK-LABEL: masked_scatter_nxv2i64_null_with_vec_plus_scalar_offsets:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    st1d { z1.d }, p0, [x0, z0.d, lsl #3]
+; CHECK-NEXT:    mov x8, xzr
+; CHECK-NEXT:    mov z2.d, x0
+; CHECK-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEXT:    st1d { z1.d }, p0, [x8, z0.d, lsl #3]
 ; CHECK-NEXT:    ret
   %scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 %scalar_offset, i64 0
   %scalar_offset.splat = shufflevector <vscale x 2 x i64> %scalar_offset.ins, <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
@@ -440,12 +441,11 @@ define void @masked_scatter_nxv2i64_null_with_vec_plus_scalar_offsets(<vscale x
   ret void
 }
 
-; TODO: The generated code is wrong because we've lost the scaling applied to
-; constant scalar offset (i.e. i64 1)  when it's used to calculate %ptrs.
 define void @masked_scatter_nxv2i64_null_with__vec_plus_imm_offsets(<vscale x 2 x i64> %vector_offsets, <vscale x 2 x i1> %pg, <vscale x 2 x i64> %data) #0 {
 ; CHECK-LABEL: masked_scatter_nxv2i64_null_with__vec_plus_imm_offsets:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov w8, #1
+; CHECK-NEXT:    mov x8, xzr
+; CHECK-NEXT:    add z0.d, z0.d, #1 // =0x1
 ; CHECK-NEXT:    st1d { z1.d }, p0, [x8, z0.d, lsl #3]
 ; CHECK-NEXT:    ret
   %scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 1, i64 0


        


More information about the llvm-commits mailing list