[llvm] 216f546 - [SVE] Refactor lowering for fixed length MGATHER/MSCATTER.

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Sat May 21 02:16:13 PDT 2022


Author: Paul Walker
Date: 2022-05-21T10:14:45+01:00
New Revision: 216f546c846ca69005de193f4d1eea78e0efb2c2

URL: https://github.com/llvm/llvm-project/commit/216f546c846ca69005de193f4d1eea78e0efb2c2
DIFF: https://github.com/llvm/llvm-project/commit/216f546c846ca69005de193f4d1eea78e0efb2c2.diff

LOG: [SVE] Refactor lowering for fixed length MGATHER/MSCATTER.

Lower fixed-length MGATHER/MSCATTER operations to scalable vector
equivalents, which are then lowered to SVE-specific nodes. This
two-stage process is in preparation for making scalable vector
MGATHER/MSCATTER operations legal.

Differential Revision: https://reviews.llvm.org/D125192

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e3fa268732b30..f8b71ee642877 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4696,6 +4696,56 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
                                MGT->getMemOperand(), IndexType, ExtType);
   }
 
+  // Lower fixed length gather to a scalable equivalent.
+  if (VT.isFixedLengthVector()) {
+    assert(Subtarget->useSVEForFixedLengthVectors() &&
+           "Cannot lower when not using SVE for fixed vectors!");
+
+    // NOTE: Handle floating-point as if integer then bitcast the result.
+    EVT DataVT = VT.changeVectorElementTypeToInteger();
+    MemVT = MemVT.changeVectorElementTypeToInteger();
+
+    // Find the smallest integer fixed length vector we can use for the gather.
+    EVT PromotedVT = VT.changeVectorElementType(MVT::i32);
+    if (DataVT.getVectorElementType() == MVT::i64 ||
+        Index.getValueType().getVectorElementType() == MVT::i64 ||
+        Mask.getValueType().getVectorElementType() == MVT::i64)
+      PromotedVT = VT.changeVectorElementType(MVT::i64);
+
+    // Promote vector operands except for passthrough, which we know is either
+    // undef or zero, and thus best constructed directly.
+    unsigned ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+    Index = DAG.getNode(ExtOpcode, DL, PromotedVT, Index);
+    Mask = DAG.getNode(ISD::SIGN_EXTEND, DL, PromotedVT, Mask);
+
+    // A promoted result type forces the need for an extending load.
+    if (PromotedVT != DataVT && ExtType == ISD::NON_EXTLOAD)
+      ExtType = ISD::EXTLOAD;
+
+    EVT ContainerVT = getContainerForFixedLengthVector(DAG, PromotedVT);
+
+    // Convert fixed length vector operands to scalable.
+    MemVT = ContainerVT.changeVectorElementType(MemVT.getVectorElementType());
+    Index = convertToScalableVector(DAG, ContainerVT, Index);
+    Mask = convertFixedMaskToScalableVector(Mask, DAG);
+    PassThru = PassThru->isUndef() ? DAG.getUNDEF(ContainerVT)
+                                   : DAG.getConstant(0, DL, ContainerVT);
+
+    // Emit equivalent scalable vector gather.
+    SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
+    SDValue Load =
+        DAG.getMaskedGather(DAG.getVTList(ContainerVT, MVT::Other), MemVT, DL,
+                            Ops, MGT->getMemOperand(), IndexType, ExtType);
+
+    // Extract fixed length data then convert to the required result type.
+    SDValue Result = convertFromScalableVector(DAG, PromotedVT, Load);
+    Result = DAG.getNode(ISD::TRUNCATE, DL, DataVT, Result);
+    if (VT.isFloatingPoint())
+      Result = DAG.getNode(ISD::BITCAST, DL, VT, Result);
+
+    return DAG.getMergeValues({Result, Load.getValue(1)}, DL);
+  }
+
   bool IdxNeedsExtend =
       getGatherScatterIndexIsExtended(Index) ||
       Index.getSimpleValueType().getVectorElementType() == MVT::i32;
@@ -4703,26 +4753,8 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
   EVT IndexVT = Index.getSimpleValueType();
   SDValue InputVT = DAG.getValueType(MemVT);
 
-  bool IsFixedLength = MGT->getMemoryVT().isFixedLengthVector();
-
-  if (IsFixedLength) {
-    assert(Subtarget->useSVEForFixedLengthVectors() &&
-           "Cannot lower when not using SVE for fixed vectors");
-    if (MemVT.getScalarSizeInBits() <= IndexVT.getScalarSizeInBits()) {
-      IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
-      MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
-    } else {
-      MemVT = getContainerForFixedLengthVector(DAG, MemVT);
-      IndexVT = MemVT.changeTypeToInteger();
-    }
-    InputVT = DAG.getValueType(MemVT.changeTypeToInteger());
-    Mask = DAG.getNode(
-        ISD::SIGN_EXTEND, DL,
-        VT.changeVectorElementType(IndexVT.getVectorElementType()), Mask);
-  }
-
   // Handle FP data by using an integer gather and casting the result.
-  if (VT.isFloatingPoint() && !IsFixedLength)
+  if (VT.isFloatingPoint())
     InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
 
   SDVTList VTs = DAG.getVTList(IndexVT, MVT::Other);
@@ -4737,25 +4769,11 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
   if (ExtType == ISD::SEXTLOAD)
     Opcode = getSignExtendedGatherOpcode(Opcode);
 
-  if (IsFixedLength) {
-    if (Index.getSimpleValueType().isFixedLengthVector())
-      Index = convertToScalableVector(DAG, IndexVT, Index);
-    if (BasePtr.getSimpleValueType().isFixedLengthVector())
-      BasePtr = convertToScalableVector(DAG, IndexVT, BasePtr);
-    Mask = convertFixedMaskToScalableVector(Mask, DAG);
-  }
-
   SDValue Ops[] = {Chain, Mask, BasePtr, Index, InputVT};
   SDValue Result = DAG.getNode(Opcode, DL, VTs, Ops);
   Chain = Result.getValue(1);
 
-  if (IsFixedLength) {
-    Result = convertFromScalableVector(
-        DAG, VT.changeVectorElementType(IndexVT.getVectorElementType()),
-        Result);
-    Result = DAG.getNode(ISD::TRUNCATE, DL, VT.changeTypeToInteger(), Result);
-    Result = DAG.getNode(ISD::BITCAST, DL, VT, Result);
-  } else if (VT.isFloatingPoint())
+  if (VT.isFloatingPoint())
     Result = getSVESafeBitCast(VT, Result, DAG);
 
   return DAG.getMergeValues({Result, Chain}, DL);
@@ -4775,6 +4793,7 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
   EVT VT = StoreVal.getValueType();
   EVT MemVT = MSC->getMemoryVT();
   ISD::MemIndexType IndexType = MSC->getIndexType();
+  bool Truncating = MSC->isTruncatingStore();
 
   bool IsScaled = MSC->isIndexScaled();
   bool IsSigned = MSC->isIndexSigned();
@@ -4791,42 +4810,60 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
 
     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
     return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops,
-                                MSC->getMemOperand(), IndexType,
-                                MSC->isTruncatingStore());
+                                MSC->getMemOperand(), IndexType, Truncating);
+  }
+
+  // Lower fixed length scatter to a scalable equivalent.
+  if (VT.isFixedLengthVector()) {
+    assert(Subtarget->useSVEForFixedLengthVectors() &&
+           "Cannot lower when not using SVE for fixed vectors!");
+
+    // Once bitcast we treat floating-point scatters as if integer.
+    if (VT.isFloatingPoint()) {
+      VT = VT.changeVectorElementTypeToInteger();
+      MemVT = MemVT.changeVectorElementTypeToInteger();
+      StoreVal = DAG.getNode(ISD::BITCAST, DL, VT, StoreVal);
+    }
+
+    // Find the smallest integer fixed length vector we can use for the scatter.
+    EVT PromotedVT = VT.changeVectorElementType(MVT::i32);
+    if (VT.getVectorElementType() == MVT::i64 ||
+        Index.getValueType().getVectorElementType() == MVT::i64 ||
+        Mask.getValueType().getVectorElementType() == MVT::i64)
+      PromotedVT = VT.changeVectorElementType(MVT::i64);
+
+    // Promote vector operands.
+    unsigned ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+    Index = DAG.getNode(ExtOpcode, DL, PromotedVT, Index);
+    Mask = DAG.getNode(ISD::SIGN_EXTEND, DL, PromotedVT, Mask);
+    StoreVal = DAG.getNode(ISD::ANY_EXTEND, DL, PromotedVT, StoreVal);
+
+    // A promoted value type forces the need for a truncating store.
+    if (PromotedVT != VT)
+      Truncating = true;
+
+    EVT ContainerVT = getContainerForFixedLengthVector(DAG, PromotedVT);
+
+    // Convert fixed length vector operands to scalable.
+    MemVT = ContainerVT.changeVectorElementType(MemVT.getVectorElementType());
+    Index = convertToScalableVector(DAG, ContainerVT, Index);
+    Mask = convertFixedMaskToScalableVector(Mask, DAG);
+    StoreVal = convertToScalableVector(DAG, ContainerVT, StoreVal);
+
+    // Emit equivalent scalable vector scatter.
+    SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
+    return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops,
+                                MSC->getMemOperand(), IndexType, Truncating);
   }
 
   bool NeedsExtend =
       getGatherScatterIndexIsExtended(Index) ||
       Index.getSimpleValueType().getVectorElementType() == MVT::i32;
 
-  EVT IndexVT = Index.getSimpleValueType();
   SDVTList VTs = DAG.getVTList(MVT::Other);
   SDValue InputVT = DAG.getValueType(MemVT);
 
-  bool IsFixedLength = MSC->getMemoryVT().isFixedLengthVector();
-
-  if (IsFixedLength) {
-    assert(Subtarget->useSVEForFixedLengthVectors() &&
-           "Cannot lower when not using SVE for fixed vectors");
-    if (MemVT.getScalarSizeInBits() <= IndexVT.getScalarSizeInBits()) {
-      IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
-      MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
-    } else {
-      MemVT = getContainerForFixedLengthVector(DAG, MemVT);
-      IndexVT = MemVT.changeTypeToInteger();
-    }
-    InputVT = DAG.getValueType(MemVT.changeTypeToInteger());
-
-    StoreVal =
-        DAG.getNode(ISD::BITCAST, DL, VT.changeTypeToInteger(), StoreVal);
-    StoreVal = DAG.getNode(
-        ISD::ANY_EXTEND, DL,
-        VT.changeVectorElementType(IndexVT.getVectorElementType()), StoreVal);
-    StoreVal = convertToScalableVector(DAG, IndexVT, StoreVal);
-    Mask = DAG.getNode(
-        ISD::SIGN_EXTEND, DL,
-        VT.changeVectorElementType(IndexVT.getVectorElementType()), Mask);
-  } else if (VT.isFloatingPoint()) {
+  if (VT.isFloatingPoint()) {
     // Handle FP data by casting the data so an integer scatter can be used.
     EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount());
     StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG);
@@ -4840,14 +4877,6 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
   selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
                               /*isGather=*/false, DAG);
 
-  if (IsFixedLength) {
-    if (Index.getSimpleValueType().isFixedLengthVector())
-      Index = convertToScalableVector(DAG, IndexVT, Index);
-    if (BasePtr.getSimpleValueType().isFixedLengthVector())
-      BasePtr = convertToScalableVector(DAG, IndexVT, BasePtr);
-    Mask = convertFixedMaskToScalableVector(Mask, DAG);
-  }
-
   SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, InputVT};
   return DAG.getNode(Opcode, DL, VTs, Ops);
 }


        


More information about the llvm-commits mailing list