[llvm] 0bc3993 - [SelectionDAG] Add an ISD node for get.active.lane.mask (#139084)
via llvm-commits
llvm-commits at lists.llvm.org
Thu May 15 01:14:50 PDT 2025
Author: Kerry McLaughlin
Date: 2025-05-15T09:14:46+01:00
New Revision: 0bc39937164f09823c906926d7eefd7a8bcb5161
URL: https://github.com/llvm/llvm-project/commit/0bc39937164f09823c906926d7eefd7a8bcb5161
DIFF: https://github.com/llvm/llvm-project/commit/0bc39937164f09823c906926d7eefd7a8bcb5161.diff
LOG: [SelectionDAG] Add an ISD node for get.active.lane.mask (#139084)
For now expansion still happens in SelectionDAGBuilder when
GET_ACTIVE_LANE_MASK is not legal on the target.
This patch also changes AArch64ISelLowering so that the
get.active.lane.mask intrinsic is handled via the new ISD node.
TableGen patterns are added that match the node to whilelo for scalable types.
A follow up change will add support for more types to be lowered to
GET_ACTIVE_LANE_MASK by allowing splitting of the node.
Added:
Modified:
llvm/include/llvm/CodeGen/ISDOpcodes.h
llvm/include/llvm/Target/TargetSelectionDAG.td
llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 1042318343987..9f66402e4c820 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1533,6 +1533,15 @@ enum NodeType {
// Operands: Mask
VECTOR_FIND_LAST_ACTIVE,
+ // GET_ACTIVE_LANE_MASK - this corresponds to the llvm.get.active.lane.mask
+ // intrinsic. It creates a mask representing active and inactive vector
+ // lanes, active while Base + index < Trip Count. As with the intrinsic,
+ // the operands Base and Trip Count have the same scalar integer type and
+ // the internal addition of Base + index cannot overflow. However, the ISD
+ // node supports result types which are wider than i1, where the high
+ // bits conform to getBooleanContents similar to the SETCC operator.
+ GET_ACTIVE_LANE_MASK,
+
// llvm.clear_cache intrinsic
// Operands: Input Chain, Start Addres, End Address
// Outputs: Output Chain
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index b28a8b118de7a..406baa4f5fdaa 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -860,6 +860,12 @@ def find_last_active
: SDNode<"ISD::VECTOR_FIND_LAST_ACTIVE",
SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisVec<1>]>, []>;
+// Result: a vector mask. Operands: Base and Trip Count, two integers that
+// must share one type.
+def get_active_lane_mask : SDNode<"ISD::GET_ACTIVE_LANE_MASK",
+  SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisInt<1>, SDTCisSameAs<1, 2>]>,
+  []>;
+
// Nodes for intrinsics, you should use the intrinsic itself and let tblgen use
// these internally. Don't reference these directly.
def intrinsic_void : SDNode<"ISD::INTRINSIC_VOID",
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 25e74a2ae5b71..9ccd6a4d1684c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -160,6 +160,10 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
Res = PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(N);
break;
+ case ISD::GET_ACTIVE_LANE_MASK:
+ Res = PromoteIntRes_GET_ACTIVE_LANE_MASK(N);
+ break;
+
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
Res = PromoteIntRes_PARTIAL_REDUCE_MLA(N);
@@ -6222,6 +6226,12 @@ SDValue DAGTypeLegalizer::PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(SDNode *N) {
return DAG.getNode(ISD::VECTOR_FIND_LAST_ACTIVE, SDLoc(N), NVT, N->ops());
}
+SDValue DAGTypeLegalizer::PromoteIntRes_GET_ACTIVE_LANE_MASK(SDNode *N) {
+  // Rebuild the node with identical operands but the promoted result type.
+  EVT NewVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  return DAG.getNode(ISD::GET_ACTIVE_LANE_MASK, SDLoc(N), NewVT, N->ops());
+}
+
SDValue DAGTypeLegalizer::PromoteIntRes_PARTIAL_REDUCE_MLA(SDNode *N) {
SDLoc DL(N);
EVT VT = N->getValueType(0);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 9e7a4030d24ec..cf3a9e23f4878 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -379,6 +379,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntRes_IS_FPCLASS(SDNode *N);
SDValue PromoteIntRes_PATCHPOINT(SDNode *N);
SDValue PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(SDNode *N);
+ SDValue PromoteIntRes_GET_ACTIVE_LANE_MASK(SDNode *N);
SDValue PromoteIntRes_PARTIAL_REDUCE_MLA(SDNode *N);
// Integer Operand Promotion.
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index d99bd230c7861..3ebd3a4b88097 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -7987,14 +7987,15 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
case Intrinsic::get_active_lane_mask: {
EVT CCVT = TLI.getValueType(DAG.getDataLayout(), I.getType());
SDValue Index = getValue(I.getOperand(0));
+ SDValue TripCount = getValue(I.getOperand(1));
EVT ElementVT = Index.getValueType();
if (!TLI.shouldExpandGetActiveLaneMask(CCVT, ElementVT)) {
- visitTargetIntrinsic(I, Intrinsic);
+ setValue(&I, DAG.getNode(ISD::GET_ACTIVE_LANE_MASK, sdl, CCVT, Index,
+ TripCount));
return;
}
- SDValue TripCount = getValue(I.getOperand(1));
EVT VecTy = EVT::getVectorVT(*DAG.getContext(), ElementVT,
CCVT.getVectorElementCount());
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 6f846bedf3c82..803894e298dd5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -577,6 +577,9 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::VECTOR_FIND_LAST_ACTIVE:
return "find_last_active";
+ case ISD::GET_ACTIVE_LANE_MASK:
+ return "get_active_lane_mask";
+
case ISD::PARTIAL_REDUCE_UMLA:
return "partial_reduce_umla";
case ISD::PARTIAL_REDUCE_SMLA:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 16f53aa3f1421..fb7f7d6f7537d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -290,6 +290,7 @@ static bool isZeroingInactiveLanes(SDValue Op) {
return false;
// We guarantee i1 splat_vectors to zero the other lanes
case ISD::SPLAT_VECTOR:
+ case ISD::GET_ACTIVE_LANE_MASK:
case AArch64ISD::PTRUE:
case AArch64ISD::SETCC_MERGE_ZERO:
return true;
@@ -1178,6 +1179,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setTargetDAGCombine(ISD::CTLZ);
+ setTargetDAGCombine(ISD::GET_ACTIVE_LANE_MASK);
+
setTargetDAGCombine(ISD::VECREDUCE_AND);
setTargetDAGCombine(ISD::VECREDUCE_OR);
setTargetDAGCombine(ISD::VECREDUCE_XOR);
@@ -1493,8 +1496,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom);
setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom);
}
- for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1})
+ for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1}) {
setOperationAction(ISD::VECTOR_FIND_LAST_ACTIVE, VT, Legal);
+ setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Legal);
+ }
+
+ for (auto VT : {MVT::v16i8, MVT::v8i8, MVT::v4i16, MVT::v2i32})
+ setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Custom);
}
if (Subtarget->isSVEorStreamingSVEAvailable()) {
@@ -5731,21 +5739,24 @@ static inline SDValue getPTrue(SelectionDAG &DAG, SDLoc DL, EVT VT,
DAG.getTargetConstant(Pattern, DL, MVT::i32));
}
-static SDValue optimizeIncrementingWhile(SDValue Op, SelectionDAG &DAG,
+static SDValue optimizeIncrementingWhile(SDNode *N, SelectionDAG &DAG,
bool IsSigned, bool IsEqual) {
- if (!isa<ConstantSDNode>(Op.getOperand(1)) ||
- !isa<ConstantSDNode>(Op.getOperand(2)))
+ unsigned Op0 = N->getOpcode() == ISD::INTRINSIC_WO_CHAIN ? 1 : 0;
+ unsigned Op1 = N->getOpcode() == ISD::INTRINSIC_WO_CHAIN ? 2 : 1;
+
+ if (!isa<ConstantSDNode>(N->getOperand(Op0)) ||
+ !isa<ConstantSDNode>(N->getOperand(Op1)))
return SDValue();
- SDLoc dl(Op);
- APInt X = Op.getConstantOperandAPInt(1);
- APInt Y = Op.getConstantOperandAPInt(2);
+ SDLoc dl(N);
+ APInt X = N->getConstantOperandAPInt(Op0);
+ APInt Y = N->getConstantOperandAPInt(Op1);
// When the second operand is the maximum value, comparisons that include
// equality can never fail and thus we can return an all active predicate.
if (IsEqual)
if (IsSigned ? Y.isMaxSignedValue() : Y.isMaxValue())
- return DAG.getConstant(1, dl, Op.getValueType());
+ return DAG.getConstant(1, dl, N->getValueType(0));
bool Overflow;
APInt NumActiveElems =
@@ -5766,10 +5777,10 @@ static SDValue optimizeIncrementingWhile(SDValue Op, SelectionDAG &DAG,
getSVEPredPatternFromNumElements(NumActiveElems.getZExtValue());
unsigned MinSVEVectorSize = std::max(
DAG.getSubtarget<AArch64Subtarget>().getMinSVEVectorSizeInBits(), 128u);
- unsigned ElementSize = 128 / Op.getValueType().getVectorMinNumElements();
+ unsigned ElementSize = 128 / N->getValueType(0).getVectorMinNumElements();
if (PredPattern != std::nullopt &&
NumActiveElems.getZExtValue() <= (MinSVEVectorSize / ElementSize))
- return getPTrue(DAG, dl, Op.getValueType(), *PredPattern);
+ return getPTrue(DAG, dl, N->getValueType(0), *PredPattern);
return SDValue();
}
@@ -6221,17 +6232,14 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
DAG.getNode(
AArch64ISD::URSHR_I, dl, Op.getOperand(1).getValueType(), Op.getOperand(1), Op.getOperand(2)));
return SDValue();
- case Intrinsic::aarch64_sve_whilelo:
- return optimizeIncrementingWhile(Op, DAG, /*IsSigned=*/false,
- /*IsEqual=*/false);
case Intrinsic::aarch64_sve_whilelt:
- return optimizeIncrementingWhile(Op, DAG, /*IsSigned=*/true,
+ return optimizeIncrementingWhile(Op.getNode(), DAG, /*IsSigned=*/true,
/*IsEqual=*/false);
case Intrinsic::aarch64_sve_whilels:
- return optimizeIncrementingWhile(Op, DAG, /*IsSigned=*/false,
+ return optimizeIncrementingWhile(Op.getNode(), DAG, /*IsSigned=*/false,
/*IsEqual=*/true);
case Intrinsic::aarch64_sve_whilele:
- return optimizeIncrementingWhile(Op, DAG, /*IsSigned=*/true,
+ return optimizeIncrementingWhile(Op.getNode(), DAG, /*IsSigned=*/true,
/*IsEqual=*/true);
case Intrinsic::aarch64_sve_sunpkhi:
return DAG.getNode(AArch64ISD::SUNPKHI, dl, Op.getValueType(),
@@ -6532,28 +6540,6 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
return DAG.getNode(AArch64ISD::USDOT, dl, Op.getValueType(),
Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
}
- case Intrinsic::get_active_lane_mask: {
- SDValue ID =
- DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, dl, MVT::i64);
-
- EVT VT = Op.getValueType();
- if (VT.isScalableVector())
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT, ID, Op.getOperand(1),
- Op.getOperand(2));
-
- // We can use the SVE whilelo instruction to lower this intrinsic by
- // creating the appropriate sequence of scalable vector operations and
- // then extracting a fixed-width subvector from the scalable vector.
-
- EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
- EVT WhileVT = ContainerVT.changeElementType(MVT::i1);
-
- SDValue Mask = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, WhileVT, ID,
- Op.getOperand(1), Op.getOperand(2));
- SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, dl, ContainerVT, Mask);
- return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, MaskAsInt,
- DAG.getVectorIdxConstant(0, dl));
- }
case Intrinsic::aarch64_neon_saddlv:
case Intrinsic::aarch64_neon_uaddlv: {
EVT OpVT = Op.getOperand(1).getValueType();
@@ -7692,6 +7678,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerVECTOR_DEINTERLEAVE(Op, DAG);
case ISD::VECTOR_INTERLEAVE:
return LowerVECTOR_INTERLEAVE(Op, DAG);
+ case ISD::GET_ACTIVE_LANE_MASK:
+ return LowerGET_ACTIVE_LANE_MASK(Op, DAG);
case ISD::LRINT:
case ISD::LLRINT:
if (Op.getValueType().isVector())
@@ -18152,6 +18140,70 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
return DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, UADDLP);
}
+static SDValue
+performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                             const AArch64Subtarget *ST) {
+  // Both folds build SVE predicates (PTRUE / whilelo_x2), so only handle
+  // scalable masks; fixed-length ones go to LowerGET_ACTIVE_LANE_MASK.
+  if (DCI.isBeforeLegalize() || !N->getValueType(0).isScalableVector())
+    return SDValue();
+
+  if (SDValue While = optimizeIncrementingWhile(N, DCI.DAG, /*IsSigned=*/false,
+                                                /*IsEqual=*/false))
+    return While;
+
+  // With SVE2.1, a mask whose only two users extract its low and high halves
+  // is replaced by a single whilelo_x2 that produces both halves directly.
+  if (!ST->hasSVE2p1())
+    return SDValue();
+
+  if (!N->hasNUsesOfValue(2, 0))
+    return SDValue();
+
+  const uint64_t HalfSize = N->getValueType(0).getVectorMinNumElements() / 2;
+  if (HalfSize < 2)
+    return SDValue();
+
+  auto It = N->user_begin();
+  SDNode *Lo = *It++;
+  SDNode *Hi = *It;
+
+  if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR)
+    return SDValue();
+
+  uint64_t OffLo = Lo->getConstantOperandVal(1);
+  uint64_t OffHi = Hi->getConstantOperandVal(1);
+
+  if (OffLo > OffHi) {
+    std::swap(Lo, Hi);
+    std::swap(OffLo, OffHi);
+  }
+
+  if (OffLo != 0 || OffHi != HalfSize)
+    return SDValue();
+
+  EVT HalfVec = Lo->getValueType(0);
+  if (HalfVec != Hi->getValueType(0) ||
+      HalfVec.getVectorElementCount() != ElementCount::getScalable(HalfSize))
+    return SDValue();
+
+  SelectionDAG &DAG = DCI.DAG;
+  SDLoc DL(N);
+  SDValue ID =
+      DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
+  // whilelo_x2 is built with i64 operands; getZExtOrTrunc is a no-op on i64.
+  SDValue Idx = DAG.getZExtOrTrunc(N->getOperand(0), DL, MVT::i64);
+  SDValue TC = DAG.getZExtOrTrunc(N->getOperand(1), DL, MVT::i64);
+  auto R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {HalfVec, HalfVec},
+                       {ID, Idx, TC});
+
+  DCI.CombineTo(Lo, R.getValue(0));
+  DCI.CombineTo(Hi, R.getValue(1));
+
+  return SDValue(N, 0);
+}
+
// Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce
// vecreduce.add(ext(A)) to vecreduce.add(DOT(zero, A, one))
// vecreduce.add(mul(ext(A), ext(B))) to vecreduce.add(DOT(zero, A, B))
@@ -19682,6 +19734,8 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
static bool isPredicateCCSettingOp(SDValue N) {
if ((N.getOpcode() == ISD::SETCC) ||
+ // get_active_lane_mask is lowered to a whilelo instruction.
+ (N.getOpcode() == ISD::GET_ACTIVE_LANE_MASK) ||
(N.getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
(N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilege ||
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilegt ||
@@ -19690,9 +19744,7 @@ static bool isPredicateCCSettingOp(SDValue N) {
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilele ||
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelo ||
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilels ||
- N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelt ||
- // get_active_lane_mask is lowered to a whilelo instruction.
- N.getConstantOperandVal(0) == Intrinsic::get_active_lane_mask)))
+ N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelt)))
return true;
return false;
@@ -21806,66 +21858,6 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
return SDValue();
}
-static SDValue tryCombineWhileLo(SDNode *N,
- TargetLowering::DAGCombinerInfo &DCI,
- const AArch64Subtarget *Subtarget) {
- if (DCI.isBeforeLegalize())
- return SDValue();
-
- if (!Subtarget->hasSVE2p1())
- return SDValue();
-
- if (!N->hasNUsesOfValue(2, 0))
- return SDValue();
-
- const uint64_t HalfSize = N->getValueType(0).getVectorMinNumElements() / 2;
- if (HalfSize < 2)
- return SDValue();
-
- auto It = N->user_begin();
- SDNode *Lo = *It++;
- SDNode *Hi = *It;
-
- if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
- Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR)
- return SDValue();
-
- uint64_t OffLo = Lo->getConstantOperandVal(1);
- uint64_t OffHi = Hi->getConstantOperandVal(1);
-
- if (OffLo > OffHi) {
- std::swap(Lo, Hi);
- std::swap(OffLo, OffHi);
- }
-
- if (OffLo != 0 || OffHi != HalfSize)
- return SDValue();
-
- EVT HalfVec = Lo->getValueType(0);
- if (HalfVec != Hi->getValueType(0) ||
- HalfVec.getVectorElementCount() != ElementCount::getScalable(HalfSize))
- return SDValue();
-
- SelectionDAG &DAG = DCI.DAG;
- SDLoc DL(N);
- SDValue ID =
- DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
- SDValue Idx = N->getOperand(1);
- SDValue TC = N->getOperand(2);
- if (Idx.getValueType() != MVT::i64) {
- Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
- TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
- }
- auto R =
- DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
- {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
-
- DCI.CombineTo(Lo, R.getValue(0));
- DCI.CombineTo(Hi, R.getValue(1));
-
- return SDValue(N, 0);
-}
-
SDValue tryLowerPartialReductionToDot(SDNode *N,
const AArch64Subtarget *Subtarget,
SelectionDAG &DAG) {
@@ -22345,7 +22337,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2),
AArch64CC::LAST_ACTIVE);
case Intrinsic::aarch64_sve_whilelo:
- return tryCombineWhileLo(N, DCI, Subtarget);
+ return DAG.getNode(ISD::GET_ACTIVE_LANE_MASK, SDLoc(N), N->getValueType(0),
+ N->getOperand(1), N->getOperand(2));
case Intrinsic::aarch64_sve_bsl:
case Intrinsic::aarch64_sve_bsl1n:
case Intrinsic::aarch64_sve_bsl2n:
@@ -26777,6 +26770,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performExtractVectorEltCombine(N, DCI, Subtarget);
case ISD::VECREDUCE_ADD:
return performVecReduceAddCombine(N, DCI.DAG, Subtarget);
+ case ISD::GET_ACTIVE_LANE_MASK:
+ return performActiveLaneMaskCombine(N, DCI, Subtarget);
case AArch64ISD::UADDV:
return performUADDVCombine(N, DAG);
case AArch64ISD::SMULL:
@@ -27759,8 +27754,7 @@ void AArch64TargetLowering::ReplaceNodeResults(
DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, RuntimePStateSM));
return;
}
- case Intrinsic::experimental_vector_match:
- case Intrinsic::get_active_lane_mask: {
+ case Intrinsic::experimental_vector_match: {
if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i1)
return;
@@ -29552,6 +29546,29 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
}
+SDValue
+AArch64TargetLowering::LowerGET_ACTIVE_LANE_MASK(SDValue Op,
+                                                 SelectionDAG &DAG) const {
+  EVT ResVT = Op.getValueType();
+  assert(ResVT.isFixedLengthVector() && "Expected fixed length vector type!");
+
+  assert(Subtarget->isSVEorStreamingSVEAvailable() &&
+         "Lowering fixed length get_active_lane_mask requires SVE!");
+
+  // Fixed-length vectors have no dedicated instruction for this node, so
+  // build the mask as an SVE predicate over the scalable container type,
+  // sign-extend it so every active lane becomes all-ones, and then take the
+  // leading fixed-length slice.
+  SDLoc DL(Op);
+  EVT IntContainerVT = getContainerForFixedLengthVector(DAG, ResVT);
+  EVT PredVT = IntContainerVT.changeElementType(MVT::i1);
+  SDValue Pred = DAG.getNode(ISD::GET_ACTIVE_LANE_MASK, DL, PredVT,
+                             Op.getOperand(0), Op.getOperand(1));
+  SDValue Widened = DAG.getNode(ISD::SIGN_EXTEND, DL, IntContainerVT, Pred);
+  SDValue ZeroIdx = DAG.getVectorIdxConstant(0, DL);
+  return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ResVT, Widened, ZeroIdx);
+}
+
SDValue
AArch64TargetLowering::LowerFixedLengthFPToIntToSVE(SDValue Op,
SelectionDAG &DAG) const {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index ec8b0b920c453..c1e6d70099fa5 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1183,6 +1183,7 @@ class AArch64TargetLowering : public TargetLowering {
SDValue LowerVECTOR_INTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVECTOR_HISTOGRAM(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerGET_ACTIVE_LANE_MASK(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerDIV(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index d6bd59adef03b..65fe31a2143f2 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2141,12 +2141,12 @@ let Predicates = [HasSVE_or_SME] in {
defm WHILELT_PWW : sve_int_while4_rr<0b010, "whilelt", int_aarch64_sve_whilelt, int_aarch64_sve_whilegt>;
defm WHILELE_PWW : sve_int_while4_rr<0b011, "whilele", int_aarch64_sve_whilele, null_frag>;
- defm WHILELO_PWW : sve_int_while4_rr<0b110, "whilelo", int_aarch64_sve_whilelo, int_aarch64_sve_whilehi>;
+ defm WHILELO_PWW : sve_int_while4_rr<0b110, "whilelo", get_active_lane_mask, int_aarch64_sve_whilehi>;
defm WHILELS_PWW : sve_int_while4_rr<0b111, "whilels", int_aarch64_sve_whilels, null_frag>;
defm WHILELT_PXX : sve_int_while8_rr<0b010, "whilelt", int_aarch64_sve_whilelt, int_aarch64_sve_whilegt>;
defm WHILELE_PXX : sve_int_while8_rr<0b011, "whilele", int_aarch64_sve_whilele, null_frag>;
- defm WHILELO_PXX : sve_int_while8_rr<0b110, "whilelo", int_aarch64_sve_whilelo, int_aarch64_sve_whilehi>;
+ defm WHILELO_PXX : sve_int_while8_rr<0b110, "whilelo", get_active_lane_mask, int_aarch64_sve_whilehi>;
defm WHILELS_PXX : sve_int_while8_rr<0b111, "whilels", int_aarch64_sve_whilels, null_frag>;
def CTERMEQ_WW : sve_int_cterm<0b0, 0b0, "ctermeq", GPR32>;
@@ -3998,12 +3998,12 @@ let Predicates = [HasSVE2_or_SME] in {
defm WHILEGE_PWW : sve_int_while4_rr<0b000, "whilege", int_aarch64_sve_whilege, null_frag>;
defm WHILEGT_PWW : sve_int_while4_rr<0b001, "whilegt", int_aarch64_sve_whilegt, int_aarch64_sve_whilelt>;
defm WHILEHS_PWW : sve_int_while4_rr<0b100, "whilehs", int_aarch64_sve_whilehs, null_frag>;
- defm WHILEHI_PWW : sve_int_while4_rr<0b101, "whilehi", int_aarch64_sve_whilehi, int_aarch64_sve_whilelo>;
+ defm WHILEHI_PWW : sve_int_while4_rr<0b101, "whilehi", int_aarch64_sve_whilehi, get_active_lane_mask>;
defm WHILEGE_PXX : sve_int_while8_rr<0b000, "whilege", int_aarch64_sve_whilege, null_frag>;
defm WHILEGT_PXX : sve_int_while8_rr<0b001, "whilegt", int_aarch64_sve_whilegt, int_aarch64_sve_whilelt>;
defm WHILEHS_PXX : sve_int_while8_rr<0b100, "whilehs", int_aarch64_sve_whilehs, null_frag>;
- defm WHILEHI_PXX : sve_int_while8_rr<0b101, "whilehi", int_aarch64_sve_whilehi, int_aarch64_sve_whilelo>;
+ defm WHILEHI_PXX : sve_int_while8_rr<0b101, "whilehi", int_aarch64_sve_whilehi, get_active_lane_mask>;
// SVE2 pointer conflict compare
defm WHILEWR_PXX : sve2_int_while_rr<0b0, "whilewr", "int_aarch64_sve_whilewr">;
More information about the llvm-commits
mailing list