[llvm] de4b458 - [AArch64][Codegen]Transform saturating smull to sqdmulh (#143671)
Author: Nashe Mncube
Date: 2025-07-16T09:50:04+01:00
New Revision: de4b458aa5e52812aa9c392f62a616b6c6c1716f
URL: https://github.com/llvm/llvm-project/commit/de4b458aa5e52812aa9c392f62a616b6c6c1716f
DIFF: https://github.com/llvm/llvm-project/commit/de4b458aa5e52812aa9c392f62a616b6c6c1716f.diff
LOG: [AArch64][Codegen]Transform saturating smull to sqdmulh (#143671)
This patch adds a pattern for recognizing saturating vector
smull. Prior to this patch, these were lowered to a
combination of smull+smull2+uzp2+smin like the following
```
smull2 v5.2d, v1.4s, v2.4s
smull v1.2d, v1.2s, v2.2s
uzp2 v1.4s, v1.4s, v5.4s
smin v1.4s, v1.4s, v0.4s
```
which now optimizes to
```
sqdmulh v0.4s, v1.4s, v0.4s
```
This combine only operates on vectors of i16 and i32 elements.
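For intuition: with i16 elements, (sext(a) * sext(b)) >> 15 is the same value as (2*a*b) >> 16, the high half of the doubled product that sqdmulh computes, and the only input pair that can push the shifted product above the i16 range is a = b = -32768, which is exactly the case the smin clamp catches (no smax is needed, since the shifted product can never fall below -32768). The i32 case is analogous with a shift of 31 and a clamp of 2147483647. Below is a standalone scalar sketch of that equivalence, not part of the patch and with illustrative names, checking all values of one operand against edge values of the other:
```
// Scalar model (illustration only): the smin-of-ashr-of-widened-mul pattern
// computes the same result as the saturating doubling multiply-high that
// sqdmulh performs.
#include <cassert>
#include <cstdint>
#include <initializer_list>

// The pattern the combine matches: sext, mul, ashr by 15, smin with 0x7FFF.
int32_t patternI16(int16_t A, int16_t B) {
  int32_t P = int32_t(A) * int32_t(B); // sext + mul; cannot overflow i32
  int32_t S = P >> 15;                 // ashr (arithmetic shift on signed)
  return S < 32767 ? S : 32767;        // smin with the i16 saturating value
}

// sqdmulh semantics: saturating (2*a*b) >> 16. The doubled product only
// saturates for A == B == INT16_MIN, the one case the smin above handles.
int32_t sqdmulhI16(int16_t A, int16_t B) {
  int64_t D = 2 * int64_t(A) * int64_t(B);
  int64_t S = D >> 16;
  return S > 32767 ? 32767 : int32_t(S);
}

int main() {
  for (int A = -32768; A <= 32767; ++A)
    for (int B : {-32768, -32767, -12345, -1, 0, 1, 12345, 32766, 32767})
      assert(patternI16(int16_t(A), int16_t(B)) ==
             sqdmulhI16(int16_t(A), int16_t(B)));
  return 0;
}
```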
Added:
llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64InstrInfo.td
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 235df9022c6fb..4f13a14d24649 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1143,6 +1143,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
ISD::SIGN_EXTEND_INREG, ISD::CONCAT_VECTORS,
ISD::EXTRACT_SUBVECTOR, ISD::INSERT_SUBVECTOR,
ISD::STORE, ISD::BUILD_VECTOR});
+ setTargetDAGCombine(ISD::SMIN);
setTargetDAGCombine(ISD::TRUNCATE);
setTargetDAGCombine(ISD::LOAD);
@@ -2392,6 +2393,15 @@ static bool isIntImmediate(const SDNode *N, uint64_t &Imm) {
return false;
}
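+// Opcodes that are treated as vector binops by the concat-vectors combine
+// below but are not covered by TargetLowering::isBinOp.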
+bool isVectorizedBinOp(unsigned Opcode) {
+ switch (Opcode) {
+ case AArch64ISD::SQDMULH:
+ return true;
+ default:
+ return false;
+ }
+}
+
// isOpcWithIntImmediate - This method tests to see if the node is a specific
// opcode and that it has an immediate integer right operand.
// If so Imm will receive the value.
@@ -20131,8 +20141,9 @@ static SDValue performConcatVectorsCombine(SDNode *N,
// size, combine into a binop of two concats of the source vectors. e.g.:
// concat(uhadd(a,b), uhadd(c, d)) -> uhadd(concat(a, c), concat(b, d))
if (N->getNumOperands() == 2 && N0Opc == N1Opc && VT.is128BitVector() &&
- DAG.getTargetLoweringInfo().isBinOp(N0Opc) && N0->hasOneUse() &&
- N1->hasOneUse()) {
+ (DAG.getTargetLoweringInfo().isBinOp(N0Opc) ||
+ isVectorizedBinOp(N0Opc)) &&
+ N0->hasOneUse() && N1->hasOneUse()) {
SDValue N00 = N0->getOperand(0);
SDValue N01 = N0->getOperand(1);
SDValue N10 = N1->getOperand(0);
@@ -20991,6 +21002,98 @@ static SDValue performBuildVectorCombine(SDNode *N,
return SDValue();
}
+// A special combine for the sqdmulh family of instructions.
+// smin(sra(mul(sext v0, sext v1), SHIFT_AMOUNT), SATURATING_VAL)
+// can be reduced to sqdmulh(v0, v1).
+static SDValue trySQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
+
+ if (N->getOpcode() != ISD::SMIN)
+ return SDValue();
+
+ EVT DestVT = N->getValueType(0);
+
+ if (!DestVT.isVector() || DestVT.getScalarSizeInBits() > 64 ||
+ DestVT.isScalableVector())
+ return SDValue();
+
+ ConstantSDNode *Clamp = isConstOrConstSplat(N->getOperand(1));
+
+ if (!Clamp)
+ return SDValue();
+
+ MVT ScalarType;
+ unsigned ShiftAmt = 0;
+ switch (Clamp->getSExtValue()) {
+ case (1ULL << 15) - 1:
+ ScalarType = MVT::i16;
+ ShiftAmt = 16;
+ break;
+ case (1ULL << 31) - 1:
+ ScalarType = MVT::i32;
+ ShiftAmt = 32;
+ break;
+ default:
+ return SDValue();
+ }
+
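+  // The clamp constant determines the narrow element type: ShiftAmt is its
+  // bit width, and the ashr feeding the smin must shift by ShiftAmt - 1.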
+ SDValue Sra = N->getOperand(0);
+ if (Sra.getOpcode() != ISD::SRA || !Sra.hasOneUse())
+ return SDValue();
+
+ ConstantSDNode *RightShiftVec = isConstOrConstSplat(Sra.getOperand(1));
+ if (!RightShiftVec)
+ return SDValue();
+ unsigned SExtValue = RightShiftVec->getSExtValue();
+
+ if (SExtValue != (ShiftAmt - 1))
+ return SDValue();
+
+ SDValue Mul = Sra.getOperand(0);
+ if (Mul.getOpcode() != ISD::MUL)
+ return SDValue();
+
+ SDValue SExt0 = Mul.getOperand(0);
+ SDValue SExt1 = Mul.getOperand(1);
+
+ if (SExt0.getOpcode() != ISD::SIGN_EXTEND ||
+ SExt1.getOpcode() != ISD::SIGN_EXTEND)
+ return SDValue();
+
+ EVT SExt0Type = SExt0.getOperand(0).getValueType();
+ EVT SExt1Type = SExt1.getOperand(0).getValueType();
+
+ if (SExt0Type != SExt1Type || SExt0Type.getScalarType() != ScalarType ||
+ SExt0Type.getFixedSizeInBits() > 128 || !SExt0Type.isPow2VectorType() ||
+ SExt0Type.getVectorNumElements() == 1)
+ return SDValue();
+
+ SDLoc DL(N);
+ SDValue V0 = SExt0.getOperand(0);
+ SDValue V1 = SExt1.getOperand(0);
+
+ // Ensure input vectors are extended to legal types
+ if (SExt0Type.getFixedSizeInBits() < 64) {
+ unsigned VecNumElements = SExt0Type.getVectorNumElements();
+ EVT ExtVecVT = MVT::getVectorVT(MVT::getIntegerVT(64 / VecNumElements),
+ VecNumElements);
+ V0 = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVecVT, V0);
+ V1 = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVecVT, V1);
+ }
+
+ SDValue SQDMULH =
+ DAG.getNode(AArch64ISD::SQDMULH, DL, V0.getValueType(), V0, V1);
+
+ return DAG.getNode(ISD::SIGN_EXTEND, DL, DestVT, SQDMULH);
+}
+
+static SDValue performSMINCombine(SDNode *N, SelectionDAG &DAG) {
+ if (SDValue V = trySQDMULHCombine(N, DAG)) {
+ return V;
+ }
+
+ return SDValue();
+}
+
static SDValue performTruncateCombine(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI) {
SDLoc DL(N);
@@ -26742,6 +26845,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performAddSubCombine(N, DCI);
case ISD::BUILD_VECTOR:
return performBuildVectorCombine(N, DCI, DAG);
+ case ISD::SMIN:
+ return performSMINCombine(N, DAG);
case ISD::TRUNCATE:
return performTruncateCombine(N, DAG, DCI);
case AArch64ISD::ANDS:
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index ddc685fae5e9a..ce91b72fa24e5 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -1022,6 +1022,7 @@ def AArch64smull : SDNode<"AArch64ISD::SMULL", SDT_AArch64mull,
[SDNPCommutative]>;
def AArch64umull : SDNode<"AArch64ISD::UMULL", SDT_AArch64mull,
[SDNPCommutative]>;
+def AArch64sqdmulh : SDNode<"AArch64ISD::SQDMULH", SDT_AArch64mull>;
// Reciprocal estimates and steps.
def AArch64frecpe : SDNode<"AArch64ISD::FRECPE", SDTFPUnaryOp>;
@@ -9439,6 +9440,15 @@ def : Pat<(v4i32 (mulhu V128:$Rn, V128:$Rm)),
(EXTRACT_SUBREG V128:$Rm, dsub)),
(UMULLv4i32_v2i64 V128:$Rn, V128:$Rm))>;
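+// Select the saturating-doubling-multiply-high node to the AdvSIMD SQDMULH
+// instructions for each supported vector type.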
+def : Pat<(v4i16 (AArch64sqdmulh (v4i16 V64:$Rn), (v4i16 V64:$Rm))),
+ (SQDMULHv4i16 V64:$Rn, V64:$Rm)>;
+def : Pat<(v2i32 (AArch64sqdmulh (v2i32 V64:$Rn), (v2i32 V64:$Rm))),
+ (SQDMULHv2i32 V64:$Rn, V64:$Rm)>;
+def : Pat<(v8i16 (AArch64sqdmulh (v8i16 V128:$Rn), (v8i16 V128:$Rm))),
+ (SQDMULHv8i16 V128:$Rn, V128:$Rm)>;
+def : Pat<(v4i32 (AArch64sqdmulh (v4i32 V128:$Rn), (v4i32 V128:$Rm))),
+ (SQDMULHv4i32 V128:$Rn, V128:$Rm)>;
+
// Conversions within AdvSIMD types in the same register size are free.
// But because we need a consistent lane ordering, in big endian many
// conversions require one or more REV instructions.
diff --git a/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
new file mode 100644
index 0000000000000..b647daf72ca35
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/saturating-vec-smull.ll
@@ -0,0 +1,223 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-none-elf < %s | FileCheck %s
+
+
+define <2 x i16> @saturating_2xi16(<2 x i16> %a, <2 x i16> %b) {
+; CHECK-LABEL: saturating_2xi16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: shl v0.2s, v0.2s, #16
+; CHECK-NEXT: shl v1.2s, v1.2s, #16
+; CHECK-NEXT: sshr v0.2s, v0.2s, #16
+; CHECK-NEXT: sshr v1.2s, v1.2s, #16
+; CHECK-NEXT: sqdmulh v0.2s, v1.2s, v0.2s
+; CHECK-NEXT: ret
+ %as = sext <2 x i16> %a to <2 x i32>
+ %bs = sext <2 x i16> %b to <2 x i32>
+ %m = mul <2 x i32> %bs, %as
+ %sh = ashr <2 x i32> %m, splat (i32 15)
+ %ma = tail call <2 x i32> @llvm.smin.v2i32(<2 x i32> %sh, <2 x i32> splat (i32 32767))
+ %t = trunc <2 x i32> %ma to <2 x i16>
+ ret <2 x i16> %t
+}
+
+define <4 x i16> @saturating_4xi16(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: saturating_4xi16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sqdmulh v0.4h, v1.4h, v0.4h
+; CHECK-NEXT: ret
+ %as = sext <4 x i16> %a to <4 x i32>
+ %bs = sext <4 x i16> %b to <4 x i32>
+ %m = mul <4 x i32> %bs, %as
+ %sh = ashr <4 x i32> %m, splat (i32 15)
+ %ma = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %sh, <4 x i32> splat (i32 32767))
+ %t = trunc <4 x i32> %ma to <4 x i16>
+ ret <4 x i16> %t
+}
+
+define <8 x i16> @saturating_8xi16(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: saturating_8xi16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sqdmulh v0.8h, v1.8h, v0.8h
+; CHECK-NEXT: ret
+ %as = sext <8 x i16> %a to <8 x i32>
+ %bs = sext <8 x i16> %b to <8 x i32>
+ %m = mul <8 x i32> %bs, %as
+ %sh = ashr <8 x i32> %m, splat (i32 15)
+ %ma = tail call <8 x i32> @llvm.smin.v8i32(<8 x i32> %sh, <8 x i32> splat (i32 32767))
+ %t = trunc <8 x i32> %ma to <8 x i16>
+ ret <8 x i16> %t
+}
+
+define <2 x i32> @saturating_2xi32(<2 x i32> %a, <2 x i32> %b) {
+; CHECK-LABEL: saturating_2xi32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sqdmulh v0.2s, v1.2s, v0.2s
+; CHECK-NEXT: ret
+ %as = sext <2 x i32> %a to <2 x i64>
+ %bs = sext <2 x i32> %b to <2 x i64>
+ %m = mul <2 x i64> %bs, %as
+ %sh = ashr <2 x i64> %m, splat (i64 31)
+ %ma = tail call <2 x i64> @llvm.smin.v2i64(<2 x i64> %sh, <2 x i64> splat (i64 2147483647))
+ %t = trunc <2 x i64> %ma to <2 x i32>
+ ret <2 x i32> %t
+}
+
+define <4 x i32> @saturating_4xi32(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-LABEL: saturating_4xi32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sqdmulh v0.4s, v1.4s, v0.4s
+; CHECK-NEXT: ret
+ %as = sext <4 x i32> %a to <4 x i64>
+ %bs = sext <4 x i32> %b to <4 x i64>
+ %m = mul <4 x i64> %bs, %as
+ %sh = ashr <4 x i64> %m, splat (i64 31)
+ %ma = tail call <4 x i64> @llvm.smin.v4i64(<4 x i64> %sh, <4 x i64> splat (i64 2147483647))
+ %t = trunc <4 x i64> %ma to <4 x i32>
+ ret <4 x i32> %t
+}
+
+define <8 x i32> @saturating_8xi32(<8 x i32> %a, <8 x i32> %b) {
+; CHECK-LABEL: saturating_8xi32:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sqdmulh v1.4s, v3.4s, v1.4s
+; CHECK-NEXT: sqdmulh v0.4s, v2.4s, v0.4s
+; CHECK-NEXT: ret
+ %as = sext <8 x i32> %a to <8 x i64>
+ %bs = sext <8 x i32> %b to <8 x i64>
+ %m = mul <8 x i64> %bs, %as
+ %sh = ashr <8 x i64> %m, splat (i64 31)
+ %ma = tail call <8 x i64> @llvm.smin.v8i64(<8 x i64> %sh, <8 x i64> splat (i64 2147483647))
+ %t = trunc <8 x i64> %ma to <8 x i32>
+ ret <8 x i32> %t
+}
+
+define <2 x i64> @saturating_2xi32_2xi64(<2 x i32> %a, <2 x i32> %b) {
+; CHECK-LABEL: saturating_2xi32_2xi64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sqdmulh v0.2s, v1.2s, v0.2s
+; CHECK-NEXT: sshll v0.2d, v0.2s, #0
+; CHECK-NEXT: ret
+ %as = sext <2 x i32> %a to <2 x i64>
+ %bs = sext <2 x i32> %b to <2 x i64>
+ %m = mul <2 x i64> %bs, %as
+ %sh = ashr <2 x i64> %m, splat (i64 31)
+ %ma = tail call <2 x i64> @llvm.smin.v2i64(<2 x i64> %sh, <2 x i64> splat (i64 2147483647))
+ ret <2 x i64> %ma
+}
+
+define <6 x i16> @saturating_6xi16(<6 x i16> %a, <6 x i16> %b) {
+; CHECK-LABEL: saturating_6xi16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: smull2 v3.4s, v1.8h, v0.8h
+; CHECK-NEXT: movi v2.4s, #127, msl #8
+; CHECK-NEXT: sqdmulh v0.4h, v1.4h, v0.4h
+; CHECK-NEXT: sshr v3.4s, v3.4s, #15
+; CHECK-NEXT: smin v2.4s, v3.4s, v2.4s
+; CHECK-NEXT: xtn2 v0.8h, v2.4s
+; CHECK-NEXT: ret
+ %as = sext <6 x i16> %a to <6 x i32>
+ %bs = sext <6 x i16> %b to <6 x i32>
+ %m = mul <6 x i32> %bs, %as
+ %sh = ashr <6 x i32> %m, splat (i32 15)
+ %ma = tail call <6 x i32> @llvm.smin.v6i32(<6 x i32> %sh, <6 x i32> splat (i32 32767))
+ %t = trunc <6 x i32> %ma to <6 x i16>
+ ret <6 x i16> %t
+}
+
+define <4 x i16> @unsupported_saturation_value_v4i16(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: unsupported_saturation_value_v4i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: smull v0.4s, v1.4h, v0.4h
+; CHECK-NEXT: movi v1.4s, #42
+; CHECK-NEXT: sshr v0.4s, v0.4s, #15
+; CHECK-NEXT: smin v0.4s, v0.4s, v1.4s
+; CHECK-NEXT: xtn v0.4h, v0.4s
+; CHECK-NEXT: ret
+ %as = sext <4 x i16> %a to <4 x i32>
+ %bs = sext <4 x i16> %b to <4 x i32>
+ %m = mul <4 x i32> %bs, %as
+ %sh = ashr <4 x i32> %m, splat (i32 15)
+ %ma = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %sh, <4 x i32> splat (i32 42))
+ %t = trunc <4 x i32> %ma to <4 x i16>
+ ret <4 x i16> %t
+}
+
+define <4 x i16> @unsupported_shift_value_v4i16(<4 x i16> %a, <4 x i16> %b) {
+; CHECK-LABEL: unsupported_shift_value_v4i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: smull v0.4s, v1.4h, v0.4h
+; CHECK-NEXT: movi v1.4s, #127, msl #8
+; CHECK-NEXT: sshr v0.4s, v0.4s, #3
+; CHECK-NEXT: smin v0.4s, v0.4s, v1.4s
+; CHECK-NEXT: xtn v0.4h, v0.4s
+; CHECK-NEXT: ret
+ %as = sext <4 x i16> %a to <4 x i32>
+ %bs = sext <4 x i16> %b to <4 x i32>
+ %m = mul <4 x i32> %bs, %as
+ %sh = ashr <4 x i32> %m, splat (i32 3)
+ %ma = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %sh, <4 x i32> splat (i32 32767))
+ %t = trunc <4 x i32> %ma to <4 x i16>
+ ret <4 x i16> %t
+}
+
+define <2 x i16> @extend_to_illegal_type(<2 x i16> %a, <2 x i16> %b) {
+; CHECK-LABEL: extend_to_illegal_type:
+; CHECK: // %bb.0:
+; CHECK-NEXT: shl v0.2s, v0.2s, #16
+; CHECK-NEXT: shl v1.2s, v1.2s, #16
+; CHECK-NEXT: sshr v0.2s, v0.2s, #16
+; CHECK-NEXT: sshr v1.2s, v1.2s, #16
+; CHECK-NEXT: sqdmulh v0.2s, v1.2s, v0.2s
+; CHECK-NEXT: ret
+ %as = sext <2 x i16> %a to <2 x i48>
+ %bs = sext <2 x i16> %b to <2 x i48>
+ %m = mul <2 x i48> %bs, %as
+ %sh = ashr <2 x i48> %m, splat (i48 15)
+ %ma = tail call <2 x i48> @llvm.smin.v2i48(<2 x i48> %sh, <2 x i48> splat (i48 32767))
+ %t = trunc <2 x i48> %ma to <2 x i16>
+ ret <2 x i16> %t
+}
+
+define <2 x i11> @illegal_source(<2 x i11> %a, <2 x i11> %b) {
+; CHECK-LABEL: illegal_source:
+; CHECK: // %bb.0:
+; CHECK-NEXT: shl v0.2s, v0.2s, #21
+; CHECK-NEXT: shl v1.2s, v1.2s, #21
+; CHECK-NEXT: sshr v0.2s, v0.2s, #21
+; CHECK-NEXT: sshr v1.2s, v1.2s, #21
+; CHECK-NEXT: mul v0.2s, v1.2s, v0.2s
+; CHECK-NEXT: movi v1.2s, #127, msl #8
+; CHECK-NEXT: sshr v0.2s, v0.2s, #15
+; CHECK-NEXT: smin v0.2s, v0.2s, v1.2s
+; CHECK-NEXT: ret
+ %as = sext <2 x i11> %a to <2 x i32>
+ %bs = sext <2 x i11> %b to <2 x i32>
+ %m = mul <2 x i32> %bs, %as
+ %sh = ashr <2 x i32> %m, splat (i32 15)
+ %ma = tail call <2 x i32> @llvm.smin.v2i32(<2 x i32> %sh, <2 x i32> splat (i32 32767))
+ %t = trunc <2 x i32> %ma to <2 x i11>
+ ret <2 x i11> %t
+}
+
+define <1 x i16> @saturating_1xi16(<1 x i16> %a, <1 x i16> %b) {
+; CHECK-LABEL: saturating_1xi16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: zip1 v0.4h, v0.4h, v0.4h
+; CHECK-NEXT: zip1 v1.4h, v1.4h, v0.4h
+; CHECK-NEXT: shl v0.2s, v0.2s, #16
+; CHECK-NEXT: sshr v0.2s, v0.2s, #16
+; CHECK-NEXT: shl v1.2s, v1.2s, #16
+; CHECK-NEXT: sshr v1.2s, v1.2s, #16
+; CHECK-NEXT: mul v0.2s, v1.2s, v0.2s
+; CHECK-NEXT: movi v1.2s, #127, msl #8
+; CHECK-NEXT: sshr v0.2s, v0.2s, #15
+; CHECK-NEXT: smin v0.2s, v0.2s, v1.2s
+; CHECK-NEXT: uzp1 v0.4h, v0.4h, v0.4h
+; CHECK-NEXT: ret
+ %as = sext <1 x i16> %a to <1 x i32>
+ %bs = sext <1 x i16> %b to <1 x i32>
+ %m = mul <1 x i32> %bs, %as
+ %sh = ashr <1 x i32> %m, splat (i32 15)
+ %ma = tail call <1 x i32> @llvm.smin.v1i32(<1 x i32> %sh, <1 x i32> splat (i32 32767))
+ %t = trunc <1 x i32> %ma to <1 x i16>
+ ret <1 x i16> %t
+}