[llvm] de4b458 - [AArch64][Codegen]Transform saturating smull to sqdmulh (#143671)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 16 01:50:08 PDT 2025


Author: Nashe Mncube
Date: 2025-07-16T09:50:04+01:00
New Revision: de4b458aa5e52812aa9c392f62a616b6c6c1716f

URL: https://github.com/llvm/llvm-project/commit/de4b458aa5e52812aa9c392f62a616b6c6c1716f
DIFF: https://github.com/llvm/llvm-project/commit/de4b458aa5e52812aa9c392f62a616b6c6c1716f.diff

LOG: [AArch64][Codegen]Transform saturating smull to sqdmulh (#143671)

This patch adds a pattern for recognizing saturating vector
smull. Prior to this patch, these were lowered to a
combination of smull+smull2+uzp2+smin, like the following:

```
        smull2  v5.2d, v1.4s, v2.4s
        smull   v1.2d, v1.2s, v2.2s
        uzp2    v1.4s, v1.4s, v5.4s
        smin    v1.4s, v1.4s, v0.4s
```

which is now optimized to:

```
    sqdmulh v0.4s, v1.4s, v0.4s
```

This only operates on vectors whose elements are int16 or int32.

Added: 
    llvm/test/CodeGen/AArch64/saturating-vec-smull.ll

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64InstrInfo.td

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 235df9022c6fb..4f13a14d24649 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1143,6 +1143,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
                        ISD::SIGN_EXTEND_INREG, ISD::CONCAT_VECTORS,
                        ISD::EXTRACT_SUBVECTOR, ISD::INSERT_SUBVECTOR,
                        ISD::STORE, ISD::BUILD_VECTOR});
+  setTargetDAGCombine(ISD::SMIN);
   setTargetDAGCombine(ISD::TRUNCATE);
   setTargetDAGCombine(ISD::LOAD);
 
@@ -2392,6 +2393,15 @@ static bool isIntImmediate(const SDNode *N, uint64_t &Imm) {
   return false;
 }
 
+bool isVectorizedBinOp(unsigned Opcode) {
+  switch (Opcode) {
+  case AArch64ISD::SQDMULH:
+    return true;
+  default:
+    return false;
+  }
+}
+
 // isOpcWithIntImmediate - This method tests to see if the node is a specific
 // opcode and that it has a immediate integer right operand.
 // If so Imm will receive the value.
@@ -20131,8 +20141,9 @@ static SDValue performConcatVectorsCombine(SDNode *N,
   // size, combine into an binop of two contacts of the source vectors. eg:
   // concat(uhadd(a,b), uhadd(c, d)) -> uhadd(concat(a, c), concat(b, d))
   if (N->getNumOperands() == 2 && N0Opc == N1Opc && VT.is128BitVector() &&
-      DAG.getTargetLoweringInfo().isBinOp(N0Opc) && N0->hasOneUse() &&
-      N1->hasOneUse()) {
+      (DAG.getTargetLoweringInfo().isBinOp(N0Opc) ||
+       isVectorizedBinOp(N0Opc)) &&
+      N0->hasOneUse() && N1->hasOneUse()) {
     SDValue N00 = N0->getOperand(0);
     SDValue N01 = N0->getOperand(1);
     SDValue N10 = N1->getOperand(0);
@@ -20991,6 +21002,98 @@ static SDValue performBuildVectorCombine(SDNode *N,
   return SDValue();
 }
 
+// A special combine for the sqdmulh family of instructions.
+// smin( sra( mul( sext v0, sext v1 ), SHIFT_AMOUNT ),
+// SATURATING_VAL ) can be reduced to sqdmulh(...)
+static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
+
+  if (N->getOpcode() != ISD::SMIN)
+    return SDValue();
+
+  EVT DestVT = N->getValueType(0);
+
+  if (!DestVT.isVector() || DestVT.getScalarSizeInBits() > 64 ||
+      DestVT.isScalableVector())
+    return SDValue();
+
+  ConstantSDNode *Clamp = isConstOrConstSplat(N->getOperand(1));
+
+  if (!Clamp)
+    return SDValue();
+
+  MVT ScalarType;
+  unsigned ShiftAmt = 0;
+  switch (Clamp->getSExtValue()) {
+  case (1ULL << 15) - 1:
+    ScalarType = MVT::i16;
+    ShiftAmt = 16;
+    break;
+  case (1ULL << 31) - 1:
+    ScalarType = MVT::i32;
+    ShiftAmt = 32;
+    break;
+  default:
+    return SDValue();
+  }
+
+  SDValue Sra = N->getOperand(0);
+  if (Sra.getOpcode() != ISD::SRA || !Sra.hasOneUse())
+    return SDValue();
+
+  ConstantSDNode *RightShiftVec = isConstOrConstSplat(Sra.getOperand(1));
+  if (!RightShiftVec)
+    return SDValue();
+  unsigned SExtValue = RightShiftVec->getSExtValue();
+
+  if (SExtValue != (ShiftAmt - 1))
+    return SDValue();
+
+  SDValue Mul = Sra.getOperand(0);
+  if (Mul.getOpcode() != ISD::MUL)
+    return SDValue();
+
+  SDValue SExt0 = Mul.getOperand(0);
+  SDValue SExt1 = Mul.getOperand(1);
+
+  if (SExt0.getOpcode() != ISD::SIGN_EXTEND ||
+      SExt1.getOpcode() != ISD::SIGN_EXTEND)
+    return SDValue();
+
+  EVT SExt0Type = SExt0.getOperand(0).getValueType();
+  EVT SExt1Type = SExt1.getOperand(0).getValueType();
+
+  if (SExt0Type != SExt1Type || SExt0Type.getScalarType() != ScalarType ||
+      SExt0Type.getFixedSizeInBits() > 128 || !SExt0Type.isPow2VectorType() ||
+      SExt0Type.getVectorNumElements() == 1)
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue V0 = SExt0.getOperand(0);
+  SDValue V1 = SExt1.getOperand(0);
+
+  // Ensure input vectors are extended to legal types
+  if (SExt0Type.getFixedSizeInBits() < 64) {
+    unsigned VecNumElements = SExt0Type.getVectorNumElements();
+    EVT ExtVecVT = MVT::getVectorVT(MVT::getIntegerVT(64 / VecNumElements),
+                                    VecNumElements);
+    V0 = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVecVT, V0);
+    V1 = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVecVT, V1);
+  }
+
+  SDValue SQDMULH =
+      DAG.getNode(AArch64ISD::SQDMULH, DL, V0.getValueType(), V0, V1);
+
+  return DAG.getNode(ISD::SIGN_EXTEND, DL, DestVT, SQDMULH);
+}
+
+static SDValue performSMINCombine(SDNode *N, SelectionDAG &DAG) {
+  if (SDValue V = trySQDMULHCombine(N, DAG)) {
+    return V;
+  }
+
+  return SDValue();
+}
+
 static SDValue performTruncateCombine(SDNode *N, SelectionDAG &DAG,
                                       TargetLowering::DAGCombinerInfo &DCI) {
   SDLoc DL(N);
@@ -26742,6 +26845,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performAddSubCombine(N, DCI);
   case ISD::BUILD_VECTOR:
     return performBuildVectorCombine(N, DCI, DAG);
+  case ISD::SMIN:
+    return performSMINCombine(N, DAG);
   case ISD::TRUNCATE:
     return performTruncateCombine(N, DAG, DCI);
   case AArch64ISD::ANDS:

diff  --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index ddc685fae5e9a..ce91b72fa24e5 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1022,6 +1022,7 @@ def AArch64smull    : SDNode<"AArch64ISD::SMULL", SDT_AArch64mull,
                              [SDNPCommutative]>;
 def AArch64umull    : SDNode<"AArch64ISD::UMULL", SDT_AArch64mull,
                              [SDNPCommutative]>;
+def AArch64sqdmulh : SDNode<"AArch64ISD::SQDMULH", SDT_AArch64mull>;
 
 // Reciprocal estimates and steps.
 def AArch64frecpe   : SDNode<"AArch64ISD::FRECPE", SDTFPUnaryOp>;
@@ -9439,6 +9440,15 @@ def : Pat<(v4i32 (mulhu V128:$Rn, V128:$Rm)),
                              (EXTRACT_SUBREG V128:$Rm, dsub)),
            (UMULLv4i32_v2i64 V128:$Rn, V128:$Rm))>;
 
+def : Pat<(v4i16 (AArch64sqdmulh (v4i16 V64:$Rn), (v4i16 V64:$Rm))),
+          (SQDMULHv4i16 V64:$Rn, V64:$Rm)>;
+def : Pat<(v2i32 (AArch64sqdmulh (v2i32 V64:$Rn), (v2i32 V64:$Rm))),
+          (SQDMULHv2i32 V64:$Rn, V64:$Rm)>;
+def : Pat<(v8i16 (AArch64sqdmulh (v8i16 V128:$Rn), (v8i16 V128:$Rm))),
+          (SQDMULHv8i16 V128:$Rn, V128:$Rm)>;
+def : Pat<(v4i32 (AArch64sqdmulh (v4i32 V128:$Rn), (v4i32 V128:$Rm))),
+          (SQDMULHv4i32 V128:$Rn, V128:$Rm)>;
+
 // Conversions within AdvSIMD types in the same register size are free.
 // But because we need a consistent lane ordering, in big endian many
 // conversions require one or more REV instructions.

diff  --git a/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
new file mode 100644
index 0000000000000..b647daf72ca35
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
@@ -0,0 +1,223 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-none-elf < %s | FileCheck %s
+
+
+define <2 x i16> @saturating_2xi16(<2 x i16> %a, <2 x i16> %b) {
+; CHECK-LABEL: saturating_2xi16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shl v0.2s, v0.2s, #16
+; CHECK-NEXT:    shl v1.2s, v1.2s, #16
+; CHECK-NEXT:    sshr v0.2s, v0.2s, #16
+; CHECK-NEXT:    sshr v1.2s, v1.2s, #16
+; CHECK-NEXT:    sqdmulh v0.2s, v1.2s, v0.2s
+; CHECK-NEXT:    ret
+  %as = sext <2 x i16> %a to <2 x i32>
+  %bs = sext <2 x i16> %b to <2 x i32>
+  %m = mul <2 x i32> %bs, %as
+  %sh = ashr <2 x i32> %m, splat (i32 15)
+  %ma = tail call <2 x i32> @llvm.smin.v4i32(<2 x i32> %sh, <2 x i32> splat (i32 32767))
+  %t = trunc <2 x i32> %ma to <2 x i16>
+  ret <2 x i16> %t
+}
+
+define <4 x i16> @saturating_4xi16(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: saturating_4xi16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sqdmulh v0.4h, v1.4h, v0.4h
+; CHECK-NEXT:    ret
+  %as = sext <4 x i16> %a to <4 x i32>
+  %bs = sext <4 x i16> %b to <4 x i32>
+  %m = mul <4 x i32> %bs, %as
+  %sh = ashr <4 x i32> %m, splat (i32 15)
+  %ma = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %sh, <4 x i32> splat (i32 32767))
+  %t = trunc <4 x i32> %ma to <4 x i16>
+  ret <4 x i16> %t
+}
+
+define <8 x i16> @saturating_8xi16(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: saturating_8xi16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sqdmulh v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    ret
+  %as = sext <8 x i16> %a to <8 x i32>
+  %bs = sext <8 x i16> %b to <8 x i32>
+  %m = mul <8 x i32> %bs, %as
+  %sh = ashr <8 x i32> %m, splat (i32 15)
+  %ma = tail call <8 x i32> @llvm.smin.v8i32(<8 x i32> %sh, <8 x i32> splat (i32 32767))
+  %t = trunc <8 x i32> %ma to <8 x i16>
+  ret <8 x i16> %t
+}
+
+define <2 x i32> @saturating_2xi32(<2 x i32> %a, <2 x i32> %b) {
+; CHECK-LABEL: saturating_2xi32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sqdmulh v0.2s, v1.2s, v0.2s
+; CHECK-NEXT:    ret
+  %as = sext <2 x i32> %a to <2 x i64>
+  %bs = sext <2 x i32> %b to <2 x i64>
+  %m = mul <2 x i64> %bs, %as
+  %sh = ashr <2 x i64> %m, splat (i64 31)
+  %ma = tail call <2 x i64> @llvm.smin.v8i64(<2 x i64> %sh, <2 x i64> splat (i64 2147483647))
+  %t = trunc <2 x i64> %ma to <2 x i32>
+  ret <2 x i32> %t
+}
+
+define <4 x i32> @saturating_4xi32(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: saturating_4xi32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sqdmulh v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    ret
+  %as = sext <4 x i32> %a to <4 x i64>
+  %bs = sext <4 x i32> %b to <4 x i64>
+  %m = mul <4 x i64> %bs, %as
+  %sh = ashr <4 x i64> %m, splat (i64 31)
+  %ma = tail call <4 x i64> @llvm.smin.v4i64(<4 x i64> %sh, <4 x i64> splat (i64 2147483647))
+  %t = trunc <4 x i64> %ma to <4 x i32>
+  ret <4 x i32> %t
+}
+
+define <8 x i32> @saturating_8xi32(<8 x i32> %a, <8 x i32> %b) {
+; CHECK-LABEL: saturating_8xi32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sqdmulh v1.4s, v3.4s, v1.4s
+; CHECK-NEXT:    sqdmulh v0.4s, v2.4s, v0.4s
+; CHECK-NEXT:    ret
+  %as = sext <8 x i32> %a to <8 x i64>
+  %bs = sext <8 x i32> %b to <8 x i64>
+  %m = mul <8 x i64> %bs, %as
+  %sh = ashr <8 x i64> %m, splat (i64 31)
+  %ma = tail call <8 x i64> @llvm.smin.v8i64(<8 x i64> %sh, <8 x i64> splat (i64 2147483647))
+  %t = trunc <8 x i64> %ma to <8 x i32>
+  ret <8 x i32> %t
+}
+
+define <2 x i64> @saturating_2xi32_2xi64(<2 x i32> %a, <2 x i32> %b) {
+; CHECK-LABEL: saturating_2xi32_2xi64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sqdmulh v0.2s, v1.2s, v0.2s
+; CHECK-NEXT:    sshll v0.2d, v0.2s, #0
+; CHECK-NEXT:    ret
+  %as = sext <2 x i32> %a to <2 x i64>
+  %bs = sext <2 x i32> %b to <2 x i64>
+  %m = mul <2 x i64> %bs, %as
+  %sh = ashr <2 x i64> %m, splat (i64 31)
+  %ma = tail call <2 x i64> @llvm.smin.v8i64(<2 x i64> %sh, <2 x i64> splat (i64 2147483647))
+  ret <2 x i64> %ma
+}
+
+define <6 x i16> @saturating_6xi16(<6 x i16> %a, <6 x i16> %b) {
+; CHECK-LABEL: saturating_6xi16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    smull2 v3.4s, v1.8h, v0.8h
+; CHECK-NEXT:    movi v2.4s, #127, msl #8
+; CHECK-NEXT:    sqdmulh v0.4h, v1.4h, v0.4h
+; CHECK-NEXT:    sshr v3.4s, v3.4s, #15
+; CHECK-NEXT:    smin v2.4s, v3.4s, v2.4s
+; CHECK-NEXT:    xtn2 v0.8h, v2.4s
+; CHECK-NEXT:    ret
+  %as = sext <6 x i16> %a to <6 x i32>
+  %bs = sext <6 x i16> %b to <6 x i32>
+  %m = mul <6 x i32> %bs, %as
+  %sh = ashr <6 x i32> %m, splat (i32 15)
+  %ma = tail call <6 x i32> @llvm.smin.v6i32(<6 x i32> %sh, <6 x i32> splat (i32 32767))
+  %t = trunc <6 x i32> %ma to <6 x i16>
+  ret <6 x i16> %t
+}
+
+define <4 x i16> @unsupported_saturation_value_v4i16(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: unsupported_saturation_value_v4i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    smull v0.4s, v1.4h, v0.4h
+; CHECK-NEXT:    movi v1.4s, #42
+; CHECK-NEXT:    sshr v0.4s, v0.4s, #15
+; CHECK-NEXT:    smin v0.4s, v0.4s, v1.4s
+; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    ret
+  %as = sext <4 x i16> %a to <4 x i32>
+  %bs = sext <4 x i16> %b to <4 x i32>
+  %m = mul <4 x i32> %bs, %as
+  %sh = ashr <4 x i32> %m, splat (i32 15)
+  %ma = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %sh, <4 x i32> splat (i32 42))
+  %t = trunc <4 x i32> %ma to <4 x i16>
+  ret <4 x i16> %t
+}
+
+define <4 x i16> @unsupported_shift_value_v4i16(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: unsupported_shift_value_v4i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    smull v0.4s, v1.4h, v0.4h
+; CHECK-NEXT:    movi v1.4s, #127, msl #8
+; CHECK-NEXT:    sshr v0.4s, v0.4s, #3
+; CHECK-NEXT:    smin v0.4s, v0.4s, v1.4s
+; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    ret
+  %as = sext <4 x i16> %a to <4 x i32>
+  %bs = sext <4 x i16> %b to <4 x i32>
+  %m = mul <4 x i32> %bs, %as
+  %sh = ashr <4 x i32> %m, splat (i32 3)
+  %ma = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %sh, <4 x i32> splat (i32 32767))
+  %t = trunc <4 x i32> %ma to <4 x i16>
+  ret <4 x i16> %t
+}
+
+define <2 x i16> @extend_to_illegal_type(<2 x i16> %a, <2 x i16> %b) {
+; CHECK-LABEL: extend_to_illegal_type:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shl v0.2s, v0.2s, #16
+; CHECK-NEXT:    shl v1.2s, v1.2s, #16
+; CHECK-NEXT:    sshr v0.2s, v0.2s, #16
+; CHECK-NEXT:    sshr v1.2s, v1.2s, #16
+; CHECK-NEXT:    sqdmulh v0.2s, v1.2s, v0.2s
+; CHECK-NEXT:    ret
+  %as = sext <2 x i16> %a to <2 x i48>
+  %bs = sext <2 x i16> %b to <2 x i48>
+  %m = mul <2 x i48> %bs, %as
+  %sh = ashr <2 x i48> %m, splat (i48 15)
+  %ma = tail call <2 x i48> @llvm.smin.v4i32(<2 x i48> %sh, <2 x i48> splat (i48 32767))
+  %t = trunc <2 x i48> %ma to <2 x i16>
+  ret <2 x i16> %t
+}
+
+define <2 x i11> @illegal_source(<2 x i11> %a, <2 x i11> %b) {
+; CHECK-LABEL: illegal_source:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shl v0.2s, v0.2s, #21
+; CHECK-NEXT:    shl v1.2s, v1.2s, #21
+; CHECK-NEXT:    sshr v0.2s, v0.2s, #21
+; CHECK-NEXT:    sshr v1.2s, v1.2s, #21
+; CHECK-NEXT:    mul v0.2s, v1.2s, v0.2s
+; CHECK-NEXT:    movi v1.2s, #127, msl #8
+; CHECK-NEXT:    sshr v0.2s, v0.2s, #15
+; CHECK-NEXT:    smin v0.2s, v0.2s, v1.2s
+; CHECK-NEXT:    ret
+  %as = sext <2 x i11> %a to <2 x i32>
+  %bs = sext <2 x i11> %b to <2 x i32>
+  %m = mul <2 x i32> %bs, %as
+  %sh = ashr <2 x i32> %m, splat (i32 15)
+  %ma = tail call <2 x i32> @llvm.smin.v2i32(<2 x i32> %sh, <2 x i32> splat (i32 32767))
+  %t = trunc <2 x i32> %ma to <2 x i11>
+  ret <2 x i11> %t
+}
+define <1 x i16> @saturating_1xi16(<1 x i16> %a, <1 x i16> %b) {
+; CHECK-LABEL: saturating_1xi16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    zip1 v0.4h, v0.4h, v0.4h
+; CHECK-NEXT:    zip1 v1.4h, v1.4h, v0.4h
+; CHECK-NEXT:    shl v0.2s, v0.2s, #16
+; CHECK-NEXT:    sshr v0.2s, v0.2s, #16
+; CHECK-NEXT:    shl v1.2s, v1.2s, #16
+; CHECK-NEXT:    sshr v1.2s, v1.2s, #16
+; CHECK-NEXT:    mul v0.2s, v1.2s, v0.2s
+; CHECK-NEXT:    movi v1.2s, #127, msl #8
+; CHECK-NEXT:    sshr v0.2s, v0.2s, #15
+; CHECK-NEXT:    smin v0.2s, v0.2s, v1.2s
+; CHECK-NEXT:    uzp1 v0.4h, v0.4h, v0.4h
+; CHECK-NEXT:    ret
+  %as = sext <1 x i16> %a to <1 x i32>
+  %bs = sext <1 x i16> %b to <1 x i32>
+  %m = mul <1 x i32> %bs, %as
+  %sh = ashr <1 x i32> %m, splat (i32 15)
+  %ma = tail call <1 x i32> @llvm.smin.v1i32(<1 x i32> %sh, <1 x i32> splat (i32 32767))
+  %t = trunc <1 x i32> %ma to <1 x i16>
+  ret <1 x i16> %t
+}


        


More information about the llvm-commits mailing list