[llvm] 0773854 - [DAG] Fold trunc(avg(x,y)) for avgceil/floor u/s nodes if they have sufficient leading zero/sign bits (#152273)

via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 18 08:36:29 PDT 2025


Author: 黃國庭
Date: 2025-08-18T16:36:26+01:00
New Revision: 07738545758be942cb674254ed4bc6d12db48563

URL: https://github.com/llvm/llvm-project/commit/07738545758be942cb674254ed4bc6d12db48563
DIFF: https://github.com/llvm/llvm-project/commit/07738545758be942cb674254ed4bc6d12db48563.diff

LOG: [DAG] Fold trunc(avg(x,y)) for avgceil/floor u/s nodes if they have sufficient leading zero/sign bits (#152273)

avgceil version: https://alive2.llvm.org/ce/z/2CKrRh

Fixes #147773 
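
A minimal IR sketch of the unsigned case, mirroring the new AArch64 test below (the NEON urhadd intrinsic is used to materialize the AVGCEILU node; the function name is illustrative):

define <8 x i8> @trunc_avgceilu_sketch(<8 x i8> %a, <8 x i8> %b) {
  ; Both operands are zero-extended, so the upper 8 bits of every i16 lane
  ; are known zero and the i16 average can be narrowed to an i8 average.
  %a16 = zext <8 x i8> %a to <8 x i16>
  %b16 = zext <8 x i8> %b to <8 x i16>
  %avg16 = call <8 x i16> @llvm.aarch64.neon.urhadd.v8i16(<8 x i16> %a16, <8 x i16> %b16)
  %r = trunc <8 x i16> %avg16 to <8 x i8>
  ret <8 x i8> %r
}

For the signed variants the combine instead requires SrcBits - DstBits + 1 sign bits in each operand (e.g. 9 sign bits when narrowing i16 to i8), which a sext from i8 trivially provides.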

---------

Co-authored-by: Simon Pilgrim <llvm-dev at redking.me.uk>

Added: 
    llvm/test/CodeGen/AArch64/trunc-avg-fold.ll

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 43d4138df8b49..c16ccaf926bc7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -16279,6 +16279,40 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
   // because targets may prefer a wider type during later combines and invert
   // this transform.
   switch (N0.getOpcode()) {
+  case ISD::AVGCEILU:
+  case ISD::AVGFLOORU:
+    if (!LegalOperations && N0.hasOneUse() &&
+        TLI.isOperationLegal(N0.getOpcode(), VT)) {
+      SDValue X = N0.getOperand(0);
+      SDValue Y = N0.getOperand(1);
+      unsigned SrcBits = X.getScalarValueSizeInBits();
+      unsigned DstBits = VT.getScalarSizeInBits();
+      APInt UpperBits = APInt::getBitsSetFrom(SrcBits, DstBits);
+      if (DAG.MaskedValueIsZero(X, UpperBits) &&
+          DAG.MaskedValueIsZero(Y, UpperBits)) {
+        SDValue Tx = DAG.getNode(ISD::TRUNCATE, DL, VT, X);
+        SDValue Ty = DAG.getNode(ISD::TRUNCATE, DL, VT, Y);
+        return DAG.getNode(N0.getOpcode(), DL, VT, Tx, Ty);
+      }
+    }
+    break;
+  case ISD::AVGCEILS:
+  case ISD::AVGFLOORS:
+    if (!LegalOperations && N0.hasOneUse() &&
+        TLI.isOperationLegal(N0.getOpcode(), VT)) {
+      SDValue X = N0.getOperand(0);
+      SDValue Y = N0.getOperand(1);
+      unsigned SrcBits = X.getScalarValueSizeInBits();
+      unsigned DstBits = VT.getScalarSizeInBits();
+      unsigned NeededSignBits = SrcBits - DstBits + 1;
+      if (DAG.ComputeNumSignBits(X) >= NeededSignBits &&
+          DAG.ComputeNumSignBits(Y) >= NeededSignBits) {
+        SDValue Tx = DAG.getNode(ISD::TRUNCATE, DL, VT, X);
+        SDValue Ty = DAG.getNode(ISD::TRUNCATE, DL, VT, Y);
+        return DAG.getNode(N0.getOpcode(), DL, VT, Tx, Ty);
+      }
+    }
+    break;
   case ISD::ADD:
   case ISD::SUB:
   case ISD::MUL:

diff --git a/llvm/test/CodeGen/AArch64/trunc-avg-fold.ll b/llvm/test/CodeGen/AArch64/trunc-avg-fold.ll
new file mode 100644
index 0000000000000..54fcae4ba28b7
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/trunc-avg-fold.ll
@@ -0,0 +1,53 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64-- -O2 -mattr=+neon < %s | FileCheck %s
+
+define <8 x i8> @avgceil_u_i8_to_i16(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: avgceil_u_i8_to_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    urhadd v0.8b, v0.8b, v1.8b
+; CHECK-NEXT:    ret
+  %a16 = zext <8 x i8> %a to <8 x i16>
+  %b16 = zext <8 x i8> %b to <8 x i16>
+  %avg16 = call <8 x i16> @llvm.aarch64.neon.urhadd.v8i16(<8 x i16> %a16, <8 x i16> %b16)
+  %r = trunc <8 x i16> %avg16 to <8 x i8>
+  ret <8 x i8> %r
+}
+
+
+define <8 x i8> @test_avgceil_s(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: test_avgceil_s:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    srhadd v0.8b, v0.8b, v1.8b
+; CHECK-NEXT:    ret
+  %a16 = sext <8 x i8> %a to <8 x i16>
+  %b16 = sext <8 x i8> %b to <8 x i16>
+  %avg16 = call <8 x i16> @llvm.aarch64.neon.srhadd.v8i16(<8 x i16> %a16, <8 x i16> %b16)
+  %res  = trunc <8 x i16> %avg16 to <8 x i8>
+  ret <8 x i8> %res
+}
+
+define <8 x i8> @avgfloor_u_i8_to_i16(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: avgfloor_u_i8_to_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uhadd v0.8b, v0.8b, v1.8b
+; CHECK-NEXT:    ret
+  %a16 = zext  <8 x i8>  %a to <8 x i16>
+  %b16 = zext  <8 x i8>  %b to <8 x i16>
+  %avg16 = call <8 x i16> @llvm.aarch64.neon.uhadd.v8i16(<8 x i16> %a16, <8 x i16> %b16)
+  %res = trunc <8 x i16> %avg16 to <8 x i8>
+  ret <8 x i8> %res
+}
+
+define <8 x i8> @test_avgfloor_s(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: test_avgfloor_s:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shadd v0.8b, v0.8b, v1.8b
+; CHECK-NEXT:    ret
+  %a16 = sext  <8 x i8>  %a to <8 x i16>
+  %b16 = sext  <8 x i8>  %b to <8 x i16>
+  %avg16 = call <8 x i16> @llvm.aarch64.neon.shadd.v8i16(<8 x i16> %a16, <8 x i16> %b16)
+  %res  = trunc <8 x i16> %avg16 to <8 x i8>
+  ret <8 x i8> %res
+}
+
+

More information about the llvm-commits mailing list