[llvm] 0cc981e - [AArch64] implement isReassocProfitable, disable for (u|s)mlal.
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Mon May 23 01:42:04 PDT 2022
Author: Florian Hahn
Date: 2022-05-23T09:39:00+01:00
New Revision: 0cc981e021eda2b9c14d41302c5f0409b0a42719
URL: https://github.com/llvm/llvm-project/commit/0cc981e021eda2b9c14d41302c5f0409b0a42719
DIFF: https://github.com/llvm/llvm-project/commit/0cc981e021eda2b9c14d41302c5f0409b0a42719.diff
LOG: [AArch64] implement isReassocProfitable, disable for (u|s)mlal.
Currently, reassociating add expressions can prevent (u|s)mlal from being
selected. Implement isReassocProfitable to skip reassociating
expressions that can be lowered to (u|s)mlal.
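As a rough illustration (function and value names here are made up; the
real cases are in the updated arm64-vmul.ll tests below), the affected
chains look like:

  declare <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8>, <8 x i8>)

  ; ((smull a, b) + splat(1)) + (smull c, d): keeping the constant inside
  ; the chain lets it seed the accumulator register, so this can select to
  ; movi + smlal + smlal instead of smull + smlal + add.
  define <8 x i16> @smlal_chain(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) {
    %mul.1 = call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %a, <8 x i8> %b)
    %add.1 = add <8 x i16> %mul.1, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
    %mul.2 = call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %c, <8 x i8> %d)
    %add.2 = add <8 x i16> %add.1, %mul.2
    ret <8 x i16> %add.2
  }

Reassociating %add.2 into (%mul.1 + %mul.2) + splat(1) pulls the constant
back out of the chain and forces a separate vector add.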
The same issue exists for the *mlsl variants, but the DAG combiner
doesn't use the isReassocProfitable hook before reassociating in that
case. This will be fixed in a follow-up commit, as it requires
DAGCombiner changes as well.
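For reference, a hypothetical sketch of the analogous *mlsl chain that
the follow-up would need to handle:

  declare <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8>, <8 x i8>)

  ; splat(1) - (smull a, b) - (smull c, d): ideally movi + smlsl + smlsl,
  ; but the hook is not yet consulted before this reassociation.
  define <8 x i16> @smlsl_chain(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d) {
    %mul.1 = call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %a, <8 x i8> %b)
    %sub.1 = sub <8 x i16> <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>, %mul.1
    %mul.2 = call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %c, <8 x i8> %d)
    %sub.2 = sub <8 x i16> %sub.1, %mul.2
    ret <8 x i16> %sub.2
  }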
Reviewed By: dmgreen
Differential Revision: https://reviews.llvm.org/D125895
Added:
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
llvm/test/CodeGen/AArch64/arm64-vmul.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e21de4c6270f..3aaadc767dbd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -5468,6 +5468,36 @@ bool AArch64TargetLowering::useSVEForFixedLengthVectorVT(
// Calling Convention Implementation
//===----------------------------------------------------------------------===//
+static unsigned getIntrinsicID(const SDNode *N) {
+ unsigned Opcode = N->getOpcode();
+ switch (Opcode) {
+ default:
+ return Intrinsic::not_intrinsic;
+ case ISD::INTRINSIC_WO_CHAIN: {
+ unsigned IID = cast<ConstantSDNode>(N->getOperand(0))->getZExtValue();
+ if (IID < Intrinsic::num_intrinsics)
+ return IID;
+ return Intrinsic::not_intrinsic;
+ }
+ }
+}
+
+bool AArch64TargetLowering::isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+ SDValue N1) const {
+ if (!N0.hasOneUse())
+ return false;
+
+ unsigned IID = getIntrinsicID(N1.getNode());
+ // Avoid reassociating expressions that can be lowered to smlal/umlal.
+ if (IID == Intrinsic::aarch64_neon_umull ||
+ N1.getOpcode() == AArch64ISD::UMULL ||
+ IID == Intrinsic::aarch64_neon_smull ||
+ N1.getOpcode() == AArch64ISD::SMULL)
+ return N0.getOpcode() != ISD::ADD;
+
+ return true;
+}
+
/// Selects the correct CCAssignFn for a given CallingConvention value.
CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC,
bool IsVarArg) const {
@@ -10692,20 +10722,6 @@ static bool isAllConstantBuildVector(const SDValue &PotentialBVec,
return true;
}
-static unsigned getIntrinsicID(const SDNode *N) {
- unsigned Opcode = N->getOpcode();
- switch (Opcode) {
- default:
- return Intrinsic::not_intrinsic;
- case ISD::INTRINSIC_WO_CHAIN: {
- unsigned IID = cast<ConstantSDNode>(N->getOperand(0))->getZExtValue();
- if (IID < Intrinsic::num_intrinsics)
- return IID;
- return Intrinsic::not_intrinsic;
- }
- }
-}
-
// Attempt to form a vector S[LR]I from (or (and X, BvecC1), (lsl Y, C2)),
// to (SLI X, Y, C2), where X and Y have matching vector types, BvecC1 is a
// BUILD_VECTORs with constant element C1, C2 is a constant, and:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 52ead327d0be..f3f11bb43e1f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -494,6 +494,11 @@ class AArch64TargetLowering : public TargetLowering {
explicit AArch64TargetLowering(const TargetMachine &TM,
const AArch64Subtarget &STI);
+ /// Control the following reassociation of operands: (op (op x, c1), y) -> (op
+ /// (op x, y), c1) where N0 is (op x, c1) and N1 is y.
+ bool isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+ SDValue N1) const override;
+
/// Selects the correct CCAssignFn for a given CallingConvention value.
CCAssignFn *CCAssignFnForCall(CallingConv::ID CC, bool IsVarArg) const;
diff --git a/llvm/test/CodeGen/AArch64/arm64-vmul.ll b/llvm/test/CodeGen/AArch64/arm64-vmul.ll
index f09d21a920ea..46fcaafc8b44 100644
--- a/llvm/test/CodeGen/AArch64/arm64-vmul.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-vmul.ll
@@ -388,12 +388,11 @@ define <2 x i64> @smlal2d(<2 x i32>* %A, <2 x i32>* %B, <2 x i64>* %C) nounwind
define void @smlal8h_chain_with_constant(<8 x i16>* %dst, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3) {
; CHECK-LABEL: smlal8h_chain_with_constant:
; CHECK: // %bb.0:
-; CHECK-NEXT: smull.8h v0, v0, v2
-; CHECK-NEXT: mvn.8b v2, v2
; CHECK-NEXT: movi.16b v3, #1
-; CHECK-NEXT: smlal.8h v0, v1, v2
-; CHECK-NEXT: add.8h v0, v0, v3
-; CHECK-NEXT: str q0, [x0]
+; CHECK-NEXT: smlal.8h v3, v0, v2
+; CHECK-NEXT: mvn.8b v0, v2
+; CHECK-NEXT: smlal.8h v3, v1, v0
+; CHECK-NEXT: str q3, [x0]
; CHECK-NEXT: ret
%xor = xor <8 x i8> %v3, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
%smull.1 = tail call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %v1, <8 x i8> %v3)
@@ -407,13 +406,12 @@ define void @smlal8h_chain_with_constant(<8 x i16>* %dst, <8 x i8> %v1, <8 x i8>
define void @smlal2d_chain_with_constant(<2 x i64>* %dst, <2 x i32> %v1, <2 x i32> %v2, <2 x i32> %v3) {
; CHECK-LABEL: smlal2d_chain_with_constant:
; CHECK: // %bb.0:
-; CHECK-NEXT: smull.2d v0, v0, v2
; CHECK-NEXT: mov w8, #257
-; CHECK-NEXT: mvn.8b v2, v2
-; CHECK-NEXT: smlal.2d v0, v1, v2
-; CHECK-NEXT: dup.2d v1, x8
-; CHECK-NEXT: add.2d v0, v0, v1
-; CHECK-NEXT: str q0, [x0]
+; CHECK-NEXT: dup.2d v3, x8
+; CHECK-NEXT: smlal.2d v3, v0, v2
+; CHECK-NEXT: mvn.8b v0, v2
+; CHECK-NEXT: smlal.2d v3, v1, v0
+; CHECK-NEXT: str q3, [x0]
; CHECK-NEXT: ret
%xor = xor <2 x i32> %v3, <i32 -1, i32 -1>
%smull.1 = tail call <2 x i64> @llvm.aarch64.neon.smull.v2i64(<2 x i32> %v1, <2 x i32> %v3)
@@ -671,12 +669,11 @@ define <2 x i64> @umlal2d(<2 x i32>* %A, <2 x i32>* %B, <2 x i64>* %C) nounwind
define void @umlal8h_chain_with_constant(<8 x i16>* %dst, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3) {
; CHECK-LABEL: umlal8h_chain_with_constant:
; CHECK: // %bb.0:
-; CHECK-NEXT: umull.8h v0, v0, v2
-; CHECK-NEXT: mvn.8b v2, v2
; CHECK-NEXT: movi.16b v3, #1
-; CHECK-NEXT: umlal.8h v0, v1, v2
-; CHECK-NEXT: add.8h v0, v0, v3
-; CHECK-NEXT: str q0, [x0]
+; CHECK-NEXT: umlal.8h v3, v0, v2
+; CHECK-NEXT: mvn.8b v0, v2
+; CHECK-NEXT: umlal.8h v3, v1, v0
+; CHECK-NEXT: str q3, [x0]
; CHECK-NEXT: ret
%xor = xor <8 x i8> %v3, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
%umull.1 = tail call <8 x i16> @llvm.aarch64.neon.umull.v8i16(<8 x i8> %v1, <8 x i8> %v3)
@@ -690,13 +687,12 @@ define void @umlal8h_chain_with_constant(<8 x i16>* %dst, <8 x i8> %v1, <8 x i8>
define void @umlal2d_chain_with_constant(<2 x i64>* %dst, <2 x i32> %v1, <2 x i32> %v2, <2 x i32> %v3) {
; CHECK-LABEL: umlal2d_chain_with_constant:
; CHECK: // %bb.0:
-; CHECK-NEXT: umull.2d v0, v0, v2
; CHECK-NEXT: mov w8, #257
-; CHECK-NEXT: mvn.8b v2, v2
-; CHECK-NEXT: umlal.2d v0, v1, v2
-; CHECK-NEXT: dup.2d v1, x8
-; CHECK-NEXT: add.2d v0, v0, v1
-; CHECK-NEXT: str q0, [x0]
+; CHECK-NEXT: dup.2d v3, x8
+; CHECK-NEXT: umlal.2d v3, v0, v2
+; CHECK-NEXT: mvn.8b v0, v2
+; CHECK-NEXT: umlal.2d v3, v1, v0
+; CHECK-NEXT: str q3, [x0]
; CHECK-NEXT: ret
%xor = xor <2 x i32> %v3, <i32 -1, i32 -1>
%umull.1 = tail call <2 x i64> @llvm.aarch64.neon.umull.v2i64(<2 x i32> %v1, <2 x i32> %v3)