[llvm] [DAG] fold `avgu(sext(x), sext(y))` -> `sext(avgu(x, y))` (PR #95365)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 13 00:32:40 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: None (c8ef)

<details>
<summary>Changes</summary>

Follow-up to #95134.

Context: https://github.com/llvm/llvm-project/pull/95134#issuecomment-2162825594.

---
Full diff: https://github.com/llvm/llvm-project/pull/95365.diff


2 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+15) 
- (modified) llvm/test/CodeGen/AArch64/avg.ll (+78) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 78970bc4fe4ab..0d4df4a7ecda5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -5237,6 +5237,7 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
                        DAG.getShiftAmountConstant(1, VT, DL));

   // fold avgu(zext(x), zext(y)) -> zext(avgu(x, y))
+  // fold avgs(sext(x), sext(y)) -> sext(avgs(x, y))
   if (sd_match(
           N, m_BinOp(ISD::AVGFLOORU, m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
       X.getValueType() == Y.getValueType() &&
@@ -5251,6 +5252,20 @@ SDValue DAGCombiner::visitAVG(SDNode *N) {
     SDValue AvgCeilU = DAG.getNode(ISD::AVGCEILU, DL, X.getValueType(), X, Y);
     return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, AvgCeilU);
   }
+  if (sd_match(
+          N, m_BinOp(ISD::AVGFLOORS, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
+      X.getValueType() == Y.getValueType() &&
+      hasOperation(ISD::AVGFLOORS, X.getValueType())) {
+    SDValue AvgFloorS = DAG.getNode(ISD::AVGFLOORS, DL, X.getValueType(), X, Y);
+    return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgFloorS);
+  }
+  if (sd_match(
+          N, m_BinOp(ISD::AVGCEILS, m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
+      X.getValueType() == Y.getValueType() &&
+      hasOperation(ISD::AVGCEILS, X.getValueType())) {
+    SDValue AvgCeilS = DAG.getNode(ISD::AVGCEILS, DL, X.getValueType(), X, Y);
+    return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, AvgCeilS);
+  }

   // Fold avgflooru(x,y) -> avgceilu(x,y-1) iff y != 0
   // Fold avgflooru(x,y) -> avgceilu(x-1,y) iff x != 0
diff --git a/llvm/test/CodeGen/AArch64/avg.ll b/llvm/test/CodeGen/AArch64/avg.ll
index dc87708555987..e61b47772b7d7 100644
--- a/llvm/test/CodeGen/AArch64/avg.ll
+++ b/llvm/test/CodeGen/AArch64/avg.ll
@@ -68,3 +68,81 @@ define <16 x i16> @zext_avgceilu_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
   %avg = sub <16 x i16> %or, %shift
   ret <16 x i16> %avg
 }
+
+define <16 x i16> @sext_avgfloors(<16 x i8> %a0, <16 x i8> %a1) {
+; CHECK-LABEL: sext_avgfloors:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shadd v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    sshll2 v1.8h, v0.16b, #0
+; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
+; CHECK-NEXT:    ret
+  %x0 = sext <16 x i8> %a0 to <16 x i16>
+  %x1 = sext <16 x i8> %a1 to <16 x i16>
+  %and = and <16 x i16> %x0, %x1
+  %xor = xor <16 x i16> %x0, %x1
+  %shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %avg = add <16 x i16> %and, %shift
+  ret <16 x i16> %avg
+}
+
+define <16 x i16> @sext_avgfloors_mismatch(<16 x i8> %a0, <16 x i4> %a1) {
+; CHECK-LABEL: sext_avgfloors_mismatch:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll2 v2.8h, v1.16b, #0
+; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
+; CHECK-NEXT:    sshll v3.8h, v0.8b, #0
+; CHECK-NEXT:    sshll2 v0.8h, v0.16b, #0
+; CHECK-NEXT:    shl v1.8h, v1.8h, #12
+; CHECK-NEXT:    shl v2.8h, v2.8h, #12
+; CHECK-NEXT:    sshr v4.8h, v1.8h, #12
+; CHECK-NEXT:    sshr v1.8h, v2.8h, #12
+; CHECK-NEXT:    shadd v1.8h, v0.8h, v1.8h
+; CHECK-NEXT:    shadd v0.8h, v3.8h, v4.8h
+; CHECK-NEXT:    ret
+  %x0 = sext <16 x i8> %a0 to <16 x i16>
+  %x1 = sext <16 x i4> %a1 to <16 x i16>
+  %and = and <16 x i16> %x0, %x1
+  %xor = xor <16 x i16> %x0, %x1
+  %shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %avg = add <16 x i16> %and, %shift
+  ret <16 x i16> %avg
+}
+
+define <16 x i16> @sext_avgceils(<16 x i8> %a0, <16 x i8> %a1) {
+; CHECK-LABEL: sext_avgceils:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    srhadd v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    sshll2 v1.8h, v0.16b, #0
+; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
+; CHECK-NEXT:    ret
+  %x0 = sext <16 x i8> %a0 to <16 x i16>
+  %x1 = sext <16 x i8> %a1 to <16 x i16>
+  %or = or <16 x i16> %x0, %x1
+  %xor = xor <16 x i16> %x0, %x1
+  %shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %avg = sub <16 x i16> %or, %shift
+  ret <16 x i16> %avg
+}
+
+define <16 x i16> @sext_avgceils_mismatch(<16 x i4> %a0, <16 x i8> %a1) {
+; CHECK-LABEL: sext_avgceils_mismatch:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ushll v2.8h, v0.8b, #0
+; CHECK-NEXT:    ushll2 v0.8h, v0.16b, #0
+; CHECK-NEXT:    sshll v3.8h, v1.8b, #0
+; CHECK-NEXT:    sshll2 v1.8h, v1.16b, #0
+; CHECK-NEXT:    shl v2.8h, v2.8h, #12
+; CHECK-NEXT:    shl v0.8h, v0.8h, #12
+; CHECK-NEXT:    sshr v2.8h, v2.8h, #12
+; CHECK-NEXT:    sshr v0.8h, v0.8h, #12
+; CHECK-NEXT:    srhadd v1.8h, v0.8h, v1.8h
+; CHECK-NEXT:    srhadd v0.8h, v2.8h, v3.8h
+; CHECK-NEXT:    ret
+  %x0 = sext <16 x i4> %a0 to <16 x i16>
+  %x1 = sext <16 x i8> %a1 to <16 x i16>
+  %or = or <16 x i16> %x0, %x1
+  %xor = xor <16 x i16> %x0, %x1
+  %shift = ashr <16 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %avg = sub <16 x i16> %or, %shift
+  ret <16 x i16> %avg
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/95365


More information about the llvm-commits mailing list