[llvm] daa8033 - [CodeGen] Support folds of not(cmp(cc, ...)) -> cmp(!cc, ...) for scalable vectors

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 1 01:50:33 PST 2022


Author: David Sherwood
Date: 2022-02-01T09:50:00Z
New Revision: daa80339dfcb2a8cfea412cf918cbfae0c279fa0

URL: https://github.com/llvm/llvm-project/commit/daa80339dfcb2a8cfea412cf918cbfae0c279fa0
DIFF: https://github.com/llvm/llvm-project/commit/daa80339dfcb2a8cfea412cf918cbfae0c279fa0.diff

LOG: [CodeGen] Support folds of not(cmp(cc, ...)) -> cmp(!cc, ...) for scalable vectors

I have updated TargetLowering::isConstTrueVal to also consider
SPLAT_VECTOR nodes with constant integer operands. This allows the
optimisation to also work for targets that support scalable vectors.

Differential Revision: https://reviews.llvm.org/D117210

Added: 
    llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
    llvm/test/CodeGen/RISCV/rvv/cmp-folds.ll

Modified: 
    llvm/include/llvm/CodeGen/TargetLowering.h
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/lib/Target/ARM/ARMISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 94d845b39c148..e53da0274480a 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3682,11 +3682,11 @@ class TargetLowering : public TargetLoweringBase {
 
   /// Return if the N is a constant or constant vector equal to the true value
   /// from getBooleanContents().
-  bool isConstTrueVal(const SDNode *N) const;
+  bool isConstTrueVal(SDValue N) const;
 
   /// Return if the N is a constant or constant vector equal to the false value
   /// from getBooleanContents().
-  bool isConstFalseVal(const SDNode *N) const;
+  bool isConstFalseVal(SDValue N) const;
 
   /// Return if \p N is a True value when extended to \p VT.
   bool isExtendedTrueVal(const ConstantSDNode *N, EVT VT, bool SExt) const;

diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 57b933c05450d..1957f5231bdf0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -906,9 +906,8 @@ bool DAGCombiner::isSetCCEquivalent(SDValue N, SDValue &LHS, SDValue &RHS,
     return true;
   }
 
-  if (N.getOpcode() != ISD::SELECT_CC ||
-      !TLI.isConstTrueVal(N.getOperand(2).getNode()) ||
-      !TLI.isConstFalseVal(N.getOperand(3).getNode()))
+  if (N.getOpcode() != ISD::SELECT_CC || !TLI.isConstTrueVal(N.getOperand(2)) ||
+      !TLI.isConstFalseVal(N.getOperand(3)))
     return false;
 
   if (TLI.getBooleanContents(N.getValueType()) ==
@@ -8035,8 +8034,8 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
   // fold !(x cc y) -> (x !cc y)
   unsigned N0Opcode = N0.getOpcode();
   SDValue LHS, RHS, CC;
-  if (TLI.isConstTrueVal(N1.getNode()) &&
-      isSetCCEquivalent(N0, LHS, RHS, CC, /*MatchStrict*/true)) {
+  if (TLI.isConstTrueVal(N1) &&
+      isSetCCEquivalent(N0, LHS, RHS, CC, /*MatchStrict*/ true)) {
     ISD::CondCode NotCC = ISD::getSetCCInverse(cast<CondCodeSDNode>(CC)->get(),
                                                LHS.getValueType());
     if (!LegalOperations ||

diff  --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index cdd51bb88beaa..f6d1fa87676f2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -3194,29 +3194,25 @@ bool TargetLowering::isSplatValueForTargetNode(SDValue Op,
 // FIXME: Ideally, this would use ISD::isConstantSplatVector(), but that must
 // work with truncating build vectors and vectors with elements of less than
 // 8 bits.
-bool TargetLowering::isConstTrueVal(const SDNode *N) const {
+bool TargetLowering::isConstTrueVal(SDValue N) const {
   if (!N)
     return false;
 
+  unsigned EltWidth;
   APInt CVal;
-  if (auto *CN = dyn_cast<ConstantSDNode>(N)) {
+  if (ConstantSDNode *CN = isConstOrConstSplat(N, /*AllowUndefs=*/false,
+                                               /*AllowTruncation=*/true)) {
     CVal = CN->getAPIntValue();
-  } else if (auto *BV = dyn_cast<BuildVectorSDNode>(N)) {
-    auto *CN = BV->getConstantSplatNode();
-    if (!CN)
-      return false;
-
-    // If this is a truncating build vector, truncate the splat value.
-    // Otherwise, we may fail to match the expected values below.
-    unsigned BVEltWidth = BV->getValueType(0).getScalarSizeInBits();
-    CVal = CN->getAPIntValue();
-    if (BVEltWidth < CVal.getBitWidth())
-      CVal = CVal.trunc(BVEltWidth);
-  } else {
+    EltWidth = N.getValueType().getScalarSizeInBits();
+  } else
     return false;
-  }
 
-  switch (getBooleanContents(N->getValueType(0))) {
+  // If this is a truncating splat, truncate the splat value.
+  // Otherwise, we may fail to match the expected values below.
+  if (EltWidth < CVal.getBitWidth())
+    CVal = CVal.trunc(EltWidth);
+
+  switch (getBooleanContents(N.getValueType())) {
   case UndefinedBooleanContent:
     return CVal[0];
   case ZeroOrOneBooleanContent:
@@ -3228,7 +3224,7 @@ bool TargetLowering::isConstTrueVal(const SDNode *N) const {
   llvm_unreachable("Invalid boolean contents");
 }
 
-bool TargetLowering::isConstFalseVal(const SDNode *N) const {
+bool TargetLowering::isConstFalseVal(SDValue N) const {
   if (!N)
     return false;
 
@@ -3763,7 +3759,7 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
         if (TopSetCC.getValueType() == MVT::i1 && VT == MVT::i1 &&
             TopSetCC.getOpcode() == ISD::SETCC &&
             (N0Opc == ISD::ZERO_EXTEND || N0Opc == ISD::SIGN_EXTEND) &&
-            (isConstFalseVal(N1C) ||
+            (isConstFalseVal(N1) ||
              isExtendedTrueVal(N1C, N0->getValueType(0), SExt))) {
 
           bool Inverse = (N1C->isZero() && Cond == ISD::SETEQ) ||

diff  --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index fe4e6b24367a3..1b41427a1cab1 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -14527,7 +14527,7 @@ static SDValue PerformXORCombine(SDNode *N,
     SDValue N0 = N->getOperand(0);
     SDValue N1 = N->getOperand(1);
     const TargetLowering *TLI = Subtarget->getTargetLowering();
-    if (TLI->isConstTrueVal(N1.getNode()) &&
+    if (TLI->isConstTrueVal(N1) &&
         (N0->getOpcode() == ARMISD::VCMP || N0->getOpcode() == ARMISD::VCMPZ)) {
       if (CanInvertMVEVCMP(N0)) {
         SDLoc DL(N0);

diff  --git a/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll b/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
new file mode 100644
index 0000000000000..c758889f77edb
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-cmp-folds.ll
@@ -0,0 +1,54 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-unknown -mattr=+sve -o - < %s | FileCheck %s
+
+define <vscale x 8 x i1> @not_icmp_sle_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: not_icmp_sle_nxv8i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
+; CHECK-NEXT:    cmpgt p0.h, p0/z, z0.h, z1.h
+; CHECK-NEXT:    ret
+  %icmp = icmp sle <vscale x 8 x i16> %a, %b
+  %tmp = insertelement <vscale x 8 x i1> undef, i1 true, i32 0
+  %ones = shufflevector <vscale x 8 x i1> %tmp, <vscale x 8 x i1> undef, <vscale x 8 x i32> zeroinitializer
+  %not = xor <vscale x 8 x i1> %ones, %icmp
+  ret <vscale x 8 x i1> %not
+}
+
+define <vscale x 4 x i1> @not_icmp_sgt_nxv4i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b) {
+; CHECK-LABEL: not_icmp_sgt_nxv4i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    cmpge p0.s, p0/z, z1.s, z0.s
+; CHECK-NEXT:    ret
+  %icmp = icmp sgt <vscale x 4 x i32> %a, %b
+  %tmp = insertelement <vscale x 4 x i1> undef, i1 true, i32 0
+  %ones = shufflevector <vscale x 4 x i1> %tmp, <vscale x 4 x i1> undef, <vscale x 4 x i32> zeroinitializer
+  %not = xor <vscale x 4 x i1> %icmp, %ones
+  ret <vscale x 4 x i1> %not
+}
+
+define <vscale x 2 x i1> @not_fcmp_une_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b) {
+; CHECK-LABEL: not_fcmp_une_nxv2f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fcmeq p0.d, p0/z, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %icmp = fcmp une <vscale x 2 x double> %a, %b
+  %tmp = insertelement <vscale x 2 x i1> undef, i1 true, i32 0
+  %ones = shufflevector <vscale x 2 x i1> %tmp, <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer
+  %not = xor <vscale x 2 x i1> %icmp, %ones
+  ret <vscale x 2 x i1> %not
+}
+
+define <vscale x 4 x i1> @not_fcmp_uge_nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: not_fcmp_uge_nxv4f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fcmgt p0.s, p0/z, z1.s, z0.s
+; CHECK-NEXT:    ret
+  %icmp = fcmp uge <vscale x 4 x float> %a, %b
+  %tmp = insertelement <vscale x 4 x i1> undef, i1 true, i32 0
+  %ones = shufflevector <vscale x 4 x i1> %tmp, <vscale x 4 x i1> undef, <vscale x 4 x i32> zeroinitializer
+  %not = xor <vscale x 4 x i1> %icmp, %ones
+  ret <vscale x 4 x i1> %not
+}

diff  --git a/llvm/test/CodeGen/RISCV/rvv/cmp-folds.ll b/llvm/test/CodeGen/RISCV/rvv/cmp-folds.ll
new file mode 100644
index 0000000000000..bafcce4688cc2
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/cmp-folds.ll
@@ -0,0 +1,55 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+m,+d,+zfh,+v -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv64 -mattr=+m,+d,+zfh,+v -verify-machineinstrs < %s | FileCheck %s
+
+define <vscale x 8 x i1> @not_icmp_sle_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i16> %b) {
+; CHECK-LABEL: not_icmp_sle_nxv8i16:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e16, m2, ta, mu
+; CHECK-NEXT:    vmslt.vv v0, v10, v8
+; CHECK-NEXT:    ret
+  %icmp = icmp sle <vscale x 8 x i16> %a, %b
+  %tmp = insertelement <vscale x 8 x i1> undef, i1 true, i32 0
+  %ones = shufflevector <vscale x 8 x i1> %tmp, <vscale x 8 x i1> undef, <vscale x 8 x i32> zeroinitializer
+  %not = xor <vscale x 8 x i1> %ones, %icmp
+  ret <vscale x 8 x i1> %not
+}
+
+define <vscale x 4 x i1> @not_icmp_sgt_nxv4i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b) {
+; CHECK-LABEL: not_icmp_sgt_nxv4i32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e32, m2, ta, mu
+; CHECK-NEXT:    vmsle.vv v0, v8, v10
+; CHECK-NEXT:    ret
+  %icmp = icmp sgt <vscale x 4 x i32> %a, %b
+  %tmp = insertelement <vscale x 4 x i1> undef, i1 true, i32 0
+  %ones = shufflevector <vscale x 4 x i1> %tmp, <vscale x 4 x i1> undef, <vscale x 4 x i32> zeroinitializer
+  %not = xor <vscale x 4 x i1> %icmp, %ones
+  ret <vscale x 4 x i1> %not
+}
+
+define <vscale x 2 x i1> @not_fcmp_une_nxv2f64(<vscale x 2 x double> %a, <vscale x 2 x double> %b) {
+; CHECK-LABEL: not_fcmp_une_nxv2f64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e64, m2, ta, mu
+; CHECK-NEXT:    vmfeq.vv v0, v8, v10
+; CHECK-NEXT:    ret
+  %icmp = fcmp une <vscale x 2 x double> %a, %b
+  %tmp = insertelement <vscale x 2 x i1> undef, i1 true, i32 0
+  %ones = shufflevector <vscale x 2 x i1> %tmp, <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer
+  %not = xor <vscale x 2 x i1> %icmp, %ones
+  ret <vscale x 2 x i1> %not
+}
+
+define <vscale x 4 x i1> @not_fcmp_uge_nxv4f32(<vscale x 4 x float> %a, <vscale x 4 x float> %b) {
+; CHECK-LABEL: not_fcmp_uge_nxv4f32:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli a0, zero, e32, m2, ta, mu
+; CHECK-NEXT:    vmflt.vv v0, v8, v10
+; CHECK-NEXT:    ret
+  %icmp = fcmp uge <vscale x 4 x float> %a, %b
+  %tmp = insertelement <vscale x 4 x i1> undef, i1 true, i32 0
+  %ones = shufflevector <vscale x 4 x i1> %tmp, <vscale x 4 x i1> undef, <vscale x 4 x i32> zeroinitializer
+  %not = xor <vscale x 4 x i1> %icmp, %ones
+  ret <vscale x 4 x i1> %not
+}


        


More information about the llvm-commits mailing list