[llvm] [AArch64] Improve lowering of truncating uzp1 (PR #82457)

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 22 02:04:58 PST 2024


================
@@ -21059,20 +21055,53 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   auto getSourceOp = [](SDValue Operand) -> SDValue {
-    const unsigned Opcode = Operand.getOpcode();
-    if (Opcode == ISD::TRUNCATE)
-      return Operand->getOperand(0);
-    if (Opcode == ISD::BITCAST &&
-        Operand->getOperand(0).getOpcode() == ISD::TRUNCATE)
-      return Operand->getOperand(0)->getOperand(0);
-    return SDValue();
+    if (Operand.getOpcode() == ISD::BITCAST)
+      Operand = Operand->getOperand(0);
+    if (Operand.getOpcode() == ISD::AssertSext ||
+        Operand.getOpcode() == ISD::AssertZext)
+      Operand = Operand->getOperand(0);
+    return Operand;
   };
 
   SDValue SourceOp0 = getSourceOp(Op0);
   SDValue SourceOp1 = getSourceOp(Op1);
 
-  if (!SourceOp0 || !SourceOp1)
+  auto IsTruncatingUZP1Concat = [](SDNode *N, LLVMContext &Ctx) -> bool {
+    if (N->getOpcode() != AArch64ISD::UZP1)
+      return false;
+    SDValue Op0 = N->getOperand(0);
+    SDValue Op1 = N->getOperand(1);
+    if (Op0.getOpcode() != ISD::BITCAST || Op1.getOpcode() != ISD::BITCAST)
+      return false;
+    EVT Op0Ty = Op0.getOperand(0).getValueType();
+    if (Op0Ty != Op1.getOperand(0).getValueType())
+      return false;
+
+    EVT ResVT = N->getValueType(0);
+    return ResVT.widenIntegerVectorElementType(Ctx).getHalfNumVectorElementsVT(
+               Ctx) == Op0Ty;
+  };
+
+  // truncating uzp1(x=uzp1, y=uzp1) -> trunc(concat (x, y))
+  // This is similar to the transform below, except that it looks for truncation
+  // done using the uzp1 node.
+  LLVMContext &Ctx = *DAG.getContext();
+  if (IsTruncatingUZP1Concat(N, Ctx) &&
+      IsTruncatingUZP1Concat(SourceOp0.getNode(), Ctx) &&
+      IsTruncatingUZP1Concat(SourceOp1.getNode(), Ctx)) {
+    SDValue UZP1 =
----------------
david-arm wrote:

nit: You're actually creating a CONCAT_VECTORS node here, so perhaps the name `UZP1` is a bit misleading? How about `ConcatOfUzp1` or something like that?

https://github.com/llvm/llvm-project/pull/82457


More information about the llvm-commits mailing list