[llvm] [SelectionDAG][x86] Ensure vector reduction optimization (PR #144231)
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Tue Jun 17 11:16:46 PDT 2025
Suhajda Tamás <sutajo at gmail.com>,
Suhajda Tamás <sutajo at gmail.com>,
Suhajda Tamás <sutajo at gmail.com>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/144231 at github.com>
================
@@ -25409,6 +25423,94 @@ static SDValue LowerEXTEND_VECTOR_INREG(SDValue Op,
return SignExt;
}
+// Create a min/max v8i16/v16i8 horizontal reduction with PHMINPOSUW.
+static SDValue createMinMaxReduction(SDValue Src, EVT TargetVT, SDLoc DL,
+ ISD::NodeType BinOp, SelectionDAG &DAG,
+ const X86Subtarget &Subtarget)
+{
+ assert(Subtarget.hasSSE41() && "The caller must check if SSE4.1 is available");
+
+ EVT SrcVT = Src.getValueType();
+ EVT SrcSVT = SrcVT.getScalarType();
+
+ if (SrcSVT != TargetVT || (SrcVT.getSizeInBits() % 128) != 0)
+ return SDValue();
+
+ // First, reduce the source down to 128-bit, applying BinOp to lo/hi.
+ while (SrcVT.getSizeInBits() > 128) {
+ SDValue Lo, Hi;
+ std::tie(Lo, Hi) = splitVector(Src, DAG, DL);
+ SrcVT = Lo.getValueType();
+ Src = DAG.getNode(BinOp, DL, SrcVT, Lo, Hi);
+ }
+ assert(((SrcVT == MVT::v8i16 && TargetVT == MVT::i16) ||
+ (SrcVT == MVT::v16i8 && TargetVT == MVT::i8)) &&
+ "Unexpected value type");
+
+ // PHMINPOSUW applies to UMIN(v8i16), for SMIN/SMAX/UMAX we must apply a mask
+ // to flip the value accordingly.
+ SDValue Mask;
+ unsigned MaskEltsBits = TargetVT.getSizeInBits();
+ if (BinOp == ISD::SMAX)
+ Mask = DAG.getConstant(APInt::getSignedMaxValue(MaskEltsBits), DL, SrcVT);
+ else if (BinOp == ISD::SMIN)
+ Mask = DAG.getConstant(APInt::getSignedMinValue(MaskEltsBits), DL, SrcVT);
+ else if (BinOp == ISD::UMAX)
+ Mask = DAG.getAllOnesConstant(DL, SrcVT);
+
+ if (Mask)
+ Src = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, Src);
+
+ // For v16i8 cases we need to perform UMIN on pairs of byte elements,
+ // shuffling each upper element down and insert zeros. This means that the
+ // v16i8 UMIN will leave the upper element as zero, performing zero-extension
+ // ready for the PHMINPOS.
+ if (TargetVT == MVT::i8) {
+ SDValue Upper = DAG.getVectorShuffle(
+ SrcVT, DL, Src, DAG.getConstant(0, DL, MVT::v16i8),
+ {1, 16, 3, 16, 5, 16, 7, 16, 9, 16, 11, 16, 13, 16, 15, 16});
+ Src = DAG.getNode(ISD::UMIN, DL, SrcVT, Src, Upper);
+ }
+
+ // Perform the PHMINPOS on a v8i16 vector,
+ Src = DAG.getBitcast(MVT::v8i16, Src);
+ Src = DAG.getNode(X86ISD::PHMINPOS, DL, MVT::v8i16, Src);
+ Src = DAG.getBitcast(SrcVT, Src);
+
+ if (Mask)
+ Src = DAG.getNode(ISD::XOR, DL, SrcVT, Mask, Src);
+
+ return DAG.getExtractVectorElt(DL, TargetVT, Src, 0);
+}
+
+static SDValue LowerVECTOR_REDUCE_MINMAX(SDValue Op,
+ const X86Subtarget& Subtarget,
+ SelectionDAG& DAG)
+{
+ ISD::NodeType BinOp;
+ switch (Op.getOpcode())
+ {
+ default:
+ assert(false && "Expected min/max reduction");
+ break;
+ case ISD::VECREDUCE_UMIN:
+ BinOp = ISD::UMIN;
+ break;
+ case ISD::VECREDUCE_UMAX:
+ BinOp = ISD::UMAX;
+ break;
+ case ISD::VECREDUCE_SMIN:
+ BinOp = ISD::SMIN;
+ break;
+ case ISD::VECREDUCE_SMAX:
+ BinOp = ISD::SMAX;
+ break;
+ }
----------------
RKSimon wrote:
Replace this switch with `ISD::NodeType BinOp = ISD::getVecReduceBaseOpcode(Op.getOpcode())` - you can then assert that BinOp is a min/max opcode if you want.
https://github.com/llvm/llvm-project/pull/144231
More information about the llvm-commits
mailing list