[llvm] [AArch64][Codegen]Transform saturating smull to sqdmulh (PR #143671)

David Green via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 27 02:08:40 PDT 2025


================
@@ -20717,6 +20717,83 @@ static SDValue performBuildVectorCombine(SDNode *N,
   return SDValue();
 }
 
+// A special combine for the vqdmulh family of instructions.
+// truncate( smin( sra( mul( sext v0, sext v1 ), SHIFT_AMOUNT ),
+// SATURATING_VAL ) ) can be reduced to sqdmulh(...)
+static SDValue trySQDMULHCombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI,
+                                 SelectionDAG &DAG) {
+
+  if (N->getOpcode() != ISD::TRUNCATE)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+
+  if (!VT.isVector() || VT.getScalarSizeInBits() > 64)
+    return SDValue();
+
+  SDValue SMin = N->getOperand(0);
+
+  if (SMin.getOpcode() != ISD::SMIN)
+    return SDValue();
+
+  ConstantSDNode *Clamp = isConstOrConstSplat(SMin.getOperand(1));
+
+  if (!Clamp)
+    return SDValue();
+
+  MVT ScalarType;
+  unsigned ShiftAmt = 0;
+  // Here we are considering clamped Arm Q format
+  // data types which use 2 upper bits, one for the
+  // integer part and one for the sign. We also consider
+  // standard signed integer types
+  switch (Clamp->getSExtValue()) {
+  case (1ULL << 14) - 1: // Q15 saturation
+  case (1ULL << 15) - 1:
+    ScalarType = MVT::i16;
+    ShiftAmt = 16;
+    break;
+  case (1ULL << 30) - 1: // Q31 saturation
+  case (1ULL << 31) - 1:
+    ScalarType = MVT::i32;
+    ShiftAmt = 32;
+    break;
+  default:
+    return SDValue();
+  }
+
+  SDValue Sra = SMin.getOperand(0);
+  if (Sra.getOpcode() != ISD::SRA)
+    return SDValue();
+
+  ConstantSDNode *RightShiftVec = isConstOrConstSplat(Sra.getOperand(1));
+  if (!RightShiftVec)
+    return SDValue();
+  unsigned SExtValue = RightShiftVec->getSExtValue();
+
+  if (SExtValue != (ShiftAmt - 1))
+    return SDValue();
+
+  SDValue Mul = Sra.getOperand(0);
+  if (Mul.getOpcode() != ISD::MUL)
+    return SDValue();
+
+  SDValue SExt0 = Mul.getOperand(0);
+  SDValue SExt1 = Mul.getOperand(1);
+
+  if (SExt0.getOpcode() != ISD::SIGN_EXTEND ||
+      SExt1.getOpcode() != ISD::SIGN_EXTEND)
----------------
davemgreen wrote:

Check things like whether the types of the two SExts match, and that they are the legal types you expect them to be.

It can also be good to test with larger types like <8 x i32>, to see whether they still match correctly after type legalization. There are often more/different nodes, but the important part is that they do not fail to select because we produced illegal nodes.

https://github.com/llvm/llvm-project/pull/143671


More information about the llvm-commits mailing list