[llvm] [DAG] fold `avgu(sext(x), sext(y))` -> `sext(avgu(x, y))` (PR #95365)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 13 00:32:09 PDT 2024


https://github.com/c8ef created https://github.com/llvm/llvm-project/pull/95365

Follow-up to #95134.

Context: https://github.com/llvm/llvm-project/pull/95134#issuecomment-2162825594.

>From df7448f15f2ced82e343678da0f5ab4b545c8f85 Mon Sep 17 00:00:00 2001
From: c8ef <c8ef at outlook.com>
Date: Thu, 13 Jun 2024 07:29:41 +0000
Subject: [PATCH] fold avgu(sext(x), sext(y)) -> sext(avgu(x, y))

---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 15 ++++
 llvm/test/CodeGen/AArch64/avg.ll              | 84 +++++++++++++++++++
 2 files changed, 99 insertions(+)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 78970bc4fe4ab..0d4df4a7ecda5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -5237,6 +5237,7 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
                        DAG.getShiftAmountConstant(1, VT, DL));
 
   // fold avgu(zext(x), zext(y)) -> zext(avgu(x, y))
+  // fold avgs(sext(x), sext(y)) -> sext(avgs(x, y))
   if (sd_match(
           N, m_BinOp(ISD::AVGFLOORU, m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
       X.getValueType() == Y.getValueType() &&
@@ -5251,6 +5252,20 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
     SDValue AvgCeilU = DAG.getNode(ISD::AVGCEILU, DL, X.getValueType(), X, Y);
     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, AvgCeilU);
   }
+  if (sd_match(
+          N, m_BinOp(ISD::AVGFLOORS, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
+      X.getValueType() == Y.getValueType() &&
+      hasOperation(ISD::AVGFLOORS, X.getValueType())) {
+    SDValue AvgFloorS = DAG.getNode(ISD::AVGFLOORS, DL, X.getValueType(), X, Y);
+    return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgFloorS);
+  }
+  if (sd_match(
+          N, m_BinOp(ISD::AVGCEILS, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
+      X.getValueType() == Y.getValueType() &&
+      hasOperation(ISD::AVGCEILS, X.getValueType())) {
+    SDValue AvgCeilS = DAG.getNode(ISD::AVGCEILS, DL, X.getValueType(), X, Y);
+    return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgCeilS);
+  }
 
   // Fold avgflooru(x,y) -> avgceilu(x,y-1) iff y != 0
   // Fold avgflooru(x,y) -> avgceilu(x-1,y) iff x != 0
diff --git a/llvm/test/CodeGen/AArch64/avg.ll b/llvm/test/CodeGen/AArch64/avg.ll
index dc87708555987..e61b47772b7d7 100644
--- a/llvm/test/CodeGen/AArch64/avg.ll
+++ b/llvm/test/CodeGen/AArch64/avg.ll
@@ -68,3 +68,87 @@ define <16 x i16> @zext_avgceilu_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
   %avg = sub <16 x i16> %or, %shift
   ret <16 x i16> %avg
 }
+
+define <16 x i16> @sext_avgflooru(<16 x i8> %a0, <16 x i8> %a1) {
+; CHECK-LABEL: sext_avgflooru:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sshll v2.8h, v0.8b, #0
+; CHECK-NEXT:    sshll2 v0.8h, v0.16b, #0
+; CHECK-NEXT:    sshll v3.8h, v1.8b, #0
+; CHECK-NEXT:    sshll2 v1.8h, v1.16b, #0
+; CHECK-NEXT:    shadd v1.8h, v0.8h, v1.8h
+; CHECK-NEXT:    shadd v0.8h, v2.8h, v3.8h
+; CHECK-NEXT:    ret
+  %x0 = sext <16 x i8> %a0 to <16 x i16> ; NOTE(review): the ashr below makes this the signed floor-average idiom (the expected shadd is a signed halving add), so despite its name this does not exercise an unsigned (AVGFLOORU) fold
+  %x1 = sext <16 x i8> %a1 to <16 x i16> ; both operands are sign-extended from the same type (i8)
+  %and = and <16 x i16> %x0, %x1
+  %xor = xor <16 x i16> %x0, %x1
+  %shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %avg = add <16 x i16> %and, %shift ; (x0 & x1) + ((x0 ^ x1) ashr 1) == floor((x0 + x1) / 2)
+  ret <16 x i16> %avg
+}
+
+define <16 x i16> @sext_avgflooru_mismatch(<16 x i8> %a0, <16 x i4> %a1) {
+; CHECK-LABEL: sext_avgflooru_mismatch:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll2 v2.8h, v1.16b, #0
+; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NEXT:    sshll v3.8h, v0.8b, #0
+; CHECK-NEXT:    sshll2 v0.8h, v0.16b, #0
+; CHECK-NEXT:    shl v1.8h, v1.8h, #12
+; CHECK-NEXT:    shl v2.8h, v2.8h, #12
+; CHECK-NEXT:    sshr v4.8h, v1.8h, #12
+; CHECK-NEXT:    sshr v1.8h, v2.8h, #12
+; CHECK-NEXT:    shadd v1.8h, v0.8h, v1.8h
+; CHECK-NEXT:    shadd v0.8h, v3.8h, v4.8h
+; CHECK-NEXT:    ret
+  %x0 = sext <16 x i8> %a0 to <16 x i16> ; NOTE(review): ashr below => signed floor-average idiom (expected shadd), not AVGFLOORU as the name suggests
+  %x1 = sext <16 x i4> %a1 to <16 x i16> ; sources differ in width (i8 vs i4); presumably checks that the same-source-type requirement blocks the fold -- the expected output still averages at i16 (.8h)
+  %and = and <16 x i16> %x0, %x1
+  %xor = xor <16 x i16> %x0, %x1
+  %shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %avg = add <16 x i16> %and, %shift ; (x0 & x1) + ((x0 ^ x1) ashr 1) == floor((x0 + x1) / 2)
+  ret <16 x i16> %avg
+}
+
+define <16 x i16> @sext_avgceilu(<16 x i8> %a0, <16 x i8> %a1) {
+; CHECK-LABEL: sext_avgceilu:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sshll v2.8h, v0.8b, #0
+; CHECK-NEXT:    sshll2 v0.8h, v0.16b, #0
+; CHECK-NEXT:    sshll v3.8h, v1.8b, #0
+; CHECK-NEXT:    sshll2 v1.8h, v1.16b, #0
+; CHECK-NEXT:    srhadd v1.8h, v0.8h, v1.8h
+; CHECK-NEXT:    srhadd v0.8h, v2.8h, v3.8h
+; CHECK-NEXT:    ret
+  %x0 = sext <16 x i8> %a0 to <16 x i16> ; NOTE(review): the ashr below makes this the signed ceil-average idiom (the expected srhadd is a signed rounding halving add), so despite its name this does not exercise an unsigned (AVGCEILU) fold
+  %x1 = sext <16 x i8> %a1 to <16 x i16> ; both operands are sign-extended from the same type (i8)
+  %or = or <16 x i16> %x0, %x1
+  %xor = xor <16 x i16> %x0, %x1
+  %shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %avg = sub <16 x i16> %or, %shift ; (x0 | x1) - ((x0 ^ x1) ashr 1) == ceil((x0 + x1) / 2)
+  ret <16 x i16> %avg
+}
+
+define <16 x i16> @sext_avgceilu_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
+; CHECK-LABEL: sext_avgceilu_mismatch:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll v2.8h, v0.8b, #0
+; CHECK-NEXT:    ushll2 v0.8h, v0.16b, #0
+; CHECK-NEXT:    sshll v3.8h, v1.8b, #0
+; CHECK-NEXT:    sshll2 v1.8h, v1.16b, #0
+; CHECK-NEXT:    shl v2.8h, v2.8h, #12
+; CHECK-NEXT:    shl v0.8h, v0.8h, #12
+; CHECK-NEXT:    sshr v2.8h, v2.8h, #12
+; CHECK-NEXT:    sshr v0.8h, v0.8h, #12
+; CHECK-NEXT:    srhadd v1.8h, v0.8h, v1.8h
+; CHECK-NEXT:    srhadd v0.8h, v2.8h, v3.8h
+; CHECK-NEXT:    ret
+  %x0 = sext <16 x i4> %a0 to <16 x i16> ; sources differ in width (i4 vs i8); presumably checks that the same-source-type requirement blocks the fold -- the expected output still averages at i16 (.8h)
+  %x1 = sext <16 x i8> %a1 to <16 x i16> ; NOTE(review): ashr below => signed ceil-average idiom (expected srhadd), not AVGCEILU as the name suggests
+  %or = or <16 x i16> %x0, %x1
+  %xor = xor <16 x i16> %x0, %x1
+  %shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %avg = sub <16 x i16> %or, %shift ; (x0 | x1) - ((x0 ^ x1) ashr 1) == ceil((x0 + x1) / 2)
+  ret <16 x i16> %avg
+}



More information about the llvm-commits mailing list