[llvm] [AArch64] Convert UADDV(add(zext, zext)) into UADDLV(concat). (PR #78301)

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 18 07:39:49 PST 2024


================
@@ -16613,11 +16613,56 @@ static SDValue performUADDVAddCombine(SDValue A, SelectionDAG &DAG) {
   return SDValue();
 }
 
+// We can convert a UADDV(add(zext(64-bit source), zext(64-bit source))) into
+// UADDLV(concat), where the concat represents the 64-bit zext sources.
+static SDValue performUADDVZextCombine(SDValue A, SelectionDAG &DAG) {
+  // Look for add(zext(64-bit source), zext(64-bit source)), returning
+  // UADDLV(concat(zext, zext)) if found.
+  if (A.getOpcode() != ISD::ADD)
+    return SDValue();
+  EVT VT = A.getValueType();
+  if (VT != MVT::v8i16 && VT != MVT::v4i32 && VT != MVT::v2i64)
+    return SDValue();
+  SDValue Op0 = A.getOperand(0);
+  SDValue Op1 = A.getOperand(1);
+  if (Op0.getOpcode() != ISD::ZERO_EXTEND || Op0.getOpcode() != Op1.getOpcode())
+    return SDValue();
+  SDValue Ext0 = Op0.getOperand(0);
+  SDValue Ext1 = Op1.getOperand(0);
+  EVT ExtVT0 = Ext0.getValueType();
+  EVT ExtVT1 = Ext1.getValueType();
+  // Check zext VTs are the same and 64-bit length.
+  if (ExtVT0 != ExtVT1 ||
+      !(ExtVT0 == MVT::v8i8 || ExtVT0 == MVT::v4i16 || ExtVT0 == MVT::v2i32))
+    return SDValue();
+  // Get VT for concat of zext sources.
+  EVT PairVT = ExtVT0.getDoubleNumVectorElementsVT(*DAG.getContext());
+  SDValue Concat =
+      DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(A), PairVT, Ext0, Ext1);
+
+  switch (VT.getSimpleVT().SimpleTy) {
+  case MVT::v2i64:
+    return DAG.getNode(AArch64ISD::UADDLV, SDLoc(A), MVT::v2i64, Concat);
+  case MVT::v4i32:
+    return DAG.getNode(AArch64ISD::UADDLV, SDLoc(A), MVT::v4i32, Concat);
+  case MVT::v8i16: {
+    SDValue Uaddlv =
----------------
david-arm wrote:

It's not your fault, but this does look really weird. I wonder why we wrote the tablegen pattern in a way that doesn't really match the result type:

```
  def : Pat<(v4i32 (addlv (v16i8 V128:$Rn))),
            (v4i32 (SUBREG_TO_REG (i64 0), (!cast<Instruction>(Opc#"v16i8v") V128:$Rn), hsub))>;
```

I would have thought this would be more natural and avoid having to cast:

```
  def : Pat<(v8i16 (addlv (v16i8 V128:$Rn))),
            (v8i16 (SUBREG_TO_REG (i64 0), (!cast<Instruction>(Opc#"v16i8v") V128:$Rn), hsub))>;
```


https://github.com/llvm/llvm-project/pull/78301


More information about the llvm-commits mailing list