[llvm] [AArch64] Don't expand RSHRN intrinsics to add+srl+trunc. (PR #67451)

via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 26 09:17:21 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

<details>
<summary>Changes</summary>

We expand aarch64_neon_rshrn intrinsics to trunc(srl(add)), relying on tablegen patterns to combine the result back into rshrn (see D140297). Unfortunately, but perhaps not surprisingly, other combines can fire first and prevent us from converting back. For example, sext(rshrn) becomes sext(trunc(srl(add))), which then turns into sext_inreg(srl(add)).

This patch simply stops expanding rshrn intrinsics and reinstates the old tablegen patterns for selecting them. We should still be able to recognize rshrn instructions from trunc+shift+add, without performing any negative optimizations on the intrinsics.

---
Full diff: https://github.com/llvm/llvm-project/pull/67451.diff


3 Files Affected:

- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (-11) 
- (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (+28-2) 
- (modified) llvm/test/CodeGen/AArch64/arm64-vshift.ll (+6-12) 


``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 3de6bd1ec94a82a..de88be2e258fddd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -19390,17 +19390,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::aarch64_neon_sshl:
   case Intrinsic::aarch64_neon_ushl:
     return tryCombineShiftImm(IID, N, DAG);
-  case Intrinsic::aarch64_neon_rshrn: {
-    EVT VT = N->getOperand(1).getValueType();
-    SDLoc DL(N);
-    SDValue Imm =
-        DAG.getConstant(1LLU << (N->getConstantOperandVal(2) - 1), DL, VT);
-    SDValue Add = DAG.getNode(ISD::ADD, DL, VT, N->getOperand(1), Imm);
-    SDValue Sht =
-        DAG.getNode(ISD::SRL, DL, VT, Add,
-                    DAG.getConstant(N->getConstantOperandVal(2), DL, VT));
-    return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), Sht);
-  }
   case Intrinsic::aarch64_neon_sabd:
     return DAG.getNode(ISD::ABDS, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2));
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 8350bf71e29e5b4..cd377d659ad3ee9 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -777,6 +777,9 @@ def AArch64faddp     : PatFrags<(ops node:$Rn, node:$Rm),
                                 [(AArch64addp_n node:$Rn, node:$Rm),
                                  (int_aarch64_neon_faddp node:$Rn, node:$Rm)]>;
 def AArch64roundingvlshr : ComplexPattern<vAny, 2, "SelectRoundingVLShr", [AArch64vlshr]>;
+def AArch64rshrn : PatFrags<(ops node:$LHS, node:$RHS),
+                            [(trunc (AArch64roundingvlshr node:$LHS, node:$RHS)),
+                             (int_aarch64_neon_rshrn node:$LHS, node:$RHS)]>;
 def AArch64facge     : PatFrags<(ops node:$Rn, node:$Rm),
                                 [(AArch64fcmge (fabs node:$Rn), (fabs node:$Rm)),
                                  (int_aarch64_neon_facge node:$Rn, node:$Rm)]>;
@@ -7191,8 +7194,7 @@ defm FCVTZS:SIMDVectorRShiftSD<0, 0b11111, "fcvtzs", int_aarch64_neon_vcvtfp2fxs
 defm FCVTZU:SIMDVectorRShiftSD<1, 0b11111, "fcvtzu", int_aarch64_neon_vcvtfp2fxu>;
 defm SCVTF: SIMDVectorRShiftToFP<0, 0b11100, "scvtf",
                                    int_aarch64_neon_vcvtfxs2fp>;
-defm RSHRN   : SIMDVectorRShiftNarrowBHS<0, 0b10001, "rshrn",
-                          BinOpFrag<(trunc (AArch64roundingvlshr node:$LHS, node:$RHS))>>;
+defm RSHRN   : SIMDVectorRShiftNarrowBHS<0, 0b10001, "rshrn", AArch64rshrn>;
 defm SHL     : SIMDVectorLShiftBHSD<0, 0b01010, "shl", AArch64vshl>;
 
 // X << 1 ==> X + X
@@ -7263,6 +7265,12 @@ def : Pat<(v4i16 (trunc (AArch64vlshr (add (v4i32 V128:$Vn), (AArch64movi_shift
 let AddedComplexity = 5 in
 def : Pat<(v2i32 (trunc (AArch64vlshr (add (v2i64 V128:$Vn), (AArch64dup (i64 2147483648))), (i32 32)))),
           (RADDHNv2i64_v2i32 V128:$Vn, (v2i64 (MOVIv2d_ns (i32 0))))>;
+def : Pat<(v8i8 (int_aarch64_neon_rshrn (v8i16 V128:$Vn), (i32 8))),
+          (RADDHNv8i16_v8i8 V128:$Vn, (v8i16 (MOVIv2d_ns (i32 0))))>;
+def : Pat<(v4i16 (int_aarch64_neon_rshrn (v4i32 V128:$Vn), (i32 16))),
+          (RADDHNv4i32_v4i16 V128:$Vn, (v4i32 (MOVIv2d_ns (i32 0))))>;
+def : Pat<(v2i32 (int_aarch64_neon_rshrn (v2i64 V128:$Vn), (i32 32))),
+          (RADDHNv2i64_v2i32 V128:$Vn, (v2i64 (MOVIv2d_ns (i32 0))))>;
 
 // RADDHN2 patterns for when RSHRN shifts by half the size of the vector element
 def : Pat<(v16i8 (concat_vectors
@@ -7284,6 +7292,24 @@ def : Pat<(v4i32 (concat_vectors
           (RADDHNv2i64_v4i32
                  (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn,
                  (v2i64 (MOVIv2d_ns (i32 0))))>;
+def : Pat<(v16i8 (concat_vectors
+                 (v8i8 V64:$Vd),
+                 (v8i8 (int_aarch64_neon_rshrn (v8i16 V128:$Vn), (i32 8))))),
+          (RADDHNv8i16_v16i8
+                 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn,
+                 (v8i16 (MOVIv2d_ns (i32 0))))>;
+def : Pat<(v8i16 (concat_vectors
+                 (v4i16 V64:$Vd),
+                 (v4i16 (int_aarch64_neon_rshrn (v4i32 V128:$Vn), (i32 16))))),
+          (RADDHNv4i32_v8i16
+                 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn,
+                 (v4i32 (MOVIv2d_ns (i32 0))))>;
+def : Pat<(v4i32 (concat_vectors
+                 (v2i32 V64:$Vd),
+                 (v2i32 (int_aarch64_neon_rshrn (v2i64 V128:$Vn), (i32 32))))),
+          (RADDHNv2i64_v4i32
+                 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn,
+                 (v2i64 (MOVIv2d_ns (i32 0))))>;
 
 // SHRN patterns for when a logical right shift was used instead of arithmetic
 // (the immediate guarantees no sign bits actually end up in the result so it
diff --git a/llvm/test/CodeGen/AArch64/arm64-vshift.ll b/llvm/test/CodeGen/AArch64/arm64-vshift.ll
index ef54f6d2bb1828f..367c3be242a17fa 100644
--- a/llvm/test/CodeGen/AArch64/arm64-vshift.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-vshift.ll
@@ -3531,11 +3531,8 @@ entry:
 define <4 x i32> @sext_rshrn(<4 x i32> noundef %a) {
 ; CHECK-LABEL: sext_rshrn:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    movi.4s v1, #16, lsl #8
-; CHECK-NEXT:    add.4s v0, v0, v1
-; CHECK-NEXT:    ushr.4s v0, v0, #13
-; CHECK-NEXT:    shl.4s v0, v0, #16
-; CHECK-NEXT:    sshr.4s v0, v0, #16
+; CHECK-NEXT:    rshrn.4h v0, v0, #13
+; CHECK-NEXT:    sshll.4s v0, v0, #0
 ; CHECK-NEXT:    ret
 entry:
   %vrshrn_n1 = tail call <4 x i16> @llvm.aarch64.neon.rshrn.v4i16(<4 x i32> %a, i32 13)
@@ -3546,10 +3543,8 @@ entry:
 define <4 x i32> @zext_rshrn(<4 x i32> noundef %a) {
 ; CHECK-LABEL: zext_rshrn:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    movi.4s v1, #16, lsl #8
-; CHECK-NEXT:    add.4s v0, v0, v1
-; CHECK-NEXT:    ushr.4s v0, v0, #13
-; CHECK-NEXT:    bic.4s v0, #7, lsl #16
+; CHECK-NEXT:    rshrn.4h v0, v0, #13
+; CHECK-NEXT:    ushll.4s v0, v0, #0
 ; CHECK-NEXT:    ret
 entry:
   %vrshrn_n1 = tail call <4 x i16> @llvm.aarch64.neon.rshrn.v4i16(<4 x i32> %a, i32 13)
@@ -3560,10 +3555,9 @@ entry:
 define <4 x i16> @mul_rshrn(<4 x i32> noundef %a) {
 ; CHECK-LABEL: mul_rshrn:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    mov w8, #4099 // =0x1003
-; CHECK-NEXT:    dup.4s v1, w8
+; CHECK-NEXT:    movi.4s v1, #3
 ; CHECK-NEXT:    add.4s v0, v0, v1
-; CHECK-NEXT:    shrn.4h v0, v0, #13
+; CHECK-NEXT:    rshrn.4h v0, v0, #13
 ; CHECK-NEXT:    ret
 entry:
   %b = add <4 x i32> %a, <i32 3, i32 3, i32 3, i32 3>

``````````

</details>


https://github.com/llvm/llvm-project/pull/67451


More information about the llvm-commits mailing list