[llvm] [AArch64][Codegen]Transform saturating smull to sqdmulh (PR #143671)

Nashe Mncube via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 13 06:24:17 PDT 2025


https://github.com/nasherm updated https://github.com/llvm/llvm-project/pull/143671

>From 37613743cfb3ef212f9cfd556f50c082311202d5 Mon Sep 17 00:00:00 2001
From: nasmnc01 <nashe.mncube at arm.com>
Date: Tue, 10 Jun 2025 16:20:42 +0100
Subject: [PATCH] [AArch64][Codegen]Transform saturating smull to sqdmulh

This patch adds a pattern for recognizing saturating vector
smull. Prior to this patch these were lowered to a
combination of smull+smull2+uzp2+smin+add like the following:

```
        smull2  v5.2d, v1.4s, v2.4s
        smull   v1.2d, v1.2s, v2.2s
        uzp2    v1.4s, v1.4s, v5.4s
        smin    v1.4s, v1.4s, v0.4s
        add     v1.4s, v1.4s, v1.4s
```

which now optimizes to

```
    sqdmulh v0.4s, v1.4s, v0.4s
    sshr    v0.4s, v0.4s, #1
    add v0.4s, v0.4s, v0.4s
```

This only operates on vectors containing Q31 data types.

Change-Id: Ib7d4d5284d1bd3fdd0907365f9e2f37f4da14671
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 73 +++++++++++++++++++
 llvm/lib/Target/AArch64/AArch64InstrInfo.td   |  4 +
 .../CodeGen/AArch64/saturating-vec-smull.ll   | 25 +++++++
 3 files changed, 102 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/saturating-vec-smull.ll

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9f51caef6d228..4abe7af42aba8 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -26356,6 +26356,77 @@ performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
   return NVCAST;
 }
 
+// A special combine for the sqdmulh family of instructions. This is one of the
+// potential set of patterns that could match this instruction. The base pattern
+//   vshl(smin(mulhs(a, b), splat((1 << 30) - 1)), 1)
+// can be reduced to
+//   vshl(vashr(sqdmulh(a, b), 1), 1)
+// when operating on Q31 data types.
+//
+// Only v4i32 is handled. The SQDMULH node built below is always v4i32, so any
+// other vector type is left untouched instead of being rewritten into nodes
+// whose operand and result types disagree.
+//
+// N is the AArch64ISD::VSHL node being combined. Returns the replacement value,
+// or an empty SDValue when the pattern does not match. DCI is unused.
+static SDValue performVSHLCombine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI,
+                                  SelectionDAG &DAG) {
+
+  SDValue Op0 = N->getOperand(0);
+  SDValue ShiftAmount = N->getOperand(1);
+  ConstantSDNode *Splat = isConstOrConstSplat(ShiftAmount);
+
+  // The shift must be a left shift by exactly one of a signed-min result.
+  if (Op0.getOpcode() != ISD::SMIN || !Splat || !Splat->isOne())
+    return SDValue();
+
+  // Returns sqdmulh(a, b) when Min is smin(mulhs(a, b), splat((1 << 30) - 1))
+  // on v4i32, and an empty SDValue otherwise.
+  auto trySQDMULHCombine = [](SDNode *Min, SelectionDAG &DAG) -> SDValue {
+    if (Min->getOpcode() != ISD::SMIN)
+      return SDValue();
+
+    // The matched value, its multiply operands and the new node must all share
+    // one type; v4i32 is the only one supported here.
+    EVT VT = Min->getValueType(0);
+    if (VT != MVT::v4i32)
+      return SDValue();
+
+    ConstantSDNode *Clamp = isConstOrConstSplat(Min->getOperand(1));
+    if (!Clamp)
+      return SDValue();
+
+    // Here we are considering clamped Arm Q format data types, which use the
+    // 2 upper bits: one for the integer part and one for the sign. For Q31
+    // values held in i32 lanes that makes the clamp (1 << 30) - 1.
+    if (Clamp->getSExtValue() != (int64_t(1) << 30) - 1)
+      return SDValue();
+
+    SDValue Mulhs = Min->getOperand(0);
+    if (Mulhs.getOpcode() != ISD::MULHS)
+      return SDValue();
+
+    SDLoc DL(Mulhs);
+    return DAG.getNode(AArch64ISD::SQDMULH, DL, VT, Mulhs.getOperand(0),
+                       Mulhs.getOperand(1));
+  };
+
+  SDValue Val = trySQDMULHCombine(Op0.getNode(), DAG);
+  if (!Val)
+    return SDValue();
+
+  SDLoc DL(N);
+  EVT VecVT = Op0.getValueType();
+
+  // sqdmulh produces the doubled, saturated high half of the product. The
+  // arithmetic shift right by one undoes the doubling, and together with the
+  // shift left it clears the low bit, as the original vshl would have.
+  SDValue RightShift =
+      DAG.getNode(AArch64ISD::VASHR, DL, VecVT, Val, ShiftAmount);
+  return DAG.getNode(AArch64ISD::VSHL, DL, VecVT, RightShift, ShiftAmount);
+}
+
 /// If the operand is a bitwise AND with a constant RHS, and the shift has a
 /// constant RHS and is the only use, we can pull it out of the shift, i.e.
 ///
@@ -26496,6 +26567,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performMaskedGatherScatterCombine(N, DCI, DAG);
   case ISD::FP_EXTEND:
     return performFPExtendCombine(N, DAG, DCI, Subtarget);
+  case AArch64ISD::VSHL:
+    return performVSHLCombine(N, DCI, DAG);
   case AArch64ISD::BRCOND:
     return performBRCONDCombine(N, DCI, DAG);
   case AArch64ISD::TBNZ:
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 727831896737d..87990c63d533f 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1194,6 +1194,7 @@ def AArch64gld1q_index_merge_zero
   : SDNode<"AArch64ISD::GLD1Q_INDEX_MERGE_ZERO", SDTypeProfile<1, 4, []>,
            [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
 
+
 // Match add node and also treat an 'or' node is as an 'add' if the or'ed operands
 // have no common bits.
 def add_and_or_is_add : PatFrags<(ops node:$lhs, node:$rhs),
@@ -9365,6 +9366,9 @@ def : Pat<(v4i32 (mulhu V128:$Rn, V128:$Rm)),
                              (EXTRACT_SUBREG V128:$Rm, dsub)),
            (UMULLv4i32_v2i64 V128:$Rn, V128:$Rm))>;
 
+def : Pat<(v4i32 (AArch64sqdmulh (v4i32 V128:$Rn), (v4i32 V128:$Rm))),
+          (SQDMULHv4i32 V128:$Rn, V128:$Rm)>;
+
 // Conversions within AdvSIMD types in the same register size are free.
 // But because we need a consistent lane ordering, in big endian many
 // conversions require one or more REV instructions.
diff --git a/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
new file mode 100644
index 0000000000000..c1bb370ac3e89
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
@@ -0,0 +1,25 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-none-elf < %s | FileCheck %s
+
+define <4 x i32> @arm_mult_q31(ptr %0, ptr %1){
+; CHECK-LABEL: arm_mult_q31:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldr q0, [x0]
+; CHECK-NEXT:    ldr q1, [x1]
+; CHECK-NEXT:    sqdmulh v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    sshr v0.4s, v0.4s, #1
+; CHECK-NEXT:    add v0.4s, v0.4s, v0.4s
+; CHECK-NEXT:    ret
+  %7 = getelementptr i8, ptr %0, i64 0
+  %9 = getelementptr i8, ptr %1, i64 0
+  %12 = load <4 x i32>, ptr %7, align 4
+  %13 = sext <4 x i32> %12 to <4 x i64>
+  %14 = load <4 x i32>, ptr %9, align 4
+  %15 = sext <4 x i32> %14 to <4 x i64>
+  %16 = mul nsw <4 x i64> %15, %13
+  %17 = lshr <4 x i64> %16, splat (i64 32)
+  %18 = trunc nuw <4 x i64> %17 to <4 x i32>
+  %19 = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %18, <4 x i32> splat (i32 1073741823))
+  %20 = shl <4 x i32> %19, splat (i32 1)
+  ret <4 x i32> %20
+}



More information about the llvm-commits mailing list