[llvm] [X86][ISel] Improve logic for optimizing `movmsk(bitcast(shuffle(x)))` (PR #68369)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 6 02:25:08 PDT 2023


================
@@ -45836,18 +45836,52 @@ static SDValue combineSetCCMOVMSK(SDValue EFLAGS, X86::CondCode &CC,
   // MOVMSK(SHUFFLE(X,u)) -> MOVMSK(X) iff every element is referenced.
   SmallVector<int, 32> ShuffleMask;
   SmallVector<SDValue, 2> ShuffleInputs;
+  SDValue BaseVec = peekThroughBitcasts(Vec);
   if (NumElts <= CmpBits &&
-      getTargetShuffleInputs(peekThroughBitcasts(Vec), ShuffleInputs,
-                             ShuffleMask, DAG) &&
+      getTargetShuffleInputs(BaseVec, ShuffleInputs, ShuffleMask, DAG) &&
       ShuffleInputs.size() == 1 && !isAnyZeroOrUndef(ShuffleMask) &&
       ShuffleInputs[0].getValueSizeInBits() == VecVT.getSizeInBits()) {
----------------
RKSimon wrote:

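I think all we need to do is add a `scaleShuffleElements(ShuffleMask, NumElts, ScaledMask)` check here. It ensures that the shuffle mask can be scaled back to the original mask width (`NumElts` elements). The demanded elements are then computed from `ScaledMask` rather than `ShuffleMask`, so every mask index maps directly onto a bit of the `NumElts`-wide `DemandedElts`: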
```cpp
  // MOVMSK(SHUFFLE(X,u)) -> MOVMSK(X) iff every element is referenced.
  SmallVector<int, 32> ShuffleMask, ScaledMask;
  SmallVector<SDValue, 2> ShuffleInputs;
  if (NumElts <= CmpBits &&
      getTargetShuffleInputs(peekThroughBitcasts(Vec), ShuffleInputs,
                             ShuffleMask, DAG) &&
      ShuffleInputs.size() == 1 && !isAnyZeroOrUndef(ShuffleMask) &&
      ShuffleInputs[0].getValueSizeInBits() == VecVT.getSizeInBits() &&
      scaleShuffleElements(ShuffleMask, NumElts, ScaledMask)) {
    APInt DemandedElts = APInt::getZero(NumElts);
    for (int M : ScaledMask) {
      assert(0 <= M && M < (int)NumElts && "Bad unary shuffle index");
      DemandedElts.setBit(M);
    }
    if (DemandedElts.isAllOnes()) {
      SDLoc DL(EFLAGS);
      SDValue Result = DAG.getBitcast(VecVT, ShuffleInputs[0]);
      Result = DAG.getNode(X86ISD::MOVMSK, DL, MVT::i32, Result);
      Result =
          DAG.getZExtOrTrunc(Result, DL, EFLAGS.getOperand(0).getValueType());
      return DAG.getNode(X86ISD::CMP, DL, MVT::i32, Result,
                         EFLAGS.getOperand(1));
    }
  }
```

https://github.com/llvm/llvm-project/pull/68369


More information about the llvm-commits mailing list