[llvm] Match fixed width ISD::AVGFLOORS + ISD::AVGCEILS patterns (PR #86222)

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 21 17:24:17 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: None (houndlord)

<details>
<summary>Changes</summary>

Fixes [#<!-- -->8577 ].

---
Full diff: https://github.com/llvm/llvm-project/pull/86222.diff


2 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+42) 
- (modified) llvm/test/CodeGen/AArch64/hadd-combine.ll (+24) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 7009f375df1151..8debff80628432 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2546,6 +2546,23 @@ static SDValue combineFixedwidthToAVGCEILU(SDNode *N, SelectionDAG &DAG) {
   return SDValue();
 }
 
+// Attempt to form avgceils(A, B) from (A | B) - sra(A ^ B, 1)
+static SDValue combineFixedwidthToAVGCEILS(SDNode *N, SelectionDAG &DAG) {
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  SDValue N0 = N->getOperand(0);
+  EVT VT = N0.getValueType();
+  SDLoc DL(N);
+  if (TLI.isOperationLegal(ISD::AVGCEILS, VT)) {
+    SDValue A, B;
+    if (sd_match(N, m_Sub(m_Or(m_Value(A), m_Value(B)),
+                          m_Sra(m_Xor(m_Deferred(A), m_Deferred(B)),
+                                m_SpecificInt(1))))) {
+      return DAG.getNode(ISD::AVGCEILS, DL, VT, A, B);
+    }
+  }
+  return SDValue();
+}
+
 /// Try to fold a 'not' shifted sign-bit with add/sub with constant operand into
 /// a shift and add with a different constant.
 static SDValue foldAddSubOfSignBit(SDNode *N, SelectionDAG &DAG) {
@@ -2854,6 +2871,23 @@ static SDValue combineFixedwidthToAVGFLOORU(SDNode *N, SelectionDAG &DAG) {
   return SDValue();
 }
 
+// Attempt to form avgfloors(A, B) from (A & B) + sra(A ^ B, 1)
+static SDValue combineFixedwidthToAVGFLOORS(SDNode *N, SelectionDAG &DAG) {
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  SDValue N0 = N->getOperand(0);
+  EVT VT = N0.getValueType();
+  SDLoc DL(N);
+  if (TLI.isOperationLegal(ISD::AVGFLOORS, VT)) {
+    SDValue A, B;
+    if (sd_match(N, m_Add(m_And(m_Value(A), m_Value(B)),
+                          m_Sra(m_Xor(m_Deferred(A), m_Deferred(B)),
+                                m_SpecificInt(1))))) {
+      return DAG.getNode(ISD::AVGFLOORS, DL, VT, A, B);
+    }
+  }
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitADD(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -2872,6 +2906,10 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
   // Try to match AVGFLOORU fixedwidth pattern
   if (SDValue V = combineFixedwidthToAVGFLOORU(N, DAG))
     return V;
+  
+  // Try to match AVGFLOORS fixedwidth pattern
+  if (SDValue V = combineFixedwidthToAVGFLOORS(N, DAG))
+    return V;
 
   // fold (a+b) -> (a|b) iff a and b share no bits.
   if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
@@ -3867,6 +3905,10 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
   if (SDValue V = combineFixedwidthToAVGCEILU(N, DAG))
     return V;
 
+  // Try to match AVGCEILS fixedwidth pattern
+  if (SDValue V = combineFixedwidthToAVGCEILS(N, DAG))
+    return V;
+
   if (SDValue V = foldAddSubMasked1(false, N0, N1, DAG, SDLoc(N)))
     return V;
 
diff --git a/llvm/test/CodeGen/AArch64/hadd-combine.ll b/llvm/test/CodeGen/AArch64/hadd-combine.ll
index e12502980790da..cbc0386ea8553a 100644
--- a/llvm/test/CodeGen/AArch64/hadd-combine.ll
+++ b/llvm/test/CodeGen/AArch64/hadd-combine.ll
@@ -341,6 +341,18 @@ define <8 x i16> @sub_fixedwidth_v4i32(<8 x i16> %a0, <8 x i16> %a1)  {
   ret <8 x i16> %res
 }
 
+define <8 x i16> @shsub_fixedwidth_v4i32(<8 x i16> %a0, <8 x i16> %a1)  {
+; CHECK-LABEL: shsub_fixedwidth_v4i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    srhadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ret
+  %or = or <8 x i16> %a0, %a1
+  %xor = xor <8 x i16> %a0, %a1
+  %srl = ashr <8 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %res = sub <8 x i16> %or, %srl
+  ret <8 x i16> %res
+}
+
 define <8 x i16> @rhaddu_base(<8 x i16> %src1, <8 x i16> %src2) {
 ; CHECK-LABEL: rhaddu_base:
 ; CHECK:       // %bb.0:
@@ -879,6 +891,18 @@ define <8 x i16> @uhadd_fixedwidth_v4i32(<8 x i16> %a0, <8 x i16> %a1)  {
   ret <8 x i16> %res
 }
 
+define <8 x i16> @shadd_fixedwidth_v4i32(<8 x i16> %a0, <8 x i16> %a1)  {
+; CHECK-LABEL: shadd_fixedwidth_v4i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shadd v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ret
+  %and = and <8 x i16> %a0, %a1
+  %xor = xor <8 x i16> %a0, %a1
+  %srl = ashr <8 x i16> %xor, <i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1, i16 1>
+  %res = add <8 x i16> %and, %srl
+  ret <8 x i16> %res
+}
+
 declare <8 x i8> @llvm.aarch64.neon.shadd.v8i8(<8 x i8>, <8 x i8>)
 declare <4 x i16> @llvm.aarch64.neon.shadd.v4i16(<4 x i16>, <4 x i16>)
 declare <2 x i32> @llvm.aarch64.neon.shadd.v2i32(<2 x i32>, <2 x i32>)

``````````

</details>


https://github.com/llvm/llvm-project/pull/86222


More information about the llvm-commits mailing list