[llvm] 96c8d61 - [SVE] Extend findMoreOptimalIndexType so BUILD_VECTORs do not force 64bit indices.

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 18 10:03:43 PDT 2022


Author: Paul Walker
Date: 2022-08-18T18:00:53+01:00
New Revision: 96c8d615d6660b0136b898303e78d1be0dd82108

URL: https://github.com/llvm/llvm-project/commit/96c8d615d6660b0136b898303e78d1be0dd82108
DIFF: https://github.com/llvm/llvm-project/commit/96c8d615d6660b0136b898303e78d1be0dd82108.diff

LOG: [SVE] Extend findMoreOptimalIndexType so BUILD_VECTORs do not force 64bit indices.

Extends findMoreOptimalIndexType to allow ISD::BUILD_VECTOR-based
indices to be truncated when such truncation is lossless. This can
enable the use of 32-bit gather/scatter indices, making it less
likely that a gather/scatter has to be split in two.

Depends on D125194

Differential Revision: https://reviews.llvm.org/D130533
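
As a reading aid (not part of the commit), here is a minimal standalone
sketch of the round-trip check this change relies on: a 64-bit index
constant can be narrowed to 32 bits when truncating it and then extending
it back (sign- or zero-extending, matching whether the gather/scatter index
is signed) reproduces the original value. The helper name below is
hypothetical; the patch implements the same idea per BUILD_VECTOR element
in ISD::isVectorShrinkable, skipping undef elements.

    // Sketch only: assumes LLVM's APInt; isLosslesslyNarrowable is a
    // hypothetical helper, not an API added by this commit.
    #include "llvm/ADT/APInt.h"
    #include <cassert>

    static bool isLosslesslyNarrowable(const llvm::APInt &C,
                                       unsigned NewEltSize, bool Signed) {
      unsigned EltSize = C.getBitWidth();
      llvm::APInt Narrow = C.trunc(NewEltSize);
      return (Signed ? Narrow.sext(EltSize) : Narrow.zext(EltSize)) == C;
    }

    int main() {
      using llvm::APInt;
      // 49, the largest index in the updated test below, fits in 32 bits.
      assert(isLosslesslyNarrowable(APInt(64, 49), 32, /*Signed=*/false));
      // -7 only survives the round trip when the index is signed.
      assert(isLosslesslyNarrowable(APInt(64, -7, /*isSigned=*/true), 32,
                                    /*Signed=*/true));
      assert(!isLosslesslyNarrowable(APInt(64, -7, /*isSigned=*/true), 32,
                                     /*Signed=*/false));
      return 0;
    }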

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/SelectionDAGNodes.h
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/sve-fixed-length-addressing-modes.ll

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index fad63590c230c..7ebbc4895d7c9 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -118,6 +118,10 @@ bool isBuildVectorOfConstantSDNodes(const SDNode *N);
 /// ConstantFPSDNode or undef.
 bool isBuildVectorOfConstantFPSDNodes(const SDNode *N);
 
+/// Returns true if the specified node is a vector where all elements can
+/// be truncated to the specified element size without a loss in meaning.
+bool isVectorShrinkable(const SDNode *N, unsigned NewEltSize, bool Signed);
+
 /// Return true if the node has at least one operand and all operands of the
 /// specified node are ISD::UNDEF.
 bool allOperandsUndef(const SDNode *N);

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index a89a3613a993d..d6688be839d69 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -291,6 +291,31 @@ bool ISD::isBuildVectorOfConstantFPSDNodes(const SDNode *N) {
   return true;
 }
 
+bool ISD::isVectorShrinkable(const SDNode *N, unsigned NewEltSize,
+                             bool Signed) {
+  if (N->getOpcode() != ISD::BUILD_VECTOR)
+    return false;
+
+  unsigned EltSize = N->getValueType(0).getScalarSizeInBits();
+  if (EltSize <= NewEltSize)
+    return false;
+
+  for (const SDValue &Op : N->op_values()) {
+    if (Op.isUndef())
+      continue;
+    if (!isa<ConstantSDNode>(Op))
+      return false;
+
+    APInt C = cast<ConstantSDNode>(Op)->getAPIntValue().trunc(EltSize);
+    if (Signed && C.trunc(NewEltSize).sext(EltSize) != C)
+      return false;
+    if (!Signed && C.trunc(NewEltSize).zext(EltSize) != C)
+      return false;
+  }
+
+  return true;
+}
+
 bool ISD::allOperandsUndef(const SDNode *N) {
   // Return false if the node has no operands.
   // This is "logically inconsistent" with the definition of "all" but
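
A subtlety worth noting in the new ISD::isVectorShrinkable above: integer
BUILD_VECTOR operands are allowed to be wider than the vector's element
type (they are implicitly truncated), which is why each constant is first
truncated to EltSize before the round-trip test. A small self-contained
sketch with hypothetical widths and values, not code from this patch:

    #include "llvm/ADT/APInt.h"
    #include <cassert>

    int main() {
      using llvm::APInt;
      // A v4i16 BUILD_VECTOR may carry i32 constant operands, so the raw
      // APIntValue can be 32 bits wide even though only the low 16 bits
      // are meaningful; trunc(EltSize) discards the high bits first.
      APInt Raw(32, 0x0001FFFF);  // hypothetical i32 operand of a v4i16 node
      APInt Elt = Raw.trunc(16);  // the value the element actually holds: 0xFFFF
      assert(Elt.trunc(8).zext(16) != Elt); // 0xFFFF is not an 8-bit unsigned value
      assert(Elt.trunc(8).sext(16) == Elt); // but it is -1, which fits in 8 signed bits
      return 0;
    }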

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index dda52248b6d7c..a1f423246db7a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17827,12 +17827,19 @@ static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
   if (IndexVT.getVectorElementType() != MVT::i64 || IndexVT == MVT::nxv2i64)
     return Changed;
 
+  // Can indices be trivially shrunk?
+  if (ISD::isVectorShrinkable(Index.getNode(), 32, N->isIndexSigned())) {
+    EVT NewIndexVT = IndexVT.changeVectorElementType(MVT::i32);
+    Index = DAG.getNode(ISD::TRUNCATE, SDLoc(N), NewIndexVT, Index);
+    return true;
+  }
+
   // Match:
   //   Index = step(const)
   int64_t Stride = 0;
-  if (Index.getOpcode() == ISD::STEP_VECTOR)
+  if (Index.getOpcode() == ISD::STEP_VECTOR) {
     Stride = cast<ConstantSDNode>(Index.getOperand(0))->getSExtValue();
-
+  }
   // Match:
   //   Index = step(const) << shift(const)
   else if (Index.getOpcode() == ISD::SHL &&
@@ -17866,8 +17873,7 @@ static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
   EVT NewIndexVT = IndexVT.changeVectorElementType(MVT::i32);
   // Stride does not scale explicitly by 'Scale', because it happens in
   // the gather/scatter addressing mode.
-  Index = DAG.getNode(ISD::STEP_VECTOR, SDLoc(N), NewIndexVT,
-                      DAG.getTargetConstant(Stride, SDLoc(N), MVT::i32));
+  Index = DAG.getStepVector(SDLoc(N), NewIndexVT, APInt(32, Stride));
   return true;
 }
 

diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-addressing-modes.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-addressing-modes.ll
index d3ac40b706b34..2cf0b47a42114 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-addressing-modes.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-addressing-modes.ll
@@ -1,35 +1,17 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -aarch64-sve-vector-bits-min=256 < %s | FileCheck %s -check-prefixes=CHECK,VBITS_GE_256
-; RUN: llc -aarch64-sve-vector-bits-min=512 < %s | FileCheck %s -check-prefixes=CHECK,VBITS_GE_512
+; RUN: llc -aarch64-sve-vector-bits-min=256 < %s | FileCheck %s
+; RUN: llc -aarch64-sve-vector-bits-min=512 < %s | FileCheck %s
 
 target triple = "aarch64-unknown-linux-gnu"
 
 define void @masked_gather_base_plus_stride_v8f32(ptr %dst, ptr %src) #0 {
-; VBITS_GE_256-LABEL: masked_gather_base_plus_stride_v8f32:
-; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    index z0.d, #0, #7
-; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
-; VBITS_GE_256-NEXT:    mov z1.d, z0.d
-; VBITS_GE_256-NEXT:    ld1w { z0.d }, p0/z, [x1, z0.d, lsl #2]
-; VBITS_GE_256-NEXT:    add z1.d, z1.d, #28 // =0x1c
-; VBITS_GE_256-NEXT:    ld1w { z1.d }, p0/z, [x1, z1.d, lsl #2]
-; VBITS_GE_256-NEXT:    ptrue p0.s, vl4
-; VBITS_GE_256-NEXT:    uzp1 z0.s, z0.s, z0.s
-; VBITS_GE_256-NEXT:    uzp1 z1.s, z1.s, z1.s
-; VBITS_GE_256-NEXT:    splice z0.s, p0, z0.s, z1.s
-; VBITS_GE_256-NEXT:    ptrue p0.s, vl8
-; VBITS_GE_256-NEXT:    st1w { z0.s }, p0, [x0]
-; VBITS_GE_256-NEXT:    ret
-;
-; VBITS_GE_512-LABEL: masked_gather_base_plus_stride_v8f32:
-; VBITS_GE_512:       // %bb.0:
-; VBITS_GE_512-NEXT:    index z0.d, #0, #7
-; VBITS_GE_512-NEXT:    ptrue p0.d, vl8
-; VBITS_GE_512-NEXT:    ld1w { z0.d }, p0/z, [x1, z0.d, lsl #2]
-; VBITS_GE_512-NEXT:    ptrue p0.s, vl8
-; VBITS_GE_512-NEXT:    uzp1 z0.s, z0.s, z0.s
-; VBITS_GE_512-NEXT:    st1w { z0.s }, p0, [x0]
-; VBITS_GE_512-NEXT:    ret
+; CHECK-LABEL: masked_gather_base_plus_stride_v8f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    index z0.s, #0, #7
+; CHECK-NEXT:    ptrue p0.s, vl8
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x1, z0.s, sxtw #2]
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0]
+; CHECK-NEXT:    ret
   %ptrs = getelementptr float, ptr %src, <8 x i64> <i64 0, i64 7, i64 14, i64 21, i64 28, i64 35, i64 42, i64 49>
   %data = tail call <8 x float> @llvm.masked.gather.v8f32.v8p0(<8 x ptr> %ptrs, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x float> undef)
   store <8 x float> %data, ptr %dst, align 4
@@ -52,30 +34,13 @@ define void @masked_gather_base_plus_stride_v4f64(ptr %dst, ptr %src) #0 {
 }
 
 define void @masked_scatter_base_plus_stride_v8f32(ptr %dst, ptr %src) #0 {
-; VBITS_GE_256-LABEL: masked_scatter_base_plus_stride_v8f32:
-; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    ptrue p0.s, vl8
-; VBITS_GE_256-NEXT:    mov z1.d, #-28 // =0xffffffffffffffe4
-; VBITS_GE_256-NEXT:    ld1w { z0.s }, p0/z, [x1]
-; VBITS_GE_256-NEXT:    index z2.d, #0, #-7
-; VBITS_GE_256-NEXT:    add z1.d, z2.d, z1.d
-; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
-; VBITS_GE_256-NEXT:    uunpklo z3.d, z0.s
-; VBITS_GE_256-NEXT:    ext z0.b, z0.b, z0.b, #16
-; VBITS_GE_256-NEXT:    uunpklo z0.d, z0.s
-; VBITS_GE_256-NEXT:    st1w { z3.d }, p0, [x0, z2.d, lsl #2]
-; VBITS_GE_256-NEXT:    st1w { z0.d }, p0, [x0, z1.d, lsl #2]
-; VBITS_GE_256-NEXT:    ret
-;
-; VBITS_GE_512-LABEL: masked_scatter_base_plus_stride_v8f32:
-; VBITS_GE_512:       // %bb.0:
-; VBITS_GE_512-NEXT:    ptrue p0.s, vl8
-; VBITS_GE_512-NEXT:    index z1.d, #0, #-7
-; VBITS_GE_512-NEXT:    ld1w { z0.s }, p0/z, [x1]
-; VBITS_GE_512-NEXT:    ptrue p0.d, vl8
-; VBITS_GE_512-NEXT:    uunpklo z0.d, z0.s
-; VBITS_GE_512-NEXT:    st1w { z0.d }, p0, [x0, z1.d, lsl #2]
-; VBITS_GE_512-NEXT:    ret
+; CHECK-LABEL: masked_scatter_base_plus_stride_v8f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s, vl8
+; CHECK-NEXT:    index z1.s, #0, #-7
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x1]
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
+; CHECK-NEXT:    ret
   %data = load <8 x float>, ptr %src, align 4
   %ptrs = getelementptr float, ptr %dst, <8 x i64> <i64 0, i64 -7, i64 -14, i64 -21, i64 -28, i64 -35, i64 -42, i64 -49>
   tail call void @llvm.masked.scatter.v8f32.v8p0(<8 x float> %data, <8 x ptr> %ptrs, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)


        

