[llvm] [AArch64][Codegen]Transform saturating smull to sqdmulh (PR #143671)

David Green via llvm-commits llvm-commits at lists.llvm.org
Sun Jun 29 13:24:28 PDT 2025


================
@@ -20717,6 +20717,83 @@ static SDValue performBuildVectorCombine(SDNode *N,
   return SDValue();
 }
 
+// A special combine for the vqdmulh family of instructions.
+// smin( sra( mul( sext v0, sext v1 ), SHIFT_AMOUNT ),
+// SATURATING_VAL ) can be reduced to sext(sqdmulh(...))
+static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
+
+  if (N->getOpcode() != ISD::TRUNCATE)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+
+  if (!VT.isVector() || VT.getScalarSizeInBits() > 64)
+    return SDValue();
+
+  SDValue SMin = N->getOperand(0);
+
+  if (SMin.getOpcode() != ISD::SMIN)
+    return SDValue();
+
+  ConstantSDNode *Clamp = isConstOrConstSplat(SMin.getOperand(1));
+
+  if (!Clamp)
+    return SDValue();
+
+  MVT ScalarType;
+  unsigned ShiftAmt = 0;
+  switch (Clamp->getSExtValue()) {
+  case (1ULL << 15) - 1:
+    ScalarType = MVT::i16;
+    ShiftAmt = 16;
+    break;
+  case (1ULL << 31) - 1:
+    ScalarType = MVT::i32;
+    ShiftAmt = 32;
+    break;
+  default:
+    return SDValue();
+  }
+
+  SDValue Sra = SMin.getOperand(0);
+  if (Sra.getOpcode() != ISD::SRA)
+    return SDValue();
+
+  ConstantSDNode *RightShiftVec = isConstOrConstSplat(Sra.getOperand(1));
+  if (!RightShiftVec)
+    return SDValue();
+  unsigned SExtValue = RightShiftVec->getSExtValue();
+
+  if (SExtValue != (ShiftAmt - 1))
+    return SDValue();
+
+  SDValue Mul = Sra.getOperand(0);
+  if (Mul.getOpcode() != ISD::MUL)
+    return SDValue();
+
+  SDValue SExt0 = Mul.getOperand(0);
+  SDValue SExt1 = Mul.getOperand(1);
+
+  if (SExt0.getOpcode() != ISD::SIGN_EXTEND ||
+      SExt1.getOpcode() != ISD::SIGN_EXTEND ||
+      SExt0.getValueType() != SExt1.getValueType())
+    return SDValue();
+
+  if ((ShiftAmt == 16 && (SExt0.getValueType() != MVT::v8i32 &&
+                          SExt0.getValueType() != MVT::v4i32)) ||
+      (ShiftAmt == 32 && (SExt0.getValueType() != MVT::v4i64 &&
+                          SExt0.getValueType() != MVT::v2i64)))
+    return SDValue();
+
+  SDValue V0 = SExt0.getOperand(0);
+  SDValue V1 = SExt1.getOperand(0);
+
+  SDLoc DL(SMin);
+  EVT VecVT = N->getValueType(0);
+  SDValue SQDMULH = DAG.getNode(AArch64ISD::SQDMULH, DL, VecVT, V0, V1);
+  return DAG.getNode(ISD::SIGN_EXTEND, DL, N->getValueType(0), SQDMULH);
----------------
davemgreen wrote:

I don't think this will do anything at the moment. The idea is that if we have `trunc(add(etc(sqdmulh(..))))`, we can still recognise it without needing the trunc to be adjacent. The VecVT can come from SExt0.getOperand(0).getValueType().

https://github.com/llvm/llvm-project/pull/143671


More information about the llvm-commits mailing list