[llvm] [DAG] fold `avgs(sext(x), sext(y))` -> `sext(avgs(x, y))` (PR #95365)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 13 01:03:07 PDT 2024


https://github.com/c8ef updated https://github.com/llvm/llvm-project/pull/95365

>From df7448f15f2ced82e343678da0f5ab4b545c8f85 Mon Sep 17 00:00:00 2001
From: c8ef <c8ef at outlook.com>
Date: Thu, 13 Jun 2024 07:29:41 +0000
Subject: [PATCH 1/2] fold avgu(sext(x), sext(y)) -> sext(avgu(x, y))

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 15 ++++
 llvm/test/CodeGen/AArch64/avg.ll              | 84 +++++++++++++++++++
 2 files changed, 99 insertions(+)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 78970bc4fe4ab..0d4df4a7ecda5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -5237,6 +5237,7 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
                        DAG.getShiftAmountConstant(1, VT, DL));
 
   // fold avgu(zext(x), zext(y)) -> zext(avgu(x, y))
+  // fold avgu(sext(x), sext(y)) -> sext(avgu(x, y))
   if (sd_match(
           N, m_BinOp(ISD::AVGFLOORU, m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
       X.getValueType() == Y.getValueType() &&
@@ -5251,6 +5252,20 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
     SDValue AvgCeilU = DAG.getNode(ISD::AVGCEILU, DL, X.getValueType(), X, Y);
     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, AvgCeilU);
   }
+  if (sd_match(
+          N, m_BinOp(ISD::AVGFLOORU, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
+      X.getValueType() == Y.getValueType() &&
+      hasOperation(ISD::AVGFLOORU, X.getValueType())) {
+    SDValue AvgFloorU = DAG.getNode(ISD::AVGFLOORU, DL, X.getValueType(), X, Y);
+    return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgFloorU);
+  }
+  if (sd_match(
+          N, m_BinOp(ISD::AVGCEILU, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
+      X.getValueType() == Y.getValueType() &&
+      hasOperation(ISD::AVGCEILU, X.getValueType())) {
+    SDValue AvgCeilU = DAG.getNode(ISD::AVGCEILU, DL, X.getValueType(), X, Y);
+    return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgCeilU);
+  }
 
   // Fold avgflooru(x,y) -> avgceilu(x,y-1) iff y != 0
   // Fold avgflooru(x,y) -> avgceilu(x-1,y) iff x != 0
diff --git a/llvm/test/CodeGen/AArch64/avg.ll b/llvm/test/CodeGen/AArch64/avg.ll
index dc87708555987..e61b47772b7d7 100644
--- a/llvm/test/CodeGen/AArch64/avg.ll
+++ b/llvm/test/CodeGen/AArch64/avg.ll
@@ -68,3 +68,87 @@ define <16 x i16> @zext_avgceilu_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
   %avg = sub <16 x i16> %or, %shift
   ret <16 x i16> %avg
 }
+
+define <16 x i16> @sext_avgflooru(<16 x i8> %a0, <16 x i8> %a1) {
+; CHECK-LABEL: sext_avgflooru:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sshll v2.8h, v0.8b, #0
+; CHECK-NEXT:    sshll2 v0.8h, v0.16b, #0
+; CHECK-NEXT:    sshll v3.8h, v1.8b, #0
+; CHECK-NEXT:    sshll2 v1.8h, v1.16b, #0
+; CHECK-NEXT:    shadd v1.8h, v0.8h, v1.8h
+; CHECK-NEXT:    shadd v0.8h, v2.8h, v3.8h
+; CHECK-NEXT:    ret
+  %x0 = sext <16 x i8> %a0 to <16 x i16>
+  %x1 = sext <16 x i8> %a1 to <16 x i16>
+  %and = and <16 x i16> %x0, %x1
+  %xor = xor <16 x i16> %x0, %x1
+  %shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %avg = add <16 x i16> %and, %shift
+  ret <16 x i16> %avg
+}
+
+define <16 x i16> @sext_avgflooru_mismatch(<16 x i8> %a0, <16 x i4> %a1) {
+; CHECK-LABEL: sext_avgflooru_mismatch:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll2 v2.8h, v1.16b, #0
+; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NEXT:    sshll v3.8h, v0.8b, #0
+; CHECK-NEXT:    sshll2 v0.8h, v0.16b, #0
+; CHECK-NEXT:    shl v1.8h, v1.8h, #12
+; CHECK-NEXT:    shl v2.8h, v2.8h, #12
+; CHECK-NEXT:    sshr v4.8h, v1.8h, #12
+; CHECK-NEXT:    sshr v1.8h, v2.8h, #12
+; CHECK-NEXT:    shadd v1.8h, v0.8h, v1.8h
+; CHECK-NEXT:    shadd v0.8h, v3.8h, v4.8h
+; CHECK-NEXT:    ret
+  %x0 = sext <16 x i8> %a0 to <16 x i16>
+  %x1 = sext <16 x i4> %a1 to <16 x i16>
+  %and = and <16 x i16> %x0, %x1
+  %xor = xor <16 x i16> %x0, %x1
+  %shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %avg = add <16 x i16> %and, %shift
+  ret <16 x i16> %avg
+}
+
+define <16 x i16> @sext_avgceilu(<16 x i8> %a0, <16 x i8> %a1) {
+; CHECK-LABEL: sext_avgceilu:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sshll v2.8h, v0.8b, #0
+; CHECK-NEXT:    sshll2 v0.8h, v0.16b, #0
+; CHECK-NEXT:    sshll v3.8h, v1.8b, #0
+; CHECK-NEXT:    sshll2 v1.8h, v1.16b, #0
+; CHECK-NEXT:    srhadd v1.8h, v0.8h, v1.8h
+; CHECK-NEXT:    srhadd v0.8h, v2.8h, v3.8h
+; CHECK-NEXT:    ret
+  %x0 = sext <16 x i8> %a0 to <16 x i16>
+  %x1 = sext <16 x i8> %a1 to <16 x i16>
+  %or = or <16 x i16> %x0, %x1
+  %xor = xor <16 x i16> %x0, %x1
+  %shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %avg = sub <16 x i16> %or, %shift
+  ret <16 x i16> %avg
+}
+
+define <16 x i16> @sext_avgceilu_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
+; CHECK-LABEL: sext_avgceilu_mismatch:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll v2.8h, v0.8b, #0
+; CHECK-NEXT:    ushll2 v0.8h, v0.16b, #0
+; CHECK-NEXT:    sshll v3.8h, v1.8b, #0
+; CHECK-NEXT:    sshll2 v1.8h, v1.16b, #0
+; CHECK-NEXT:    shl v2.8h, v2.8h, #12
+; CHECK-NEXT:    shl v0.8h, v0.8h, #12
+; CHECK-NEXT:    sshr v2.8h, v2.8h, #12
+; CHECK-NEXT:    sshr v0.8h, v0.8h, #12
+; CHECK-NEXT:    srhadd v1.8h, v0.8h, v1.8h
+; CHECK-NEXT:    srhadd v0.8h, v2.8h, v3.8h
+; CHECK-NEXT:    ret
+  %x0 = sext <16 x i4> %a0 to <16 x i16>
+  %x1 = sext <16 x i8> %a1 to <16 x i16>
+  %or = or <16 x i16> %x0, %x1
+  %xor = xor <16 x i16> %x0, %x1
+  %shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %avg = sub <16 x i16> %or, %shift
+  ret <16 x i16> %avg
+}

>From fbe39f42ab0559fef068b810768ee81456a0b15e Mon Sep 17 00:00:00 2001
From: c8ef <c8ef at outlook.com>
Date: Thu, 13 Jun 2024 08:02:56 +0000
Subject: [PATCH 2/2] handle avgs, not avgu

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 18 +++++-----
 llvm/test/CodeGen/AArch64/avg.ll              | 34 ++++++++-----------
 2 files changed, 23 insertions(+), 29 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 0d4df4a7ecda5..80b8d48251472 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -5237,7 +5237,7 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
                        DAG.getShiftAmountConstant(1, VT, DL));
 
   // fold avgu(zext(x), zext(y)) -> zext(avgu(x, y))
-  // fold avgu(sext(x), sext(y)) -> sext(avgu(x, y))
+  // fold avgs(sext(x), sext(y)) -> sext(avgs(x, y))
   if (sd_match(
           N, m_BinOp(ISD::AVGFLOORU, m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
       X.getValueType() == Y.getValueType() &&
@@ -5253,18 +5253,18 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, AvgCeilU);
   }
   if (sd_match(
-          N, m_BinOp(ISD::AVGFLOORU, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
+          N, m_BinOp(ISD::AVGFLOORS, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
       X.getValueType() == Y.getValueType() &&
-      hasOperation(ISD::AVGFLOORU, X.getValueType())) {
-    SDValue AvgFloorU = DAG.getNode(ISD::AVGFLOORU, DL, X.getValueType(), X, Y);
-    return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgFloorU);
+      hasOperation(ISD::AVGFLOORS, X.getValueType())) {
+    SDValue AvgFloorS = DAG.getNode(ISD::AVGFLOORS, DL, X.getValueType(), X, Y);
+    return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgFloorS);
   }
   if (sd_match(
-          N, m_BinOp(ISD::AVGCEILU, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
+          N, m_BinOp(ISD::AVGCEILS, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
       X.getValueType() == Y.getValueType() &&
-      hasOperation(ISD::AVGCEILU, X.getValueType())) {
-    SDValue AvgCeilU = DAG.getNode(ISD::AVGCEILU, DL, X.getValueType(), X, Y);
-    return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgCeilU);
+      hasOperation(ISD::AVGCEILS, X.getValueType())) {
+    SDValue AvgCeilS = DAG.getNode(ISD::AVGCEILS, DL, X.getValueType(), X, Y);
+    return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgCeilS);
   }
 
   // Fold avgflooru(x,y) -> avgceilu(x,y-1) iff y != 0
diff --git a/llvm/test/CodeGen/AArch64/avg.ll b/llvm/test/CodeGen/AArch64/avg.ll
index e61b47772b7d7..cabc0d346b806 100644
--- a/llvm/test/CodeGen/AArch64/avg.ll
+++ b/llvm/test/CodeGen/AArch64/avg.ll
@@ -69,15 +69,12 @@ define <16 x i16> @zext_avgceilu_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
   ret <16 x i16> %avg
 }
 
-define <16 x i16> @sext_avgflooru(<16 x i8> %a0, <16 x i8> %a1) {
-; CHECK-LABEL: sext_avgflooru:
+define <16 x i16> @sext_avgfloors(<16 x i8> %a0, <16 x i8> %a1) {
+; CHECK-LABEL: sext_avgfloors:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    sshll v2.8h, v0.8b, #0
-; CHECK-NEXT:    sshll2 v0.8h, v0.16b, #0
-; CHECK-NEXT:    sshll v3.8h, v1.8b, #0
-; CHECK-NEXT:    sshll2 v1.8h, v1.16b, #0
-; CHECK-NEXT:    shadd v1.8h, v0.8h, v1.8h
-; CHECK-NEXT:    shadd v0.8h, v2.8h, v3.8h
+; CHECK-NEXT:    shadd v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    sshll2 v1.8h, v0.16b, #0
+; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
 ; CHECK-NEXT:    ret
   %x0 = sext <16 x i8> %a0 to <16 x i16>
   %x1 = sext <16 x i8> %a1 to <16 x i16>
@@ -88,8 +85,8 @@ define <16 x i16> @sext_avgflooru(<16 x i8> %a0, <16 x i8> %a1) {
   ret <16 x i16> %avg
 }
 
-define <16 x i16> @sext_avgflooru_mismatch(<16 x i8> %a0, <16 x i4> %a1) {
-; CHECK-LABEL: sext_avgflooru_mismatch:
+define <16 x i16> @sext_avgfloors_mismatch(<16 x i8> %a0, <16 x i4> %a1) {
+; CHECK-LABEL: sext_avgfloors_mismatch:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ushll2 v2.8h, v1.16b, #0
 ; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
@@ -111,15 +108,12 @@ define <16 x i16> @sext_avgflooru_mismatch(<16 x i8> %a0, <16 x i4> %a1) {
   ret <16 x i16> %avg
 }
 
-define <16 x i16> @sext_avgceilu(<16 x i8> %a0, <16 x i8> %a1) {
-; CHECK-LABEL: sext_avgceilu:
+define <16 x i16> @sext_avgceils(<16 x i8> %a0, <16 x i8> %a1) {
+; CHECK-LABEL: sext_avgceils:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    sshll v2.8h, v0.8b, #0
-; CHECK-NEXT:    sshll2 v0.8h, v0.16b, #0
-; CHECK-NEXT:    sshll v3.8h, v1.8b, #0
-; CHECK-NEXT:    sshll2 v1.8h, v1.16b, #0
-; CHECK-NEXT:    srhadd v1.8h, v0.8h, v1.8h
-; CHECK-NEXT:    srhadd v0.8h, v2.8h, v3.8h
+; CHECK-NEXT:    srhadd v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    sshll2 v1.8h, v0.16b, #0
+; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
 ; CHECK-NEXT:    ret
   %x0 = sext <16 x i8> %a0 to <16 x i16>
   %x1 = sext <16 x i8> %a1 to <16 x i16>
@@ -130,8 +124,8 @@ define <16 x i16> @sext_avgceilu(<16 x i8> %a0, <16 x i8> %a1) {
   ret <16 x i16> %avg
 }
 
-define <16 x i16> @sext_avgceilu_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
-; CHECK-LABEL: sext_avgceilu_mismatch:
+define <16 x i16> @sext_avgceils_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
+; CHECK-LABEL: sext_avgceils_mismatch:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ushll v2.8h, v0.8b, #0
 ; CHECK-NEXT:    ushll2 v0.8h, v0.16b, #0



More information about the llvm-commits mailing list