[llvm] 0cc981e - [AArch64] implement isReassocProfitable, disable for (u|s)mlal.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon May 23 01:42:04 PDT 2022


Author: Florian Hahn
Date: 2022-05-23T09:39:00+01:00
New Revision: 0cc981e021eda2b9c14d41302c5f0409b0a42719

URL: https://github.com/llvm/llvm-project/commit/0cc981e021eda2b9c14d41302c5f0409b0a42719
DIFF: https://github.com/llvm/llvm-project/commit/0cc981e021eda2b9c14d41302c5f0409b0a42719.diff

LOG: [AArch64] implement isReassocProfitable, disable for (u|s)mlal.

Currently, reassociating add expressions can prevent (u|s)mlal from
being selected. Implement isReassocProfitable to skip reassociating
expressions that can be lowered to (u|s)mlal.

The same issue exists for the *mlsl variants, but the DAG combiner
does not consult the isReassocProfitable hook before reassociating in
that case. This will be fixed in a follow-up commit, as it also
requires DAGCombiner changes.

Reviewed By: dmgreen

Differential Revision: https://reviews.llvm.org/D125895

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/test/CodeGen/AArch64/arm64-vmul.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index e21de4c6270f..3aaadc767dbd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -5468,6 +5468,36 @@ bool AArch64TargetLowering::useSVEForFixedLengthVectorVT(
 //                      Calling Convention Implementation
 //===----------------------------------------------------------------------===//
 
+static unsigned getIntrinsicID(const SDNode *N) {
+  unsigned Opcode = N->getOpcode();
+  switch (Opcode) {
+  default:
+    return Intrinsic::not_intrinsic;
+  case ISD::INTRINSIC_WO_CHAIN: {
+    unsigned IID = cast<ConstantSDNode>(N->getOperand(0))->getZExtValue();
+    if (IID < Intrinsic::num_intrinsics)
+      return IID;
+    return Intrinsic::not_intrinsic;
+  }
+  }
+}
+
+bool AArch64TargetLowering::isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+                                                SDValue N1) const {
+  if (!N0.hasOneUse())
+    return false;
+
+  unsigned IID = getIntrinsicID(N1.getNode());
+  // Avoid reassociating expressions that can be lowered to smlal/umlal.
+  if (IID == Intrinsic::aarch64_neon_umull ||
+      N1.getOpcode() == AArch64ISD::UMULL ||
+      IID == Intrinsic::aarch64_neon_smull ||
+      N1.getOpcode() == AArch64ISD::SMULL)
+    return N0.getOpcode() != ISD::ADD;
+
+  return true;
+}
+
 /// Selects the correct CCAssignFn for a given CallingConvention value.
 CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC,
                                                      bool IsVarArg) const {
@@ -10692,20 +10722,6 @@ static bool isAllConstantBuildVector(const SDValue &PotentialBVec,
   return true;
 }
 
-static unsigned getIntrinsicID(const SDNode *N) {
-  unsigned Opcode = N->getOpcode();
-  switch (Opcode) {
-  default:
-    return Intrinsic::not_intrinsic;
-  case ISD::INTRINSIC_WO_CHAIN: {
-    unsigned IID = cast<ConstantSDNode>(N->getOperand(0))->getZExtValue();
-    if (IID < Intrinsic::num_intrinsics)
-      return IID;
-    return Intrinsic::not_intrinsic;
-  }
-  }
-}
-
 // Attempt to form a vector S[LR]I from (or (and X, BvecC1), (lsl Y, C2)),
 // to (SLI X, Y, C2), where X and Y have matching vector types, BvecC1 is a
 // BUILD_VECTORs with constant element C1, C2 is a constant, and:

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 52ead327d0be..f3f11bb43e1f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -494,6 +494,11 @@ class AArch64TargetLowering : public TargetLowering {
   explicit AArch64TargetLowering(const TargetMachine &TM,
                                  const AArch64Subtarget &STI);
 
+  /// Control the following reassociation of operands: (op (op x, c1), y) -> (op
+  /// (op x, y), c1) where N0 is (op x, c1) and N1 is y.
+  bool isReassocProfitable(SelectionDAG &DAG, SDValue N0,
+                           SDValue N1) const override;
+
   /// Selects the correct CCAssignFn for a given CallingConvention value.
   CCAssignFn *CCAssignFnForCall(CallingConv::ID CC, bool IsVarArg) const;
 

diff  --git a/llvm/test/CodeGen/AArch64/arm64-vmul.ll b/llvm/test/CodeGen/AArch64/arm64-vmul.ll
index f09d21a920ea..46fcaafc8b44 100644
--- a/llvm/test/CodeGen/AArch64/arm64-vmul.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-vmul.ll
@@ -388,12 +388,11 @@ define <2 x i64> @smlal2d(<2 x i32>* %A, <2 x i32>* %B, <2 x i64>* %C) nounwind
 define void @smlal8h_chain_with_constant(<8 x i16>* %dst, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3) {
 ; CHECK-LABEL: smlal8h_chain_with_constant:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    smull.8h v0, v0, v2
-; CHECK-NEXT:    mvn.8b v2, v2
 ; CHECK-NEXT:    movi.16b v3, #1
-; CHECK-NEXT:    smlal.8h v0, v1, v2
-; CHECK-NEXT:    add.8h v0, v0, v3
-; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    smlal.8h v3, v0, v2
+; CHECK-NEXT:    mvn.8b v0, v2
+; CHECK-NEXT:    smlal.8h v3, v1, v0
+; CHECK-NEXT:    str q3, [x0]
 ; CHECK-NEXT:    ret
   %xor = xor <8 x i8> %v3, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
   %smull.1 = tail call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %v1, <8 x i8> %v3)
@@ -407,13 +406,12 @@ define void @smlal8h_chain_with_constant(<8 x i16>* %dst, <8 x i8> %v1, <8 x i8>
 define void @smlal2d_chain_with_constant(<2 x i64>* %dst, <2 x i32> %v1, <2 x i32> %v2, <2 x i32> %v3) {
 ; CHECK-LABEL: smlal2d_chain_with_constant:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    smull.2d v0, v0, v2
 ; CHECK-NEXT:    mov w8, #257
-; CHECK-NEXT:    mvn.8b v2, v2
-; CHECK-NEXT:    smlal.2d v0, v1, v2
-; CHECK-NEXT:    dup.2d v1, x8
-; CHECK-NEXT:    add.2d v0, v0, v1
-; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    dup.2d v3, x8
+; CHECK-NEXT:    smlal.2d v3, v0, v2
+; CHECK-NEXT:    mvn.8b v0, v2
+; CHECK-NEXT:    smlal.2d v3, v1, v0
+; CHECK-NEXT:    str q3, [x0]
 ; CHECK-NEXT:    ret
   %xor = xor <2 x i32> %v3, <i32 -1, i32 -1>
   %smull.1 = tail call <2 x i64> @llvm.aarch64.neon.smull.v2i64(<2 x i32> %v1, <2 x i32> %v3)
@@ -671,12 +669,11 @@ define <2 x i64> @umlal2d(<2 x i32>* %A, <2 x i32>* %B, <2 x i64>* %C) nounwind
 define void @umlal8h_chain_with_constant(<8 x i16>* %dst, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3) {
 ; CHECK-LABEL: umlal8h_chain_with_constant:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    umull.8h v0, v0, v2
-; CHECK-NEXT:    mvn.8b v2, v2
 ; CHECK-NEXT:    movi.16b v3, #1
-; CHECK-NEXT:    umlal.8h v0, v1, v2
-; CHECK-NEXT:    add.8h v0, v0, v3
-; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    umlal.8h v3, v0, v2
+; CHECK-NEXT:    mvn.8b v0, v2
+; CHECK-NEXT:    umlal.8h v3, v1, v0
+; CHECK-NEXT:    str q3, [x0]
 ; CHECK-NEXT:    ret
   %xor = xor <8 x i8> %v3, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
   %umull.1 = tail call <8 x i16> @llvm.aarch64.neon.umull.v8i16(<8 x i8> %v1, <8 x i8> %v3)
@@ -690,13 +687,12 @@ define void @umlal8h_chain_with_constant(<8 x i16>* %dst, <8 x i8> %v1, <8 x i8>
 define void @umlal2d_chain_with_constant(<2 x i64>* %dst, <2 x i32> %v1, <2 x i32> %v2, <2 x i32> %v3) {
 ; CHECK-LABEL: umlal2d_chain_with_constant:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    umull.2d v0, v0, v2
 ; CHECK-NEXT:    mov w8, #257
-; CHECK-NEXT:    mvn.8b v2, v2
-; CHECK-NEXT:    umlal.2d v0, v1, v2
-; CHECK-NEXT:    dup.2d v1, x8
-; CHECK-NEXT:    add.2d v0, v0, v1
-; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    dup.2d v3, x8
+; CHECK-NEXT:    umlal.2d v3, v0, v2
+; CHECK-NEXT:    mvn.8b v0, v2
+; CHECK-NEXT:    umlal.2d v3, v1, v0
+; CHECK-NEXT:    str q3, [x0]
 ; CHECK-NEXT:    ret
   %xor = xor <2 x i32> %v3, <i32 -1, i32 -1>
   %umull.1 = tail call <2 x i64> @llvm.aarch64.neon.umull.v2i64(<2 x i32> %v1, <2 x i32> %v3)


        


More information about the llvm-commits mailing list