[llvm] [AArch64][Codegen]Transform saturating smull to sqdmulh (PR #143671)
David Green via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 2 23:55:27 PDT 2025
================
@@ -20918,6 +20929,90 @@ static SDValue performBuildVectorCombine(SDNode *N,
return SDValue();
}
+// A special combine for the sqdmulh family of instructions.
+// smin( sra( mul( sext v0, sext v1 ), SHIFT_AMOUNT ),
+// SATURATING_VAL ) can be reduced to sqdmulh(...)
+static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
+
+ if (N->getOpcode() != ISD::SMIN)
+ return SDValue();
+
+ EVT VT = N->getValueType(0);
+
+ if (!VT.isVector() || VT.getScalarSizeInBits() > 64)
+ return SDValue();
+
+ ConstantSDNode *Clamp = isConstOrConstSplat(N->getOperand(1));
+
+ if (!Clamp)
+ return SDValue();
+
+ MVT ScalarType;
+ unsigned ShiftAmt = 0;
+ switch (Clamp->getSExtValue()) {
+ case (1ULL << 15) - 1:
+ ScalarType = MVT::i16;
+ ShiftAmt = 16;
+ break;
+ case (1ULL << 31) - 1:
+ ScalarType = MVT::i32;
+ ShiftAmt = 32;
+ break;
+ default:
+ return SDValue();
+ }
+
+ SDValue Sra = N->getOperand(0);
+ if (Sra.getOpcode() != ISD::SRA || !Sra.hasOneUse())
+ return SDValue();
+
+ ConstantSDNode *RightShiftVec = isConstOrConstSplat(Sra.getOperand(1));
+ if (!RightShiftVec)
+ return SDValue();
+ unsigned SExtValue = RightShiftVec->getSExtValue();
+
+ if (SExtValue != (ShiftAmt - 1))
+ return SDValue();
+
+ SDValue Mul = Sra.getOperand(0);
+ if (Mul.getOpcode() != ISD::MUL)
+ return SDValue();
+
+ SDValue SExt0 = Mul.getOperand(0);
+ SDValue SExt1 = Mul.getOperand(1);
+
+ EVT SExt0Type = SExt0.getOperand(0).getValueType();
+ EVT SExt1Type = SExt1.getOperand(0).getValueType();
+
+ if (SExt0.getOpcode() != ISD::SIGN_EXTEND ||
+ SExt1.getOpcode() != ISD::SIGN_EXTEND || SExt0Type != SExt1Type ||
+ SExt0Type.getScalarType() != ScalarType ||
+ SExt0Type.getFixedSizeInBits() > 128)
+ return SDValue();
+
+ // Source vectors with width < 64 are illegal and will need to be extended
+ unsigned SourceVectorWidth = SExt0Type.getFixedSizeInBits();
+ SDValue V0 = (SourceVectorWidth < 64) ? SExt0 : SExt0.getOperand(0);
+ SDValue V1 = (SourceVectorWidth < 64) ? SExt1 : SExt1.getOperand(0);
+
+ SDLoc DL(N);
+ SDValue SQDMULH =
+ DAG.getNode(AArch64ISD::SQDMULH, DL, V0.getValueType(), V0, V1);
+ EVT DestVT = N->getValueType(0);
+ if (DestVT.getScalarSizeInBits() > SExt0Type.getScalarSizeInBits())
+ return DAG.getNode(ISD::SIGN_EXTEND, DL, DestVT, SQDMULH);
----------------
davemgreen wrote:
I think it should always return `sext(SQDMULH)` or else the transform wasn't valid (there were not enough bits to prevent wrapping).
https://github.com/llvm/llvm-project/pull/143671
More information about the llvm-commits
mailing list