[llvm] [AArch64][Codegen]Transform saturating smull to sqdmulh (PR #143671)
Nashe Mncube via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 14 07:40:45 PDT 2025
https://github.com/nasherm updated https://github.com/llvm/llvm-project/pull/143671
>From dc5a42425a426fbdc404acd5fa917bf5ecd4ecca Mon Sep 17 00:00:00 2001
From: nasmnc01 <nashe.mncube at arm.com>
Date: Tue, 10 Jun 2025 16:20:42 +0100
Subject: [PATCH 1/7] [AArch64][Codegen]Transform saturating smull to sqdmulh
This patch adds a pattern for recognizing saturating vector
smull. Prior to this patch, such multiplies were lowered to a
combination of smull+smull2+uzp+smin like the following:
```
smull2 v5.2d, v1.4s, v2.4s
smull v1.2d, v1.2s, v2.2s
uzp2 v1.4s, v1.4s, v5.4s
smin v1.4s, v1.4s, v0.4s
add v1.4s, v1.4s, v1.4s
```
which now optimizes to
```
sqdmulh v0.4s, v1.4s, v0.4s
sshr v0.4s, v0.4s, #1
add v0.4s, v0.4s, v0.4s
```
This only operates on vectors containing Q31 data types.
Change-Id: Ib7d4d5284d1bd3fdd0907365f9e2f37f4da14671
---
.../Target/AArch64/AArch64ISelLowering.cpp | 73 +++++++++++++++++++
llvm/lib/Target/AArch64/AArch64InstrInfo.td | 7 ++
.../CodeGen/AArch64/saturating-vec-smull.ll | 25 +++++++
3 files changed, 105 insertions(+)
create mode 100644 llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 55601e6327e98..2484336f4ff19 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -26671,6 +26671,77 @@ performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
return NVCAST;
}
+// A special combine for the vqdmulh family of instructions. This is one of the
+// potential set of patterns that could patch this instruction. The base pattern
+// vshl(smin(uzp(smull, smull2), 1) can be reduced to vshl(vshr(sqdmulh(...),
+// 1), 1) when operating on Q31 data types
+static SDValue performVSHLCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ SelectionDAG &DAG) {
+
+ SDValue Op0 = N->getOperand(0);
+ ConstantSDNode *Splat = isConstOrConstSplat(N->getOperand(1));
+
+ if (Op0.getOpcode() != ISD::SMIN || !Splat || !Splat->isOne())
+ return SDValue();
+
+ auto trySQDMULHCombine = [](SDNode *N, SelectionDAG &DAG) -> SDValue {
+ EVT VT = N->getValueType(0);
+
+ if (!VT.isVector() || VT.getScalarSizeInBits() > 64)
+ return SDValue();
+
+ ConstantSDNode *Clamp;
+
+ if (N->getOpcode() != ISD::SMIN)
+ return SDValue();
+
+ Clamp = isConstOrConstSplat(N->getOperand(1));
+
+ if (!Clamp) {
+ return SDValue();
+ }
+
+ MVT ScalarType;
+ int ShftAmt = 0;
+ // Here we are considering clamped Arm Q format
+ // data types which uses 2 upper bits, one for the
+ // integer part and one for the sign.
+ switch (Clamp->getSExtValue()) {
+ case (1ULL << 30) - 1:
+ ScalarType = MVT::i32;
+ ShftAmt = 32;
+ break;
+ default:
+ return SDValue();
+ }
+
+ SDValue Mulhs = N->getOperand(0);
+ if (Mulhs.getOpcode() != ISD::MULHS)
+ return SDValue();
+
+ SDValue V0 = Mulhs.getOperand(0);
+ SDValue V1 = Mulhs.getOperand(1);
+
+ SDLoc DL(Mulhs);
+ const unsigned LegalLanes = 128 / ShftAmt;
+ EVT LegalVecVT = MVT::getVectorVT(ScalarType, LegalLanes);
+ return DAG.getNode(AArch64ISD::SQDMULH, DL, LegalVecVT, V0, V1);
+ };
+
+ if (SDValue Val = trySQDMULHCombine(Op0.getNode(), DAG)) {
+ SDLoc DL(N);
+ EVT VecVT = N->getOperand(0).getValueType();
+ // Clear lower bits for correctness
+ SDValue RightShift =
+ DAG.getNode(AArch64ISD::VASHR, DL, VecVT, Val, N->getOperand(1));
+ return DAG.getNode(AArch64ISD::VSHL, DL, VecVT, RightShift,
+ N->getOperand(1));
+ }
+
+ return SDValue();
+}
+
/// If the operand is a bitwise AND with a constant RHS, and the shift has a
/// constant RHS and is the only use, we can pull it out of the shift, i.e.
///
@@ -26811,6 +26882,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performMaskedGatherScatterCombine(N, DCI, DAG);
case ISD::FP_EXTEND:
return performFPExtendCombine(N, DAG, DCI, Subtarget);
+ case AArch64ISD::VSHL:
+ return performVSHLCombine(N, DCI, DAG);
case AArch64ISD::BRCOND:
return performBRCONDCombine(N, DCI, DAG);
case AArch64ISD::TBNZ:
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index ddc685fae5e9a..b60d96cbecda3 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1022,6 +1022,7 @@ def AArch64smull : SDNode<"AArch64ISD::SMULL", SDT_AArch64mull,
[SDNPCommutative]>;
def AArch64umull : SDNode<"AArch64ISD::UMULL", SDT_AArch64mull,
[SDNPCommutative]>;
+def AArch64sqdmulh : SDNode<"AArch64ISD::SQDMULH", SDT_AArch64mull>;
// Reciprocal estimates and steps.
def AArch64frecpe : SDNode<"AArch64ISD::FRECPE", SDTFPUnaryOp>;
@@ -1224,6 +1225,7 @@ def AArch64gld1q_index_merge_zero
: SDNode<"AArch64ISD::GLD1Q_INDEX_MERGE_ZERO", SDTypeProfile<1, 4, []>,
[SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
+
// Match add node and also treat an 'or' node is as an 'add' if the or'ed operands
// have no common bits.
def add_and_or_is_add : PatFrags<(ops node:$lhs, node:$rhs),
@@ -8321,6 +8323,7 @@ def : Pat<(v2f64 (any_fmul V128:$Rn, (AArch64dup (f64 FPR64:$Rm)))),
defm SQDMULH : SIMDIndexedHS<0, 0b1100, "sqdmulh", int_aarch64_neon_sqdmulh>;
defm SQRDMULH : SIMDIndexedHS<0, 0b1101, "sqrdmulh", int_aarch64_neon_sqrdmulh>;
+
defm SQDMULH : SIMDIndexedHSPatterns<int_aarch64_neon_sqdmulh_lane,
int_aarch64_neon_sqdmulh_laneq>;
defm SQRDMULH : SIMDIndexedHSPatterns<int_aarch64_neon_sqrdmulh_lane,
@@ -9439,6 +9442,10 @@ def : Pat<(v4i32 (mulhu V128:$Rn, V128:$Rm)),
(EXTRACT_SUBREG V128:$Rm, dsub)),
(UMULLv4i32_v2i64 V128:$Rn, V128:$Rm))>;
+
+def : Pat<(v4i32 (AArch64sqdmulh (v4i32 V128:$Rn), (v4i32 V128:$Rm))),
+ (SQDMULHv4i32 V128:$Rn, V128:$Rm)>;
+
// Conversions within AdvSIMD types in the same register size are free.
// But because we need a consistent lane ordering, in big endian many
// conversions require one or more REV instructions.
diff --git a/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
new file mode 100644
index 0000000000000..c1bb370ac3e89
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
@@ -0,0 +1,25 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-none-elf < %s | FileCheck %s
+
+define <4 x i32> @arm_mult_q31(ptr %0, ptr %1){
+; CHECK-LABEL: arm_mult_q31:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ldr q0, [x0]
+; CHECK-NEXT: ldr q1, [x1]
+; CHECK-NEXT: sqdmulh v0.4s, v1.4s, v0.4s
+; CHECK-NEXT: sshr v0.4s, v0.4s, #1
+; CHECK-NEXT: add v0.4s, v0.4s, v0.4s
+; CHECK-NEXT: ret
+ %7 = getelementptr i8, ptr %0, i64 0
+ %9 = getelementptr i8, ptr %1, i64 0
+ %12 = load <4 x i32>, ptr %7, align 4
+ %13 = sext <4 x i32> %12 to <4 x i64>
+ %14 = load <4 x i32>, ptr %9, align 4
+ %15 = sext <4 x i32> %14 to <4 x i64>
+ %16 = mul nsw <4 x i64> %15, %13
+ %17 = lshr <4 x i64> %16, splat (i64 32)
+ %18 = trunc nuw <4 x i64> %17 to <4 x i32>
+ %19 = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %18, <4 x i32> splat (i32 1073741823))
+ %20 = shl <4 x i32> %19, splat (i32 1)
+ ret <4 x i32> %20
+}
>From c495d49cbebfcf618bc1c2eca869bbb64a3bff0e Mon Sep 17 00:00:00 2001
From: nasmnc01 <nashe.mncube at arm.com>
Date: Thu, 26 Jun 2025 11:27:14 +0100
Subject: [PATCH 2/7] Responding to review comments
Based on the most recent PR comments I've
- refactored the change to work on a reduced pattern
which is truer to the actual SQDMULH instruction
- written pattern matches for q31, q15 and int32, int16
data types
- rewritten and extended the tests
Change-Id: I18c05e56b3979b8dd757d533e44a65496434937b
---
.../Target/AArch64/AArch64ISelLowering.cpp | 153 +++++++++---------
llvm/lib/Target/AArch64/AArch64InstrInfo.td | 3 +
.../CodeGen/AArch64/saturating-vec-smull.ll | 69 +++++---
3 files changed, 134 insertions(+), 91 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 2484336f4ff19..2a74113b72a65 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20980,6 +20980,83 @@ static SDValue performBuildVectorCombine(SDNode *N,
return SDValue();
}
+// A special combine for the vqdmulh family of instructions.
+// truncate( smin( sra ( mul( sext v0, sext v1 ) ), SHIFT_AMOUNT ),
+// SATURATING_VAL ) can be reduced to sqdmulh(...)
+static SDValue trySQDMULHCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ SelectionDAG &DAG) {
+
+ if (N->getOpcode() != ISD::TRUNCATE)
+ return SDValue();
+
+ EVT VT = N->getValueType(0);
+
+ if (!VT.isVector() || VT.getScalarSizeInBits() > 64)
+ return SDValue();
+
+ SDValue SMin = N->getOperand(0);
+
+ if (SMin.getOpcode() != ISD::SMIN)
+ return SDValue();
+
+ ConstantSDNode *Clamp = isConstOrConstSplat(SMin.getOperand(1));
+
+ if (!Clamp)
+ return SDValue();
+
+ MVT ScalarType;
+ unsigned ShiftAmt = 0;
+ // Here we are considering clamped Arm Q format
+ // data types which use 2 upper bits, one for the
+ // integer part and one for the sign. We also consider
+ // standard signed integer types
+ switch (Clamp->getSExtValue()) {
+ case (1ULL << 14) - 1: // Q15 saturation
+ case (1ULL << 15) - 1:
+ ScalarType = MVT::i16;
+ ShiftAmt = 16;
+ break;
+ case (1ULL << 30) - 1: // Q31 saturation
+ case (1ULL << 31) - 1:
+ ScalarType = MVT::i32;
+ ShiftAmt = 32;
+ break;
+ default:
+ return SDValue();
+ }
+
+ SDValue Sra = SMin.getOperand(0);
+ if (Sra.getOpcode() != ISD::SRA)
+ return SDValue();
+
+ ConstantSDNode *RightShiftVec = isConstOrConstSplat(Sra.getOperand(1));
+ if (!RightShiftVec)
+ return SDValue();
+ unsigned SExtValue = RightShiftVec->getSExtValue();
+
+ if (SExtValue != ShiftAmt && SExtValue != (ShiftAmt - 1))
+ return SDValue();
+
+ SDValue Mul = Sra.getOperand(0);
+ if (Mul.getOpcode() != ISD::MUL)
+ return SDValue();
+
+ SDValue SExt0 = Mul.getOperand(0);
+ SDValue SExt1 = Mul.getOperand(1);
+
+ if (SExt0.getOpcode() != ISD::SIGN_EXTEND ||
+ SExt1.getOpcode() != ISD::SIGN_EXTEND)
+ return SDValue();
+
+ SDValue V0 = SExt0.getOperand(0);
+ SDValue V1 = SExt1.getOperand(0);
+
+ SDLoc DL(N);
+ EVT VecVT = N->getValueType(0);
+ return DAG.getNode(AArch64ISD::SQDMULH, DL, VecVT, V0, V1);
+}
+
static SDValue performTruncateCombine(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI) {
SDLoc DL(N);
@@ -20994,6 +21071,9 @@ static SDValue performTruncateCombine(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(N0.getOpcode(), DL, VT, Op);
}
+ if (SDValue V = trySQDMULHCombine(N, DCI, DAG))
+ return V;
+
// Performing the following combine produces a preferable form for ISEL.
// i32 (trunc (extract Vi64, idx)) -> i32 (extract (nvcast Vi32), idx*2))
if (DCI.isAfterLegalizeDAG() && N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
@@ -26671,77 +26751,6 @@ performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
return NVCAST;
}
-// A special combine for the vqdmulh family of instructions. This is one of the
-// potential set of patterns that could patch this instruction. The base pattern
-// vshl(smin(uzp(smull, smull2), 1) can be reduced to vshl(vshr(sqdmulh(...),
-// 1), 1) when operating on Q31 data types
-static SDValue performVSHLCombine(SDNode *N,
- TargetLowering::DAGCombinerInfo &DCI,
- SelectionDAG &DAG) {
-
- SDValue Op0 = N->getOperand(0);
- ConstantSDNode *Splat = isConstOrConstSplat(N->getOperand(1));
-
- if (Op0.getOpcode() != ISD::SMIN || !Splat || !Splat->isOne())
- return SDValue();
-
- auto trySQDMULHCombine = [](SDNode *N, SelectionDAG &DAG) -> SDValue {
- EVT VT = N->getValueType(0);
-
- if (!VT.isVector() || VT.getScalarSizeInBits() > 64)
- return SDValue();
-
- ConstantSDNode *Clamp;
-
- if (N->getOpcode() != ISD::SMIN)
- return SDValue();
-
- Clamp = isConstOrConstSplat(N->getOperand(1));
-
- if (!Clamp) {
- return SDValue();
- }
-
- MVT ScalarType;
- int ShftAmt = 0;
- // Here we are considering clamped Arm Q format
- // data types which uses 2 upper bits, one for the
- // integer part and one for the sign.
- switch (Clamp->getSExtValue()) {
- case (1ULL << 30) - 1:
- ScalarType = MVT::i32;
- ShftAmt = 32;
- break;
- default:
- return SDValue();
- }
-
- SDValue Mulhs = N->getOperand(0);
- if (Mulhs.getOpcode() != ISD::MULHS)
- return SDValue();
-
- SDValue V0 = Mulhs.getOperand(0);
- SDValue V1 = Mulhs.getOperand(1);
-
- SDLoc DL(Mulhs);
- const unsigned LegalLanes = 128 / ShftAmt;
- EVT LegalVecVT = MVT::getVectorVT(ScalarType, LegalLanes);
- return DAG.getNode(AArch64ISD::SQDMULH, DL, LegalVecVT, V0, V1);
- };
-
- if (SDValue Val = trySQDMULHCombine(Op0.getNode(), DAG)) {
- SDLoc DL(N);
- EVT VecVT = N->getOperand(0).getValueType();
- // Clear lower bits for correctness
- SDValue RightShift =
- DAG.getNode(AArch64ISD::VASHR, DL, VecVT, Val, N->getOperand(1));
- return DAG.getNode(AArch64ISD::VSHL, DL, VecVT, RightShift,
- N->getOperand(1));
- }
-
- return SDValue();
-}
-
/// If the operand is a bitwise AND with a constant RHS, and the shift has a
/// constant RHS and is the only use, we can pull it out of the shift, i.e.
///
@@ -26882,8 +26891,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performMaskedGatherScatterCombine(N, DCI, DAG);
case ISD::FP_EXTEND:
return performFPExtendCombine(N, DAG, DCI, Subtarget);
- case AArch64ISD::VSHL:
- return performVSHLCombine(N, DCI, DAG);
case AArch64ISD::BRCOND:
return performBRCONDCombine(N, DCI, DAG);
case AArch64ISD::TBNZ:
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index b60d96cbecda3..f2610de98eecf 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -9443,6 +9443,9 @@ def : Pat<(v4i32 (mulhu V128:$Rn, V128:$Rm)),
(UMULLv4i32_v2i64 V128:$Rn, V128:$Rm))>;
+def : Pat<(v8i16 (AArch64sqdmulh (v8i16 V128:$Rn), (v8i16 V128:$Rm))),
+ (SQDMULHv8i16 V128:$Rn, V128:$Rm)>;
+
def : Pat<(v4i32 (AArch64sqdmulh (v4i32 V128:$Rn), (v4i32 V128:$Rm))),
(SQDMULHv4i32 V128:$Rn, V128:$Rm)>;
diff --git a/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
index c1bb370ac3e89..2bc1a427a6b99 100644
--- a/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
+++ b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
@@ -1,25 +1,58 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64-none-elf < %s | FileCheck %s
-define <4 x i32> @arm_mult_q31(ptr %0, ptr %1){
-; CHECK-LABEL: arm_mult_q31:
+define <8 x i16> @saturating_int16(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: saturating_int16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sqdmulh v0.8h, v1.8h, v0.8h
+; CHECK-NEXT: ret
+ %as = sext <8 x i16> %a to <8 x i32>
+ %bs = sext <8 x i16> %b to <8 x i32>
+ %m = mul <8 x i32> %bs, %as
+ %sh = ashr <8 x i32> %m, splat (i32 15)
+ %ma = tail call <8 x i32> @llvm.smin.v8i32(<8 x i32> %sh, <8 x i32> splat (i32 32767))
+ %t = trunc <8 x i32> %ma to <8 x i16>
+ ret <8 x i16> %t
+}
+
+define <4 x i32> @saturating_int32(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: saturating_int32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sqdmulh v0.4s, v1.4s, v0.4s
+; CHECK-NEXT: ret
+ %as = sext <4 x i32> %a to <4 x i64>
+ %bs = sext <4 x i32> %b to <4 x i64>
+ %m = mul <4 x i64> %bs, %as
+ %sh = ashr <4 x i64> %m, splat (i64 31)
+ %ma = tail call <4 x i64> @llvm.smin.v8i32(<4 x i64> %sh, <4 x i64> splat (i64 2147483647))
+ %t = trunc <4 x i64> %ma to <4 x i32>
+ ret <4 x i32> %t
+}
+
+define <8 x i16> @saturating_q15(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: saturating_q15:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sqdmulh v0.8h, v1.8h, v0.8h
+; CHECK-NEXT: ret
+ %as = sext <8 x i16> %a to <8 x i32>
+ %bs = sext <8 x i16> %b to <8 x i32>
+ %m = mul <8 x i32> %bs, %as
+ %sh = ashr <8 x i32> %m, splat (i32 16)
+ %ma = tail call <8 x i32> @llvm.smin.v8i32(<8 x i32> %sh, <8 x i32> splat (i32 16383))
+ %t = trunc <8 x i32> %ma to <8 x i16>
+ ret <8 x i16> %t
+}
+
+define <4 x i32> @saturating_q31(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: saturating_q31:
; CHECK: // %bb.0:
-; CHECK-NEXT: ldr q0, [x0]
-; CHECK-NEXT: ldr q1, [x1]
; CHECK-NEXT: sqdmulh v0.4s, v1.4s, v0.4s
-; CHECK-NEXT: sshr v0.4s, v0.4s, #1
-; CHECK-NEXT: add v0.4s, v0.4s, v0.4s
; CHECK-NEXT: ret
- %7 = getelementptr i8, ptr %0, i64 0
- %9 = getelementptr i8, ptr %1, i64 0
- %12 = load <4 x i32>, ptr %7, align 4
- %13 = sext <4 x i32> %12 to <4 x i64>
- %14 = load <4 x i32>, ptr %9, align 4
- %15 = sext <4 x i32> %14 to <4 x i64>
- %16 = mul nsw <4 x i64> %15, %13
- %17 = lshr <4 x i64> %16, splat (i64 32)
- %18 = trunc nuw <4 x i64> %17 to <4 x i32>
- %19 = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %18, <4 x i32> splat (i32 1073741823))
- %20 = shl <4 x i32> %19, splat (i32 1)
- ret <4 x i32> %20
+ %as = sext <4 x i32> %a to <4 x i64>
+ %bs = sext <4 x i32> %b to <4 x i64>
+ %m = mul <4 x i64> %bs, %as
+ %sh = ashr <4 x i64> %m, splat (i64 32)
+ %ma = tail call <4 x i64> @llvm.smin.v8i32(<4 x i64> %sh, <4 x i64> splat (i64 1073741823))
+ %t = trunc <4 x i64> %ma to <4 x i32>
+ ret <4 x i32> %t
}
>From e16c9a25a9a9d5d89a0d978f59458a8f3e158d1e Mon Sep 17 00:00:00 2001
From: nasmnc01 <nashe.mncube at arm.com>
Date: Thu, 26 Jun 2025 13:16:10 +0100
Subject: [PATCH 3/7] Arithmetic error for Q types
Spotted and fixed an arithmetic error when working
with Q types
Change-Id: I80f8e04bca08d3e6bc2740201bdd4978446a397f
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 2 +-
llvm/test/CodeGen/AArch64/saturating-vec-smull.ll | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 2a74113b72a65..3463c7907ad42 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21035,7 +21035,7 @@ static SDValue trySQDMULHCombine(SDNode *N,
return SDValue();
unsigned SExtValue = RightShiftVec->getSExtValue();
- if (SExtValue != ShiftAmt && SExtValue != (ShiftAmt - 1))
+ if (SExtValue != (ShiftAmt - 1))
return SDValue();
SDValue Mul = Sra.getOperand(0);
diff --git a/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
index 2bc1a427a6b99..9d478462feae0 100644
--- a/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
+++ b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
@@ -37,7 +37,7 @@ define <8 x i16> @saturating_q15(<8 x i16> %a, <8 x i16> %b) {
%as = sext <8 x i16> %a to <8 x i32>
%bs = sext <8 x i16> %b to <8 x i32>
%m = mul <8 x i32> %bs, %as
- %sh = ashr <8 x i32> %m, splat (i32 16)
+ %sh = ashr <8 x i32> %m, splat (i32 15)
%ma = tail call <8 x i32> @llvm.smin.v8i32(<8 x i32> %sh, <8 x i32> splat (i32 16383))
%t = trunc <8 x i32> %ma to <8 x i16>
ret <8 x i16> %t
@@ -51,7 +51,7 @@ define <4 x i32> @saturating_q31(<4 x i32> %a, <4 x i32> %b) {
%as = sext <4 x i32> %a to <4 x i64>
%bs = sext <4 x i32> %b to <4 x i64>
%m = mul <4 x i64> %bs, %as
- %sh = ashr <4 x i64> %m, splat (i64 32)
+ %sh = ashr <4 x i64> %m, splat (i64 31)
%ma = tail call <4 x i64> @llvm.smin.v8i32(<4 x i64> %sh, <4 x i64> splat (i64 1073741823))
%t = trunc <4 x i64> %ma to <4 x i32>
ret <4 x i32> %t
>From 8764be3875ef1cda7455db88400ac45f9bfa2fb4 Mon Sep 17 00:00:00 2001
From: nasmnc01 <nashe.mncube at arm.com>
Date: Fri, 27 Jun 2025 15:29:50 +0100
Subject: [PATCH 4/7] Responding to review comments
- support for v2i32 and v4i16 patterns
- extra type checking on sext
- matching on smin over sext
- cleaning trailing lines
Change-Id: I9f61b8d77a61f3d44ad5073b41555c9ad5653e1a
---
.../Target/AArch64/AArch64ISelLowering.cpp | 33 ++++----
llvm/lib/Target/AArch64/AArch64InstrInfo.td | 7 +-
.../CodeGen/AArch64/saturating-vec-smull.ll | 81 ++++++++++++-------
3 files changed, 74 insertions(+), 47 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 3463c7907ad42..53df92194a91e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20981,11 +20981,9 @@ static SDValue performBuildVectorCombine(SDNode *N,
}
// A special combine for the vqdmulh family of instructions.
-// truncate( smin( sra ( mul( sext v0, sext v1 ) ), SHIFT_AMOUNT ),
-// SATURATING_VAL ) can be reduced to sqdmulh(...)
-static SDValue trySQDMULHCombine(SDNode *N,
- TargetLowering::DAGCombinerInfo &DCI,
- SelectionDAG &DAG) {
+// smin( sra ( mul( sext v0, sext v1 ) ), SHIFT_AMOUNT ),
+// SATURATING_VAL ) can be reduced to sext(sqdmulh(...))
+static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
if (N->getOpcode() != ISD::TRUNCATE)
return SDValue();
@@ -21007,17 +21005,11 @@ static SDValue trySQDMULHCombine(SDNode *N,
MVT ScalarType;
unsigned ShiftAmt = 0;
- // Here we are considering clamped Arm Q format
- // data types which use 2 upper bits, one for the
- // integer part and one for the sign. We also consider
- // standard signed integer types
switch (Clamp->getSExtValue()) {
- case (1ULL << 14) - 1: // Q15 saturation
case (1ULL << 15) - 1:
ScalarType = MVT::i16;
ShiftAmt = 16;
break;
- case (1ULL << 30) - 1: // Q31 saturation
case (1ULL << 31) - 1:
ScalarType = MVT::i32;
ShiftAmt = 32;
@@ -21046,15 +21038,23 @@ static SDValue trySQDMULHCombine(SDNode *N,
SDValue SExt1 = Mul.getOperand(1);
if (SExt0.getOpcode() != ISD::SIGN_EXTEND ||
- SExt1.getOpcode() != ISD::SIGN_EXTEND)
+ SExt1.getOpcode() != ISD::SIGN_EXTEND ||
+ SExt0.getValueType() != SExt1.getValueType())
+ return SDValue();
+
+ if ((ShiftAmt == 16 && (SExt0.getValueType() != MVT::v8i32 &&
+ SExt0.getValueType() != MVT::v4i32)) ||
+ (ShiftAmt == 32 && (SExt0.getValueType() != MVT::v4i64 &&
+ SExt0.getValueType() != MVT::v2i64)))
return SDValue();
SDValue V0 = SExt0.getOperand(0);
SDValue V1 = SExt1.getOperand(0);
- SDLoc DL(N);
+ SDLoc DL(SMin);
EVT VecVT = N->getValueType(0);
- return DAG.getNode(AArch64ISD::SQDMULH, DL, VecVT, V0, V1);
+ SDValue SQDMULH = DAG.getNode(AArch64ISD::SQDMULH, DL, VecVT, V0, V1);
+ return DAG.getNode(ISD::SIGN_EXTEND, DL, N->getValueType(0), SQDMULH);
}
static SDValue performTruncateCombine(SDNode *N, SelectionDAG &DAG,
@@ -21071,8 +21071,9 @@ static SDValue performTruncateCombine(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(N0.getOpcode(), DL, VT, Op);
}
- if (SDValue V = trySQDMULHCombine(N, DCI, DAG))
- return V;
+ if (SDValue V = trySQDMULHCombine(N, DAG)) {
+ return DAG.getNode(ISD::TRUNCATE, DL, VT, V);
+ }
// Performing the following combine produces a preferable form for ISEL.
// i32 (trunc (extract Vi64, idx)) -> i32 (extract (nvcast Vi32), idx*2))
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index f2610de98eecf..f390caac0bdf2 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1225,7 +1225,6 @@ def AArch64gld1q_index_merge_zero
: SDNode<"AArch64ISD::GLD1Q_INDEX_MERGE_ZERO", SDTypeProfile<1, 4, []>,
[SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
-
// Match add node and also treat an 'or' node is as an 'add' if the or'ed operands
// have no common bits.
def add_and_or_is_add : PatFrags<(ops node:$lhs, node:$rhs),
@@ -8323,7 +8322,6 @@ def : Pat<(v2f64 (any_fmul V128:$Rn, (AArch64dup (f64 FPR64:$Rm)))),
defm SQDMULH : SIMDIndexedHS<0, 0b1100, "sqdmulh", int_aarch64_neon_sqdmulh>;
defm SQRDMULH : SIMDIndexedHS<0, 0b1101, "sqrdmulh", int_aarch64_neon_sqrdmulh>;
-
defm SQDMULH : SIMDIndexedHSPatterns<int_aarch64_neon_sqdmulh_lane,
int_aarch64_neon_sqdmulh_laneq>;
defm SQRDMULH : SIMDIndexedHSPatterns<int_aarch64_neon_sqrdmulh_lane,
@@ -9442,6 +9440,11 @@ def : Pat<(v4i32 (mulhu V128:$Rn, V128:$Rm)),
(EXTRACT_SUBREG V128:$Rm, dsub)),
(UMULLv4i32_v2i64 V128:$Rn, V128:$Rm))>;
+def : Pat<(v4i16 (AArch64sqdmulh (v4i16 V64:$Rn), (v4i16 V64:$Rm))),
+ (SQDMULHv4i16 V64:$Rn, V64:$Rm)>;
+
+def : Pat<(v2i32 (AArch64sqdmulh (v2i32 V64:$Rn), (v2i32 V64:$Rm))),
+ (SQDMULHv2i32 V64:$Rn, V64:$Rm)>;
def : Pat<(v8i16 (AArch64sqdmulh (v8i16 V128:$Rn), (v8i16 V128:$Rm))),
(SQDMULHv8i16 V128:$Rn, V128:$Rm)>;
diff --git a/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
index 9d478462feae0..7094f6c8aafa7 100644
--- a/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
+++ b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
@@ -1,8 +1,22 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64-none-elf < %s | FileCheck %s
-define <8 x i16> @saturating_int16(<8 x i16> %a, <8 x i16> %b) {
-; CHECK-LABEL: saturating_int16:
+define <4 x i16> @saturating_4xi16(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: saturating_4xi16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sqdmulh v0.4h, v1.4h, v0.4h
+; CHECK-NEXT: ret
+ %as = sext <4 x i16> %a to <4 x i32>
+ %bs = sext <4 x i16> %b to <4 x i32>
+ %m = mul <4 x i32> %bs, %as
+ %sh = ashr <4 x i32> %m, splat (i32 15)
+ %ma = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %sh, <4 x i32> splat (i32 32767))
+ %t = trunc <4 x i32> %ma to <4 x i16>
+ ret <4 x i16> %t
+}
+
+define <8 x i16> @saturating_8xi16(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: saturating_8xi16:
; CHECK: // %bb.0:
; CHECK-NEXT: sqdmulh v0.8h, v1.8h, v0.8h
; CHECK-NEXT: ret
@@ -15,36 +29,22 @@ define <8 x i16> @saturating_int16(<8 x i16> %a, <8 x i16> %b) {
ret <8 x i16> %t
}
-define <4 x i32> @saturating_int32(<4 x i32> %a, <4 x i32> %b) {
-; CHECK-LABEL: saturating_int32:
-; CHECK: // %bb.0:
-; CHECK-NEXT: sqdmulh v0.4s, v1.4s, v0.4s
-; CHECK-NEXT: ret
- %as = sext <4 x i32> %a to <4 x i64>
- %bs = sext <4 x i32> %b to <4 x i64>
- %m = mul <4 x i64> %bs, %as
- %sh = ashr <4 x i64> %m, splat (i64 31)
- %ma = tail call <4 x i64> @llvm.smin.v8i32(<4 x i64> %sh, <4 x i64> splat (i64 2147483647))
- %t = trunc <4 x i64> %ma to <4 x i32>
- ret <4 x i32> %t
-}
-
-define <8 x i16> @saturating_q15(<8 x i16> %a, <8 x i16> %b) {
-; CHECK-LABEL: saturating_q15:
+define <2 x i32> @saturating_2xi32(<2 x i32> %a, <2 x i32> %b) {
+; CHECK-LABEL: saturating_2xi32:
; CHECK: // %bb.0:
-; CHECK-NEXT: sqdmulh v0.8h, v1.8h, v0.8h
+; CHECK-NEXT: sqdmulh v0.2s, v1.2s, v0.2s
; CHECK-NEXT: ret
- %as = sext <8 x i16> %a to <8 x i32>
- %bs = sext <8 x i16> %b to <8 x i32>
- %m = mul <8 x i32> %bs, %as
- %sh = ashr <8 x i32> %m, splat (i32 15)
- %ma = tail call <8 x i32> @llvm.smin.v8i32(<8 x i32> %sh, <8 x i32> splat (i32 16383))
- %t = trunc <8 x i32> %ma to <8 x i16>
- ret <8 x i16> %t
+ %as = sext <2 x i32> %a to <2 x i64>
+ %bs = sext <2 x i32> %b to <2 x i64>
+ %m = mul <2 x i64> %bs, %as
+ %sh = ashr <2 x i64> %m, splat (i64 31)
+ %ma = tail call <2 x i64> @llvm.smin.v8i64(<2 x i64> %sh, <2 x i64> splat (i64 2147483647))
+ %t = trunc <2 x i64> %ma to <2 x i32>
+ ret <2 x i32> %t
}
-define <4 x i32> @saturating_q31(<4 x i32> %a, <4 x i32> %b) {
-; CHECK-LABEL: saturating_q31:
+define <4 x i32> @saturating_4xi32(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: saturating_4xi32:
; CHECK: // %bb.0:
; CHECK-NEXT: sqdmulh v0.4s, v1.4s, v0.4s
; CHECK-NEXT: ret
@@ -52,7 +52,30 @@ define <4 x i32> @saturating_q31(<4 x i32> %a, <4 x i32> %b) {
%bs = sext <4 x i32> %b to <4 x i64>
%m = mul <4 x i64> %bs, %as
%sh = ashr <4 x i64> %m, splat (i64 31)
- %ma = tail call <4 x i64> @llvm.smin.v8i32(<4 x i64> %sh, <4 x i64> splat (i64 1073741823))
+ %ma = tail call <4 x i64> @llvm.smin.v4i64(<4 x i64> %sh, <4 x i64> splat (i64 2147483647))
%t = trunc <4 x i64> %ma to <4 x i32>
ret <4 x i32> %t
}
+
+define <8 x i32> @saturating_8xi32(<8 x i32> %a, <8 x i32> %b) {
+; CHECK-LABEL: saturating_8xi32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NEXT: ext v5.16b, v3.16b, v3.16b, #8
+; CHECK-NEXT: ext v6.16b, v0.16b, v0.16b, #8
+; CHECK-NEXT: ext v7.16b, v2.16b, v2.16b, #8
+; CHECK-NEXT: sqdmulh v1.2s, v3.2s, v1.2s
+; CHECK-NEXT: sqdmulh v0.2s, v2.2s, v0.2s
+; CHECK-NEXT: sqdmulh v4.2s, v5.2s, v4.2s
+; CHECK-NEXT: sqdmulh v3.2s, v7.2s, v6.2s
+; CHECK-NEXT: mov v1.d[1], v4.d[0]
+; CHECK-NEXT: mov v0.d[1], v3.d[0]
+; CHECK-NEXT: ret
+ %as = sext <8 x i32> %a to <8 x i64>
+ %bs = sext <8 x i32> %b to <8 x i64>
+ %m = mul <8 x i64> %bs, %as
+ %sh = ashr <8 x i64> %m, splat (i64 31)
+ %ma = tail call <8 x i64> @llvm.smin.v8i64(<8 x i64> %sh, <8 x i64> splat (i64 2147483647))
+ %t = trunc <8 x i64> %ma to <8 x i32>
+ ret <8 x i32> %t
+}
>From 3d61ef920ed0405c53827212bf615b466811d079 Mon Sep 17 00:00:00 2001
From: nasmnc01 <nashe.mncube at arm.com>
Date: Mon, 30 Jun 2025 16:05:45 +0100
Subject: [PATCH 5/7] Responding to review comments
- minor cleanup
- allow optimizing concat_vectors(sqdmulh,sqdmulh) -> sqdmulh
- testing EVTs better
Change-Id: I0404fb9900896050baac372b7f7ce3a5b03517b9
---
.../Target/AArch64/AArch64ISelLowering.cpp | 34 +++++++++++--------
llvm/lib/Target/AArch64/AArch64InstrInfo.td | 3 --
.../CodeGen/AArch64/saturating-vec-smull.ll | 12 ++-----
3 files changed, 22 insertions(+), 27 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 53df92194a91e..c9a707a39a8b9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2392,6 +2392,15 @@ static bool isIntImmediate(const SDNode *N, uint64_t &Imm) {
return false;
}
+bool isVectorizedBinOp(unsigned Opcode) {
+ switch (Opcode) {
+ case AArch64ISD::SQDMULH:
+ return true;
+ default:
+ return false;
+ }
+}
+
// isOpcWithIntImmediate - This method tests to see if the node is a specific
// opcode and that it has a immediate integer right operand.
// If so Imm will receive the value.
@@ -20120,8 +20129,9 @@ static SDValue performConcatVectorsCombine(SDNode *N,
// size, combine into an binop of two contacts of the source vectors. eg:
// concat(uhadd(a,b), uhadd(c, d)) -> uhadd(concat(a, c), concat(b, d))
if (N->getNumOperands() == 2 && N0Opc == N1Opc && VT.is128BitVector() &&
- DAG.getTargetLoweringInfo().isBinOp(N0Opc) && N0->hasOneUse() &&
- N1->hasOneUse()) {
+ (DAG.getTargetLoweringInfo().isBinOp(N0Opc) ||
+ isVectorizedBinOp(N0Opc)) &&
+ N0->hasOneUse() && N1->hasOneUse()) {
SDValue N00 = N0->getOperand(0);
SDValue N01 = N0->getOperand(1);
SDValue N10 = N1->getOperand(0);
@@ -20980,7 +20990,7 @@ static SDValue performBuildVectorCombine(SDNode *N,
return SDValue();
}
-// A special combine for the vqdmulh family of instructions.
+// A special combine for the sqdmulh family of instructions.
// smin( sra ( mul( sext v0, sext v1 ) ), SHIFT_AMOUNT ),
// SATURATING_VAL ) can be reduced to sext(sqdmulh(...))
static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
@@ -21037,24 +21047,20 @@ static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
SDValue SExt0 = Mul.getOperand(0);
SDValue SExt1 = Mul.getOperand(1);
- if (SExt0.getOpcode() != ISD::SIGN_EXTEND ||
- SExt1.getOpcode() != ISD::SIGN_EXTEND ||
- SExt0.getValueType() != SExt1.getValueType())
- return SDValue();
+ EVT SExt0Type = SExt0.getOperand(0).getValueType();
+ EVT SExt1Type = SExt1.getOperand(0).getValueType();
- if ((ShiftAmt == 16 && (SExt0.getValueType() != MVT::v8i32 &&
- SExt0.getValueType() != MVT::v4i32)) ||
- (ShiftAmt == 32 && (SExt0.getValueType() != MVT::v4i64 &&
- SExt0.getValueType() != MVT::v2i64)))
+ if (SExt0.getOpcode() != ISD::SIGN_EXTEND ||
+ SExt1.getOpcode() != ISD::SIGN_EXTEND || SExt0Type != SExt1Type ||
+ SExt0Type.getScalarType() != ScalarType ||
+ SExt0Type.getFixedSizeInBits() > 128)
return SDValue();
SDValue V0 = SExt0.getOperand(0);
SDValue V1 = SExt1.getOperand(0);
SDLoc DL(SMin);
- EVT VecVT = N->getValueType(0);
- SDValue SQDMULH = DAG.getNode(AArch64ISD::SQDMULH, DL, VecVT, V0, V1);
- return DAG.getNode(ISD::SIGN_EXTEND, DL, N->getValueType(0), SQDMULH);
+ return DAG.getNode(AArch64ISD::SQDMULH, DL, SExt0Type, V0, V1);
}
static SDValue performTruncateCombine(SDNode *N, SelectionDAG &DAG,
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index f390caac0bdf2..ce91b72fa24e5 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -9442,13 +9442,10 @@ def : Pat<(v4i32 (mulhu V128:$Rn, V128:$Rm)),
def : Pat<(v4i16 (AArch64sqdmulh (v4i16 V64:$Rn), (v4i16 V64:$Rm))),
(SQDMULHv4i16 V64:$Rn, V64:$Rm)>;
-
def : Pat<(v2i32 (AArch64sqdmulh (v2i32 V64:$Rn), (v2i32 V64:$Rm))),
(SQDMULHv2i32 V64:$Rn, V64:$Rm)>;
-
def : Pat<(v8i16 (AArch64sqdmulh (v8i16 V128:$Rn), (v8i16 V128:$Rm))),
(SQDMULHv8i16 V128:$Rn, V128:$Rm)>;
-
def : Pat<(v4i32 (AArch64sqdmulh (v4i32 V128:$Rn), (v4i32 V128:$Rm))),
(SQDMULHv4i32 V128:$Rn, V128:$Rm)>;
diff --git a/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
index 7094f6c8aafa7..4bf689f373db3 100644
--- a/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
+++ b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
@@ -60,16 +60,8 @@ define <4 x i32> @saturating_4xi32(<4 x i32> %a, <4 x i32> %b) {
define <8 x i32> @saturating_8xi32(<8 x i32> %a, <8 x i32> %b) {
; CHECK-LABEL: saturating_8xi32:
; CHECK: // %bb.0:
-; CHECK-NEXT: ext v4.16b, v1.16b, v1.16b, #8
-; CHECK-NEXT: ext v5.16b, v3.16b, v3.16b, #8
-; CHECK-NEXT: ext v6.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT: ext v7.16b, v2.16b, v2.16b, #8
-; CHECK-NEXT: sqdmulh v1.2s, v3.2s, v1.2s
-; CHECK-NEXT: sqdmulh v0.2s, v2.2s, v0.2s
-; CHECK-NEXT: sqdmulh v4.2s, v5.2s, v4.2s
-; CHECK-NEXT: sqdmulh v3.2s, v7.2s, v6.2s
-; CHECK-NEXT: mov v1.d[1], v4.d[0]
-; CHECK-NEXT: mov v0.d[1], v3.d[0]
+; CHECK-NEXT: sqdmulh v1.4s, v3.4s, v1.4s
+; CHECK-NEXT: sqdmulh v0.4s, v2.4s, v0.4s
; CHECK-NEXT: ret
%as = sext <8 x i32> %a to <8 x i64>
%bs = sext <8 x i32> %b to <8 x i64>
>From 2d083e3200e8382d7580866a224340662cc843a4 Mon Sep 17 00:00:00 2001
From: nasmnc01 <nashe.mncube at arm.com>
Date: Wed, 2 Jul 2025 10:58:24 +0100
Subject: [PATCH 6/7] Responding to review comments
- making sure transform only operates on smin nodes
- adding extra tests dealing with interesting edge cases
Change-Id: Ia1114ec9b93c4de3552b867e0d745beccdae69f1
---
.../Target/AArch64/AArch64ISelLowering.cpp | 46 ++++++++-----
.../CodeGen/AArch64/saturating-vec-smull.ll | 69 +++++++++++++++++++
2 files changed, 97 insertions(+), 18 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index c9a707a39a8b9..cb02bec889ec3 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1143,6 +1143,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
ISD::SIGN_EXTEND_INREG, ISD::CONCAT_VECTORS,
ISD::EXTRACT_SUBVECTOR, ISD::INSERT_SUBVECTOR,
ISD::STORE, ISD::BUILD_VECTOR});
+ setTargetDAGCombine(ISD::SMIN);
setTargetDAGCombine(ISD::TRUNCATE);
setTargetDAGCombine(ISD::LOAD);
@@ -20992,10 +20993,10 @@ static SDValue performBuildVectorCombine(SDNode *N,
// A special combine for the sqdmulh family of instructions.
// smin( sra ( mul( sext v0, sext v1 ) ), SHIFT_AMOUNT ),
-// SATURATING_VAL ) can be reduced to sext(sqdmulh(...))
+// SATURATING_VAL ) can be reduced to sqdmulh(...)
static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
- if (N->getOpcode() != ISD::TRUNCATE)
+ if (N->getOpcode() != ISD::SMIN)
return SDValue();
EVT VT = N->getValueType(0);
@@ -21003,12 +21004,7 @@ static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
if (!VT.isVector() || VT.getScalarSizeInBits() > 64)
return SDValue();
- SDValue SMin = N->getOperand(0);
-
- if (SMin.getOpcode() != ISD::SMIN)
- return SDValue();
-
- ConstantSDNode *Clamp = isConstOrConstSplat(SMin.getOperand(1));
+ ConstantSDNode *Clamp = isConstOrConstSplat(N->getOperand(1));
if (!Clamp)
return SDValue();
@@ -21028,8 +21024,8 @@ static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
return SDValue();
}
- SDValue Sra = SMin.getOperand(0);
- if (Sra.getOpcode() != ISD::SRA)
+ SDValue Sra = N->getOperand(0);
+ if (Sra.getOpcode() != ISD::SRA || !Sra.hasOneUse())
return SDValue();
ConstantSDNode *RightShiftVec = isConstOrConstSplat(Sra.getOperand(1));
@@ -21056,11 +21052,27 @@ static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
SExt0Type.getFixedSizeInBits() > 128)
return SDValue();
- SDValue V0 = SExt0.getOperand(0);
- SDValue V1 = SExt1.getOperand(0);
+ // Source vectors with width < 64 are illegal and will need to be extended
+ unsigned SourceVectorWidth = SExt0Type.getFixedSizeInBits();
+ SDValue V0 = (SourceVectorWidth < 64) ? SExt0 : SExt0.getOperand(0);
+ SDValue V1 = (SourceVectorWidth < 64) ? SExt1 : SExt1.getOperand(0);
+
+ SDLoc DL(N);
+ SDValue SQDMULH =
+ DAG.getNode(AArch64ISD::SQDMULH, DL, V0.getValueType(), V0, V1);
+ EVT DestVT = N->getValueType(0);
+ if (DestVT.getScalarSizeInBits() > SExt0Type.getScalarSizeInBits())
+ return DAG.getNode(ISD::SIGN_EXTEND, DL, DestVT, SQDMULH);
+
+ return SQDMULH;
+}
+
+static SDValue performSMINCombine(SDNode *N, SelectionDAG &DAG) {
+ if (SDValue V = trySQDMULHCombine(N, DAG)) {
+ return V;
+ }
- SDLoc DL(SMin);
- return DAG.getNode(AArch64ISD::SQDMULH, DL, SExt0Type, V0, V1);
+ return SDValue();
}
static SDValue performTruncateCombine(SDNode *N, SelectionDAG &DAG,
@@ -21077,10 +21089,6 @@ static SDValue performTruncateCombine(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(N0.getOpcode(), DL, VT, Op);
}
- if (SDValue V = trySQDMULHCombine(N, DAG)) {
- return DAG.getNode(ISD::TRUNCATE, DL, VT, V);
- }
-
// Performing the following combine produces a preferable form for ISEL.
// i32 (trunc (extract Vi64, idx)) -> i32 (extract (nvcast Vi32), idx*2))
if (DCI.isAfterLegalizeDAG() && N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
@@ -26818,6 +26826,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performAddSubCombine(N, DCI);
case ISD::BUILD_VECTOR:
return performBuildVectorCombine(N, DCI, DAG);
+ case ISD::SMIN:
+ return performSMINCombine(N, DAG);
case ISD::TRUNCATE:
return performTruncateCombine(N, DAG, DCI);
case AArch64ISD::ANDS:
diff --git a/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
index 4bf689f373db3..e9ca1769274b3 100644
--- a/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
+++ b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
@@ -1,6 +1,25 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc -mtriple=aarch64-none-elf < %s | FileCheck %s
+
+define <2 x i16> @saturating_2xi16(<2 x i16> %a, <2 x i16> %b) {
+; CHECK-LABEL: saturating_2xi16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: shl v0.2s, v0.2s, #16
+; CHECK-NEXT: shl v1.2s, v1.2s, #16
+; CHECK-NEXT: sshr v0.2s, v0.2s, #16
+; CHECK-NEXT: sshr v1.2s, v1.2s, #16
+; CHECK-NEXT: sqdmulh v0.2s, v1.2s, v0.2s
+; CHECK-NEXT: ret
+ %as = sext <2 x i16> %a to <2 x i32>
+ %bs = sext <2 x i16> %b to <2 x i32>
+ %m = mul <2 x i32> %bs, %as
+ %sh = ashr <2 x i32> %m, splat (i32 15)
+ %ma = tail call <2 x i32> @llvm.smin.v4i32(<2 x i32> %sh, <2 x i32> splat (i32 32767))
+ %t = trunc <2 x i32> %ma to <2 x i16>
+ ret <2 x i16> %t
+}
+
define <4 x i16> @saturating_4xi16(<4 x i16> %a, <4 x i16> %b) {
; CHECK-LABEL: saturating_4xi16:
; CHECK: // %bb.0:
@@ -71,3 +90,53 @@ define <8 x i32> @saturating_8xi32(<8 x i32> %a, <8 x i32> %b) {
%t = trunc <8 x i64> %ma to <8 x i32>
ret <8 x i32> %t
}
+
+define <2 x i64> @saturating_2xi32_2xi64(<2 x i32> %a, <2 x i32> %b) {
+; CHECK-LABEL: saturating_2xi32_2xi64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sqdmulh v0.2s, v1.2s, v0.2s
+; CHECK-NEXT: sshll v0.2d, v0.2s, #0
+; CHECK-NEXT: ret
+ %as = sext <2 x i32> %a to <2 x i64>
+ %bs = sext <2 x i32> %b to <2 x i64>
+ %m = mul <2 x i64> %bs, %as
+ %sh = ashr <2 x i64> %m, splat (i64 31)
+ %ma = tail call <2 x i64> @llvm.smin.v8i64(<2 x i64> %sh, <2 x i64> splat (i64 2147483647))
+ ret <2 x i64> %ma
+}
+
+define <4 x i16> @unsupported_saturation_value_v4i16(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: unsupported_saturation_value_v4i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: smull v0.4s, v1.4h, v0.4h
+; CHECK-NEXT: movi v1.4s, #42
+; CHECK-NEXT: sshr v0.4s, v0.4s, #15
+; CHECK-NEXT: smin v0.4s, v0.4s, v1.4s
+; CHECK-NEXT: xtn v0.4h, v0.4s
+; CHECK-NEXT: ret
+ %as = sext <4 x i16> %a to <4 x i32>
+ %bs = sext <4 x i16> %b to <4 x i32>
+ %m = mul <4 x i32> %bs, %as
+ %sh = ashr <4 x i32> %m, splat (i32 15)
+ %ma = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %sh, <4 x i32> splat (i32 42))
+ %t = trunc <4 x i32> %ma to <4 x i16>
+ ret <4 x i16> %t
+}
+
+define <4 x i16> @unsupported_shift_value_v4i16(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: unsupported_shift_value_v4i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: smull v0.4s, v1.4h, v0.4h
+; CHECK-NEXT: movi v1.4s, #127, msl #8
+; CHECK-NEXT: sshr v0.4s, v0.4s, #3
+; CHECK-NEXT: smin v0.4s, v0.4s, v1.4s
+; CHECK-NEXT: xtn v0.4h, v0.4s
+; CHECK-NEXT: ret
+ %as = sext <4 x i16> %a to <4 x i32>
+ %bs = sext <4 x i16> %b to <4 x i32>
+ %m = mul <4 x i32> %bs, %as
+ %sh = ashr <4 x i32> %m, splat (i32 3)
+ %ma = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %sh, <4 x i32> splat (i32 32767))
+ %t = trunc <4 x i32> %ma to <4 x i16>
+ ret <4 x i16> %t
+}
>From 3568431c95b40a35698c5294c86178b621517bca Mon Sep 17 00:00:00 2001
From: nasmnc01 <nashe.mncube at arm.com>
Date: Mon, 14 Jul 2025 15:38:03 +0100
Subject: [PATCH 7/7] Response to review comments addressing:
- check for scalar type
- check for sign extends
- legalise vector inputs to sqdmulh
- always return sext(sqdmulh)
Change-Id: Ic58b7f267e94bc2592942fc29b829ffb6221770f
---
.../Target/AArch64/AArch64ISelLowering.cpp | 34 ++++++++++------
.../CodeGen/AArch64/saturating-vec-smull.ll | 40 +++++++++++++++++++
2 files changed, 61 insertions(+), 13 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index cb02bec889ec3..820742b70f1c0 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -20999,9 +20999,9 @@ static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
if (N->getOpcode() != ISD::SMIN)
return SDValue();
- EVT VT = N->getValueType(0);
+ EVT DestVT = N->getValueType(0);
- if (!VT.isVector() || VT.getScalarSizeInBits() > 64)
+ if (!DestVT.isVector() || DestVT.getScalarSizeInBits() > 64 || DestVT.isScalableVector())
return SDValue();
ConstantSDNode *Clamp = isConstOrConstSplat(N->getOperand(1));
@@ -21043,28 +21043,36 @@ static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
SDValue SExt0 = Mul.getOperand(0);
SDValue SExt1 = Mul.getOperand(1);
+ if (SExt0.getOpcode() != ISD::SIGN_EXTEND ||
+ SExt1.getOpcode() != ISD::SIGN_EXTEND)
+ return SDValue();
+
EVT SExt0Type = SExt0.getOperand(0).getValueType();
EVT SExt1Type = SExt1.getOperand(0).getValueType();
- if (SExt0.getOpcode() != ISD::SIGN_EXTEND ||
- SExt1.getOpcode() != ISD::SIGN_EXTEND || SExt0Type != SExt1Type ||
+ if (SExt0Type != SExt1Type ||
SExt0Type.getScalarType() != ScalarType ||
SExt0Type.getFixedSizeInBits() > 128)
return SDValue();
- // Source vectors with width < 64 are illegal and will need to be extended
- unsigned SourceVectorWidth = SExt0Type.getFixedSizeInBits();
- SDValue V0 = (SourceVectorWidth < 64) ? SExt0 : SExt0.getOperand(0);
- SDValue V1 = (SourceVectorWidth < 64) ? SExt1 : SExt1.getOperand(0);
-
SDLoc DL(N);
+ SDValue V0 = SExt0.getOperand(0);
+ SDValue V1 = SExt1.getOperand(0);
+
+ // Ensure input vectors are extended to legal types
+ if (SExt0Type.getFixedSizeInBits() < 64) {
+ unsigned VecNumElements = SExt0Type.getVectorNumElements();
+ EVT ExtVecVT =
+ MVT::getVectorVT(MVT::getIntegerVT(64 / VecNumElements),
+ VecNumElements);
+ V0 = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVecVT, V0);
+ V1 = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVecVT, V1);
+ }
+
SDValue SQDMULH =
DAG.getNode(AArch64ISD::SQDMULH, DL, V0.getValueType(), V0, V1);
- EVT DestVT = N->getValueType(0);
- if (DestVT.getScalarSizeInBits() > SExt0Type.getScalarSizeInBits())
- return DAG.getNode(ISD::SIGN_EXTEND, DL, DestVT, SQDMULH);
- return SQDMULH;
+ return DAG.getNode(ISD::SIGN_EXTEND, DL, DestVT, SQDMULH);
}
static SDValue performSMINCombine(SDNode *N, SelectionDAG &DAG) {
diff --git a/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
index e9ca1769274b3..b815cb16441cf 100644
--- a/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
+++ b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
@@ -140,3 +140,43 @@ define <4 x i16> @unsupported_shift_value_v4i16(<4 x i16> %a, <4 x i16> %b) {
%t = trunc <4 x i32> %ma to <4 x i16>
ret <4 x i16> %t
}
+
+define <2 x i16> @extend_to_illegal_type(<2 x i16> %a, <2 x i16> %b) {
+; CHECK-LABEL: extend_to_illegal_type:
+; CHECK: // %bb.0:
+; CHECK-NEXT: shl v0.2s, v0.2s, #16
+; CHECK-NEXT: shl v1.2s, v1.2s, #16
+; CHECK-NEXT: sshr v0.2s, v0.2s, #16
+; CHECK-NEXT: sshr v1.2s, v1.2s, #16
+; CHECK-NEXT: sqdmulh v0.2s, v1.2s, v0.2s
+; CHECK-NEXT: ret
+ %as = sext <2 x i16> %a to <2 x i48>
+ %bs = sext <2 x i16> %b to <2 x i48>
+ %m = mul <2 x i48> %bs, %as
+ %sh = ashr <2 x i48> %m, splat (i48 15)
+ %ma = tail call <2 x i48> @llvm.smin.v4i32(<2 x i48> %sh, <2 x i48> splat (i48 32767))
+ %t = trunc <2 x i48> %ma to <2 x i16>
+ ret <2 x i16> %t
+}
+
+define <2 x i11> @illegal_source(<2 x i11> %a, <2 x i11> %b) {
+; CHECK-LABEL: illegal_source:
+; CHECK: // %bb.0:
+; CHECK-NEXT: shl v0.2s, v0.2s, #21
+; CHECK-NEXT: shl v1.2s, v1.2s, #21
+; CHECK-NEXT: sshr v0.2s, v0.2s, #21
+; CHECK-NEXT: sshr v1.2s, v1.2s, #21
+; CHECK-NEXT: mul v0.2s, v1.2s, v0.2s
+; CHECK-NEXT: movi v1.2s, #127, msl #8
+; CHECK-NEXT: sshr v0.2s, v0.2s, #15
+; CHECK-NEXT: smin v0.2s, v0.2s, v1.2s
+; CHECK-NEXT: ret
+ %as = sext <2 x i11> %a to <2 x i32>
+ %bs = sext <2 x i11> %b to <2 x i32>
+ %m = mul <2 x i32> %bs, %as
+ %sh = ashr <2 x i32> %m, splat (i32 15)
+ %ma = tail call <2 x i32> @llvm.smin.v4i32(<2 x i32> %sh, <2 x i32> splat (i32 32767))
+ %t = trunc <2 x i32> %ma to <2 x i11>
+ ret <2 x i11> %t
+}
+
More information about the llvm-commits
mailing list