[llvm] [AArch64] Improve lowering of truncating uzp1 (PR #82457)

Eli Friedman via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 21 12:09:45 PST 2024


================
@@ -21059,20 +21055,53 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   auto getSourceOp = [](SDValue Operand) -> SDValue {
-    const unsigned Opcode = Operand.getOpcode();
-    if (Opcode == ISD::TRUNCATE)
-      return Operand->getOperand(0);
-    if (Opcode == ISD::BITCAST &&
-        Operand->getOperand(0).getOpcode() == ISD::TRUNCATE)
-      return Operand->getOperand(0)->getOperand(0);
-    return SDValue();
+    if (Operand.getOpcode() == ISD::BITCAST)
+      Operand = Operand->getOperand(0);
+    if (Operand.getOpcode() == ISD::AssertSext ||
+        Operand.getOpcode() == ISD::AssertZext)
+      Operand = Operand->getOperand(0);
+    return Operand;
   };
 
   SDValue SourceOp0 = getSourceOp(Op0);
   SDValue SourceOp1 = getSourceOp(Op1);
 
-  if (!SourceOp0 || !SourceOp1)
+  auto IsTruncatingUZP1Concat = [](SDNode *N, LLVMContext &Ctx) -> bool {
+    if (N->getOpcode() != AArch64ISD::UZP1)
+      return false;
+    SDValue Op0 = N->getOperand(0);
+    SDValue Op1 = N->getOperand(1);
+    if (Op0.getOpcode() != ISD::BITCAST || Op1.getOpcode() != ISD::BITCAST)
+      return false;
+    EVT Op0Ty = Op0.getOperand(0).getValueType();
+    if (Op0Ty != Op1.getOperand(0).getValueType())
+      return false;
+
+    EVT ResVT = N->getValueType(0);
+    return ResVT.widenIntegerVectorElementType(Ctx).getHalfNumVectorElementsVT(
+               Ctx) == Op0Ty;
+  };
+
+  // truncating uzp1(x=uzp1, y=uzp1) -> trunc(concat (x, y))
+  // This is similar to the transform below, except that it looks for truncation
+  // done using the uzp1 node.
+  LLVMContext &Ctx = *DAG.getContext();
+  if (IsTruncatingUZP1Concat(N, Ctx) &&
+      IsTruncatingUZP1Concat(SourceOp0.getNode(), Ctx) &&
+      IsTruncatingUZP1Concat(SourceOp1.getNode(), Ctx)) {
+    SDValue UZP1 =
+        DAG.getNode(ISD::CONCAT_VECTORS, DL,
----------------
efriedma-quic wrote:

I'm a bit confused about the legality checks surrounding this transform.

You're taking something like `UZP1(BITCAST(v4i16 SourceOp0 to v8i8), BITCAST(v4i16 SourceOp1 to v8i8))` and transforming it to `TRUNC(CONCAT_VECTORS(SourceOp0, SourceOp1))`. But then... you're restricting the transform to cases where SourceOp0 and SourceOp1 are also UZP1 operations. As far as I can tell, that restriction isn't necessary for correctness; the transform is legal no matter what SourceOp0 and SourceOp1 actually are.

Is there some other reason to restrict the transform?
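The equivalence the review relies on can be checked with a small scalar model. This is a minimal sketch, not LLVM code: `bitcastToV8I8`, `uzp1`, `truncConcat` and `equivalentForSamples` are hypothetical helpers that model the DAG nodes on plain arrays. It uses the types from the diff's check, where the v8i8 result widened to v8i16 and halved gives v4i16, the type of each bitcast source. It assumes little-endian lane numbering (byte lane `2*I` of the bitcast holds the low byte of element `I`) and says nothing about big-endian layouts. The operands are arbitrary, so nothing requires them to be UZP1 nodes.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical scalar models of the DAG nodes; illustrative only, not LLVM APIs.
using V4I16 = std::array<uint16_t, 4>;
using V8I8 = std::array<uint8_t, 8>;

// BITCAST v4i16 -> v8i8, assuming little-endian lane numbering:
// byte lane 2*I is the low byte of element I, lane 2*I+1 the high byte.
V8I8 bitcastToV8I8(const V4I16 &V) {
  V8I8 R{};
  for (int I = 0; I < 4; ++I) {
    R[2 * I] = static_cast<uint8_t>(V[I] & 0xFF);
    R[2 * I + 1] = static_cast<uint8_t>(V[I] >> 8);
  }
  return R;
}

// UZP1 on v8i8: the even-numbered lanes of A followed by those of B.
V8I8 uzp1(const V8I8 &A, const V8I8 &B) {
  V8I8 R{};
  for (int I = 0; I < 4; ++I) {
    R[I] = A[2 * I];
    R[4 + I] = B[2 * I];
  }
  return R;
}

// TRUNC(CONCAT_VECTORS(X, Y)): concatenate to v8i16, keep the low 8 bits.
V8I8 truncConcat(const V4I16 &X, const V4I16 &Y) {
  V8I8 R{};
  for (int I = 0; I < 4; ++I) {
    R[I] = static_cast<uint8_t>(X[I] & 0xFF);
    R[4 + I] = static_cast<uint8_t>(Y[I] & 0xFF);
  }
  return R;
}

// Compares both forms on Count pseudo-random operand pairs. No constraint is
// placed on where X and Y come from (in particular, they need not be UZP1s).
bool equivalentForSamples(int Count) {
  uint32_t State = 0x2545F491u; // xorshift32 seed; any nonzero value works
  auto Next16 = [&State]() -> uint16_t {
    State ^= State << 13;
    State ^= State >> 17;
    State ^= State << 5;
    return static_cast<uint16_t>(State);
  };
  for (int N = 0; N < Count; ++N) {
    V4I16 X, Y;
    for (int I = 0; I < 4; ++I) {
      X[I] = Next16();
      Y[I] = Next16();
    }
    if (uzp1(bitcastToV8I8(X), bitcastToV8I8(Y)) != truncConcat(X, Y))
      return false;
  }
  return true;
}
```

For instance, `assert(equivalentForSamples(100000));` is expected to hold. Sampling only illustrates the identity. The identity itself follows from the layout assumption: the even byte lanes of each bitcast operand are exactly the low bytes of its i16 elements, whatever node produced them.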

https://github.com/llvm/llvm-project/pull/82457

