[llvm] [X86] Fold blend(pshufb(x,m1),pshufb(y,m2)) -> blend(pshufb(x,blend(m1,m2)),pshufb(y,blend(m1,m2))) to reduce constant pool (PR #98466)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 16 03:30:38 PDT 2024


================
@@ -41016,23 +41016,59 @@ static SDValue combineTargetShuffle(SDValue N, const SDLoc &DL,
   case X86ISD::BLENDI: {
     SDValue N0 = N.getOperand(0);
     SDValue N1 = N.getOperand(1);
-
-    // blend(bitcast(x),bitcast(y)) -> bitcast(blend(x,y)) to narrower types.
-    // TODO: Handle MVT::v16i16 repeated blend mask.
-    if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST &&
-        N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType()) {
-      MVT SrcVT = N0.getOperand(0).getSimpleValueType();
-      if ((VT.getScalarSizeInBits() % SrcVT.getScalarSizeInBits()) == 0 &&
-          SrcVT.getScalarSizeInBits() >= 32) {
-        unsigned Size = VT.getVectorNumElements();
-        unsigned NewSize = SrcVT.getVectorNumElements();
-        APInt BlendMask = N.getConstantOperandAPInt(2).zextOrTrunc(Size);
-        APInt NewBlendMask = APIntOps::ScaleBitMask(BlendMask, NewSize);
-        return DAG.getBitcast(
-            VT, DAG.getNode(X86ISD::BLENDI, DL, SrcVT, N0.getOperand(0),
-                            N1.getOperand(0),
-                            DAG.getTargetConstant(NewBlendMask.getZExtValue(),
-                                                  DL, MVT::i8)));
+    unsigned EltBits = VT.getScalarSizeInBits();
+
+    if (N0.getOpcode() == ISD::BITCAST && N1.getOpcode() == ISD::BITCAST) {
+      // blend(bitcast(x),bitcast(y)) -> bitcast(blend(x,y)) to narrower types.
+      // TODO: Handle MVT::v16i16 repeated blend mask.
+      if (N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType()) {
+        MVT SrcVT = N0.getOperand(0).getSimpleValueType();
+        unsigned SrcBits = SrcVT.getScalarSizeInBits();
+        if ((EltBits % SrcBits) == 0 && SrcBits >= 32) {
+          unsigned Size = VT.getVectorNumElements();
+          unsigned NewSize = SrcVT.getVectorNumElements();
+          APInt BlendMask = N.getConstantOperandAPInt(2).zextOrTrunc(Size);
+          APInt NewBlendMask = APIntOps::ScaleBitMask(BlendMask, NewSize);
+          return DAG.getBitcast(
+              VT, DAG.getNode(X86ISD::BLENDI, DL, SrcVT, N0.getOperand(0),
+                              N1.getOperand(0),
+                              DAG.getTargetConstant(NewBlendMask.getZExtValue(),
+                                                    DL, MVT::i8)));
+        }
+      }
+      // Share PSHUFB masks:
+      // blend(pshufb(x,m1),pshufb(y,m2))
+      // --> m3 = blend(m1,m2)
+      //     blend(pshufb(x,m3),pshufb(y,m3))
+      if (N0.hasOneUse() && N1.hasOneUse()) {
+        SmallVector<int> Mask, ByteMask;
+        SmallVector<SDValue> Ops;
+        SDValue LHS = peekThroughOneUseBitcasts(N0);
+        SDValue RHS = peekThroughOneUseBitcasts(N1);
+        if (LHS.getOpcode() == X86ISD::PSHUFB &&
+            RHS.getOpcode() == X86ISD::PSHUFB &&
+            LHS.getOperand(1) != RHS.getOperand(1) &&
+            (LHS.getOperand(1).hasOneUse() || RHS.getOperand(1).hasOneUse()) &&
----------------
goldsteinn wrote:

I'm basing this off something you said to me a while ago, but I thought we didn't track constant uses precisely.

https://github.com/llvm/llvm-project/pull/98466


More information about the llvm-commits mailing list