[llvm] 6af1677 - [SVE][CodeGen] Fix scalable vector issues in DAGTypeLegalizer::GenWidenVectorStores

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 13 03:22:31 PDT 2020


Author: David Sherwood
Date: 2020-08-13T11:07:17+01:00
New Revision: 6af1677161fbcdedf8ba08e8ffd065c9451ae733

URL: https://github.com/llvm/llvm-project/commit/6af1677161fbcdedf8ba08e8ffd065c9451ae733
DIFF: https://github.com/llvm/llvm-project/commit/6af1677161fbcdedf8ba08e8ffd065c9451ae733.diff

LOG: [SVE][CodeGen] Fix scalable vector issues in DAGTypeLegalizer::GenWidenVectorStores

In DAGTypeLegalizer::GenWidenVectorStores the algorithm assumes it only
ever deals with fixed width types, hence the offsets for each individual
store never take 'vscale' into account. I've changed the main loop in
that function to use TypeSize instead of unsigned for tracking the
remaining store amount and offset increment. In addition, I've changed
the loop to use the new IncrementPointer helper function for updating
the addresses in each iteration, since this handles scalable vector
types.

Whilst fixing this function I also fixed a minor issue in
IncrementPointer whereby, for scalable vector types, we were not adding
the no-unsigned-wrap flag to the pointer add instruction in the same way
as the fixed width case does.

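Also, I've added a report_fatal_error in GenWidenVectorTruncStores,
since this code currently uses a sequence of element-by-element scalar
stores, which cannot be generated for scalable vectors because their
element count is not known at compile time.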

I've added new tests in

  CodeGen/AArch64/sve-intrinsics-stores.ll
  CodeGen/AArch64/sve-st1-addressing-mode-reg-imm.ll

for the changes in GenWidenVectorStores.

Differential Revision: https://reviews.llvm.org/D84937

Added: 
    

Modified: 
    llvm/include/llvm/Support/TypeSize.h
    llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
    llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
    llvm/test/CodeGen/AArch64/sve-intrinsics-stores.ll
    llvm/test/CodeGen/AArch64/sve-st1-addressing-mode-reg-imm.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Support/TypeSize.h b/llvm/include/llvm/Support/TypeSize.h
index 76564c401e8e..b8a3fa3b20c9 100644
--- a/llvm/include/llvm/Support/TypeSize.h
+++ b/llvm/include/llvm/Support/TypeSize.h
@@ -131,6 +131,20 @@ class TypeSize {
     return { MinSize / RHS, IsScalable };
   }
 
+  TypeSize &operator-=(TypeSize RHS) {
+    assert(IsScalable == RHS.IsScalable &&
+           "Subtraction using mixed scalable and fixed types");
+    MinSize -= RHS.MinSize;
+    return *this;
+  }
+
+  TypeSize &operator+=(TypeSize RHS) {
+    assert(IsScalable == RHS.IsScalable &&
+           "Addition using mixed scalable and fixed types");
+    MinSize += RHS.MinSize;
+    return *this;
+  }
+
   // Return the minimum size with the assumption that the size is exact.
   // Use in places where a scalable size doesn't make sense (e.g. non-vector
   // types, or vectors in backends which don't support scalable vectors).

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 183da49c2d40..1daa907bbf01 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -780,8 +780,8 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
 
   // Helper function for incrementing the pointer when splitting
   // memory operations
-  void IncrementPointer(MemSDNode *N, EVT MemVT,
-                        MachinePointerInfo &MPI, SDValue &Ptr);
+  void IncrementPointer(MemSDNode *N, EVT MemVT, MachinePointerInfo &MPI,
+                        SDValue &Ptr, uint64_t *ScaledOffset = nullptr);
 
   // Vector Result Splitting: <128 x ty> -> 2 x <64 x ty>.
   void SplitVectorResult(SDNode *N, unsigned ResNo);

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 4485fb044f34..6fc31125575e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -984,16 +984,21 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
 }
 
 void DAGTypeLegalizer::IncrementPointer(MemSDNode *N, EVT MemVT,
-                                        MachinePointerInfo &MPI,
-                                        SDValue &Ptr) {
+                                        MachinePointerInfo &MPI, SDValue &Ptr,
+                                        uint64_t *ScaledOffset) {
   SDLoc DL(N);
   unsigned IncrementSize = MemVT.getSizeInBits().getKnownMinSize() / 8;
 
   if (MemVT.isScalableVector()) {
+    SDNodeFlags Flags;
     SDValue BytesIncrement = DAG.getVScale(
         DL, Ptr.getValueType(),
         APInt(Ptr.getValueSizeInBits().getFixedSize(), IncrementSize));
     MPI = MachinePointerInfo(N->getPointerInfo().getAddrSpace());
-    Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, BytesIncrement);
+    Flags.setNoUnsignedWrap(true);
+    if (ScaledOffset)
+      *ScaledOffset += IncrementSize;
+    Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, BytesIncrement,
+                      Flags);
   } else {
     MPI = N->getPointerInfo().getWithOffset(IncrementSize);
@@ -4844,7 +4848,7 @@ static EVT FindMemType(SelectionDAG& DAG, const TargetLowering &TLI,
 
   // If we have one element to load/store, return it.
   EVT RetVT = WidenEltVT;
-  if (Width == WidenEltWidth)
+  if (!Scalable && Width == WidenEltWidth)
     return RetVT;
 
   // See if there is larger legal integer than the element type to load/store.
@@ -5139,55 +5143,62 @@ void DAGTypeLegalizer::GenWidenVectorStores(SmallVectorImpl<SDValue> &StChain,
   SDLoc dl(ST);
 
   EVT StVT = ST->getMemoryVT();
-  unsigned StWidth = StVT.getSizeInBits();
+  TypeSize StWidth = StVT.getSizeInBits();
   EVT ValVT = ValOp.getValueType();
-  unsigned ValWidth = ValVT.getSizeInBits();
+  TypeSize ValWidth = ValVT.getSizeInBits();
   EVT ValEltVT = ValVT.getVectorElementType();
-  unsigned ValEltWidth = ValEltVT.getSizeInBits();
+  unsigned ValEltWidth = ValEltVT.getSizeInBits().getFixedSize();
   assert(StVT.getVectorElementType() == ValEltVT);
+  assert(StVT.isScalableVector() == ValVT.isScalableVector() &&
+         "Mismatch between store and value types");
 
   int Idx = 0;          // current index to store
-  unsigned Offset = 0;  // offset from base to store
-  while (StWidth != 0) {
+
+  MachinePointerInfo MPI = ST->getPointerInfo();
+  uint64_t ScaledOffset = 0;
+  while (StWidth.isNonZero()) {
     // Find the largest vector type we can store with.
-    EVT NewVT = FindMemType(DAG, TLI, StWidth, ValVT);
-    unsigned NewVTWidth = NewVT.getSizeInBits();
-    unsigned Increment = NewVTWidth / 8;
+    EVT NewVT = FindMemType(DAG, TLI, StWidth.getKnownMinSize(), ValVT);
+    TypeSize NewVTWidth = NewVT.getSizeInBits();
+
     if (NewVT.isVector()) {
-      unsigned NumVTElts = NewVT.getVectorNumElements();
+      unsigned NumVTElts = NewVT.getVectorMinNumElements();
       do {
+        Align NewAlign = ScaledOffset == 0
+                             ? ST->getOriginalAlign()
+                             : commonAlignment(ST->getAlign(), ScaledOffset);
         SDValue EOp = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NewVT, ValOp,
                                   DAG.getVectorIdxConstant(Idx, dl));
-        StChain.push_back(DAG.getStore(
-            Chain, dl, EOp, BasePtr, ST->getPointerInfo().getWithOffset(Offset),
-            ST->getOriginalAlign(), MMOFlags, AAInfo));
+        SDValue PartStore = DAG.getStore(Chain, dl, EOp, BasePtr, MPI, NewAlign,
+                                         MMOFlags, AAInfo);
+        StChain.push_back(PartStore);
+
         StWidth -= NewVTWidth;
-        Offset += Increment;
         Idx += NumVTElts;
 
-        BasePtr =
-            DAG.getObjectPtrOffset(dl, BasePtr, TypeSize::Fixed(Increment));
-      } while (StWidth != 0 && StWidth >= NewVTWidth);
+        IncrementPointer(cast<StoreSDNode>(PartStore), NewVT, MPI, BasePtr,
+                         &ScaledOffset);
+      } while (StWidth.isNonZero() && StWidth >= NewVTWidth);
     } else {
       // Cast the vector to the scalar type we can store.
-      unsigned NumElts = ValWidth / NewVTWidth;
+      unsigned NumElts = ValWidth.getFixedSize() / NewVTWidth.getFixedSize();
       EVT NewVecVT = EVT::getVectorVT(*DAG.getContext(), NewVT, NumElts);
       SDValue VecOp = DAG.getNode(ISD::BITCAST, dl, NewVecVT, ValOp);
       // Readjust index position based on new vector type.
-      Idx = Idx * ValEltWidth / NewVTWidth;
+      Idx = Idx * ValEltWidth / NewVTWidth.getFixedSize();
       do {
         SDValue EOp = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, NewVT, VecOp,
                                   DAG.getVectorIdxConstant(Idx++, dl));
-        StChain.push_back(DAG.getStore(
-            Chain, dl, EOp, BasePtr, ST->getPointerInfo().getWithOffset(Offset),
-            ST->getOriginalAlign(), MMOFlags, AAInfo));
+        SDValue PartStore =
+            DAG.getStore(Chain, dl, EOp, BasePtr, MPI, ST->getOriginalAlign(),
+                         MMOFlags, AAInfo);
+        StChain.push_back(PartStore);
+
         StWidth -= NewVTWidth;
-        Offset += Increment;
-        BasePtr =
-            DAG.getObjectPtrOffset(dl, BasePtr, TypeSize::Fixed(Increment));
-      } while (StWidth != 0 && StWidth >= NewVTWidth);
+        IncrementPointer(cast<StoreSDNode>(PartStore), NewVT, MPI, BasePtr);
+      } while (StWidth.isNonZero() && StWidth >= NewVTWidth);
       // Restore index back to be relative to the original widen element type.
-      Idx = Idx * NewVTWidth / ValEltWidth;
+      Idx = Idx * NewVTWidth.getFixedSize() / ValEltWidth;
     }
   }
 }
@@ -5210,8 +5221,13 @@ DAGTypeLegalizer::GenWidenVectorTruncStores(SmallVectorImpl<SDValue> &StChain,
   // It must be true that the wide vector type is bigger than where we need to
   // store.
   assert(StVT.isVector() && ValOp.getValueType().isVector());
+  assert(StVT.isScalableVector() == ValOp.getValueType().isScalableVector());
   assert(StVT.bitsLT(ValOp.getValueType()));
 
+  if (StVT.isScalableVector())
+    report_fatal_error("Generating widen scalable vector truncating stores not "
+                       "yet supported");
+
   // For truncating stores, we can not play the tricks of chopping legal vector
   // types and bitcast it to the right type. Instead, we unroll the store.
   EVT StEltVT  = StVT.getVectorElementType();

diff  --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-stores.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-stores.ll
index d26ab2980ccc..92877233b2c9 100644
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-stores.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-stores.ll
@@ -437,6 +437,69 @@ define void @stnt1d_f64(<vscale x 2 x double> %data, <vscale x 2 x i1> %pred, do
 }
 
 
+; Stores (tuples)
+
+define void @store_i64_tuple3(<vscale x 6 x i64>* %out, <vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3) {
+; CHECK-LABEL: store_i64_tuple3
+; CHECK:      st1d { z2.d }, p0, [x0, #2, mul vl]
+; CHECK-NEXT: st1d { z1.d }, p0, [x0, #1, mul vl]
+; CHECK-NEXT: st1d { z0.d }, p0, [x0]
+  %tuple = tail call <vscale x 6 x i64> @llvm.aarch64.sve.tuple.create3.nxv6i64.nxv2i64(<vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3)
+  store <vscale x 6 x i64> %tuple, <vscale x 6 x i64>* %out
+  ret void
+}
+
+define void @store_i64_tuple4(<vscale x 8 x i64>* %out, <vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3, <vscale x 2 x i64> %in4) {
+; CHECK-LABEL: store_i64_tuple4
+; CHECK:      st1d { z3.d }, p0, [x0, #3, mul vl]
+; CHECK-NEXT: st1d { z2.d }, p0, [x0, #2, mul vl]
+; CHECK-NEXT: st1d { z1.d }, p0, [x0, #1, mul vl]
+; CHECK-NEXT: st1d { z0.d }, p0, [x0]
+  %tuple = tail call <vscale x 8 x i64> @llvm.aarch64.sve.tuple.create4.nxv8i64.nxv2i64(<vscale x 2 x i64> %in1, <vscale x 2 x i64> %in2, <vscale x 2 x i64> %in3, <vscale x 2 x i64> %in4)
+  store <vscale x 8 x i64> %tuple, <vscale x 8 x i64>* %out
+  ret void
+}
+
+define void @store_i16_tuple2(<vscale x 16 x i16>* %out, <vscale x 8 x i16> %in1, <vscale x 8 x i16> %in2) {
+; CHECK-LABEL: store_i16_tuple2
+; CHECK:      st1h { z1.h }, p0, [x0, #1, mul vl]
+; CHECK-NEXT: st1h { z0.h }, p0, [x0]
+  %tuple = tail call <vscale x 16 x i16> @llvm.aarch64.sve.tuple.create2.nxv16i16.nxv8i16(<vscale x 8 x i16> %in1, <vscale x 8 x i16> %in2)
+  store <vscale x 16 x i16> %tuple, <vscale x 16 x i16>* %out
+  ret void
+}
+
+define void @store_i16_tuple3(<vscale x 24 x i16>* %out, <vscale x 8 x i16> %in1, <vscale x 8 x i16> %in2, <vscale x 8 x i16> %in3) {
+; CHECK-LABEL: store_i16_tuple3
+; CHECK:      st1h { z2.h }, p0, [x0, #2, mul vl]
+; CHECK-NEXT: st1h { z1.h }, p0, [x0, #1, mul vl]
+; CHECK-NEXT: st1h { z0.h }, p0, [x0]
+  %tuple = tail call <vscale x 24 x i16> @llvm.aarch64.sve.tuple.create3.nxv24i16.nxv8i16(<vscale x 8 x i16> %in1, <vscale x 8 x i16> %in2, <vscale x 8 x i16> %in3)
+  store <vscale x 24 x i16> %tuple, <vscale x 24 x i16>* %out
+  ret void
+}
+
+define void @store_f32_tuple3(<vscale x 12 x float>* %out, <vscale x 4 x float> %in1, <vscale x 4 x float> %in2, <vscale x 4 x float> %in3) {
+; CHECK-LABEL: store_f32_tuple3
+; CHECK:      st1w { z2.s }, p0, [x0, #2, mul vl]
+; CHECK-NEXT: st1w { z1.s }, p0, [x0, #1, mul vl]
+; CHECK-NEXT: st1w { z0.s }, p0, [x0]
+  %tuple = tail call <vscale x 12 x float> @llvm.aarch64.sve.tuple.create3.nxv12f32.nxv4f32(<vscale x 4 x float> %in1, <vscale x 4 x float> %in2, <vscale x 4 x float> %in3)
+  store <vscale x 12 x float> %tuple, <vscale x 12 x float>* %out
+  ret void
+}
+
+define void @store_f32_tuple4(<vscale x 16 x float>* %out, <vscale x 4 x float> %in1, <vscale x 4 x float> %in2, <vscale x 4 x float> %in3, <vscale x 4 x float> %in4) {
+; CHECK-LABEL: store_f32_tuple4
+; CHECK:      st1w { z3.s }, p0, [x0, #3, mul vl]
+; CHECK-NEXT: st1w { z2.s }, p0, [x0, #2, mul vl]
+; CHECK-NEXT: st1w { z1.s }, p0, [x0, #1, mul vl]
+; CHECK-NEXT: st1w { z0.s }, p0, [x0]
+  %tuple = tail call <vscale x 16 x float> @llvm.aarch64.sve.tuple.create4.nxv16f32.nxv4f32(<vscale x 4 x float> %in1, <vscale x 4 x float> %in2, <vscale x 4 x float> %in3, <vscale x 4 x float> %in4)
+  store <vscale x 16 x float> %tuple, <vscale x 16 x float>* %out
+  ret void
+}
+
 declare void @llvm.aarch64.sve.st2.nxv16i8(<vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i1>, i8*)
 declare void @llvm.aarch64.sve.st2.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i1>, i16*)
 declare void @llvm.aarch64.sve.st2.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32*)
@@ -473,5 +536,14 @@ declare void @llvm.aarch64.sve.stnt1.nxv8bf16(<vscale x 8 x bfloat>, <vscale x 8
 declare void @llvm.aarch64.sve.stnt1.nxv4f32(<vscale x 4 x float>, <vscale x 4 x i1>, float*)
 declare void @llvm.aarch64.sve.stnt1.nxv2f64(<vscale x 2 x double>, <vscale x 2 x i1>, double*)
 
+declare <vscale x 6 x i64> @llvm.aarch64.sve.tuple.create3.nxv6i64.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>)
+declare <vscale x 8 x i64> @llvm.aarch64.sve.tuple.create4.nxv8i64.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>)
+
+declare <vscale x 16 x i16> @llvm.aarch64.sve.tuple.create2.nxv16i16.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>)
+declare <vscale x 24 x i16> @llvm.aarch64.sve.tuple.create3.nxv24i16.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>)
+
+declare <vscale x 12 x float> @llvm.aarch64.sve.tuple.create3.nxv12f32.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)
+declare <vscale x 16 x float> @llvm.aarch64.sve.tuple.create4.nxv16f32.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)
+
 ; +bf16 is required for the bfloat version.
 attributes #0 = { "target-features"="+sve,+bf16" }

diff  --git a/llvm/test/CodeGen/AArch64/sve-st1-addressing-mode-reg-imm.ll b/llvm/test/CodeGen/AArch64/sve-st1-addressing-mode-reg-imm.ll
index e24db77b682e..cf1aef43c0ef 100644
--- a/llvm/test/CodeGen/AArch64/sve-st1-addressing-mode-reg-imm.ll
+++ b/llvm/test/CodeGen/AArch64/sve-st1-addressing-mode-reg-imm.ll
@@ -133,3 +133,37 @@ define void @store_nxv4f16(<vscale x 4 x half>* %out) {
   store <vscale x 4 x half> %splat, <vscale x 4 x half>* %out
   ret void
 }
+
+; Splat stores of unusual FP scalable vector types
+
+define void @store_nxv6f32(<vscale x 6 x float>* %out) {
+; CHECK-LABEL: store_nxv6f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    fmov z0.s, #1.00000000
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0]
+; CHECK-NEXT:    uunpklo z0.d, z0.s
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    st1w { z0.d }, p0, [x0, #2, mul vl]
+; CHECK-NEXT:    ret
+  %ins = insertelement <vscale x 6 x float> undef, float 1.0, i32 0
+  %splat = shufflevector <vscale x 6 x float> %ins, <vscale x 6 x float> undef, <vscale x 6 x i32> zeroinitializer
+  store <vscale x 6 x float> %splat, <vscale x 6 x float>* %out
+  ret void
+}
+
+define void @store_nxv12f16(<vscale x 12 x half>* %out) {
+; CHECK-LABEL: store_nxv12f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    fmov z0.h, #1.00000000
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    st1h { z0.h }, p0, [x0]
+; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    st1h { z0.s }, p0, [x0, #2, mul vl]
+; CHECK-NEXT:    ret
+  %ins = insertelement <vscale x 12 x half> undef, half 1.0, i32 0
+  %splat = shufflevector <vscale x 12 x half> %ins, <vscale x 12 x half> undef, <vscale x 12 x i32> zeroinitializer
+  store <vscale x 12 x half> %splat, <vscale x 12 x half>* %out
+  ret void
+}


        


More information about the llvm-commits mailing list