[llvm] 3be6532 - [X86] canonicalizeShuffleWithBinOps - generalize to handle some unary ops

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 12 02:59:50 PST 2023


Author: Simon Pilgrim
Date: 2023-12-12T10:59:38Z
New Revision: 3be65325f9e37db2d6eeffe96f9709d710a9a49f

URL: https://github.com/llvm/llvm-project/commit/3be65325f9e37db2d6eeffe96f9709d710a9a49f
DIFF: https://github.com/llvm/llvm-project/commit/3be65325f9e37db2d6eeffe96f9709d710a9a49f.diff

LOG: [X86] canonicalizeShuffleWithBinOps - generalize to handle some unary ops

Rename to canonicalizeShuffleWithOp and begin adding SHUFFLE(UNARYOP(X),UNARYOP(Y)) -> UNARYOP(SHUFFLE(X,Y)) fold support.

This only kicks in after legalization, so targets that expand bit-count ops still end up with the expanded code duplicated for each shuffle operand, but it helps with a few initial cases.

I'm investigating adding support for extensions/conversions as well, but this is a first step.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/widen_bitcnt.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index adafb425babf1..66cb1a77901c3 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -39587,9 +39587,21 @@ static SDValue combineCommutableSHUFP(SDValue N, MVT VT, const SDLoc &DL,
   return SDValue();
 }
 
+// TODO - move this to TLI like isBinOp?
+static bool isUnaryOp(unsigned Opcode) {
+  switch (Opcode) {
+  case ISD::CTLZ:
+  case ISD::CTTZ:
+  case ISD::CTPOP:
+    return true;
+  }
+  return false;
+}
+
+// Canonicalize SHUFFLE(UNARYOP(X)) -> UNARYOP(SHUFFLE(X)).
 // Canonicalize SHUFFLE(BINOP(X,Y)) -> BINOP(SHUFFLE(X),SHUFFLE(Y)).
-static SDValue canonicalizeShuffleWithBinOps(SDValue N, SelectionDAG &DAG,
-                                             const SDLoc &DL) {
+static SDValue canonicalizeShuffleWithOp(SDValue N, SelectionDAG &DAG,
+                                         const SDLoc &DL) {
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   EVT ShuffleVT = N.getValueType();
 
@@ -39716,6 +39728,25 @@ static SDValue canonicalizeShuffleWithBinOps(SDValue N, SelectionDAG &DAG,
                                             DAG.getBitcast(OpVT, RHS)));
         }
       }
+      if (isUnaryOp(SrcOpcode) && N1.getOpcode() == SrcOpcode &&
+          N0.getValueType() == N1.getValueType() &&
+          IsSafeToMoveShuffle(N0, SrcOpcode) &&
+          IsSafeToMoveShuffle(N1, SrcOpcode)) {
+        SDValue Op00 = peekThroughOneUseBitcasts(N0.getOperand(0));
+        SDValue Op10 = peekThroughOneUseBitcasts(N1.getOperand(0));
+        SDValue Res;
+        Op00 = DAG.getBitcast(ShuffleVT, Op00);
+        Op10 = DAG.getBitcast(ShuffleVT, Op10);
+        if (N.getNumOperands() == 3) {
+          Res = DAG.getNode(Opc, DL, ShuffleVT, Op00, Op10, N.getOperand(2));
+        } else {
+          Res = DAG.getNode(Opc, DL, ShuffleVT, Op00, Op10);
+        }
+        EVT OpVT = N0.getValueType();
+        return DAG.getBitcast(
+            ShuffleVT,
+            DAG.getNode(SrcOpcode, DL, OpVT, DAG.getBitcast(OpVT, Res)));
+      }
     }
     break;
   }
@@ -40797,10 +40828,11 @@ static SDValue combineShuffle(SDNode *N, SelectionDAG &DAG,
     if (TLI.SimplifyDemandedVectorElts(Op, DemandedElts, DCI))
       return SDValue(N, 0);
 
+    // Canonicalize SHUFFLE(UNARYOP(X)) -> UNARYOP(SHUFFLE(X)).
     // Canonicalize SHUFFLE(BINOP(X,Y)) -> BINOP(SHUFFLE(X),SHUFFLE(Y)).
     // Perform this after other shuffle combines to allow inner shuffles to be
     // combined away first.
-    if (SDValue BinOp = canonicalizeShuffleWithBinOps(Op, DAG, dl))
+    if (SDValue BinOp = canonicalizeShuffleWithOp(Op, DAG, dl))
       return BinOp;
   }
 

diff  --git a/llvm/test/CodeGen/X86/widen_bitcnt.ll b/llvm/test/CodeGen/X86/widen_bitcnt.ll
index da468b6d809e8..0f121d88b3573 100644
--- a/llvm/test/CodeGen/X86/widen_bitcnt.ll
+++ b/llvm/test/CodeGen/X86/widen_bitcnt.ll
@@ -88,9 +88,8 @@ define <4 x i32> @widen_ctpop_v2i32_v4i32(<2 x i32> %a0, <2 x i32> %a1) {
 ;
 ; AVX512VPOPCNT-LABEL: widen_ctpop_v2i32_v4i32:
 ; AVX512VPOPCNT:       # %bb.0:
-; AVX512VPOPCNT-NEXT:    vpopcntd %xmm0, %xmm0
-; AVX512VPOPCNT-NEXT:    vpopcntd %xmm1, %xmm1
 ; AVX512VPOPCNT-NEXT:    vpunpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm1[0]
+; AVX512VPOPCNT-NEXT:    vpopcntd %xmm0, %xmm0
 ; AVX512VPOPCNT-NEXT:    retq
   %b0 = call <2 x i32> @llvm.ctpop.v2i32(<2 x i32> %a0)
   %b1 = call <2 x i32> @llvm.ctpop.v2i32(<2 x i32> %a1)
@@ -325,10 +324,9 @@ define <8 x i32> @widen_ctpop_v2i32_v8i32(<2 x i32> %a0, <2 x i32> %a1, <2 x i32
 ; AVX512VPOPCNT-NEXT:    # kill: def $xmm1 killed $xmm1 def $ymm1
 ; AVX512VPOPCNT-NEXT:    # kill: def $xmm0 killed $xmm0 def $ymm0
 ; AVX512VPOPCNT-NEXT:    vinserti128 $1, %xmm3, %ymm1, %ymm1
-; AVX512VPOPCNT-NEXT:    vpopcntd %ymm1, %ymm1
 ; AVX512VPOPCNT-NEXT:    vinserti128 $1, %xmm2, %ymm0, %ymm0
-; AVX512VPOPCNT-NEXT:    vpopcntd %ymm0, %ymm0
 ; AVX512VPOPCNT-NEXT:    vpunpcklqdq {{.*#+}} ymm0 = ymm0[0],ymm1[0],ymm0[2],ymm1[2]
+; AVX512VPOPCNT-NEXT:    vpopcntd %ymm0, %ymm0
 ; AVX512VPOPCNT-NEXT:    retq
   %b0 = call <2 x i32> @llvm.ctpop.v2i32(<2 x i32> %a0)
   %b1 = call <2 x i32> @llvm.ctpop.v2i32(<2 x i32> %a1)
@@ -438,9 +436,8 @@ define <4 x i32> @widen_ctlz_v2i32_v4i32(<2 x i32> %a0, <2 x i32> %a1) {
 ;
 ; AVX512-LABEL: widen_ctlz_v2i32_v4i32:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vplzcntd %xmm0, %xmm0
-; AVX512-NEXT:    vplzcntd %xmm1, %xmm1
 ; AVX512-NEXT:    vpunpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm1[0]
+; AVX512-NEXT:    vplzcntd %xmm0, %xmm0
 ; AVX512-NEXT:    retq
   %b0 = call <2 x i32> @llvm.ctlz.v2i32(<2 x i32> %a0, i1 0)
   %b1 = call <2 x i32> @llvm.ctlz.v2i32(<2 x i32> %a1, i1 0)
@@ -706,10 +703,9 @@ define <8 x i32> @widen_ctlz_v2i32_v8i32(<2 x i32> %a0, <2 x i32> %a1, <2 x i32>
 ; AVX512-NEXT:    # kill: def $xmm1 killed $xmm1 def $ymm1
 ; AVX512-NEXT:    # kill: def $xmm0 killed $xmm0 def $ymm0
 ; AVX512-NEXT:    vinserti128 $1, %xmm3, %ymm1, %ymm1
-; AVX512-NEXT:    vplzcntd %ymm1, %ymm1
 ; AVX512-NEXT:    vinserti128 $1, %xmm2, %ymm0, %ymm0
-; AVX512-NEXT:    vplzcntd %ymm0, %ymm0
 ; AVX512-NEXT:    vpunpcklqdq {{.*#+}} ymm0 = ymm0[0],ymm1[0],ymm0[2],ymm1[2]
+; AVX512-NEXT:    vplzcntd %ymm0, %ymm0
 ; AVX512-NEXT:    retq
   %b0 = call <2 x i32> @llvm.ctlz.v2i32(<2 x i32> %a0, i1 0)
   %b1 = call <2 x i32> @llvm.ctlz.v2i32(<2 x i32> %a1, i1 0)
@@ -819,9 +815,8 @@ define <4 x i32> @widen_ctlz_undef_v2i32_v4i32(<2 x i32> %a0, <2 x i32> %a1) {
 ;
 ; AVX512-LABEL: widen_ctlz_undef_v2i32_v4i32:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vplzcntd %xmm0, %xmm0
-; AVX512-NEXT:    vplzcntd %xmm1, %xmm1
 ; AVX512-NEXT:    vpunpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm1[0]
+; AVX512-NEXT:    vplzcntd %xmm0, %xmm0
 ; AVX512-NEXT:    retq
   %b0 = call <2 x i32> @llvm.ctlz.v2i32(<2 x i32> %a0, i1 1)
   %b1 = call <2 x i32> @llvm.ctlz.v2i32(<2 x i32> %a1, i1 1)
@@ -1087,10 +1082,9 @@ define <8 x i32> @widen_ctlz_undef_v2i32_v8i32(<2 x i32> %a0, <2 x i32> %a1, <2
 ; AVX512-NEXT:    # kill: def $xmm1 killed $xmm1 def $ymm1
 ; AVX512-NEXT:    # kill: def $xmm0 killed $xmm0 def $ymm0
 ; AVX512-NEXT:    vinserti128 $1, %xmm3, %ymm1, %ymm1
-; AVX512-NEXT:    vplzcntd %ymm1, %ymm1
 ; AVX512-NEXT:    vinserti128 $1, %xmm2, %ymm0, %ymm0
-; AVX512-NEXT:    vplzcntd %ymm0, %ymm0
 ; AVX512-NEXT:    vpunpcklqdq {{.*#+}} ymm0 = ymm0[0],ymm1[0],ymm0[2],ymm1[2]
+; AVX512-NEXT:    vplzcntd %ymm0, %ymm0
 ; AVX512-NEXT:    retq
   %b0 = call <2 x i32> @llvm.ctlz.v2i32(<2 x i32> %a0, i1 1)
   %b1 = call <2 x i32> @llvm.ctlz.v2i32(<2 x i32> %a1, i1 1)
@@ -1176,12 +1170,11 @@ define <4 x i32> @widen_cttz_v2i32_v4i32(<2 x i32> %a0, <2 x i32> %a1) {
 ; AVX512VL-NEXT:    vpcmpeqd %xmm2, %xmm2, %xmm2
 ; AVX512VL-NEXT:    vpaddd %xmm2, %xmm0, %xmm3
 ; AVX512VL-NEXT:    vpandn %xmm3, %xmm0, %xmm0
-; AVX512VL-NEXT:    vplzcntd %xmm0, %xmm0
 ; AVX512VL-NEXT:    vpbroadcastd {{.*#+}} xmm3 = [32,32,32,32]
 ; AVX512VL-NEXT:    vpaddd %xmm2, %xmm1, %xmm2
 ; AVX512VL-NEXT:    vpandn %xmm2, %xmm1, %xmm1
-; AVX512VL-NEXT:    vplzcntd %xmm1, %xmm1
 ; AVX512VL-NEXT:    vpunpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm1[0]
+; AVX512VL-NEXT:    vplzcntd %xmm0, %xmm0
 ; AVX512VL-NEXT:    vpsubd %xmm0, %xmm3, %xmm0
 ; AVX512VL-NEXT:    retq
 ;
@@ -1190,11 +1183,10 @@ define <4 x i32> @widen_cttz_v2i32_v4i32(<2 x i32> %a0, <2 x i32> %a1) {
 ; AVX512VPOPCNT-NEXT:    vpcmpeqd %xmm2, %xmm2, %xmm2
 ; AVX512VPOPCNT-NEXT:    vpaddd %xmm2, %xmm0, %xmm3
 ; AVX512VPOPCNT-NEXT:    vpandn %xmm3, %xmm0, %xmm0
-; AVX512VPOPCNT-NEXT:    vpopcntd %xmm0, %xmm0
 ; AVX512VPOPCNT-NEXT:    vpaddd %xmm2, %xmm1, %xmm2
 ; AVX512VPOPCNT-NEXT:    vpandn %xmm2, %xmm1, %xmm1
-; AVX512VPOPCNT-NEXT:    vpopcntd %xmm1, %xmm1
 ; AVX512VPOPCNT-NEXT:    vpunpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm1[0]
+; AVX512VPOPCNT-NEXT:    vpopcntd %xmm0, %xmm0
 ; AVX512VPOPCNT-NEXT:    retq
   %b0 = call <2 x i32> @llvm.cttz.v2i32(<2 x i32> %a0, i1 0)
   %b1 = call <2 x i32> @llvm.cttz.v2i32(<2 x i32> %a1, i1 0)
@@ -1416,12 +1408,11 @@ define <8 x i32> @widen_cttz_v2i32_v8i32(<2 x i32> %a0, <2 x i32> %a1, <2 x i32>
 ; AVX512VL-NEXT:    vpcmpeqd %ymm3, %ymm3, %ymm3
 ; AVX512VL-NEXT:    vpaddd %ymm3, %ymm1, %ymm4
 ; AVX512VL-NEXT:    vpandn %ymm4, %ymm1, %ymm1
-; AVX512VL-NEXT:    vplzcntd %ymm1, %ymm1
 ; AVX512VL-NEXT:    vinserti128 $1, %xmm2, %ymm0, %ymm0
 ; AVX512VL-NEXT:    vpaddd %ymm3, %ymm0, %ymm2
 ; AVX512VL-NEXT:    vpandn %ymm2, %ymm0, %ymm0
-; AVX512VL-NEXT:    vplzcntd %ymm0, %ymm0
 ; AVX512VL-NEXT:    vpunpcklqdq {{.*#+}} ymm0 = ymm0[0],ymm1[0],ymm0[2],ymm1[2]
+; AVX512VL-NEXT:    vplzcntd %ymm0, %ymm0
 ; AVX512VL-NEXT:    vpbroadcastd {{.*#+}} ymm1 = [32,32,32,32,32,32,32,32]
 ; AVX512VL-NEXT:    vpsubd %ymm0, %ymm1, %ymm0
 ; AVX512VL-NEXT:    retq
@@ -1434,12 +1425,11 @@ define <8 x i32> @widen_cttz_v2i32_v8i32(<2 x i32> %a0, <2 x i32> %a1, <2 x i32>
 ; AVX512VPOPCNT-NEXT:    vpcmpeqd %ymm3, %ymm3, %ymm3
 ; AVX512VPOPCNT-NEXT:    vpaddd %ymm3, %ymm1, %ymm4
 ; AVX512VPOPCNT-NEXT:    vpandn %ymm4, %ymm1, %ymm1
-; AVX512VPOPCNT-NEXT:    vpopcntd %ymm1, %ymm1
 ; AVX512VPOPCNT-NEXT:    vinserti128 $1, %xmm2, %ymm0, %ymm0
 ; AVX512VPOPCNT-NEXT:    vpaddd %ymm3, %ymm0, %ymm2
 ; AVX512VPOPCNT-NEXT:    vpandn %ymm2, %ymm0, %ymm0
-; AVX512VPOPCNT-NEXT:    vpopcntd %ymm0, %ymm0
 ; AVX512VPOPCNT-NEXT:    vpunpcklqdq {{.*#+}} ymm0 = ymm0[0],ymm1[0],ymm0[2],ymm1[2]
+; AVX512VPOPCNT-NEXT:    vpopcntd %ymm0, %ymm0
 ; AVX512VPOPCNT-NEXT:    retq
   %b0 = call <2 x i32> @llvm.cttz.v2i32(<2 x i32> %a0, i1 0)
   %b1 = call <2 x i32> @llvm.cttz.v2i32(<2 x i32> %a1, i1 0)
@@ -1525,12 +1515,11 @@ define <4 x i32> @widen_cttz_undef_v2i32_v4i32(<2 x i32> %a0, <2 x i32> %a1) {
 ; AVX512VL-NEXT:    vpcmpeqd %xmm2, %xmm2, %xmm2
 ; AVX512VL-NEXT:    vpaddd %xmm2, %xmm0, %xmm3
 ; AVX512VL-NEXT:    vpandn %xmm3, %xmm0, %xmm0
-; AVX512VL-NEXT:    vplzcntd %xmm0, %xmm0
 ; AVX512VL-NEXT:    vpbroadcastd {{.*#+}} xmm3 = [32,32,32,32]
 ; AVX512VL-NEXT:    vpaddd %xmm2, %xmm1, %xmm2
 ; AVX512VL-NEXT:    vpandn %xmm2, %xmm1, %xmm1
-; AVX512VL-NEXT:    vplzcntd %xmm1, %xmm1
 ; AVX512VL-NEXT:    vpunpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm1[0]
+; AVX512VL-NEXT:    vplzcntd %xmm0, %xmm0
 ; AVX512VL-NEXT:    vpsubd %xmm0, %xmm3, %xmm0
 ; AVX512VL-NEXT:    retq
 ;
@@ -1539,11 +1528,10 @@ define <4 x i32> @widen_cttz_undef_v2i32_v4i32(<2 x i32> %a0, <2 x i32> %a1) {
 ; AVX512VPOPCNT-NEXT:    vpcmpeqd %xmm2, %xmm2, %xmm2
 ; AVX512VPOPCNT-NEXT:    vpaddd %xmm2, %xmm0, %xmm3
 ; AVX512VPOPCNT-NEXT:    vpandn %xmm3, %xmm0, %xmm0
-; AVX512VPOPCNT-NEXT:    vpopcntd %xmm0, %xmm0
 ; AVX512VPOPCNT-NEXT:    vpaddd %xmm2, %xmm1, %xmm2
 ; AVX512VPOPCNT-NEXT:    vpandn %xmm2, %xmm1, %xmm1
-; AVX512VPOPCNT-NEXT:    vpopcntd %xmm1, %xmm1
 ; AVX512VPOPCNT-NEXT:    vpunpcklqdq {{.*#+}} xmm0 = xmm0[0],xmm1[0]
+; AVX512VPOPCNT-NEXT:    vpopcntd %xmm0, %xmm0
 ; AVX512VPOPCNT-NEXT:    retq
   %b0 = call <2 x i32> @llvm.cttz.v2i32(<2 x i32> %a0, i1 1)
   %b1 = call <2 x i32> @llvm.cttz.v2i32(<2 x i32> %a1, i1 1)
@@ -1765,12 +1753,11 @@ define <8 x i32> @widen_cttz_undef_v2i32_v8i32(<2 x i32> %a0, <2 x i32> %a1, <2
 ; AVX512VL-NEXT:    vpcmpeqd %ymm3, %ymm3, %ymm3
 ; AVX512VL-NEXT:    vpaddd %ymm3, %ymm1, %ymm4
 ; AVX512VL-NEXT:    vpandn %ymm4, %ymm1, %ymm1
-; AVX512VL-NEXT:    vplzcntd %ymm1, %ymm1
 ; AVX512VL-NEXT:    vinserti128 $1, %xmm2, %ymm0, %ymm0
 ; AVX512VL-NEXT:    vpaddd %ymm3, %ymm0, %ymm2
 ; AVX512VL-NEXT:    vpandn %ymm2, %ymm0, %ymm0
-; AVX512VL-NEXT:    vplzcntd %ymm0, %ymm0
 ; AVX512VL-NEXT:    vpunpcklqdq {{.*#+}} ymm0 = ymm0[0],ymm1[0],ymm0[2],ymm1[2]
+; AVX512VL-NEXT:    vplzcntd %ymm0, %ymm0
 ; AVX512VL-NEXT:    vpbroadcastd {{.*#+}} ymm1 = [32,32,32,32,32,32,32,32]
 ; AVX512VL-NEXT:    vpsubd %ymm0, %ymm1, %ymm0
 ; AVX512VL-NEXT:    retq
@@ -1783,12 +1770,11 @@ define <8 x i32> @widen_cttz_undef_v2i32_v8i32(<2 x i32> %a0, <2 x i32> %a1, <2
 ; AVX512VPOPCNT-NEXT:    vpcmpeqd %ymm3, %ymm3, %ymm3
 ; AVX512VPOPCNT-NEXT:    vpaddd %ymm3, %ymm1, %ymm4
 ; AVX512VPOPCNT-NEXT:    vpandn %ymm4, %ymm1, %ymm1
-; AVX512VPOPCNT-NEXT:    vpopcntd %ymm1, %ymm1
 ; AVX512VPOPCNT-NEXT:    vinserti128 $1, %xmm2, %ymm0, %ymm0
 ; AVX512VPOPCNT-NEXT:    vpaddd %ymm3, %ymm0, %ymm2
 ; AVX512VPOPCNT-NEXT:    vpandn %ymm2, %ymm0, %ymm0
-; AVX512VPOPCNT-NEXT:    vpopcntd %ymm0, %ymm0
 ; AVX512VPOPCNT-NEXT:    vpunpcklqdq {{.*#+}} ymm0 = ymm0[0],ymm1[0],ymm0[2],ymm1[2]
+; AVX512VPOPCNT-NEXT:    vpopcntd %ymm0, %ymm0
 ; AVX512VPOPCNT-NEXT:    retq
   %b0 = call <2 x i32> @llvm.cttz.v2i32(<2 x i32> %a0, i1 1)
   %b1 = call <2 x i32> @llvm.cttz.v2i32(<2 x i32> %a1, i1 1)


        


More information about the llvm-commits mailing list