[llvm] 0c44115 - [SVE] Add support for non-element-type sized scaling when lowering MGATHER/MSCATTER.

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 14 04:05:35 PDT 2022


Author: Paul Walker
Date: 2022-04-14T11:54:46+01:00
New Revision: 0c44115e5120167fc573e36dd878f4f95f5d63e6

URL: https://github.com/llvm/llvm-project/commit/0c44115e5120167fc573e36dd878f4f95f5d63e6
DIFF: https://github.com/llvm/llvm-project/commit/0c44115e5120167fc573e36dd878f4f95f5d63e6.diff

LOG: [SVE] Add support for non-element-type sized scaling when lowering MGATHER/MSCATTER.

The lowering code did not use the scale operand of MGATHER/MSCATTER
nodes, but instead assumed scaled indices were always scaled based
on the element type of the memory type. This patch adds the missing
support by rewriting such nodes as unscaled variants, with the index
shifted left by log2 of the scale before lowering.

Differential Revision: https://reviews.llvm.org/D123670

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/ValueTypes.h
    llvm/include/llvm/Support/MachineValueType.h
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/sve-masked-gather.ll
    llvm/test/CodeGen/AArch64/sve-masked-scatter.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/ValueTypes.h b/llvm/include/llvm/CodeGen/ValueTypes.h
index 3daa9eae44699..48d265476ca8f 100644
--- a/llvm/include/llvm/CodeGen/ValueTypes.h
+++ b/llvm/include/llvm/CodeGen/ValueTypes.h
@@ -364,6 +364,12 @@ namespace llvm {
       return {(BaseSize.getKnownMinSize() + 7) / 8, BaseSize.isScalable()};
     }
 
+    /// Return the number of bytes overwritten by a store of this value type or
+    /// this value type's element type in the case of a vector.
+    uint64_t getScalarStoreSize() const {
+      return getScalarType().getStoreSize().getFixedSize();
+    }
+
     /// Return the number of bits overwritten by a store of the specified value
     /// type.
     ///

diff  --git a/llvm/include/llvm/Support/MachineValueType.h b/llvm/include/llvm/Support/MachineValueType.h
index 643c2d8ce9817..2f4dcfd34af90 100644
--- a/llvm/include/llvm/Support/MachineValueType.h
+++ b/llvm/include/llvm/Support/MachineValueType.h
@@ -1078,6 +1078,12 @@ namespace llvm {
       return {(BaseSize.getKnownMinSize() + 7) / 8, BaseSize.isScalable()};
     }
 
+    /// Return the number of bytes overwritten by a store of this value type or
+    /// this value type's element type in the case of a vector.
+    uint64_t getScalarStoreSize() const {
+      return getScalarType().getStoreSize().getFixedSize();
+    }
+
     /// Return the number of bits overwritten by a store of the specified value
     /// type.
     ///

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index ae5ffd863f8c2..ff84699f29fef 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4400,9 +4400,14 @@ static bool getUniformBase(const Value *Ptr, SDValue &Base, SDValue &Index,
   Base = SDB->getValue(BasePtr);
   Index = SDB->getValue(IndexVal);
   IndexType = ISD::SIGNED_SCALED;
-  Scale = DAG.getTargetConstant(
-              DL.getTypeAllocSize(GEP->getResultElementType()),
-              SDB->getCurSDLoc(), TLI.getPointerTy(DL));
+
+  // MGATHER/MSCATTER only support scaling by a power-of-two.
+  uint64_t ScaleVal = DL.getTypeAllocSize(GEP->getResultElementType());
+  if (!isPowerOf2_64(ScaleVal))
+    return false;
+
+  Scale =
+      DAG.getTargetConstant(ScaleVal, SDB->getCurSDLoc(), TLI.getPointerTy(DL));
   return true;
 }
 

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b74583b6d9fb8..cc6aa4ee5a2b2 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4650,33 +4650,50 @@ void selectGatherScatterAddrMode(SDValue &BasePtr, SDValue &Index, EVT MemVT,
 
 SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
                                             SelectionDAG &DAG) const {
-  SDLoc DL(Op);
   MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(Op);
-  assert(MGT && "Can only custom lower gather load nodes");
-
-  bool IsFixedLength = MGT->getMemoryVT().isFixedLengthVector();
 
-  SDValue Index = MGT->getIndex();
+  SDLoc DL(Op);
   SDValue Chain = MGT->getChain();
   SDValue PassThru = MGT->getPassThru();
   SDValue Mask = MGT->getMask();
   SDValue BasePtr = MGT->getBasePtr();
-  ISD::LoadExtType ExtTy = MGT->getExtensionType();
-
+  SDValue Index = MGT->getIndex();
+  SDValue Scale = MGT->getScale();
+  EVT VT = Op.getValueType();
+  EVT MemVT = MGT->getMemoryVT();
+  ISD::LoadExtType ExtType = MGT->getExtensionType();
   ISD::MemIndexType IndexType = MGT->getIndexType();
+
   bool IsScaled =
       IndexType == ISD::SIGNED_SCALED || IndexType == ISD::UNSIGNED_SCALED;
   bool IsSigned =
       IndexType == ISD::SIGNED_SCALED || IndexType == ISD::SIGNED_UNSCALED;
+
+  // SVE supports an index scaled by sizeof(MemVT.elt) only, everything else
+  // must be calculated before hand.
+  uint64_t ScaleVal = cast<ConstantSDNode>(Scale)->getZExtValue();
+  if (IsScaled && ScaleVal != MemVT.getScalarStoreSize()) {
+    assert(isPowerOf2_64(ScaleVal) && "Expecting power-of-two types");
+    EVT IndexVT = Index.getValueType();
+    Index = DAG.getNode(ISD::SHL, DL, IndexVT, Index,
+                        DAG.getConstant(Log2_32(ScaleVal), DL, IndexVT));
+    Scale = DAG.getTargetConstant(1, DL, Scale.getValueType());
+
+    SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
+    IndexType = IsSigned ? ISD::SIGNED_UNSCALED : ISD::UNSIGNED_UNSCALED;
+    return DAG.getMaskedGather(MGT->getVTList(), MemVT, DL, Ops,
+                               MGT->getMemOperand(), IndexType, ExtType);
+  }
+
   bool IdxNeedsExtend =
       getGatherScatterIndexIsExtended(Index) ||
       Index.getSimpleValueType().getVectorElementType() == MVT::i32;
 
-  EVT VT = PassThru.getSimpleValueType();
   EVT IndexVT = Index.getSimpleValueType();
-  EVT MemVT = MGT->getMemoryVT();
   SDValue InputVT = DAG.getValueType(MemVT);
 
+  bool IsFixedLength = MGT->getMemoryVT().isFixedLengthVector();
+
   if (IsFixedLength) {
     assert(Subtarget->useSVEForFixedLengthVectors() &&
            "Cannot lower when not using SVE for fixed vectors");
@@ -4714,7 +4731,7 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
   selectGatherScatterAddrMode(BasePtr, Index, MemVT, Opcode,
                               /*isGather=*/true, DAG);
 
-  if (ExtTy == ISD::SEXTLOAD)
+  if (ExtType == ISD::SEXTLOAD)
     Opcode = getSignExtendedGatherOpcode(Opcode);
 
   if (IsFixedLength) {
@@ -4751,33 +4768,51 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
 
 SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
                                              SelectionDAG &DAG) const {
-  SDLoc DL(Op);
   MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(Op);
-  assert(MSC && "Can only custom lower scatter store nodes");
 
-  bool IsFixedLength = MSC->getMemoryVT().isFixedLengthVector();
-
-  SDValue Index = MSC->getIndex();
+  SDLoc DL(Op);
   SDValue Chain = MSC->getChain();
   SDValue StoreVal = MSC->getValue();
   SDValue Mask = MSC->getMask();
   SDValue BasePtr = MSC->getBasePtr();
-
+  SDValue Index = MSC->getIndex();
+  SDValue Scale = MSC->getScale();
+  EVT VT = StoreVal.getValueType();
+  EVT MemVT = MSC->getMemoryVT();
   ISD::MemIndexType IndexType = MSC->getIndexType();
+
   bool IsScaled =
       IndexType == ISD::SIGNED_SCALED || IndexType == ISD::UNSIGNED_SCALED;
   bool IsSigned =
       IndexType == ISD::SIGNED_SCALED || IndexType == ISD::SIGNED_UNSCALED;
+
+  // SVE supports an index scaled by sizeof(MemVT.elt) only, everything else
+  // must be calculated before hand.
+  uint64_t ScaleVal = cast<ConstantSDNode>(Scale)->getZExtValue();
+  if (IsScaled && ScaleVal != MemVT.getScalarStoreSize()) {
+    assert(isPowerOf2_64(ScaleVal) && "Expecting power-of-two types");
+    EVT IndexVT = Index.getValueType();
+    Index = DAG.getNode(ISD::SHL, DL, IndexVT, Index,
+                        DAG.getConstant(Log2_32(ScaleVal), DL, IndexVT));
+    Scale = DAG.getTargetConstant(1, DL, Scale.getValueType());
+
+    SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
+    IndexType = IsSigned ? ISD::SIGNED_UNSCALED : ISD::UNSIGNED_UNSCALED;
+    return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops,
+                                MSC->getMemOperand(), IndexType,
+                                MSC->isTruncatingStore());
+  }
+
   bool NeedsExtend =
       getGatherScatterIndexIsExtended(Index) ||
       Index.getSimpleValueType().getVectorElementType() == MVT::i32;
 
-  EVT VT = StoreVal.getSimpleValueType();
   EVT IndexVT = Index.getSimpleValueType();
   SDVTList VTs = DAG.getVTList(MVT::Other);
-  EVT MemVT = MSC->getMemoryVT();
   SDValue InputVT = DAG.getValueType(MemVT);
 
+  bool IsFixedLength = MSC->getMemoryVT().isFixedLengthVector();
+
   if (IsFixedLength) {
     assert(Subtarget->useSVEForFixedLengthVectors() &&
            "Cannot lower when not using SVE for fixed vectors");

diff  --git a/llvm/test/CodeGen/AArch64/sve-masked-gather.ll b/llvm/test/CodeGen/AArch64/sve-masked-gather.ll
index 784053be075fc..1df9cf935a7eb 100644
--- a/llvm/test/CodeGen/AArch64/sve-masked-gather.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-gather.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve -opaque-pointers < %s | FileCheck %s
 
 define <vscale x 2 x i64> @masked_gather_nxv2i8(<vscale x 2 x i8*> %ptrs, <vscale x 2 x i1> %mask) {
 ; CHECK-LABEL: masked_gather_nxv2i8:
@@ -127,6 +127,30 @@ define <vscale x 2 x i64> @masked_gather_passthru_0(<vscale x 2 x i32*> %ptrs, <
   ret <vscale x 2 x i64> %vals.sext
 }
 
+%i64_x3 = type { i64, i64, i64}
+define <vscale x 2 x i64> @masked_gather_non_power_of_two_based_scaling(ptr %base, <vscale x 2 x i64> %offsets, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_gather_non_power_of_two_based_scaling:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mul z0.d, z0.d, #24
+; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x0, z0.d]
+; CHECK-NEXT:    ret
+  %ptrs = getelementptr inbounds %i64_x3, ptr %base, <vscale x 2 x i64> %offsets
+  %vals = call <vscale x 2 x i64> @llvm.masked.gather.nxv2i64(<vscale x 2 x ptr> %ptrs, i32 8, <vscale x 2 x i1> %mask, <vscale x 2 x i64> undef)
+  ret <vscale x 2 x i64> %vals
+}
+
+%i64_x4 = type { i64, i64, i64, i64}
+define <vscale x 2 x i64> @masked_gather_non_element_type_based_scaling(ptr %base, <vscale x 2 x i64> %offsets, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_gather_non_element_type_based_scaling:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.d, z0.d, #5
+; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x0, z0.d]
+; CHECK-NEXT:    ret
+  %ptrs = getelementptr inbounds %i64_x4, ptr %base, <vscale x 2 x i64> %offsets
+  %vals = call <vscale x 2 x i64> @llvm.masked.gather.nxv2i64(<vscale x 2 x ptr> %ptrs, i32 8, <vscale x 2 x i1> %mask, <vscale x 2 x i64> undef)
+  ret <vscale x 2 x i64> %vals
+}
+
 declare <vscale x 2 x i8> @llvm.masked.gather.nxv2i8(<vscale x 2 x i8*>, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
 declare <vscale x 2 x i16> @llvm.masked.gather.nxv2i16(<vscale x 2 x i16*>, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
 declare <vscale x 2 x i32> @llvm.masked.gather.nxv2i32(<vscale x 2 x i32*>, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)

diff  --git a/llvm/test/CodeGen/AArch64/sve-masked-scatter.ll b/llvm/test/CodeGen/AArch64/sve-masked-scatter.ll
index da5e47823bee7..1941515486a09 100644
--- a/llvm/test/CodeGen/AArch64/sve-masked-scatter.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-scatter.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve  < %s | FileCheck %s
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -opaque-pointers < %s | FileCheck %s
 
 define void @masked_scatter_nxv2i8(<vscale x 2 x i8> %data, <vscale x 2 x i8*> %ptrs, <vscale x 2 x i1> %masks) nounwind {
 ; CHECK-LABEL: masked_scatter_nxv2i8:
@@ -79,17 +79,41 @@ define void @masked_scatter_splat_constant_pointer (<vscale x 4 x i1> %pg) {
 ; CHECK-NEXT:    mov z0.d, #0 // =0x0
 ; CHECK-NEXT:    punpklo p1.h, p0.b
 ; CHECK-NEXT:    punpkhi p0.h, p0.b
-; CHECK-NEXT:    st1w { z0.d }, p1, [x8, z0.d, lsl #2]
-; CHECK-NEXT:    st1w { z0.d }, p0, [x8, z0.d, lsl #2]
+; CHECK-NEXT:    st1w { z0.d }, p1, [z0.d]
+; CHECK-NEXT:    st1w { z0.d }, p0, [z0.d]
 ; CHECK-NEXT:    ret
 vector.body:
   call void @llvm.masked.scatter.nxv4i32.nxv4p0i32(<vscale x 4 x i32> undef,
-    <vscale x 4 x i32*> shufflevector (<vscale x 4 x i32*> insertelement (<vscale x 4 x i32*> poison, i32* undef, i32 0), <vscale x 4 x i32*> poison, <vscale x 4 x i32> zeroinitializer),
+    <vscale x 4 x i32*> shufflevector (<vscale x 4 x i32*> insertelement (<vscale x 4 x i32*> poison, i32* null, i32 0), <vscale x 4 x i32*> poison, <vscale x 4 x i32> zeroinitializer),
     i32 4,
     <vscale x 4 x i1> %pg)
   ret void
 }
 
+%i64_x3 = type { i64, i64, i64 }
+define void @masked_scatter_non_power_of_two_based_scaling(<vscale x 2 x double> %data, ptr %base, <vscale x 2 x i64> %offsets, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_scatter_non_power_of_two_based_scaling:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mul z1.d, z1.d, #24
+; CHECK-NEXT:    st1d { z0.d }, p0, [x0, z1.d]
+; CHECK-NEXT:    ret
+  %ptrs = getelementptr inbounds %i64_x3, ptr %base, <vscale x 2 x i64> %offsets
+  call void @llvm.masked.scatter.nxv2f64(<vscale x 2 x double> %data, <vscale x 2 x ptr> %ptrs, i32 8, <vscale x 2 x i1> %mask)
+  ret void
+}
+
+%i64_x4 = type { i64, i64, i64, i64}
+define void @masked_scatter_non_element_type_based_scaling(<vscale x 2 x double> %data, ptr %base, <vscale x 2 x i64> %offsets, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_scatter_non_element_type_based_scaling:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z1.d, z1.d, #5
+; CHECK-NEXT:    st1d { z0.d }, p0, [x0, z1.d]
+; CHECK-NEXT:    ret
+  %ptrs = getelementptr inbounds %i64_x4, ptr %base, <vscale x 2 x i64> %offsets
+  call void @llvm.masked.scatter.nxv2f64(<vscale x 2 x double> %data, <vscale x 2 x ptr> %ptrs, i32 8, <vscale x 2 x i1> %mask)
+  ret void
+}
+
 declare void @llvm.masked.scatter.nxv2f16(<vscale x 2 x half>, <vscale x 2 x half*>, i32, <vscale x 2 x i1>)
 declare void @llvm.masked.scatter.nxv2bf16(<vscale x 2 x bfloat>, <vscale x 2 x bfloat*>, i32, <vscale x 2 x i1>)
 declare void @llvm.masked.scatter.nxv2f32(<vscale x 2 x float>, <vscale x 2 x float*>, i32, <vscale x 2 x i1>)


        


More information about the llvm-commits mailing list