[llvm] 201e368 - [AArch64][SVE] Handle more cases in findMoreOptimalIndexType.

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 28 04:16:42 PST 2022


Author: Sander de Smalen
Date: 2022-02-28T12:13:52Z
New Revision: 201e3686ab4533213e6d72d192d4493002ffa679

URL: https://github.com/llvm/llvm-project/commit/201e3686ab4533213e6d72d192d4493002ffa679
DIFF: https://github.com/llvm/llvm-project/commit/201e3686ab4533213e6d72d192d4493002ffa679.diff

LOG: [AArch64][SVE] Handle more cases in findMoreOptimalIndexType.

This patch addresses @paulwalker-arm's comment on D117900: the
by-ref operands are now updated/written only if the function returns
true. It also handles a few more cases where a series of added
offsets can be folded into the base pointer, rather than just looking
at a single offset.

Reviewed By: paulwalker-arm

Differential Revision: https://reviews.llvm.org/D119728

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9754df1a6a641..daadcc048b163 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -16476,55 +16476,90 @@ static SDValue performSTORECombine(SDNode *N,
   return SDValue();
 }
 
-// Analyse the specified address returning true if a more optimal addressing
-// mode is available. When returning true all parameters are updated to reflect
-// their recommended values.
-static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
-                                     SDValue &BasePtr, SDValue &Index,
-                                     ISD::MemIndexType &IndexType,
-                                     SelectionDAG &DAG) {
-  // Only consider element types that are pointer sized as smaller types can
-  // be easily promoted.
+/// \return true if part of the index was folded into the Base.
+static bool foldIndexIntoBase(SDValue &BasePtr, SDValue &Index, SDValue Scale,
+                              SDLoc DL, SelectionDAG &DAG) {
+  // This function assumes a vector of i64 indices.
   EVT IndexVT = Index.getValueType();
-  if (IndexVT.getVectorElementType() != MVT::i64 || IndexVT == MVT::nxv2i64)
+  if (!IndexVT.isVector() || IndexVT.getVectorElementType() != MVT::i64)
     return false;
 
-  int64_t Stride = 0;
-  SDLoc DL(N);
-  // Index = step(const) + splat(offset)
-  if (Index.getOpcode() == ISD::ADD &&
-      Index.getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
-    SDValue StepVector = Index.getOperand(0);
+  // Simplify:
+  //   BasePtr = Ptr
+  //   Index = X + splat(Offset)
+  // ->
+  //   BasePtr = Ptr + Offset * scale.
+  //   Index = X
+  if (Index.getOpcode() == ISD::ADD) {
     if (auto Offset = DAG.getSplatValue(Index.getOperand(1))) {
-      Stride = cast<ConstantSDNode>(StepVector.getOperand(0))->getSExtValue();
-      Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, N->getScale());
+      Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, Scale);
       BasePtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr, Offset);
+      Index = Index.getOperand(0);
+      return true;
     }
   }
 
-  // Index = shl((step(const) + splat(offset))), splat(shift))
+  // Simplify:
+  //   BasePtr = Ptr
+  //   Index = (X + splat(Offset)) << splat(Shift)
+  // ->
+  //   BasePtr = Ptr + ((Offset << Shift) * scale)
+  //   Index = X << splat(Shift)
   if (Index.getOpcode() == ISD::SHL &&
-      Index.getOperand(0).getOpcode() == ISD::ADD &&
-      Index.getOperand(0).getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
+      Index.getOperand(0).getOpcode() == ISD::ADD) {
     SDValue Add = Index.getOperand(0);
     SDValue ShiftOp = Index.getOperand(1);
-    SDValue StepOp = Add.getOperand(0);
     SDValue OffsetOp = Add.getOperand(1);
-    if (auto *Shift =
-            dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(ShiftOp)))
+    if (auto Shift = DAG.getSplatValue(ShiftOp))
       if (auto Offset = DAG.getSplatValue(OffsetOp)) {
-        int64_t Step =
-            cast<ConstantSDNode>(StepOp.getOperand(0))->getSExtValue();
-        // Stride does not scale explicitly by 'Scale', because it happens in
-        // the gather/scatter addressing mode.
-        Stride = Step << Shift->getSExtValue();
-        // BasePtr = BasePtr + ((Offset * Scale) << Shift)
-        Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, N->getScale());
-        Offset = DAG.getNode(ISD::SHL, DL, MVT::i64, Offset, SDValue(Shift, 0));
+        Offset = DAG.getNode(ISD::SHL, DL, MVT::i64, Offset, Shift);
+        Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, Scale);
         BasePtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr, Offset);
+        Index = DAG.getNode(ISD::SHL, DL, Index.getValueType(),
+                            Add.getOperand(0), ShiftOp);
+        return true;
       }
   }
 
+  return false;
+}
+
+// Analyse the specified address returning true if a more optimal addressing
+// mode is available. When returning true all parameters are updated to reflect
+// their recommended values.
+static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
+                                     SDValue &BasePtr, SDValue &Index,
+                                     SelectionDAG &DAG) {
+  // Only consider element types that are pointer sized as smaller types can
+  // be easily promoted.
+  EVT IndexVT = Index.getValueType();
+  if (IndexVT.getVectorElementType() != MVT::i64 || IndexVT == MVT::nxv2i64)
+    return false;
+
+  // Try to iteratively fold parts of the index into the base pointer to
+  // simplify the index as much as possible.
+  SDValue NewBasePtr = BasePtr, NewIndex = Index;
+  while (foldIndexIntoBase(NewBasePtr, NewIndex, N->getScale(), SDLoc(N), DAG))
+    ;
+
+  // Match:
+  //   Index = step(const)
+  int64_t Stride = 0;
+  if (NewIndex.getOpcode() == ISD::STEP_VECTOR)
+    Stride = cast<ConstantSDNode>(NewIndex.getOperand(0))->getSExtValue();
+
+  // Match:
+  //   Index = step(const) << shift(const)
+  else if (NewIndex.getOpcode() == ISD::SHL &&
+           NewIndex.getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
+    SDValue RHS = NewIndex.getOperand(1);
+    if (auto *Shift =
+            dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(RHS))) {
+      int64_t Step = (int64_t)NewIndex.getOperand(0).getConstantOperandVal(1);
+      Stride = Step << Shift->getZExtValue();
+    }
+  }
+
   // Return early because no supported pattern is found.
   if (Stride == 0)
     return false;
@@ -16545,8 +16580,11 @@ static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
     return false;
 
   EVT NewIndexVT = IndexVT.changeVectorElementType(MVT::i32);
-  Index = DAG.getNode(ISD::STEP_VECTOR, DL, NewIndexVT,
-                      DAG.getTargetConstant(Stride, DL, MVT::i32));
+  // Stride does not scale explicitly by 'Scale', because it happens in
+  // the gather/scatter addressing mode.
+  Index = DAG.getNode(ISD::STEP_VECTOR, SDLoc(N), NewIndexVT,
+                      DAG.getTargetConstant(Stride, SDLoc(N), MVT::i32));
+  BasePtr = NewBasePtr;
   return true;
 }
 
@@ -16566,7 +16604,7 @@ static SDValue performMaskedGatherScatterCombine(
   SDValue BasePtr = MGS->getBasePtr();
   ISD::MemIndexType IndexType = MGS->getIndexType();
 
-  if (!findMoreOptimalIndexType(MGS, BasePtr, Index, IndexType, DAG))
+  if (!findMoreOptimalIndexType(MGS, BasePtr, Index, DAG))
     return SDValue();
 
   // Here we catch such cases early and change MGATHER's IndexType to allow

diff  --git a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
index dd5a1264f8700..e894291845d41 100644
--- a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
+++ b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
@@ -283,7 +283,54 @@ define void @scatter_f16_index_offset_8([8 x half]* %base, i64 %offset, <vscale
   ret void
 }
 
+; stepvector is hidden further behind GEP and two adds.
+define void @scatter_f16_index_add_add([8 x half]* %base, i64 %offset, i64 %offset2, <vscale x 4 x i1> %pg, <vscale x 4 x half> %data) #0 {
+; CHECK-LABEL: scatter_f16_index_add_add:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #16
+; CHECK-NEXT:    add x9, x0, x2, lsl #4
+; CHECK-NEXT:    add x9, x9, x1, lsl #4
+; CHECK-NEXT:    index z1.s, #0, w8
+; CHECK-NEXT:    st1h { z0.s }, p0, [x9, z1.s, sxtw]
+; CHECK-NEXT:    ret
+  %splat.offset.ins = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %splat.offset = shufflevector <vscale x 4 x i64> %splat.offset.ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %splat.offset2.ins = insertelement <vscale x 4 x i64> undef, i64 %offset2, i32 0
+  %splat.offset2 = shufflevector <vscale x 4 x i64> %splat.offset2.ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %add1 = add <vscale x 4 x i64> %splat.offset, %step
+  %add2 = add <vscale x 4 x i64> %add1, %splat.offset2
+  %gep = getelementptr [8 x half], [8 x half]* %base, <vscale x 4 x i64> %add2
+  %gep.bc = bitcast <vscale x 4 x [8 x half]*> %gep to <vscale x 4 x half*>
+  call void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half> %data, <vscale x 4 x half*> %gep.bc, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
 
+; stepvector is hidden further behind GEP, two adds and a shift.
+define void @scatter_f16_index_add_add_mul([8 x half]* %base, i64 %offset, i64 %offset2, <vscale x 4 x i1> %pg, <vscale x 4 x half> %data) #0 {
+; CHECK-LABEL: scatter_f16_index_add_add_mul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #128
+; CHECK-NEXT:    add x9, x0, x2, lsl #7
+; CHECK-NEXT:    add x9, x9, x1, lsl #7
+; CHECK-NEXT:    index z1.s, #0, w8
+; CHECK-NEXT:    st1h { z0.s }, p0, [x9, z1.s, sxtw]
+; CHECK-NEXT:    ret
+  %splat.offset.ins = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %splat.offset = shufflevector <vscale x 4 x i64> %splat.offset.ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %splat.offset2.ins = insertelement <vscale x 4 x i64> undef, i64 %offset2, i32 0
+  %splat.offset2 = shufflevector <vscale x 4 x i64> %splat.offset2.ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %add1 = add <vscale x 4 x i64> %splat.offset, %step
+  %add2 = add <vscale x 4 x i64> %add1, %splat.offset2
+  %splat.const8.ins = insertelement <vscale x 4 x i64> undef, i64 8, i32 0
+  %splat.const8 = shufflevector <vscale x 4 x i64> %splat.const8.ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %mul = mul <vscale x 4 x i64> %add2, %splat.const8
+  %gep = getelementptr [8 x half], [8 x half]* %base, <vscale x 4 x i64> %mul
+  %gep.bc = bitcast <vscale x 4 x [8 x half]*> %gep to <vscale x 4 x half*>
+  call void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half> %data, <vscale x 4 x half*> %gep.bc, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
 attributes #0 = { "target-features"="+sve" vscale_range(1, 16) }
 
 declare <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x float*>, i32, <vscale x 4 x i1>, <vscale x 4 x float>)


        


More information about the llvm-commits mailing list