[llvm] 216f546 - [SVE] Refactor lowering for fixed length MGATHER/MSCATTER.
Paul Walker via llvm-commits
llvm-commits at lists.llvm.org
Sat May 21 02:16:13 PDT 2022
Author: Paul Walker
Date: 2022-05-21T10:14:45+01:00
New Revision: 216f546c846ca69005de193f4d1eea78e0efb2c2
URL: https://github.com/llvm/llvm-project/commit/216f546c846ca69005de193f4d1eea78e0efb2c2
DIFF: https://github.com/llvm/llvm-project/commit/216f546c846ca69005de193f4d1eea78e0efb2c2.diff
LOG: [SVE] Refactor lowering for fixed length MGATHER/MSCATTER.
Lower fixed-length MGATHER/MSCATTER operations to scalable vector
equivalents, which are then lowered to SVE-specific nodes. This
two-stage process is in preparation for making scalable vector
MGATHER/MSCATTER operations legal.
Differential Revision: https://reviews.llvm.org/D125192
Added:
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e3fa268732b30..f8b71ee642877 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4696,6 +4696,56 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
MGT->getMemOperand(), IndexType, ExtType);
}
+ // Lower fixed length gather to a scalable equivalent.
+ if (VT.isFixedLengthVector()) {
+ assert(Subtarget->useSVEForFixedLengthVectors() &&
+ "Cannot lower when not using SVE for fixed vectors!");
+
+ // NOTE: Handle floating-point as if integer then bitcast the result.
+ EVT DataVT = VT.changeVectorElementTypeToInteger();
+ MemVT = MemVT.changeVectorElementTypeToInteger();
+
+ // Find the smallest integer fixed length vector we can use for the gather.
+ EVT PromotedVT = VT.changeVectorElementType(MVT::i32);
+ if (DataVT.getVectorElementType() == MVT::i64 ||
+ Index.getValueType().getVectorElementType() == MVT::i64 ||
+ Mask.getValueType().getVectorElementType() == MVT::i64)
+ PromotedVT = VT.changeVectorElementType(MVT::i64);
+
+ // Promote vector operands except for passthrough, which we know is either
+ // undef or zero, and thus best constructed directly.
+ unsigned ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+ Index = DAG.getNode(ExtOpcode, DL, PromotedVT, Index);
+ Mask = DAG.getNode(ISD::SIGN_EXTEND, DL, PromotedVT, Mask);
+
+ // A promoted result type forces the need for an extending load.
+ if (PromotedVT != DataVT && ExtType == ISD::NON_EXTLOAD)
+ ExtType = ISD::EXTLOAD;
+
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, PromotedVT);
+
+ // Convert fixed length vector operands to scalable.
+ MemVT = ContainerVT.changeVectorElementType(MemVT.getVectorElementType());
+ Index = convertToScalableVector(DAG, ContainerVT, Index);
+ Mask = convertFixedMaskToScalableVector(Mask, DAG);
+ PassThru = PassThru->isUndef() ? DAG.getUNDEF(ContainerVT)
+ : DAG.getConstant(0, DL, ContainerVT);
+
+ // Emit equivalent scalable vector gather.
+ SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
+ SDValue Load =
+ DAG.getMaskedGather(DAG.getVTList(ContainerVT, MVT::Other), MemVT, DL,
+ Ops, MGT->getMemOperand(), IndexType, ExtType);
+
+ // Extract fixed length data then convert to the required result type.
+ SDValue Result = convertFromScalableVector(DAG, PromotedVT, Load);
+ Result = DAG.getNode(ISD::TRUNCATE, DL, DataVT, Result);
+ if (VT.isFloatingPoint())
+ Result = DAG.getNode(ISD::BITCAST, DL, VT, Result);
+
+ return DAG.getMergeValues({Result, Load.getValue(1)}, DL);
+ }
+
bool IdxNeedsExtend =
getGatherScatterIndexIsExtended(Index) ||
Index.getSimpleValueType().getVectorElementType() == MVT::i32;
@@ -4703,26 +4753,8 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
EVT IndexVT = Index.getSimpleValueType();
SDValue InputVT = DAG.getValueType(MemVT);
- bool IsFixedLength = MGT->getMemoryVT().isFixedLengthVector();
-
- if (IsFixedLength) {
- assert(Subtarget->useSVEForFixedLengthVectors() &&
- "Cannot lower when not using SVE for fixed vectors");
- if (MemVT.getScalarSizeInBits() <= IndexVT.getScalarSizeInBits()) {
- IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
- MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
- } else {
- MemVT = getContainerForFixedLengthVector(DAG, MemVT);
- IndexVT = MemVT.changeTypeToInteger();
- }
- InputVT = DAG.getValueType(MemVT.changeTypeToInteger());
- Mask = DAG.getNode(
- ISD::SIGN_EXTEND, DL,
- VT.changeVectorElementType(IndexVT.getVectorElementType()), Mask);
- }
-
// Handle FP data by using an integer gather and casting the result.
- if (VT.isFloatingPoint() && !IsFixedLength)
+ if (VT.isFloatingPoint())
InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
SDVTList VTs = DAG.getVTList(IndexVT, MVT::Other);
@@ -4737,25 +4769,11 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
if (ExtType == ISD::SEXTLOAD)
Opcode = getSignExtendedGatherOpcode(Opcode);
- if (IsFixedLength) {
- if (Index.getSimpleValueType().isFixedLengthVector())
- Index = convertToScalableVector(DAG, IndexVT, Index);
- if (BasePtr.getSimpleValueType().isFixedLengthVector())
- BasePtr = convertToScalableVector(DAG, IndexVT, BasePtr);
- Mask = convertFixedMaskToScalableVector(Mask, DAG);
- }
-
SDValue Ops[] = {Chain, Mask, BasePtr, Index, InputVT};
SDValue Result = DAG.getNode(Opcode, DL, VTs, Ops);
Chain = Result.getValue(1);
- if (IsFixedLength) {
- Result = convertFromScalableVector(
- DAG, VT.changeVectorElementType(IndexVT.getVectorElementType()),
- Result);
- Result = DAG.getNode(ISD::TRUNCATE, DL, VT.changeTypeToInteger(), Result);
- Result = DAG.getNode(ISD::BITCAST, DL, VT, Result);
- } else if (VT.isFloatingPoint())
+ if (VT.isFloatingPoint())
Result = getSVESafeBitCast(VT, Result, DAG);
return DAG.getMergeValues({Result, Chain}, DL);
@@ -4775,6 +4793,7 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
EVT VT = StoreVal.getValueType();
EVT MemVT = MSC->getMemoryVT();
ISD::MemIndexType IndexType = MSC->getIndexType();
+ bool Truncating = MSC->isTruncatingStore();
bool IsScaled = MSC->isIndexScaled();
bool IsSigned = MSC->isIndexSigned();
@@ -4791,42 +4810,60 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops,
- MSC->getMemOperand(), IndexType,
- MSC->isTruncatingStore());
+ MSC->getMemOperand(), IndexType, Truncating);
+ }
+
+ // Lower fixed length scatter to a scalable equivalent.
+ if (VT.isFixedLengthVector()) {
+ assert(Subtarget->useSVEForFixedLengthVectors() &&
+ "Cannot lower when not using SVE for fixed vectors!");
+
+ // Once bitcast we treat floating-point scatters as if integer.
+ if (VT.isFloatingPoint()) {
+ VT = VT.changeVectorElementTypeToInteger();
+ MemVT = MemVT.changeVectorElementTypeToInteger();
+ StoreVal = DAG.getNode(ISD::BITCAST, DL, VT, StoreVal);
+ }
+
+ // Find the smallest integer fixed length vector we can use for the scatter.
+ EVT PromotedVT = VT.changeVectorElementType(MVT::i32);
+ if (VT.getVectorElementType() == MVT::i64 ||
+ Index.getValueType().getVectorElementType() == MVT::i64 ||
+ Mask.getValueType().getVectorElementType() == MVT::i64)
+ PromotedVT = VT.changeVectorElementType(MVT::i64);
+
+ // Promote vector operands.
+ unsigned ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
+ Index = DAG.getNode(ExtOpcode, DL, PromotedVT, Index);
+ Mask = DAG.getNode(ISD::SIGN_EXTEND, DL, PromotedVT, Mask);
+ StoreVal = DAG.getNode(ISD::ANY_EXTEND, DL, PromotedVT, StoreVal);
+
+ // A promoted value type forces the need for a truncating store.
+ if (PromotedVT != VT)
+ Truncating = true;
+
+ EVT ContainerVT = getContainerForFixedLengthVector(DAG, PromotedVT);
+
+ // Convert fixed length vector operands to scalable.
+ MemVT = ContainerVT.changeVectorElementType(MemVT.getVectorElementType());
+ Index = convertToScalableVector(DAG, ContainerVT, Index);
+ Mask = convertFixedMaskToScalableVector(Mask, DAG);
+ StoreVal = convertToScalableVector(DAG, ContainerVT, StoreVal);
+
+ // Emit equivalent scalable vector scatter.
+ SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
+ return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops,
+ MSC->getMemOperand(), IndexType, Truncating);
}
bool NeedsExtend =
getGatherScatterIndexIsExtended(Index) ||
Index.getSimpleValueType().getVectorElementType() == MVT::i32;
- EVT IndexVT = Index.getSimpleValueType();
SDVTList VTs = DAG.getVTList(MVT::Other);
SDValue InputVT = DAG.getValueType(MemVT);
- bool IsFixedLength = MSC->getMemoryVT().isFixedLengthVector();
-
- if (IsFixedLength) {
- assert(Subtarget->useSVEForFixedLengthVectors() &&
- "Cannot lower when not using SVE for fixed vectors");
- if (MemVT.getScalarSizeInBits() <= IndexVT.getScalarSizeInBits()) {
- IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
- MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
- } else {
- MemVT = getContainerForFixedLengthVector(DAG, MemVT);
- IndexVT = MemVT.changeTypeToInteger();
- }
- InputVT = DAG.getValueType(MemVT.changeTypeToInteger());
-
- StoreVal =
- DAG.getNode(ISD::BITCAST, DL, VT.changeTypeToInteger(), StoreVal);
- StoreVal = DAG.getNode(
- ISD::ANY_EXTEND, DL,
- VT.changeVectorElementType(IndexVT.getVectorElementType()), StoreVal);
- StoreVal = convertToScalableVector(DAG, IndexVT, StoreVal);
- Mask = DAG.getNode(
- ISD::SIGN_EXTEND, DL,
- VT.changeVectorElementType(IndexVT.getVectorElementType()), Mask);
- } else if (VT.isFloatingPoint()) {
+ if (VT.isFloatingPoint()) {
// Handle FP data by casting the data so an integer scatter can be used.
EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount());
StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG);
@@ -4840,14 +4877,6 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
/*isGather=*/false, DAG);
- if (IsFixedLength) {
- if (Index.getSimpleValueType().isFixedLengthVector())
- Index = convertToScalableVector(DAG, IndexVT, Index);
- if (BasePtr.getSimpleValueType().isFixedLengthVector())
- BasePtr = convertToScalableVector(DAG, IndexVT, BasePtr);
- Mask = convertFixedMaskToScalableVector(Mask, DAG);
- }
-
SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, InputVT};
return DAG.getNode(Opcode, DL, VTs, Ops);
}
More information about the llvm-commits mailing list