[llvm] b96967a - [AArch64] Combine concat through rshrn

David Green via llvm-commits llvm-commits at lists.llvm.org
Wed May 3 06:48:56 PDT 2023


Author: David Green
Date: 2023-05-03T14:48:50+01:00
New Revision: b96967ad172a51060ed77fdc6c46aecb168cb35e

URL: https://github.com/llvm/llvm-project/commit/b96967ad172a51060ed77fdc6c46aecb168cb35e
DIFF: https://github.com/llvm/llvm-project/commit/b96967ad172a51060ed77fdc6c46aecb168cb35e.diff

LOG: [AArch64] Combine concat through rshrn

This sinks the concat in trunc(concat(rshr, rshr)) below the rounding shifts
into their inputs, producing trunc(rshr(concat)) so that the existing rshrn
patterns can select a single rshrn(concat). This improves codegen for small types.

Differential Revision: https://reviews.llvm.org/D149636

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/test/CodeGen/AArch64/neon-rshrn.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 93a1259146c67..ae1969b03ce4b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17069,6 +17069,51 @@ static SDValue performConcatVectorsCombine(SDNode *N,
     }
   }
 
+  auto IsRSHRN = [](SDValue Shr) {
+    if (Shr.getOpcode() != AArch64ISD::VLSHR)
+      return false;
+    SDValue Op = Shr.getOperand(0);
+    EVT VT = Op.getValueType();
+    unsigned ShtAmt = Shr.getConstantOperandVal(1);
+    if (ShtAmt > VT.getScalarSizeInBits() / 2 || Op.getOpcode() != ISD::ADD)
+      return false;
+
+    APInt Imm;
+    if (Op.getOperand(1).getOpcode() == AArch64ISD::MOVIshift)
+      Imm = APInt(VT.getScalarSizeInBits(),
+                  Op.getOperand(1).getConstantOperandVal(0)
+                      << Op.getOperand(1).getConstantOperandVal(1));
+    else if (Op.getOperand(1).getOpcode() == AArch64ISD::DUP &&
+             isa<ConstantSDNode>(Op.getOperand(1).getOperand(0)))
+      Imm = APInt(VT.getScalarSizeInBits(),
+                  Op.getOperand(1).getConstantOperandVal(0));
+    else
+      return false;
+
+    if (Imm != 1ULL << (ShtAmt - 1))
+      return false;
+    return true;
+  };
+
+  // concat(rshrn(x), rshrn(y)) -> rshrn(concat(x, y))
+  if (N->getNumOperands() == 2 && IsRSHRN(N0) &&
+      ((IsRSHRN(N1) &&
+        N0.getConstantOperandVal(1) == N1.getConstantOperandVal(1)) ||
+       N1.isUndef())) {
+    SDValue X = N0.getOperand(0).getOperand(0);
+    SDValue Y = N1.isUndef() ? DAG.getUNDEF(X.getValueType())
+                             : N1.getOperand(0).getOperand(0);
+    EVT BVT =
+        X.getValueType().getDoubleNumVectorElementsVT(*DCI.DAG.getContext());
+    SDValue CC = DAG.getNode(ISD::CONCAT_VECTORS, dl, BVT, X, Y);
+    SDValue Add = DAG.getNode(
+        ISD::ADD, dl, BVT, CC,
+        DAG.getConstant(1ULL << (N0.getConstantOperandVal(1) - 1), dl, BVT));
+    SDValue Shr =
+        DAG.getNode(AArch64ISD::VLSHR, dl, BVT, Add, N0.getOperand(1));
+    return Shr;
+  }
+
   // concat(zip1(a, b), zip2(a, b)) is zip1(a, b)
   if (N->getNumOperands() == 2 && N0Opc == AArch64ISD::ZIP1 &&
       N1Opc == AArch64ISD::ZIP2 && N0.getOperand(0) == N1.getOperand(0) &&

diff  --git a/llvm/test/CodeGen/AArch64/neon-rshrn.ll b/llvm/test/CodeGen/AArch64/neon-rshrn.ll
index 7271050ada752..b29d1a52fb762 100644
--- a/llvm/test/CodeGen/AArch64/neon-rshrn.ll
+++ b/llvm/test/CodeGen/AArch64/neon-rshrn.ll
@@ -823,15 +823,8 @@ entry:
 define void @rshrn_v8i32i8_5(<8 x i32> %a, ptr %p) {
 ; CHECK-LABEL: rshrn_v8i32i8_5:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    movi v2.4h, #16
-; CHECK-NEXT:    xtn v1.4h, v1.4s
-; CHECK-NEXT:    xtn v0.4h, v0.4s
-; CHECK-NEXT:    add v1.4h, v1.4h, v2.4h
-; CHECK-NEXT:    add v0.4h, v0.4h, v2.4h
-; CHECK-NEXT:    ushr v1.4h, v1.4h, #5
-; CHECK-NEXT:    ushr v0.4h, v0.4h, #5
-; CHECK-NEXT:    mov v0.d[1], v1.d[0]
-; CHECK-NEXT:    xtn v0.8b, v0.8h
+; CHECK-NEXT:    uzp1 v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    rshrn v0.8b, v0.8h, #5
 ; CHECK-NEXT:    str d0, [x0]
 ; CHECK-NEXT:    ret
 entry:
@@ -845,15 +838,8 @@ entry:
 define void @rshrn_v4i64i16_4(<4 x i64> %a, ptr %p) {
 ; CHECK-LABEL: rshrn_v4i64i16_4:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    movi v2.2s, #8
-; CHECK-NEXT:    xtn v1.2s, v1.2d
-; CHECK-NEXT:    xtn v0.2s, v0.2d
-; CHECK-NEXT:    add v1.2s, v1.2s, v2.2s
-; CHECK-NEXT:    add v0.2s, v0.2s, v2.2s
-; CHECK-NEXT:    ushr v1.2s, v1.2s, #4
-; CHECK-NEXT:    ushr v0.2s, v0.2s, #4
-; CHECK-NEXT:    mov v0.d[1], v1.d[0]
-; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    uzp1 v0.4s, v0.4s, v1.4s
+; CHECK-NEXT:    rshrn v0.4h, v0.4s, #4
 ; CHECK-NEXT:    str d0, [x0]
 ; CHECK-NEXT:    ret
 entry:
@@ -867,10 +853,8 @@ entry:
 define void @rshrn_v4i16_5(<4 x i16> %a, ptr %p) {
 ; CHECK-LABEL: rshrn_v4i16_5:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    movi v1.4h, #16
-; CHECK-NEXT:    add v0.4h, v0.4h, v1.4h
-; CHECK-NEXT:    ushr v0.4h, v0.4h, #5
-; CHECK-NEXT:    xtn v0.8b, v0.8h
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT:    rshrn v0.8b, v0.8h, #5
 ; CHECK-NEXT:    str s0, [x0]
 ; CHECK-NEXT:    ret
 entry:
@@ -903,11 +887,8 @@ entry:
 define void @rshrn_v1i64_5(<1 x i64> %a, ptr %p) {
 ; CHECK-LABEL: rshrn_v1i64_5:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    mov w8, #16 // =0x10
-; CHECK-NEXT:    fmov d1, x8
-; CHECK-NEXT:    add d0, d0, d1
-; CHECK-NEXT:    ushr d0, d0, #5
-; CHECK-NEXT:    xtn v0.2s, v0.2d
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT:    rshrn v0.2s, v0.2d, #5
 ; CHECK-NEXT:    str s0, [x0]
 ; CHECK-NEXT:    ret
 entry:


        


More information about the llvm-commits mailing list