[llvm] [AArch64][Codegen]Transform saturating smull to sqdmulh (PR #143671)
Nashe Mncube via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 13 06:24:17 PDT 2025
https://github.com/nasherm updated https://github.com/llvm/llvm-project/pull/143671
>From 37613743cfb3ef212f9cfd556f50c082311202d5 Mon Sep 17 00:00:00 2001
From: nasmnc01 <nashe.mncube at arm.com>
Date: Tue, 10 Jun 2025 16:20:42 +0100
Subject: [PATCH] [AArch64][Codegen]Transform saturating smull to sqdmulh
This patch adds a pattern for recognizing a saturating vector
smull. Prior to this patch, such multiplies were lowered to a
combination of smull+smull2+uzp2+smin+add, like the following:
```
smull2 v5.2d, v1.4s, v2.4s
smull v1.2d, v1.2s, v2.2s
uzp2 v1.4s, v1.4s, v5.4s
smin v1.4s, v1.4s, v0.4s
add v1.4s, v1.4s, v1.4s
```
which now optimizes to
```
sqdmulh v0.4s, v1.4s, v0.4s
sshr v0.4s, v0.4s, #1
add v0.4s, v0.4s, v0.4s
```
This only operates on vectors containing Q31 data types.
Change-Id: Ib7d4d5284d1bd3fdd0907365f9e2f37f4da14671
---
.../Target/AArch64/AArch64ISelLowering.cpp | 73 +++++++++++++++++++
llvm/lib/Target/AArch64/AArch64InstrInfo.td | 4 +
.../CodeGen/AArch64/saturating-vec-smull.ll | 25 +++++++
3 files changed, 102 insertions(+)
create mode 100644 llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9f51caef6d228..4abe7af42aba8 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -26356,6 +26356,77 @@ performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
return NVCAST;
}
+// A special combine for the SQDMULH (signed saturating doubling multiply
+// returning high half) instruction. This is one of the potential set of
+// patterns that could match this instruction. For Q31 data held in v4i32,
+//   vshl(smin(mulhs(a, b), 2^30 - 1), 1)
+// is rewritten to
+//   vshl(vashr(sqdmulh(a, b), 1), 1).
+// sqdmulh(a, b) is sat((2 * a * b) >> 32), so an arithmetic shift right by one
+// gives mulhs(a, b) clamped to 2^30 - 1; the clamp only takes effect for
+// a == b == INT32_MIN, which is also the only input on which sqdmulh saturates.
+//
+// Returns the replacement value, or an empty SDValue when N does not match.
+// DCI is currently unused; it is kept so the signature matches the other
+// perform*Combine helpers.
+static SDValue performVSHLCombine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI,
+                                  SelectionDAG &DAG) {
+  SDValue Op0 = N->getOperand(0);
+  ConstantSDNode *Splat = isConstOrConstSplat(N->getOperand(1));
+
+  // Only a left shift by exactly one of a clamped value is of interest.
+  if (Op0.getOpcode() != ISD::SMIN || !Splat || !Splat->isOne())
+    return SDValue();
+
+  // The rewrite emits a v4i32 SQDMULH, so bail out on every other vector type
+  // rather than creating a node whose type disagrees with its operands.
+  EVT VT = Op0.getValueType();
+  if (VT != MVT::v4i32)
+    return SDValue();
+
+  // Here we are considering the clamped Arm Q31 format: the clamp has to be a
+  // splat of 2^30 - 1, i.e. the two upper bits (sign and integer part) are
+  // not available to the magnitude.
+  ConstantSDNode *Clamp = isConstOrConstSplat(Op0.getOperand(1));
+  if (!Clamp || Clamp->getSExtValue() != (int64_t(1) << 30) - 1)
+    return SDValue();
+
+  SDValue Mulhs = Op0.getOperand(0);
+  if (Mulhs.getOpcode() != ISD::MULHS)
+    return SDValue();
+
+  SDValue Sqdmulh = DAG.getNode(AArch64ISD::SQDMULH, SDLoc(Mulhs), VT,
+                                Mulhs.getOperand(0), Mulhs.getOperand(1));
+
+  // Undo the doubling first; the outer shift then restores the scale and
+  // leaves the low bit clear, exactly as in the original expression.
+  SDLoc DL(N);
+  SDValue RightShift =
+      DAG.getNode(AArch64ISD::VASHR, DL, VT, Sqdmulh, N->getOperand(1));
+  return DAG.getNode(AArch64ISD::VSHL, DL, VT, RightShift, N->getOperand(1));
+}
+
/// If the operand is a bitwise AND with a constant RHS, and the shift has a
/// constant RHS and is the only use, we can pull it out of the shift, i.e.
///
@@ -26496,6 +26567,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performMaskedGatherScatterCombine(N, DCI, DAG);
case ISD::FP_EXTEND:
return performFPExtendCombine(N, DAG, DCI, Subtarget);
+ case AArch64ISD::VSHL:
+ return performVSHLCombine(N, DCI, DAG);
case AArch64ISD::BRCOND:
return performBRCONDCombine(N, DCI, DAG);
case AArch64ISD::TBNZ:
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 727831896737d..87990c63d533f 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1194,6 +1194,7 @@ def AArch64gld1q_index_merge_zero
: SDNode<"AArch64ISD::GLD1Q_INDEX_MERGE_ZERO", SDTypeProfile<1, 4, []>,
[SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
+
// Match add node and also treat an 'or' node is as an 'add' if the or'ed operands
// have no common bits.
def add_and_or_is_add : PatFrags<(ops node:$lhs, node:$rhs),
@@ -9365,6 +9366,9 @@ def : Pat<(v4i32 (mulhu V128:$Rn, V128:$Rm)),
(EXTRACT_SUBREG V128:$Rm, dsub)),
(UMULLv4i32_v2i64 V128:$Rn, V128:$Rm))>;
+// Select the v4i32 form of the AArch64sqdmulh node to the SQDMULHv4i32
+// instruction.
+def : Pat<(v4i32 (AArch64sqdmulh (v4i32 V128:$Rn), (v4i32 V128:$Rm))),
+ (SQDMULHv4i32 V128:$Rn, V128:$Rm)>;
+
// Conversions within AdvSIMD types in the same register size are free.
// But because we need a consistent lane ordering, in big endian many
// conversions require one or more REV instructions.
diff --git a/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
new file mode 100644
index 0000000000000..c1bb370ac3e89
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
@@ -0,0 +1,25 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-none-elf < %s | FileCheck %s
+
+; Q31 saturating multiply: both <4 x i32> inputs are sign-extended to i64 and
+; multiplied, the high 32 bits of each product are clamped with smin to
+; 1073741823 (2^30 - 1), and the result is shifted left by one. The CHECK
+; lines expect this to be selected as sqdmulh + sshr #1 + add.
+define <4 x i32> @arm_mult_q31(ptr %0, ptr %1){
+; CHECK-LABEL: arm_mult_q31:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ldr q0, [x0]
+; CHECK-NEXT: ldr q1, [x1]
+; CHECK-NEXT: sqdmulh v0.4s, v1.4s, v0.4s
+; CHECK-NEXT: sshr v0.4s, v0.4s, #1
+; CHECK-NEXT: add v0.4s, v0.4s, v0.4s
+; CHECK-NEXT: ret
+ %7 = getelementptr i8, ptr %0, i64 0
+ %9 = getelementptr i8, ptr %1, i64 0
+ %12 = load <4 x i32>, ptr %7, align 4
+ %13 = sext <4 x i32> %12 to <4 x i64>
+ %14 = load <4 x i32>, ptr %9, align 4
+ %15 = sext <4 x i32> %14 to <4 x i64>
+ %16 = mul nsw <4 x i64> %15, %13
+ %17 = lshr <4 x i64> %16, splat (i64 32)
+ %18 = trunc nuw <4 x i64> %17 to <4 x i32>
+ %19 = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %18, <4 x i32> splat (i32 1073741823))
+ %20 = shl <4 x i32> %19, splat (i32 1)
+ ret <4 x i32> %20
+}
More information about the llvm-commits
mailing list