[llvm] [ARM] Add NEON support for ISD::ABDS/ABDU nodes. (PR #94504)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 6 01:36:08 PDT 2024


https://github.com/RKSimon updated https://github.com/llvm/llvm-project/pull/94504

>From 2f9a635bc94d4bd5d30d46ab0263d33b8133a09c Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Wed, 5 Jun 2024 18:06:29 +0100
Subject: [PATCH] [ARM] Add NEON support for ISD::ABDS/ABDU nodes.

As noted on #94466, NEON has ABDS/ABDU instructions but only handles them via intrinsics, plus some VABDL custom patterns.

Fixes #94466
---
 llvm/lib/Target/ARM/ARMISelLowering.cpp | 29 +++++++++-----------
 llvm/lib/Target/ARM/ARMInstrNEON.td     | 36 ++++++++-----------------
 llvm/test/CodeGen/ARM/neon_vabs.ll      |  5 +++-
 3 files changed, 27 insertions(+), 43 deletions(-)

diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 5212d2c620b7..faf639e39bce 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -205,9 +205,9 @@ void ARMTargetLowering::addTypeForNEON(MVT VT, MVT PromotedLdStVT) {
   setOperationAction(ISD::SDIVREM, VT, Expand);
   setOperationAction(ISD::UDIVREM, VT, Expand);
 
-  if (!VT.isFloatingPoint() &&
-      VT != MVT::v2i64 && VT != MVT::v1i64)
-    for (auto Opcode : {ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX})
+  if (!VT.isFloatingPoint() && VT != MVT::v2i64 && VT != MVT::v1i64)
+    for (auto Opcode : {ISD::ABS, ISD::ABDS, ISD::ABDU, ISD::SMIN, ISD::SMAX,
+                        ISD::UMIN, ISD::UMAX})
       setOperationAction(Opcode, VT, Legal);
   if (!VT.isFloatingPoint())
     for (auto Opcode : {ISD::SADDSAT, ISD::UADDSAT, ISD::SSUBSAT, ISD::USUBSAT})
@@ -4174,7 +4174,15 @@ ARMTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG,
   }
   case Intrinsic::arm_neon_vabs:
     return DAG.getNode(ISD::ABS, SDLoc(Op), Op.getValueType(),
-                        Op.getOperand(1));
+                       Op.getOperand(1));
+  case Intrinsic::arm_neon_vabds:
+    if (Op.getValueType().isInteger())
+      return DAG.getNode(ISD::ABDS, SDLoc(Op), Op.getValueType(),
+                         Op.getOperand(1), Op.getOperand(2));
+    return SDValue();
+  case Intrinsic::arm_neon_vabdu:
+    return DAG.getNode(ISD::ABDU, SDLoc(Op), Op.getValueType(),
+                       Op.getOperand(1), Op.getOperand(2));
   case Intrinsic::arm_neon_vmulls:
   case Intrinsic::arm_neon_vmullu: {
     unsigned NewOpc = (IntNo == Intrinsic::arm_neon_vmulls)
@@ -13496,18 +13504,6 @@ static SDValue PerformVSetCCToVCTPCombine(SDNode *N,
                          DCI.DAG.getZExtOrTrunc(Op1S, DL, MVT::i32));
 }
 
-static SDValue PerformABSCombine(SDNode *N,
-                                 TargetLowering::DAGCombinerInfo &DCI,
-                                 const ARMSubtarget *Subtarget) {
-  SelectionDAG &DAG = DCI.DAG;
-  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-
-  if (TLI.isOperationLegal(N->getOpcode(), N->getValueType(0)))
-    return SDValue();
-
-  return TLI.expandABS(N, DAG);
-}
-
 /// PerformADDECombine - Target-specific dag combine transform from
 /// ARMISD::ADDC, ARMISD::ADDE, and ISD::MUL_LOHI to MLAL or
 /// ARMISD::ADDC, ARMISD::ADDE and ARMISD::UMLAL to ARMISD::UMAAL
@@ -18871,7 +18867,6 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::SELECT:     return PerformSELECTCombine(N, DCI, Subtarget);
   case ISD::VSELECT:    return PerformVSELECTCombine(N, DCI, Subtarget);
   case ISD::SETCC:      return PerformVSetCCToVCTPCombine(N, DCI, Subtarget);
-  case ISD::ABS:        return PerformABSCombine(N, DCI, Subtarget);
   case ARMISD::ADDE:    return PerformADDECombine(N, DCI, Subtarget);
   case ARMISD::UMLAL:   return PerformUMLALCombine(N, DCI.DAG, Subtarget);
   case ISD::ADD:        return PerformADDCombine(N, DCI, Subtarget);
diff --git a/llvm/lib/Target/ARM/ARMInstrNEON.td b/llvm/lib/Target/ARM/ARMInstrNEON.td
index 21a5817252ae..fcabc9076e4d 100644
--- a/llvm/lib/Target/ARM/ARMInstrNEON.td
+++ b/llvm/lib/Target/ARM/ARMInstrNEON.td
@@ -5640,10 +5640,10 @@ def  VBITq    : N3VX<1, 0, 0b10, 0b0001, 1, 1,
 //   VABD     : Vector Absolute Difference
 defm VABDs    : N3VInt_QHS<0, 0, 0b0111, 0, N3RegFrm,
                            IIC_VSUBi4D, IIC_VSUBi4D, IIC_VSUBi4Q, IIC_VSUBi4Q,
-                           "vabd", "s", int_arm_neon_vabds, 1>;
+                           "vabd", "s", abds, 1>;
 defm VABDu    : N3VInt_QHS<1, 0, 0b0111, 0, N3RegFrm,
                            IIC_VSUBi4D, IIC_VSUBi4D, IIC_VSUBi4Q, IIC_VSUBi4Q,
-                           "vabd", "u", int_arm_neon_vabdu, 1>;
+                           "vabd", "u", abdu, 1>;
 def  VABDfd   : N3VDInt<1, 0, 0b10, 0b1101, 0, N3RegFrm, IIC_VBIND,
                         "vabd", "f32", v2f32, v2f32, int_arm_neon_vabds, 1>;
 def  VABDfq   : N3VQInt<1, 0, 0b10, 0b1101, 0, N3RegFrm, IIC_VBINQ,
@@ -5657,44 +5657,30 @@ def  VABDhq   : N3VQInt<1, 0, 0b11, 0b1101, 0, N3RegFrm, IIC_VBINQ,
 
 //   VABDL    : Vector Absolute Difference Long (Q = | D - D |)
 defm VABDLs   : N3VLIntExt_QHS<0,1,0b0111,0, IIC_VSUBi4Q,
-                               "vabdl", "s", int_arm_neon_vabds, zext, 1>;
+                               "vabdl", "s", abds, zext, 1>;
 defm VABDLu   : N3VLIntExt_QHS<1,1,0b0111,0, IIC_VSUBi4Q,
-                               "vabdl", "u", int_arm_neon_vabdu, zext, 1>;
+                               "vabdl", "u", abdu, zext, 1>;
 
 let Predicates = [HasNEON] in {
-def : Pat<(v8i16 (abs (sub (zext (v8i8 DPR:$opA)), (zext (v8i8 DPR:$opB))))),
+def : Pat<(v8i16 (zext (abdu (v8i8 DPR:$opA), (v8i8 DPR:$opB)))),
           (VABDLuv8i16 DPR:$opA, DPR:$opB)>;
-def : Pat<(v4i32 (abs (sub (zext (v4i16 DPR:$opA)), (zext (v4i16 DPR:$opB))))),
+def : Pat<(v4i32 (zext (abdu (v4i16 DPR:$opA), (v4i16 DPR:$opB)))),
           (VABDLuv4i32 DPR:$opA, DPR:$opB)>;
-}
-
-// ISD::ABS is not legal for v2i64, so VABDL needs to be matched from the
-// shift/xor pattern for ABS.
-
-def abd_shr :
-    PatFrag<(ops node:$in1, node:$in2, node:$shift),
-            (ARMvshrsImm (sub (zext node:$in1),
-                            (zext node:$in2)), (i32 $shift))>;
-
-let Predicates = [HasNEON] in {
-def : Pat<(xor (v2i64 (abd_shr (v2i32 DPR:$opA), (v2i32 DPR:$opB), 63)),
-               (v2i64 (add (sub (zext (v2i32 DPR:$opA)),
-                                (zext (v2i32 DPR:$opB))),
-                           (abd_shr (v2i32 DPR:$opA), (v2i32 DPR:$opB), 63)))),
+def : Pat<(v2i64 (zext (abdu (v2i32 DPR:$opA), (v2i32 DPR:$opB)))),
           (VABDLuv2i64 DPR:$opA, DPR:$opB)>;
 }
 
 //   VABA     : Vector Absolute Difference and Accumulate
 defm VABAs    : N3VIntOp_QHS<0,0,0b0111,1, IIC_VABAD, IIC_VABAQ,
-                             "vaba", "s", int_arm_neon_vabds, add>;
+                             "vaba", "s", abds, add>;
 defm VABAu    : N3VIntOp_QHS<1,0,0b0111,1, IIC_VABAD, IIC_VABAQ,
-                             "vaba", "u", int_arm_neon_vabdu, add>;
+                             "vaba", "u", abdu, add>;
 
 //   VABAL    : Vector Absolute Difference and Accumulate Long (Q += | D - D |)
 defm VABALs   : N3VLIntExtOp_QHS<0,1,0b0101,0, IIC_VABAD,
-                                 "vabal", "s", int_arm_neon_vabds, zext, add>;
+                                 "vabal", "s", abds, zext, add>;
 defm VABALu   : N3VLIntExtOp_QHS<1,1,0b0101,0, IIC_VABAD,
-                                 "vabal", "u", int_arm_neon_vabdu, zext, add>;
+                                 "vabal", "u", abdu, zext, add>;
 
 // Vector Maximum and Minimum.
 
diff --git a/llvm/test/CodeGen/ARM/neon_vabs.ll b/llvm/test/CodeGen/ARM/neon_vabs.ll
index 4064aae65f66..9e0d885bd948 100644
--- a/llvm/test/CodeGen/ARM/neon_vabs.ll
+++ b/llvm/test/CodeGen/ARM/neon_vabs.ll
@@ -184,7 +184,10 @@ define <2 x i64> @test13(<2 x i32> %a, <2 x i32> %b) nounwind {
 ; CHECK:       @ %bb.0:
 ; CHECK-NEXT:    vmov d16, r2, r3
 ; CHECK-NEXT:    vmov d17, r0, r1
-; CHECK-NEXT:    vabdl.u32 q8, d17, d16
+; CHECK-NEXT:    vsubl.u32 q8, d17, d16
+; CHECK-NEXT:    vshr.s64 q9, q8, #63
+; CHECK-NEXT:    vsra.s64 q8, q8, #63
+; CHECK-NEXT:    veor q8, q9, q8
 ; CHECK-NEXT:    vmov r0, r1, d16
 ; CHECK-NEXT:    vmov r2, r3, d17
 ; CHECK-NEXT:    mov pc, lr



More information about the llvm-commits mailing list