[llvm-branch-commits] [RISCV] Support memcmp expansion for vectors (PR #114517)

Craig Topper via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sun Nov 3 21:26:20 PST 2024


================
@@ -14474,17 +14475,116 @@ static bool narrowIndex(SDValue &N, ISD::MemIndexType IndexType, SelectionDAG &D
   return true;
 }
 
+/// Recursive helper for combineVectorSizedSetCCEquality(): reports whether
+/// \p X has the shape a memcmp expansion leaves behind, namely a tree of
+/// ISD::OR nodes whose leaves are all ISD::XOR nodes. The top-level node
+/// (\p Root) must itself be an OR; a lone XOR at the root is rejected.
+static bool isOrXorXorTree(SDValue X, bool Root = true) {
+  switch (X.getOpcode()) {
+  case ISD::OR:
+    // Interior node: both children must themselves be OR/XOR subtrees.
+    return isOrXorXorTree(X.getOperand(0), /*Root=*/false) &&
+           isOrXorXorTree(X.getOperand(1), /*Root=*/false);
+  case ISD::XOR:
+    // Leaves are only acceptable below the root.
+    return !Root;
+  default:
+    return false;
+  }
+}
+
+/// Recursive helper for combineVectorSizedSetCCEquality(): rebuilds a tree
+/// accepted by isOrXorXorTree() as vector operations. Each XOR leaf becomes a
+/// lane-wise compare of its two operands bitcast to \p VecVT, producing a
+/// \p CmpVT mask; OR nodes merge the masks of their children. When \p VecVT
+/// and \p CmpVT differ, leaves use SETNE and are merged with OR (a set lane
+/// means "differs"); otherwise leaves use SETEQ and are merged with AND.
+static SDValue emitOrXorXorTree(SDValue X, const SDLoc &DL, SelectionDAG &DAG,
+                                EVT VecVT, EVT CmpVT) {
+  const bool MaskMeansNotEqual = VecVT != CmpVT;
+  SDValue LHS = X.getOperand(0);
+  SDValue RHS = X.getOperand(1);
+
+  switch (X.getOpcode()) {
+  case ISD::OR: {
+    SDValue LHSMask = emitOrXorXorTree(LHS, DL, DAG, VecVT, CmpVT);
+    SDValue RHSMask = emitOrXorXorTree(RHS, DL, DAG, VecVT, CmpVT);
+    unsigned MergeOpc = MaskMeansNotEqual ? ISD::OR : ISD::AND;
+    return DAG.getNode(MergeOpc, DL, CmpVT, LHSMask, RHSMask);
+  }
+  case ISD::XOR: {
+    SDValue VecLHS = DAG.getBitcast(VecVT, LHS);
+    SDValue VecRHS = DAG.getBitcast(VecVT, RHS);
+    ISD::CondCode LaneCC = MaskMeansNotEqual ? ISD::SETNE : ISD::SETEQ;
+    return DAG.getSetCC(DL, CmpVT, VecLHS, VecRHS, LaneCC);
+  }
+  default:
+    llvm_unreachable("Impossible");
+  }
+}
+
+/// Try to map a wide scalar integer equality comparison to vector
+/// instructions before type legalization splits it up into chunks.
+///
+/// Handles (setcc X, Y, eq/ne) where X and Y are scalar integers whose width
+/// is a whole number of bytes between RealMinVLen and 8 * RealMinVLen bits,
+/// and (setcc (or (xor a, b), (xor c, d), ...), 0, eq/ne) as left behind by
+/// memcmp expansion. The operands are bitcast to byte vectors and compared
+/// lane-wise with SETNE, giving a mask with a set lane for every differing
+/// byte. The inputs are equal exactly when that mask OR-reduces to zero, so
+/// the reduction is compared against 0 using the original condition code.
+static SDValue
+combineVectorSizedSetCCEquality(EVT VT, SDValue X, SDValue Y, ISD::CondCode CC,
+                                const SDLoc &DL, SelectionDAG &DAG,
+                                const RISCVSubtarget &Subtarget) {
+  assert(ISD::isIntEqualitySetCC(CC) && "Bad comparison predicate");
+
+  if (!Subtarget.hasVInstructions())
+    return SDValue();
+
+  // We're looking for an oversized integer equality comparison. Check this
+  // before asking for the size so a scalable vector type never reaches the
+  // fixed-size query below.
+  EVT OpVT = X.getValueType();
+  if (!OpVT.isScalarInteger())
+    return SDValue();
+
+  // The operands get reinterpreted as a vector of bytes, so the width must be
+  // a whole number of bytes, and it must be no wider than 8 * the minimum
+  // VLEN (the largest register group RVV provides).
+  unsigned OpSize = OpVT.getSizeInBits();
+  unsigned MinVLen = Subtarget.getRealMinVLen();
+  if (OpSize % 8 != 0 || OpSize < MinVLen || OpSize > MinVLen * 8)
+    return SDValue();
+
+  // A compare against zero is only worth vectorizing when X is the
+  // OR-of-XOR tree a memcmp expansion produces.
+  bool IsOrXorXorTreeCCZero = false;
+  if (isNullConstant(Y)) {
+    if (!isOrXorXorTree(X))
+      return SDValue();
+    IsOrXorXorTreeCCZero = true;
+  }
+
+  // Don't perform this combine if constructing the vector will be expensive.
+  auto IsVectorBitCastCheap = [](SDValue V) {
+    V = peekThroughBitcasts(V);
+    return isa<ConstantSDNode>(V) || V.getValueType().isVector() ||
+           V.getOpcode() == ISD::LOAD;
+  };
+  if (!IsOrXorXorTreeCCZero &&
+      (!IsVectorBitCastCheap(X) || !IsVectorBitCastCheap(Y)))
+    return SDValue();
+
+  if (DAG.getMachineFunction().getFunction().hasFnAttribute(
+          Attribute::NoImplicitFloat))
+    return SDValue();
+
+  // Use EVT rather than MVT: not every byte count has a simple vector type
+  // (e.g. v17i8, or very large VLEN), and MVT::getVectorVT would then hand
+  // back an invalid type.
+  MVT XLenVT = Subtarget.getXLenVT();
+  unsigned VecSize = OpSize / 8;
+  EVT VecVT = EVT::getVectorVT(*DAG.getContext(), MVT::i8, VecSize);
+  EVT CmpVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, VecSize);
+
+  // Build a mask with a set lane for every byte that differs.
+  SDValue Cmp;
+  if (IsOrXorXorTreeCCZero) {
+    // VecVT (i8 lanes) != CmpVT (i1 lanes), so the tree is emitted as an OR
+    // of SETNE compares, i.e. the same "lane differs" mask.
+    Cmp = emitOrXorXorTree(X, DL, DAG, VecVT, CmpVT);
+  } else {
+    SDValue VecX = DAG.getBitcast(VecVT, X);
+    SDValue VecY = DAG.getBitcast(VecVT, Y);
+    Cmp = DAG.getSetCC(DL, CmpVT, VecX, VecY, ISD::SETNE);
+  }
+
+  // Equal iff no lane differs, i.e. the OR-reduction is zero. Comparing the
+  // reduction against 0 with the original CC keeps the EQ/NE sense intact.
+  return DAG.getSetCC(DL, VT, DAG.getNode(ISD::VECREDUCE_OR, DL, XLenVT, Cmp),
+                      DAG.getConstant(0, DL, XLenVT), CC);
+}
+
 // Replace (seteq (i64 (and X, 0xffffffff)), C1) with
 // (seteq (i64 (sext_inreg (X, i32)), C1')) where C1' is C1 sign extended from
 // bit 31. Same for setne. C1' may be cheaper to materialize and the sext_inreg
 // can become a sext.w instead of a shift pair.
 static SDValue performSETCCCombine(SDNode *N, SelectionDAG &DAG,
                                    const RISCVSubtarget &Subtarget) {
+  SDLoc dl(N);
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   EVT VT = N->getValueType(0);
   EVT OpVT = N0.getValueType();
 
+  // Looking for an equality compare.
+  ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
+  if (Cond == ISD::SETNE || Cond == ISD::SETEQ) {
----------------
topperc wrote:

Use `ISD::isIntEqualitySetCC(Cond)`

https://github.com/llvm/llvm-project/pull/114517


More information about the llvm-branch-commits mailing list