[llvm] [AArch64][Codegen]Transform saturating smull to sqdmulh (PR #143671)

Nashe Mncube via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 14 07:34:19 PDT 2025


================
@@ -20918,6 +20929,90 @@ static SDValue performBuildVectorCombine(SDNode *N,
   return SDValue();
 }
 
+// A special combine for the sqdmulh family of instructions.
+// smin( sra ( mul( sext v0, sext v1 ) ), SHIFT_AMOUNT ),
+// SATURATING_VAL ) can be reduced to sqdmulh(...)
+static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
+
+  if (N->getOpcode() != ISD::SMIN)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+
+  if (!VT.isVector() || VT.getScalarSizeInBits() > 64)
+    return SDValue();
+
+  ConstantSDNode *Clamp = isConstOrConstSplat(N->getOperand(1));
+
+  if (!Clamp)
+    return SDValue();
+
+  MVT ScalarType;
+  unsigned ShiftAmt = 0;
+  switch (Clamp->getSExtValue()) {
+  case (1ULL << 15) - 1:
+    ScalarType = MVT::i16;
+    ShiftAmt = 16;
+    break;
+  case (1ULL << 31) - 1:
+    ScalarType = MVT::i32;
+    ShiftAmt = 32;
+    break;
+  default:
+    return SDValue();
+  }
+
+  SDValue Sra = N->getOperand(0);
+  if (Sra.getOpcode() != ISD::SRA || !Sra.hasOneUse())
+    return SDValue();
+
+  ConstantSDNode *RightShiftVec = isConstOrConstSplat(Sra.getOperand(1));
+  if (!RightShiftVec)
+    return SDValue();
+  unsigned SExtValue = RightShiftVec->getSExtValue();
+
+  if (SExtValue != (ShiftAmt - 1))
+    return SDValue();
+
+  SDValue Mul = Sra.getOperand(0);
+  if (Mul.getOpcode() != ISD::MUL)
+    return SDValue();
+
+  SDValue SExt0 = Mul.getOperand(0);
+  SDValue SExt1 = Mul.getOperand(1);
+
+  EVT SExt0Type = SExt0.getOperand(0).getValueType();
+  EVT SExt1Type = SExt1.getOperand(0).getValueType();
+
+  if (SExt0.getOpcode() != ISD::SIGN_EXTEND ||
+      SExt1.getOpcode() != ISD::SIGN_EXTEND || SExt0Type != SExt1Type ||
+      SExt0Type.getScalarType() != ScalarType ||
+      SExt0Type.getFixedSizeInBits() > 128)
+    return SDValue();
+
+  // Source vectors with width < 64 are illegal and will need to be extended
+  unsigned SourceVectorWidth = SExt0Type.getFixedSizeInBits();
+  SDValue V0 = (SourceVectorWidth < 64) ? SExt0 : SExt0.getOperand(0);
+  SDValue V1 = (SourceVectorWidth < 64) ? SExt1 : SExt1.getOperand(0);
----------------
nasherm wrote:

Yes. My latest patch checks the types directly and sign-extends to legal types. With v2i48 as the extension destination, we get the following:

```
    shl v0.2s, v0.2s, #16
    shl v1.2s, v1.2s, #16
    sshr v0.2s, v0.2s, #16
    sshr v1.2s, v1.2s, #16
    sqdmulh v0.2s, v1.2s, v0.2s
    ret
```

I also wanted to see what happens when the source vector is illegal before extension. The code below is generated for a source vector of v2i11 (which I've also added to the test cases). In this case, the combine does not apply, so the original mul/sshr/smin sequence remains:

```
    shl	v0.2s, v0.2s, #21
    shl	v1.2s, v1.2s, #21
    sshr	v0.2s, v0.2s, #21
    sshr	v1.2s, v1.2s, #21
    mul	v0.2s, v1.2s, v0.2s
    movi	v1.2s, #127, msl #8
    sshr	v0.2s, v0.2s, #15
    smin	v0.2s, v0.2s, v1.2s
    ret
```

https://github.com/llvm/llvm-project/pull/143671


More information about the llvm-commits mailing list