[llvm] 08fae46 - [DAG] fold `avgs(sext(x), sext(y))` -> `sext(avgs(x, y))` (#95365)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 14 06:53:33 PDT 2024


Author: c8ef
Date: 2024-06-14T21:53:29+08:00
New Revision: 08fae467e4c742e91c8fdff8519718cf2c7c9b0e

URL: https://github.com/llvm/llvm-project/commit/08fae467e4c742e91c8fdff8519718cf2c7c9b0e
DIFF: https://github.com/llvm/llvm-project/commit/08fae467e4c742e91c8fdff8519718cf2c7c9b0e.diff

LOG: [DAG] fold `avgs(sext(x), sext(y))` -> `sext(avgs(x, y))` (#95365)

Follow-up to #95134.

Context:
https://github.com/llvm/llvm-project/pull/95134#issuecomment-2162825594.

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/AArch64/aarch64-known-bits-hadd.ll
    llvm/test/CodeGen/AArch64/avg.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 78970bc4fe4ab..80b8d48251472 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -5237,6 +5237,7 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
                        DAG.getShiftAmountConstant(1, VT, DL));
 
   // fold avgu(zext(x), zext(y)) -> zext(avgu(x, y))
+  // fold avgs(sext(x), sext(y)) -> sext(avgs(x, y))
   if (sd_match(
           N, m_BinOp(ISD::AVGFLOORU, m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
       X.getValueType() == Y.getValueType() &&
@@ -5251,6 +5252,20 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
     SDValue AvgCeilU = DAG.getNode(ISD::AVGCEILU, DL, X.getValueType(), X, Y);
     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, AvgCeilU);
   }
+  if (sd_match(
+          N, m_BinOp(ISD::AVGFLOORS, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
+      X.getValueType() == Y.getValueType() &&
+      hasOperation(ISD::AVGFLOORS, X.getValueType())) {
+    SDValue AvgFloorS = DAG.getNode(ISD::AVGFLOORS, DL, X.getValueType(), X, Y);
+    return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgFloorS);
+  }
+  if (sd_match(
+          N, m_BinOp(ISD::AVGCEILS, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
+      X.getValueType() == Y.getValueType() &&
+      hasOperation(ISD::AVGCEILS, X.getValueType())) {
+    SDValue AvgCeilS = DAG.getNode(ISD::AVGCEILS, DL, X.getValueType(), X, Y);
+    return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgCeilS);
+  }
 
   // Fold avgflooru(x,y) -> avgceilu(x,y-1) iff y != 0
   // Fold avgflooru(x,y) -> avgceilu(x-1,y) iff x != 0

diff  --git a/llvm/test/CodeGen/AArch64/aarch64-known-bits-hadd.ll b/llvm/test/CodeGen/AArch64/aarch64-known-bits-hadd.ll
index b2cf089d8145f..0506e1ed9710b 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-known-bits-hadd.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-known-bits-hadd.ll
@@ -95,9 +95,8 @@ define <8 x i16> @urhadd_sext(<8 x i8> %a0, <8 x i8> %a1) {
 define <8 x i16> @hadds_sext(<8 x i8> %a0, <8 x i8> %a1) {
 ; CHECK-LABEL: hadds_sext:
 ; CHECK:       // %bb.0:
+; CHECK-NEXT:    shadd v0.8b, v0.8b, v1.8b
 ; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
-; CHECK-NEXT:    sshll v1.8h, v1.8b, #0
-; CHECK-NEXT:    shadd v0.8h, v0.8h, v1.8h
 ; CHECK-NEXT:    bic v0.8h, #254, lsl #8
 ; CHECK-NEXT:    ret
   %x0 = sext <8 x i8> %a0 to <8 x i16>
@@ -110,9 +109,8 @@ define <8 x i16> @hadds_sext(<8 x i8> %a0, <8 x i8> %a1) {
 define <8 x i16> @shaddu_sext(<8 x i8> %a0, <8 x i8> %a1) {
 ; CHECK-LABEL: shaddu_sext:
 ; CHECK:       // %bb.0:
+; CHECK-NEXT:    srhadd v0.8b, v0.8b, v1.8b
 ; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
-; CHECK-NEXT:    sshll v1.8h, v1.8b, #0
-; CHECK-NEXT:    srhadd v0.8h, v0.8h, v1.8h
 ; CHECK-NEXT:    bic v0.8h, #254, lsl #8
 ; CHECK-NEXT:    ret
   %x0 = sext <8 x i8> %a0 to <8 x i16>

diff  --git a/llvm/test/CodeGen/AArch64/avg.ll b/llvm/test/CodeGen/AArch64/avg.ll
index dc87708555987..cabc0d346b806 100644
--- a/llvm/test/CodeGen/AArch64/avg.ll
+++ b/llvm/test/CodeGen/AArch64/avg.ll
@@ -68,3 +68,81 @@ define <16 x i16> @zext_avgceilu_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
   %avg = sub <16 x i16> %or, %shift
   ret <16 x i16> %avg
 }
+
+define <16 x i16> @sext_avgfloors(<16 x i8> %a0, <16 x i8> %a1) {
+; CHECK-LABEL: sext_avgfloors:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shadd v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    sshll2 v1.8h, v0.16b, #0
+; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
+; CHECK-NEXT:    ret
+  %x0 = sext <16 x i8> %a0 to <16 x i16>
+  %x1 = sext <16 x i8> %a1 to <16 x i16>
+  %and = and <16 x i16> %x0, %x1
+  %xor = xor <16 x i16> %x0, %x1
+  %shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %avg = add <16 x i16> %and, %shift
+  ret <16 x i16> %avg
+}
+
+define <16 x i16> @sext_avgfloors_mismatch(<16 x i8> %a0, <16 x i4> %a1) {
+; CHECK-LABEL: sext_avgfloors_mismatch:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll2 v2.8h, v1.16b, #0
+; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NEXT:    sshll v3.8h, v0.8b, #0
+; CHECK-NEXT:    sshll2 v0.8h, v0.16b, #0
+; CHECK-NEXT:    shl v1.8h, v1.8h, #12
+; CHECK-NEXT:    shl v2.8h, v2.8h, #12
+; CHECK-NEXT:    sshr v4.8h, v1.8h, #12
+; CHECK-NEXT:    sshr v1.8h, v2.8h, #12
+; CHECK-NEXT:    shadd v1.8h, v0.8h, v1.8h
+; CHECK-NEXT:    shadd v0.8h, v3.8h, v4.8h
+; CHECK-NEXT:    ret
+  %x0 = sext <16 x i8> %a0 to <16 x i16>
+  %x1 = sext <16 x i4> %a1 to <16 x i16>
+  %and = and <16 x i16> %x0, %x1
+  %xor = xor <16 x i16> %x0, %x1
+  %shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %avg = add <16 x i16> %and, %shift
+  ret <16 x i16> %avg
+}
+
+define <16 x i16> @sext_avgceils(<16 x i8> %a0, <16 x i8> %a1) {
+; CHECK-LABEL: sext_avgceils:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    srhadd v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    sshll2 v1.8h, v0.16b, #0
+; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
+; CHECK-NEXT:    ret
+  %x0 = sext <16 x i8> %a0 to <16 x i16>
+  %x1 = sext <16 x i8> %a1 to <16 x i16>
+  %or = or <16 x i16> %x0, %x1
+  %xor = xor <16 x i16> %x0, %x1
+  %shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %avg = sub <16 x i16> %or, %shift
+  ret <16 x i16> %avg
+}
+
+define <16 x i16> @sext_avgceils_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
+; CHECK-LABEL: sext_avgceils_mismatch:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll v2.8h, v0.8b, #0
+; CHECK-NEXT:    ushll2 v0.8h, v0.16b, #0
+; CHECK-NEXT:    sshll v3.8h, v1.8b, #0
+; CHECK-NEXT:    sshll2 v1.8h, v1.16b, #0
+; CHECK-NEXT:    shl v2.8h, v2.8h, #12
+; CHECK-NEXT:    shl v0.8h, v0.8h, #12
+; CHECK-NEXT:    sshr v2.8h, v2.8h, #12
+; CHECK-NEXT:    sshr v0.8h, v0.8h, #12
+; CHECK-NEXT:    srhadd v1.8h, v0.8h, v1.8h
+; CHECK-NEXT:    srhadd v0.8h, v2.8h, v3.8h
+; CHECK-NEXT:    ret
+  %x0 = sext <16 x i4> %a0 to <16 x i16>
+  %x1 = sext <16 x i8> %a1 to <16 x i16>
+  %or = or <16 x i16> %x0, %x1
+  %xor = xor <16 x i16> %x0, %x1
+  %shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %avg = sub <16 x i16> %or, %shift
+  ret <16 x i16> %avg
+}


        


More information about the llvm-commits mailing list