[llvm] 99ae5c1 - [X86] Add 'getSplitVectorSrc' helper to determine if subvectors all come from the same source

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 26 07:17:28 PST 2022


Author: Simon Pilgrim
Date: 2022-01-26T15:17:21Z
New Revision: 99ae5c13f64e138d6b17c00bd01c87c3ce58cb6b

URL: https://github.com/llvm/llvm-project/commit/99ae5c13f64e138d6b17c00bd01c87c3ce58cb6b
DIFF: https://github.com/llvm/llvm-project/commit/99ae5c13f64e138d6b17c00bd01c87c3ce58cb6b.diff

LOG: [X86] Add 'getSplitVectorSrc' helper to determine if subvectors all come from the same source

Helps determine if the subvector ops come from the same larger vector and match its lower/upper half extractions.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b25a170a683ab..ba606d7a80edb 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -6146,6 +6146,29 @@ static SDValue getZeroVector(MVT VT, const X86Subtarget &Subtarget,
   return DAG.getBitcast(VT, Vec);
 }
 
+// Helper to determine if LHS/RHS are the Lo/Hi extracted halves of a single
+// source vector; returns it or SDValue(). AllowCommute also accepts Hi/Lo.
+static SDValue getSplitVectorSrc(SDValue LHS, SDValue RHS, bool AllowCommute) {
+  if (LHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      RHS.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
+      LHS.getValueType() != RHS.getValueType() ||
+      LHS.getOperand(0) != RHS.getOperand(0))
+    return SDValue(); // Not two same-typed extracts from one vector.
+
+  SDValue Src = LHS.getOperand(0);
+  if (Src.getValueSizeInBits() != (LHS.getValueSizeInBits() * 2))
+    return SDValue(); // Ops must each be exactly half of Src.
+
+  unsigned NumElts = LHS.getValueType().getVectorNumElements();
+  if ((LHS.getConstantOperandAPInt(1) == 0 && // LHS = Lo, RHS = Hi
+       RHS.getConstantOperandAPInt(1) == NumElts) ||
+      (AllowCommute && RHS.getConstantOperandAPInt(1) == 0 && // Swapped
+       LHS.getConstantOperandAPInt(1) == NumElts))
+    return Src; // Ops are the two halves of Src.
+
+  return SDValue();
+}
+
 static SDValue extractSubVector(SDValue Vec, unsigned IdxVal, SelectionDAG &DAG,
                                 const SDLoc &dl, unsigned vectorWidth) {
   EVT VT = Vec.getValueType();
@@ -44512,30 +44535,28 @@ static SDValue combineSetCCMOVMSK(SDValue EFLAGS, X86::CondCode &CC,
     // PMOVMSKB(PACKSSBW(LO(X), HI(X)))
     // -> PMOVMSKB(BITCAST_v32i8(X)) & 0xAAAAAAAA.
     if (CmpBits >= 16 && Subtarget.hasInt256() &&
-        VecOp0.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
-        VecOp1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
-        VecOp0.getOperand(0) == VecOp1.getOperand(0) &&
-        VecOp0.getConstantOperandAPInt(1) == 0 &&
-        VecOp1.getConstantOperandAPInt(1) == 8 &&
         (IsAnyOf || (SignExt0 && SignExt1))) {
-      SDLoc DL(EFLAGS);
-      SDValue Result = peekThroughBitcasts(VecOp0.getOperand(0));
-      if (IsAllOf && Result.getOpcode() == X86ISD::PCMPEQ) {
-        SDValue V = DAG.getNode(ISD::SUB, DL, Result.getValueType(),
-                                Result.getOperand(0), Result.getOperand(1));
-        V = DAG.getBitcast(MVT::v4i64, V);
-        return DAG.getNode(X86ISD::PTEST, SDLoc(EFLAGS), MVT::i32, V, V);
-      }
-      Result = DAG.getBitcast(MVT::v32i8, Result);
-      Result = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Result);
-      unsigned CmpMask = IsAnyOf ? 0 : 0xFFFFFFFF;
-      if (!SignExt0 || !SignExt1) {
-        assert(IsAnyOf && "Only perform v16i16 signmasks for any_of patterns");
-        Result = DAG.getNode(ISD::AND, DL, MVT::i32, Result,
-                             DAG.getConstant(0xAAAAAAAA, DL, MVT::i32));
+      if (SDValue Src = getSplitVectorSrc(VecOp0, VecOp1, true)) {
+        SDLoc DL(EFLAGS);
+        SDValue Result = peekThroughBitcasts(Src);
+        if (IsAllOf && Result.getOpcode() == X86ISD::PCMPEQ) {
+          SDValue V = DAG.getNode(ISD::SUB, DL, Result.getValueType(),
+                                  Result.getOperand(0), Result.getOperand(1));
+          V = DAG.getBitcast(MVT::v4i64, V);
+          return DAG.getNode(X86ISD::PTEST, SDLoc(EFLAGS), MVT::i32, V, V);
+        }
+        Result = DAG.getBitcast(MVT::v32i8, Result);
+        Result = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Result);
+        unsigned CmpMask = IsAnyOf ? 0 : 0xFFFFFFFF;
+        if (!SignExt0 || !SignExt1) {
+          assert(IsAnyOf &&
+                 "Only perform v16i16 signmasks for any_of patterns");
+          Result = DAG.getNode(ISD::AND, DL, MVT::i32, Result,
+                               DAG.getConstant(0xAAAAAAAA, DL, MVT::i32));
+        }
+        return DAG.getNode(X86ISD::CMP, DL, MVT::i32, Result,
+                           DAG.getConstant(CmpMask, DL, MVT::i32));
       }
-      return DAG.getNode(X86ISD::CMP, DL, MVT::i32, Result,
-                         DAG.getConstant(CmpMask, DL, MVT::i32));
     }
   }
 
@@ -45582,33 +45603,28 @@ static SDValue combineHorizOpWithShuffle(SDNode *N, SelectionDAG &DAG,
   // truncation trees that help us avoid lane crossing shuffles.
   // TODO: There's a lot more we can do for PACK/HADD style shuffle combines.
   // TODO: We don't handle vXf64 shuffles yet.
-  if (VT.is128BitVector() && SrcVT.getScalarSizeInBits() <= 32 &&
-      BC0.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
-      BC1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
-      BC0.getOperand(0) == BC1.getOperand(0) &&
-      BC0.getOperand(0).getValueType().is256BitVector() &&
-      BC0.getConstantOperandAPInt(1) == 0 &&
-      BC1.getConstantOperandAPInt(1) ==
-          BC0.getValueType().getVectorNumElements()) {
-    SmallVector<SDValue> ShuffleOps;
-    SmallVector<int> ShuffleMask, ScaledMask;
-    SDValue Vec = peekThroughBitcasts(BC0.getOperand(0));
-    if (getTargetShuffleInputs(Vec, ShuffleOps, ShuffleMask, DAG)) {
-      resolveTargetShuffleInputsAndMask(ShuffleOps, ShuffleMask);
-      // To keep the HOP LHS/RHS coherency, we must be able to scale the unary
-      // shuffle to a v4X64 width - we can probably relax this in the future.
-      if (!isAnyZero(ShuffleMask) && ShuffleOps.size() == 1 &&
-          ShuffleOps[0].getValueType().is256BitVector() &&
-          scaleShuffleElements(ShuffleMask, 4, ScaledMask)) {
-        SDValue Lo, Hi;
-        MVT ShufVT = VT.isFloatingPoint() ? MVT::v4f32 : MVT::v4i32;
-        std::tie(Lo, Hi) = DAG.SplitVector(ShuffleOps[0], DL);
-        Lo = DAG.getBitcast(SrcVT, Lo);
-        Hi = DAG.getBitcast(SrcVT, Hi);
-        SDValue Res = DAG.getNode(Opcode, DL, VT, Lo, Hi);
-        Res = DAG.getBitcast(ShufVT, Res);
-        Res = DAG.getVectorShuffle(ShufVT, DL, Res, Res, ScaledMask);
-        return DAG.getBitcast(VT, Res);
+  if (VT.is128BitVector() && SrcVT.getScalarSizeInBits() <= 32) {
+    if (SDValue BCSrc = getSplitVectorSrc(BC0, BC1, false)) {
+      SmallVector<SDValue> ShuffleOps;
+      SmallVector<int> ShuffleMask, ScaledMask;
+      SDValue Vec = peekThroughBitcasts(BCSrc);
+      if (getTargetShuffleInputs(Vec, ShuffleOps, ShuffleMask, DAG)) {
+        resolveTargetShuffleInputsAndMask(ShuffleOps, ShuffleMask);
+        // To keep the HOP LHS/RHS coherency, we must be able to scale the unary
+        // shuffle to a v4X64 width - we can probably relax this in the future.
+        if (!isAnyZero(ShuffleMask) && ShuffleOps.size() == 1 &&
+            ShuffleOps[0].getValueType().is256BitVector() &&
+            scaleShuffleElements(ShuffleMask, 4, ScaledMask)) {
+          SDValue Lo, Hi;
+          MVT ShufVT = VT.isFloatingPoint() ? MVT::v4f32 : MVT::v4i32;
+          std::tie(Lo, Hi) = DAG.SplitVector(ShuffleOps[0], DL);
+          Lo = DAG.getBitcast(SrcVT, Lo);
+          Hi = DAG.getBitcast(SrcVT, Hi);
+          SDValue Res = DAG.getNode(Opcode, DL, VT, Lo, Hi);
+          Res = DAG.getBitcast(ShufVT, Res);
+          Res = DAG.getVectorShuffle(ShufVT, DL, Res, Res, ScaledMask);
+          return DAG.getBitcast(VT, Res);
+        }
       }
     }
   }


        


More information about the llvm-commits mailing list