[llvm] [SelectionDAG][x86] Ensure vector reduction optimization (PR #144231)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 17 11:16:46 PDT 2025


Suhajda Tamás <sutajo at gmail.com>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/144231 at github.com>


================
@@ -25409,6 +25423,94 @@ static SDValue LowerEXTEND_VECTOR_INREG(SDValue Op,
   return SignExt;
 }
 
+// Create a min/max v8i16/v16i8 horizontal reduction with PHMINPOSUW.
+//
+// \p Src is the vector being reduced and \p TargetVT is the scalar result
+// type. \p BinOp must be one of ISD::UMIN, ISD::UMAX, ISD::SMIN or ISD::SMAX.
+// Returns an empty SDValue when the reduction cannot be lowered this way:
+// TargetVT is not i8/i16, the source element type differs from TargetVT, or
+// the source width is not a multiple of 128 bits.
+static SDValue createMinMaxReduction(SDValue Src, EVT TargetVT,
+                                     const SDLoc &DL, ISD::NodeType BinOp,
+                                     SelectionDAG &DAG,
+                                     const X86Subtarget &Subtarget) {
+  assert(Subtarget.hasSSE41() &&
+         "The caller must check if SSE4.1 is available");
+  assert((BinOp == ISD::UMIN || BinOp == ISD::UMAX || BinOp == ISD::SMIN ||
+          BinOp == ISD::SMAX) &&
+         "Expected a min/max opcode");
+
+  // PHMINPOSUW works on 16-bit lanes, and the v16i8 path below packs byte
+  // pairs into 16-bit lanes, so only i8/i16 results are supported. Bail out
+  // explicitly instead of relying on the assert further down, which is
+  // compiled out in release builds and would let e.g. a v4i32 reduction be
+  // silently miscompiled.
+  if (TargetVT != MVT::i16 && TargetVT != MVT::i8)
+    return SDValue();
+
+  EVT SrcVT = Src.getValueType();
+  EVT SrcSVT = SrcVT.getScalarType();
+
+  if (SrcSVT != TargetVT || (SrcVT.getSizeInBits() % 128) != 0)
+    return SDValue();
+
+  // First, reduce the source down to 128-bit, applying BinOp to lo/hi.
+  while (SrcVT.getSizeInBits() > 128) {
+    SDValue Lo, Hi;
+    std::tie(Lo, Hi) = splitVector(Src, DAG, DL);
+    SrcVT = Lo.getValueType();
+    Src = DAG.getNode(BinOp, DL, SrcVT, Lo, Hi);
+  }
+  assert(((SrcVT == MVT::v8i16 && TargetVT == MVT::i16) ||
+          (SrcVT == MVT::v16i8 && TargetVT == MVT::i8)) &&
+         "Unexpected value type");
+
+  // PHMINPOSUW applies to UMIN(v8i16), for SMIN/SMAX/UMAX we must apply a mask
+  // to flip the value accordingly.
+  SDValue Mask;
+  unsigned MaskEltsBits = TargetVT.getSizeInBits();
+  if (BinOp == ISD::SMAX)
+    Mask = DAG.getConstant(APInt::getSignedMaxValue(MaskEltsBits), DL, SrcVT);
+  else if (BinOp == ISD::SMIN)
+    Mask = DAG.getConstant(APInt::getSignedMinValue(MaskEltsBits), DL, SrcVT);
+  else if (BinOp == ISD::UMAX)
+    Mask = DAG.getAllOnesConstant(DL, SrcVT);
+
+  if (Mask)
+    Src = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, Src);
+
+  // For v16i8 cases we need to perform UMIN on pairs of byte elements: each
+  // odd (upper) byte is shuffled down into the even lane below it and zeros
+  // are inserted into the odd lanes. The v16i8 UMIN therefore leaves every
+  // upper byte as zero, which zero-extends each pair into a 16-bit lane ready
+  // for PHMINPOS.
+  if (TargetVT == MVT::i8) {
+    SDValue Upper = DAG.getVectorShuffle(
+        SrcVT, DL, Src, DAG.getConstant(0, DL, MVT::v16i8),
+        {1, 16, 3, 16, 5, 16, 7, 16, 9, 16, 11, 16, 13, 16, 15, 16});
+    Src = DAG.getNode(ISD::UMIN, DL, SrcVT, Src, Upper);
+  }
+
+  // Perform the PHMINPOS on a v8i16 vector, then view the result in the
+  // original element type; the reduced value is read back from element 0.
+  Src = DAG.getBitcast(MVT::v8i16, Src);
+  Src = DAG.getNode(X86ISD::PHMINPOS, DL, MVT::v8i16, Src);
+  Src = DAG.getBitcast(SrcVT, Src);
+
+  // Undo the flip applied before PHMINPOS.
+  if (Mask)
+    Src = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, Src);
+
+  return DAG.getExtractVectorElt(DL, TargetVT, Src, 0);
+}
+
+static SDValue LowerVECTOR_REDUCE_MINMAX(SDValue Op,
+    const X86Subtarget& Subtarget,
+    SelectionDAG& DAG)
+{
+  ISD::NodeType BinOp;
+  switch (Op.getOpcode())
+  {
+    default: 
+      assert(false && "Expected min/max reduction");
+      break;
+    case ISD::VECREDUCE_UMIN:
+      BinOp = ISD::UMIN;
+      break;
+    case ISD::VECREDUCE_UMAX:
+      BinOp = ISD::UMAX;
+      break;
+    case ISD::VECREDUCE_SMIN:
+      BinOp = ISD::SMIN;
+      break;
+    case ISD::VECREDUCE_SMAX:
+      BinOp = ISD::SMAX;
+      break;
+  }
----------------
RKSimon wrote:

replace this with `ISD::NodeType BinOp = ISD::getVecReduceBaseOpcode(Op.getOpcode())` - you can then assert that BinOp is a min/max opcode if you want.

https://github.com/llvm/llvm-project/pull/144231


More information about the llvm-commits mailing list