[llvm] df12524 - [X86] Turn X86DAGToDAGISel::tryVPTERNLOG into a fully custom instruction selector that can handle bitcasts between logic ops

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Sun Jul 26 12:19:35 PDT 2020


Author: Craig Topper
Date: 2020-07-26T12:19:08-07:00
New Revision: df12524e6ba02d3eda975de4541f55e151074b07

URL: https://github.com/llvm/llvm-project/commit/df12524e6ba02d3eda975de4541f55e151074b07
DIFF: https://github.com/llvm/llvm-project/commit/df12524e6ba02d3eda975de4541f55e151074b07.diff

LOG: [X86] Turn X86DAGToDAGISel::tryVPTERNLOG into a fully custom instruction selector that can handle bitcasts between logic ops

Previously we just matched the logic ops and replaced with an
X86ISD::VPTERNLOG node that we would send through the normal
pattern match. But that approach couldn't handle a bitcast
between the logic ops. Extending that approach would require us
to peek through the bitcasts and emit new bitcasts to match
the types. Those new bitcasts would then have to be properly
topologically sorted.

This patch instead switches to directly emitting the
MachineSDNode and skips the normal tablegen pattern matching.
We do have to handle load folding and broadcast load folding
ourselves now, which also means commuting the immediate control.

Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D83630

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
    llvm/test/CodeGen/X86/avx512-logic.ll
    llvm/test/CodeGen/X86/avx512vl-logic.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index 3cd80cb04ab8..4098911dee3b 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -3940,30 +3940,39 @@ bool X86DAGToDAGISel::tryVPTERNLOG(SDNode *N) {
   if (!(Subtarget->hasVLX() || NVT.is512BitVector()))
     return false;
 
-  unsigned Opc1 = N->getOpcode();
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
 
-  auto isLogicOp = [](unsigned Opc) {
-    return Opc == ISD::AND || Opc == ISD::OR || Opc == ISD::XOR ||
-           Opc == X86ISD::ANDNP;
+  auto getFoldableLogicOp = [](SDValue Op) {
+    // Peek through single use bitcast.
+    if (Op.getOpcode() == ISD::BITCAST && Op.hasOneUse())
+      Op = Op.getOperand(0);
+
+    if (!Op.hasOneUse())
+      return SDValue();
+
+    unsigned Opc = Op.getOpcode();
+    if (Opc == ISD::AND || Opc == ISD::OR || Opc == ISD::XOR ||
+        Opc == X86ISD::ANDNP)
+      return Op;
+
+    return SDValue();
   };
 
-  SDValue A, B, C;
-  unsigned Opc2;
-  if (isLogicOp(N1.getOpcode()) && N1.hasOneUse()) {
-    Opc2 = N1.getOpcode();
+  SDValue A, FoldableOp;
+  if ((FoldableOp = getFoldableLogicOp(N1))) {
     A = N0;
-    B = N1.getOperand(0);
-    C = N1.getOperand(1);
-  } else if (isLogicOp(N0.getOpcode()) && N0.hasOneUse()) {
-    Opc2 = N0.getOpcode();
+  } else if ((FoldableOp = getFoldableLogicOp(N0))) {
     A = N1;
-    B = N0.getOperand(0);
-    C = N0.getOperand(1);
   } else
     return false;
 
+  SDValue B = FoldableOp.getOperand(0);
+  SDValue C = FoldableOp.getOperand(1);
+
+  unsigned Opc1 = N->getOpcode();
+  unsigned Opc2 = FoldableOp.getOpcode();
+
   uint64_t Imm;
   switch (Opc1) {
   default: llvm_unreachable("Unexpected opcode!");
@@ -3996,11 +4005,117 @@ bool X86DAGToDAGISel::tryVPTERNLOG(SDNode *N) {
     break;
   }
 
+  auto tryFoldLoadOrBCast =
+      [this](SDNode *Root, SDNode *P, SDValue &L, SDValue &Base, SDValue &Scale,
+             SDValue &Index, SDValue &Disp, SDValue &Segment) {
+        if (tryFoldLoad(Root, P, L, Base, Scale, Index, Disp, Segment))
+          return true;
+
+        // Not a load, check for broadcast which may be behind a bitcast.
+        if (L.getOpcode() == ISD::BITCAST && L.hasOneUse()) {
+          P = L.getNode();
+          L = L.getOperand(0);
+        }
+
+        if (L.getOpcode() != X86ISD::VBROADCAST_LOAD)
+          return false;
+
+        // Only 32 and 64 bit broadcasts are supported.
+        auto *MemIntr = cast<MemIntrinsicSDNode>(L);
+        unsigned Size = MemIntr->getMemoryVT().getSizeInBits();
+        if (Size != 32 && Size != 64)
+          return false;
+
+        return tryFoldBroadcast(Root, P, L, Base, Scale, Index, Disp, Segment);
+      };
+
+  bool FoldedLoad = false;
+  SDValue Tmp0, Tmp1, Tmp2, Tmp3, Tmp4;
+  if (tryFoldLoadOrBCast(N, FoldableOp.getNode(), C, Tmp0, Tmp1, Tmp2, Tmp3,
+                         Tmp4)) {
+    FoldedLoad = true;
+  } else if (tryFoldLoadOrBCast(N, N, A, Tmp0, Tmp1, Tmp2, Tmp3, Tmp4)) {
+    FoldedLoad = true;
+    std::swap(A, C);
+    // Swap bits 1/4 and 3/6.
+    uint8_t OldImm = Imm;
+    Imm = OldImm & 0xa5;
+    if (OldImm & 0x02) Imm |= 0x10;
+    if (OldImm & 0x10) Imm |= 0x02;
+    if (OldImm & 0x08) Imm |= 0x40;
+    if (OldImm & 0x40) Imm |= 0x08;
+  } else if (tryFoldLoadOrBCast(N, FoldableOp.getNode(), B, Tmp0, Tmp1, Tmp2,
+                                Tmp3, Tmp4)) {
+    FoldedLoad = true;
+    std::swap(B, C);
+    // Swap bits 1/2 and 5/6.
+    uint8_t OldImm = Imm;
+    Imm = OldImm & 0x99;
+    if (OldImm & 0x02) Imm |= 0x04;
+    if (OldImm & 0x04) Imm |= 0x02;
+    if (OldImm & 0x20) Imm |= 0x40;
+    if (OldImm & 0x40) Imm |= 0x20;
+  }
+
   SDLoc DL(N);
-  SDValue New = CurDAG->getNode(X86ISD::VPTERNLOG, DL, NVT, A, B, C,
-                                CurDAG->getTargetConstant(Imm, DL, MVT::i8));
-  ReplaceNode(N, New.getNode());
-  SelectCode(New.getNode());
+
+  SDValue TImm = CurDAG->getTargetConstant(Imm, DL, MVT::i8);
+
+  MachineSDNode *MNode;
+  if (FoldedLoad) {
+    SDVTList VTs = CurDAG->getVTList(NVT, MVT::Other);
+
+    unsigned Opc;
+    if (C.getOpcode() == X86ISD::VBROADCAST_LOAD) {
+      auto *MemIntr = cast<MemIntrinsicSDNode>(C);
+      unsigned EltSize = MemIntr->getMemoryVT().getSizeInBits();
+      assert((EltSize == 32 || EltSize == 64) && "Unexpected broadcast size!");
+
+      bool UseD = EltSize == 32;
+      if (NVT.is128BitVector())
+        Opc = UseD ? X86::VPTERNLOGDZ128rmbi : X86::VPTERNLOGQZ128rmbi;
+      else if (NVT.is256BitVector())
+        Opc = UseD ? X86::VPTERNLOGDZ256rmbi : X86::VPTERNLOGQZ256rmbi;
+      else if (NVT.is512BitVector())
+        Opc = UseD ? X86::VPTERNLOGDZrmbi : X86::VPTERNLOGQZrmbi;
+      else
+        llvm_unreachable("Unexpected vector size!");
+    } else {
+      bool UseD = NVT.getVectorElementType() == MVT::i32;
+      if (NVT.is128BitVector())
+        Opc = UseD ? X86::VPTERNLOGDZ128rmi : X86::VPTERNLOGQZ128rmi;
+      else if (NVT.is256BitVector())
+        Opc = UseD ? X86::VPTERNLOGDZ256rmi : X86::VPTERNLOGQZ256rmi;
+      else if (NVT.is512BitVector())
+        Opc = UseD ? X86::VPTERNLOGDZrmi : X86::VPTERNLOGQZrmi;
+      else
+        llvm_unreachable("Unexpected vector size!");
+    }
+
+    SDValue Ops[] = {A, B, Tmp0, Tmp1, Tmp2, Tmp3, Tmp4, TImm, C.getOperand(0)};
+    MNode = CurDAG->getMachineNode(Opc, DL, VTs, Ops);
+
+    // Update the chain.
+    ReplaceUses(C.getValue(1), SDValue(MNode, 1));
+    // Record the mem-refs
+    CurDAG->setNodeMemRefs(MNode, {cast<MemSDNode>(C)->getMemOperand()});
+  } else {
+    bool UseD = NVT.getVectorElementType() == MVT::i32;
+    unsigned Opc;
+    if (NVT.is128BitVector())
+      Opc = UseD ? X86::VPTERNLOGDZ128rri : X86::VPTERNLOGQZ128rri;
+    else if (NVT.is256BitVector())
+      Opc = UseD ? X86::VPTERNLOGDZ256rri : X86::VPTERNLOGQZ256rri;
+    else if (NVT.is512BitVector())
+      Opc = UseD ? X86::VPTERNLOGDZrri : X86::VPTERNLOGQZrri;
+    else
+      llvm_unreachable("Unexpected vector size!");
+
+    MNode = CurDAG->getMachineNode(Opc, DL, NVT, {A, B, C, TImm});
+  }
+
+  ReplaceUses(SDValue(N, 0), SDValue(MNode, 0));
+  CurDAG->RemoveDeadNode(N);
   return true;
 }
 

diff  --git a/llvm/test/CodeGen/X86/avx512-logic.ll b/llvm/test/CodeGen/X86/avx512-logic.ll
index 30607214f56d..24e58149eb4c 100644
--- a/llvm/test/CodeGen/X86/avx512-logic.ll
+++ b/llvm/test/CodeGen/X86/avx512-logic.ll
@@ -887,34 +887,20 @@ define <16 x i32> @ternlog_xor_andn(<16 x i32> %x, <16 x i32> %y, <16 x i32> %z)
 }
 
 define <16 x i32> @ternlog_or_and_mask(<16 x i32> %x, <16 x i32> %y) {
-; KNL-LABEL: ternlog_or_and_mask:
-; KNL:       ## %bb.0:
-; KNL-NEXT:    vpandq {{.*}}(%rip), %zmm0, %zmm0
-; KNL-NEXT:    vpord %zmm1, %zmm0, %zmm0
-; KNL-NEXT:    retq
-;
-; SKX-LABEL: ternlog_or_and_mask:
-; SKX:       ## %bb.0:
-; SKX-NEXT:    vandps {{.*}}(%rip), %zmm0, %zmm0
-; SKX-NEXT:    vorps %zmm1, %zmm0, %zmm0
-; SKX-NEXT:    retq
+; ALL-LABEL: ternlog_or_and_mask:
+; ALL:       ## %bb.0:
+; ALL-NEXT:    vpternlogd $236, {{.*}}(%rip), %zmm1, %zmm0
+; ALL-NEXT:    retq
   %a = and <16 x i32> %x, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
   %b = or <16 x i32> %a, %y
   ret <16 x i32> %b
 }
 
 define <8 x i64> @ternlog_xor_and_mask(<8 x i64> %x, <8 x i64> %y) {
-; KNL-LABEL: ternlog_xor_and_mask:
-; KNL:       ## %bb.0:
-; KNL-NEXT:    vpandd {{.*}}(%rip), %zmm0, %zmm0
-; KNL-NEXT:    vpxorq %zmm1, %zmm0, %zmm0
-; KNL-NEXT:    retq
-;
-; SKX-LABEL: ternlog_xor_and_mask:
-; SKX:       ## %bb.0:
-; SKX-NEXT:    vandps {{.*}}(%rip), %zmm0, %zmm0
-; SKX-NEXT:    vxorps %zmm1, %zmm0, %zmm0
-; SKX-NEXT:    retq
+; ALL-LABEL: ternlog_xor_and_mask:
+; ALL:       ## %bb.0:
+; ALL-NEXT:    vpternlogq $108, {{.*}}(%rip), %zmm1, %zmm0
+; ALL-NEXT:    retq
   %a = and <8 x i64> %x, <i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295>
   %b = xor <8 x i64> %a, %y
   ret <8 x i64> %b

diff  --git a/llvm/test/CodeGen/X86/avx512vl-logic.ll b/llvm/test/CodeGen/X86/avx512vl-logic.ll
index 3f0ce3092847..13c4c8afb9a8 100644
--- a/llvm/test/CodeGen/X86/avx512vl-logic.ll
+++ b/llvm/test/CodeGen/X86/avx512vl-logic.ll
@@ -991,8 +991,7 @@ define <4 x i32> @ternlog_xor_andn(<4 x i32> %x, <4 x i32> %y, <4 x i32> %z) {
 define <4 x i32> @ternlog_or_and_mask(<4 x i32> %x, <4 x i32> %y) {
 ; CHECK-LABEL: ternlog_or_and_mask:
 ; CHECK:       ## %bb.0:
-; CHECK-NEXT:    vandps {{.*}}(%rip), %xmm0, %xmm0
-; CHECK-NEXT:    vorps %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vpternlogd $236, {{.*}}(%rip), %xmm1, %xmm0
 ; CHECK-NEXT:    retq
   %a = and <4 x i32> %x, <i32 255, i32 255, i32 255, i32 255>
   %b = or <4 x i32> %a, %y
@@ -1002,8 +1001,7 @@ define <4 x i32> @ternlog_or_and_mask(<4 x i32> %x, <4 x i32> %y) {
 define <8 x i32> @ternlog_or_and_mask_ymm(<8 x i32> %x, <8 x i32> %y) {
 ; CHECK-LABEL: ternlog_or_and_mask_ymm:
 ; CHECK:       ## %bb.0:
-; CHECK-NEXT:    vandps {{.*}}(%rip), %ymm0, %ymm0
-; CHECK-NEXT:    vorps %ymm1, %ymm0, %ymm0
+; CHECK-NEXT:    vpternlogd $236, {{.*}}(%rip), %ymm1, %ymm0
 ; CHECK-NEXT:    retq
   %a = and <8 x i32> %x, <i32 -16777216, i32 -16777216, i32 -16777216, i32 -16777216, i32 -16777216, i32 -16777216, i32 -16777216, i32 -16777216>
   %b = or <8 x i32> %a, %y
@@ -1013,8 +1011,7 @@ define <8 x i32> @ternlog_or_and_mask_ymm(<8 x i32> %x, <8 x i32> %y) {
 define <2 x i64> @ternlog_xor_and_mask(<2 x i64> %x, <2 x i64> %y) {
 ; CHECK-LABEL: ternlog_xor_and_mask:
 ; CHECK:       ## %bb.0:
-; CHECK-NEXT:    vandps {{.*}}(%rip), %xmm0, %xmm0
-; CHECK-NEXT:    vxorps %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vpternlogq $108, {{.*}}(%rip), %xmm1, %xmm0
 ; CHECK-NEXT:    retq
   %a = and <2 x i64> %x, <i64 1099511627775, i64 1099511627775>
   %b = xor <2 x i64> %a, %y
@@ -1024,8 +1021,7 @@ define <2 x i64> @ternlog_xor_and_mask(<2 x i64> %x, <2 x i64> %y) {
 define <4 x i64> @ternlog_xor_and_mask_ymm(<4 x i64> %x, <4 x i64> %y) {
 ; CHECK-LABEL: ternlog_xor_and_mask_ymm:
 ; CHECK:       ## %bb.0:
-; CHECK-NEXT:    vandps {{.*}}(%rip), %ymm0, %ymm0
-; CHECK-NEXT:    vxorps %ymm1, %ymm0, %ymm0
+; CHECK-NEXT:    vpternlogq $108, {{.*}}(%rip), %ymm1, %ymm0
 ; CHECK-NEXT:    retq
   %a = and <4 x i64> %x, <i64 72057594037927935, i64 72057594037927935, i64 72057594037927935, i64 72057594037927935>
   %b = xor <4 x i64> %a, %y


        


More information about the llvm-commits mailing list