[llvm] [AArch64] Improve lowering of truncating uzp1 (PR #82457)

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 29 05:34:16 PST 2024


================
@@ -21058,21 +21054,28 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
   if (ResVT != MVT::v2i32 && ResVT != MVT::v4i16 && ResVT != MVT::v8i8)
     return SDValue();
 
-  auto getSourceOp = [](SDValue Operand) -> SDValue {
-    const unsigned Opcode = Operand.getOpcode();
-    if (Opcode == ISD::TRUNCATE)
-      return Operand->getOperand(0);
-    if (Opcode == ISD::BITCAST &&
-        Operand->getOperand(0).getOpcode() == ISD::TRUNCATE)
-      return Operand->getOperand(0)->getOperand(0);
-    return SDValue();
-  };
+  SDValue SourceOp0 = peekThroughBitcasts(Op0);
+  SDValue SourceOp1 = peekThroughBitcasts(Op1);
 
-  SDValue SourceOp0 = getSourceOp(Op0);
-  SDValue SourceOp1 = getSourceOp(Op1);
+  // truncating uzp1(x, y) -> xtn(concat (x, y))
+  if (SourceOp0.getValueType() == SourceOp1.getValueType()) {
+    EVT Op0Ty = SourceOp0.getValueType();
+    if ((ResVT == MVT::v4i16 && Op0Ty == MVT::v2i32) ||
----------------
david-arm wrote:

I just realised that we actually don't care about the operand type. The Arm developer website does say this about uzp1:

`Note: UZP1 is equivalent to truncating and packing each element from two source vectors into a single destination vector with elements of half the size.`

So you should always be able to transform this into truncate(concat(x, y)), regardless of the types of x and y. If x or y has the wrong type, you can simply insert a new bitcast and write the code like this:

```
  if (SourceOp0.getValueType() == SourceOp1.getValueType() && (ResVT == MVT::v4i16 || ResVT == MVT::v8i8)) {
    EVT Op0Ty = SourceOp0.getValueType();
    EVT RequiredOpTy = ResVT == MVT::v4i16 ? MVT::v2i32 : MVT::v4i16;
    if (Op0Ty != RequiredOpTy) {
      SourceOp0 = DAG.getNode(ISD::BITCAST, DL, RequiredOpTy, SourceOp0);
      SourceOp1 = DAG.getNode(ISD::BITCAST, DL, RequiredOpTy, SourceOp1);
    }
    SDValue Concat =
           DAG.getNode(ISD::CONCAT_VECTORS, DL,
                       Op0Ty.getDoubleNumVectorElementsVT(*DAG.getContext()),
                       SourceOp0, SourceOp1);
    return DAG.getNode(ISD::TRUNCATE, DL, ResVT, Concat);
  }
```

https://github.com/llvm/llvm-project/pull/82457


More information about the llvm-commits mailing list