[llvm-branch-commits] [llvm] release/21.x: [LoongArch] Fix failure to widen operand for `[X]VMSK{LT, GE, NE}Z` (#149442) (PR #149778)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Jul 21 01:44:28 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-loongarch
Author: None (llvmbot)
<details>
<summary>Changes</summary>
Backport 8a307ae61963a3f967052f7ea3c89aafa56934cf
Requested by: @heiher
---
Full diff: https://github.com/llvm/llvm-project/pull/149778.diff
2 Files Affected:
- (modified) llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp (+124-97)
- (modified) llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll (+15)
``````````diff
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index c47987fbf683b..12cf04bbbab56 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -4563,6 +4563,80 @@ static SDValue signExtendBitcastSrcVector(SelectionDAG &DAG, EVT SExtVT,
llvm_unreachable("Unexpected node type for vXi1 sign extension");
}
+static SDValue
+performSETCC_BITCASTCombine(SDNode *N, SelectionDAG &DAG,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const LoongArchSubtarget &Subtarget) {
+ SDLoc DL(N);
+ EVT VT = N->getValueType(0);
+ SDValue Src = N->getOperand(0);
+ EVT SrcVT = Src.getValueType();
+
+ if (Src.getOpcode() != ISD::SETCC || !Src.hasOneUse())
+ return SDValue();
+
+ bool UseLASX;
+ unsigned Opc = ISD::DELETED_NODE;
+ EVT CmpVT = Src.getOperand(0).getValueType();
+ EVT EltVT = CmpVT.getVectorElementType();
+
+ if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() == 128)
+ UseLASX = false;
+ else if (Subtarget.has32S() && Subtarget.hasExtLASX() &&
+ CmpVT.getSizeInBits() == 256)
+ UseLASX = true;
+ else
+ return SDValue();
+
+ SDValue SrcN1 = Src.getOperand(1);
+ switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) {
+ default:
+ break;
+ case ISD::SETEQ:
+ // x == 0 => not (vmsknez.b x)
+ if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+ Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
+ break;
+ case ISD::SETGT:
+ // x > -1 => vmskgez.b x
+ if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8)
+ Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
+ break;
+ case ISD::SETGE:
+ // x >= 0 => vmskgez.b x
+ if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+ Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
+ break;
+ case ISD::SETLT:
+ // x < 0 => vmskltz.{b,h,w,d} x
+ if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) &&
+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
+ EltVT == MVT::i64))
+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
+ break;
+ case ISD::SETLE:
+ // x <= -1 => vmskltz.{b,h,w,d} x
+ if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) &&
+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
+ EltVT == MVT::i64))
+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
+ break;
+ case ISD::SETNE:
+ // x != 0 => vmsknez.b x
+ if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+ Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
+ break;
+ }
+
+ if (Opc == ISD::DELETED_NODE)
+ return SDValue();
+
+ SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src.getOperand(0));
+ EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());
+ V = DAG.getZExtOrTrunc(V, DL, T);
+ return DAG.getBitcast(VT, V);
+}
+
static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const LoongArchSubtarget &Subtarget) {
@@ -4577,110 +4651,63 @@ static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
if (!SrcVT.isSimple() || SrcVT.getScalarType() != MVT::i1)
return SDValue();
- unsigned Opc = ISD::DELETED_NODE;
// Combine SETCC and BITCAST into [X]VMSK{LT,GE,NE} when possible
+ SDValue Res = performSETCC_BITCASTCombine(N, DAG, DCI, Subtarget);
+ if (Res)
+ return Res;
+
+ // Generate vXi1 using [X]VMSKLTZ
+ MVT SExtVT;
+ unsigned Opc;
+ bool UseLASX = false;
+ bool PropagateSExt = false;
+
if (Src.getOpcode() == ISD::SETCC && Src.hasOneUse()) {
- bool UseLASX;
EVT CmpVT = Src.getOperand(0).getValueType();
- EVT EltVT = CmpVT.getVectorElementType();
-
- if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() <= 128)
- UseLASX = false;
- else if (Subtarget.has32S() && Subtarget.hasExtLASX() &&
- CmpVT.getSizeInBits() <= 256)
- UseLASX = true;
- else
+ if (CmpVT.getSizeInBits() > 256)
return SDValue();
-
- SDValue SrcN1 = Src.getOperand(1);
- switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) {
- default:
- break;
- case ISD::SETEQ:
- // x == 0 => not (vmsknez.b x)
- if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
- Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
- break;
- case ISD::SETGT:
- // x > -1 => vmskgez.b x
- if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8)
- Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
- break;
- case ISD::SETGE:
- // x >= 0 => vmskgez.b x
- if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
- Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
- break;
- case ISD::SETLT:
- // x < 0 => vmskltz.{b,h,w,d} x
- if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) &&
- (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
- EltVT == MVT::i64))
- Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
- break;
- case ISD::SETLE:
- // x <= -1 => vmskltz.{b,h,w,d} x
- if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) &&
- (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
- EltVT == MVT::i64))
- Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
- break;
- case ISD::SETNE:
- // x != 0 => vmsknez.b x
- if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
- Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
- break;
- }
}
- // Generate vXi1 using [X]VMSKLTZ
- if (Opc == ISD::DELETED_NODE) {
- MVT SExtVT;
- bool UseLASX = false;
- bool PropagateSExt = false;
- switch (SrcVT.getSimpleVT().SimpleTy) {
- default:
- return SDValue();
- case MVT::v2i1:
- SExtVT = MVT::v2i64;
- break;
- case MVT::v4i1:
- SExtVT = MVT::v4i32;
- if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
- SExtVT = MVT::v4i64;
- UseLASX = true;
- PropagateSExt = true;
- }
- break;
- case MVT::v8i1:
- SExtVT = MVT::v8i16;
- if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
- SExtVT = MVT::v8i32;
- UseLASX = true;
- PropagateSExt = true;
- }
- break;
- case MVT::v16i1:
- SExtVT = MVT::v16i8;
- if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
- SExtVT = MVT::v16i16;
- UseLASX = true;
- PropagateSExt = true;
- }
- break;
- case MVT::v32i1:
- SExtVT = MVT::v32i8;
+ switch (SrcVT.getSimpleVT().SimpleTy) {
+ default:
+ return SDValue();
+ case MVT::v2i1:
+ SExtVT = MVT::v2i64;
+ break;
+ case MVT::v4i1:
+ SExtVT = MVT::v4i32;
+ if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
+ SExtVT = MVT::v4i64;
UseLASX = true;
- break;
- };
- if (UseLASX && !Subtarget.has32S() && !Subtarget.hasExtLASX())
- return SDValue();
- Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL)
- : DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
- Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
- } else {
- Src = Src.getOperand(0);
- }
+ PropagateSExt = true;
+ }
+ break;
+ case MVT::v8i1:
+ SExtVT = MVT::v8i16;
+ if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
+ SExtVT = MVT::v8i32;
+ UseLASX = true;
+ PropagateSExt = true;
+ }
+ break;
+ case MVT::v16i1:
+ SExtVT = MVT::v16i8;
+ if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
+ SExtVT = MVT::v16i16;
+ UseLASX = true;
+ PropagateSExt = true;
+ }
+ break;
+ case MVT::v32i1:
+ SExtVT = MVT::v32i8;
+ UseLASX = true;
+ break;
+ };
+ if (UseLASX && !(Subtarget.has32S() && Subtarget.hasExtLASX()))
+ return SDValue();
+ Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL)
+ : DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src);
EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());
diff --git a/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll b/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll
index 0ee30120f77a6..ad57bbf9ee5c0 100644
--- a/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll
+++ b/llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll
@@ -588,3 +588,18 @@ define i2 @vmsk_trunc_i64(<2 x i64> %a) {
%res = bitcast <2 x i1> %y to i2
ret i2 %res
}
+
+define i4 @vmsk_eq_allzeros_v4i8(<4 x i8> %a) {
+; CHECK-LABEL: vmsk_eq_allzeros_v4i8:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vseqi.b $vr0, $vr0, 0
+; CHECK-NEXT: vilvl.b $vr0, $vr0, $vr0
+; CHECK-NEXT: vilvl.h $vr0, $vr0, $vr0
+; CHECK-NEXT: vslli.w $vr0, $vr0, 24
+; CHECK-NEXT: vmskltz.w $vr0, $vr0
+; CHECK-NEXT: vpickve2gr.hu $a0, $vr0, 0
+; CHECK-NEXT: ret
+ %1 = icmp eq <4 x i8> %a, zeroinitializer
+ %2 = bitcast <4 x i1> %1 to i4
+ ret i4 %2
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/149778
More information about the llvm-branch-commits mailing list