[llvm] [LLVM][AArch64] Improve big endian code generation for SVE BITCASTs. (PR #104769)

via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 19 05:42:16 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-aarch64

Author: Paul Walker (paulwalker-arm)

<details>
<summary>Changes</summary>

For the most part I've tried to maintain the use of ISD::BITCAST wherever possible. I'm assuming this will preserve access to more DAG combines, but perhaps it is more likely to encourage the proliferation of invalid combines than if I ensured that only AArch64ISD::NVCAST/REINTERPRET_CAST survive lowering?

---

Patch is 98.21 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/104769.diff


4 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+2) 
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+40-19) 
- (modified) llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td (+56-107) 
- (modified) llvm/test/CodeGen/AArch64/sve-bitcast.ll (+249-769) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index ab12c3b0e728a8..8d06206755d4f8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -6140,6 +6140,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
     break;
   case ISD::BSWAP:
     assert(VT.isInteger() && VT == N1.getValueType() && "Invalid BSWAP!");
+    if (VT.getScalarSizeInBits() == 8)
+      return N1;
     assert((VT.getScalarSizeInBits() % 16 == 0) &&
            "BSWAP types must be a multiple of 16 bits!");
     if (OpOpcode == ISD::UNDEF)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 97fb2c5f552731..4d5034d67c5ed8 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1496,7 +1496,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::AVGCEILU, VT, Custom);
 
       if (!Subtarget->isLittleEndian())
-        setOperationAction(ISD::BITCAST, VT, Expand);
+        setOperationAction(ISD::BITCAST, VT, Custom);
 
       if (Subtarget->hasSVE2() ||
           (Subtarget->hasSME() && Subtarget->isStreaming()))
@@ -1510,9 +1510,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
     }
 
-    // Legalize unpacked bitcasts to REINTERPRET_CAST.
-    for (auto VT : {MVT::nxv2i16, MVT::nxv4i16, MVT::nxv2i32, MVT::nxv2bf16,
-                    MVT::nxv4bf16, MVT::nxv2f16, MVT::nxv4f16, MVT::nxv2f32})
+    // Type legalize unpacked bitcasts.
+    for (auto VT : {MVT::nxv2i16, MVT::nxv4i16, MVT::nxv2i32})
       setOperationAction(ISD::BITCAST, VT, Custom);
 
     for (auto VT :
@@ -1587,6 +1586,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
     for (auto VT : {MVT::nxv2f16, MVT::nxv4f16, MVT::nxv8f16, MVT::nxv2f32,
                     MVT::nxv4f32, MVT::nxv2f64}) {
+      setOperationAction(ISD::BITCAST, VT, Custom);
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
       setOperationAction(ISD::MLOAD, VT, Custom);
@@ -1658,20 +1658,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setCondCodeAction(ISD::SETUGT, VT, Expand);
       setCondCodeAction(ISD::SETUEQ, VT, Expand);
       setCondCodeAction(ISD::SETONE, VT, Expand);
-
-      if (!Subtarget->isLittleEndian())
-        setOperationAction(ISD::BITCAST, VT, Expand);
     }
 
     for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
+      setOperationAction(ISD::BITCAST, VT, Custom);
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
       setOperationAction(ISD::MLOAD, VT, Custom);
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
       setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
       setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
-
-      if (!Subtarget->isLittleEndian())
-        setOperationAction(ISD::BITCAST, VT, Expand);
     }
 
     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -4960,22 +4955,35 @@ SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,
     return LowerFixedLengthBitcastToSVE(Op, DAG);
 
   if (OpVT.isScalableVector()) {
-    // Bitcasting between unpacked vector types of different element counts is
-    // not a NOP because the live elements are laid out differently.
-    //                01234567
-    // e.g. nxv2i32 = XX??XX??
-    //      nxv4f16 = X?X?X?X?
-    if (OpVT.getVectorElementCount() != ArgVT.getVectorElementCount())
-      return SDValue();
+    assert(isTypeLegal(OpVT) && "Unexpected result type!");
 
-    if (isTypeLegal(OpVT) && !isTypeLegal(ArgVT)) {
+    // Handle type legalisation first.
+    if (!isTypeLegal(ArgVT)) {
       assert(OpVT.isFloatingPoint() && !ArgVT.isFloatingPoint() &&
              "Expected int->fp bitcast!");
+
+      // Bitcasting between unpacked vector types of different element counts is
+      // not a NOP because the live elements are laid out differently.
+      //                01234567
+      // e.g. nxv2i32 = XX??XX??
+      //      nxv4f16 = X?X?X?X?
+      if (OpVT.getVectorElementCount() != ArgVT.getVectorElementCount())
+        return SDValue();
+
       SDValue ExtResult =
           DAG.getNode(ISD::ANY_EXTEND, SDLoc(Op), getSVEContainerType(ArgVT),
                       Op.getOperand(0));
       return getSVESafeBitCast(OpVT, ExtResult, DAG);
     }
+
+    // Bitcasts between legal types with the same element count are legal.
+    if (OpVT.getVectorElementCount() == ArgVT.getVectorElementCount())
+      return Op;
+
+    // getSVESafeBitCast does not support casting between unpacked types.
+    if (!isPackedVectorType(OpVT, DAG))
+      return SDValue();
+
     return getSVESafeBitCast(OpVT, Op.getOperand(0), DAG);
   }
 
@@ -28877,7 +28885,20 @@ SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
   if (InVT != PackedInVT)
     Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, PackedInVT, Op);
 
-  Op = DAG.getNode(ISD::BITCAST, DL, PackedVT, Op);
+  if (Subtarget->isLittleEndian() ||
+      PackedVT.getScalarSizeInBits() == PackedInVT.getScalarSizeInBits())
+    Op = DAG.getNode(ISD::BITCAST, DL, PackedVT, Op);
+  else {
+    EVT PackedVTAsInt = PackedVT.changeTypeToInteger();
+    EVT PackedInVTAsInt = PackedInVT.changeTypeToInteger();
+
+    // Simulate the effect of casting through memory.
+    Op = DAG.getNode(ISD::BITCAST, DL, PackedInVTAsInt, Op);
+    Op = DAG.getNode(ISD::BSWAP, DL, PackedInVTAsInt, Op);
+    Op = DAG.getNode(AArch64ISD::NVCAST, DL, PackedVTAsInt, Op);
+    Op = DAG.getNode(ISD::BSWAP, DL, PackedVTAsInt, Op);
+    Op = DAG.getNode(ISD::BITCAST, DL, PackedVT, Op);
+  }
 
   // Unpack result if required.
   if (VT != PackedVT)
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index d9a70b5ef02fcb..35035aae05ecb6 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2650,113 +2650,62 @@ let Predicates = [HasSVEorSME] in {
                                sub_32)>;
   }
 
-  // FIXME: BigEndian requires an additional REV instruction to satisfy the
-  // constraint that none of the bits change when stored to memory as one
-  // type, and reloaded as another type.
-  let Predicates = [IsLE] in {
-    def : Pat<(nxv16i8 (bitconvert nxv8i16:$src)), (nxv16i8 ZPR:$src)>;
-    def : Pat<(nxv16i8 (bitconvert nxv4i32:$src)), (nxv16i8 ZPR:$src)>;
-    def : Pat<(nxv16i8 (bitconvert nxv2i64:$src)), (nxv16i8 ZPR:$src)>;
-    def : Pat<(nxv16i8 (bitconvert nxv8f16:$src)), (nxv16i8 ZPR:$src)>;
-    def : Pat<(nxv16i8 (bitconvert nxv4f32:$src)), (nxv16i8 ZPR:$src)>;
-    def : Pat<(nxv16i8 (bitconvert nxv2f64:$src)), (nxv16i8 ZPR:$src)>;
-
-    def : Pat<(nxv8i16 (bitconvert nxv16i8:$src)), (nxv8i16 ZPR:$src)>;
-    def : Pat<(nxv8i16 (bitconvert nxv4i32:$src)), (nxv8i16 ZPR:$src)>;
-    def : Pat<(nxv8i16 (bitconvert nxv2i64:$src)), (nxv8i16 ZPR:$src)>;
-    def : Pat<(nxv8i16 (bitconvert nxv8f16:$src)), (nxv8i16 ZPR:$src)>;
-    def : Pat<(nxv8i16 (bitconvert nxv4f32:$src)), (nxv8i16 ZPR:$src)>;
-    def : Pat<(nxv8i16 (bitconvert nxv2f64:$src)), (nxv8i16 ZPR:$src)>;
-
-    def : Pat<(nxv4i32 (bitconvert nxv16i8:$src)), (nxv4i32 ZPR:$src)>;
-    def : Pat<(nxv4i32 (bitconvert nxv8i16:$src)), (nxv4i32 ZPR:$src)>;
-    def : Pat<(nxv4i32 (bitconvert nxv2i64:$src)), (nxv4i32 ZPR:$src)>;
-    def : Pat<(nxv4i32 (bitconvert nxv8f16:$src)), (nxv4i32 ZPR:$src)>;
-    def : Pat<(nxv4i32 (bitconvert nxv4f32:$src)), (nxv4i32 ZPR:$src)>;
-    def : Pat<(nxv4i32 (bitconvert nxv2f64:$src)), (nxv4i32 ZPR:$src)>;
-
-    def : Pat<(nxv2i64 (bitconvert nxv16i8:$src)), (nxv2i64 ZPR:$src)>;
-    def : Pat<(nxv2i64 (bitconvert nxv8i16:$src)), (nxv2i64 ZPR:$src)>;
-    def : Pat<(nxv2i64 (bitconvert nxv4i32:$src)), (nxv2i64 ZPR:$src)>;
-    def : Pat<(nxv2i64 (bitconvert nxv8f16:$src)), (nxv2i64 ZPR:$src)>;
-    def : Pat<(nxv2i64 (bitconvert nxv4f32:$src)), (nxv2i64 ZPR:$src)>;
-    def : Pat<(nxv2i64 (bitconvert nxv2f64:$src)), (nxv2i64 ZPR:$src)>;
-
-    def : Pat<(nxv8f16 (bitconvert nxv16i8:$src)), (nxv8f16 ZPR:$src)>;
-    def : Pat<(nxv8f16 (bitconvert nxv8i16:$src)), (nxv8f16 ZPR:$src)>;
-    def : Pat<(nxv8f16 (bitconvert nxv4i32:$src)), (nxv8f16 ZPR:$src)>;
-    def : Pat<(nxv8f16 (bitconvert nxv2i64:$src)), (nxv8f16 ZPR:$src)>;
-    def : Pat<(nxv8f16 (bitconvert nxv4f32:$src)), (nxv8f16 ZPR:$src)>;
-    def : Pat<(nxv8f16 (bitconvert nxv2f64:$src)), (nxv8f16 ZPR:$src)>;
-
-    def : Pat<(nxv4f32 (bitconvert nxv16i8:$src)), (nxv4f32 ZPR:$src)>;
-    def : Pat<(nxv4f32 (bitconvert nxv8i16:$src)), (nxv4f32 ZPR:$src)>;
-    def : Pat<(nxv4f32 (bitconvert nxv4i32:$src)), (nxv4f32 ZPR:$src)>;
-    def : Pat<(nxv4f32 (bitconvert nxv2i64:$src)), (nxv4f32 ZPR:$src)>;
-    def : Pat<(nxv4f32 (bitconvert nxv8f16:$src)), (nxv4f32 ZPR:$src)>;
-    def : Pat<(nxv4f32 (bitconvert nxv2f64:$src)), (nxv4f32 ZPR:$src)>;
-
-    def : Pat<(nxv2f64 (bitconvert nxv16i8:$src)), (nxv2f64 ZPR:$src)>;
-    def : Pat<(nxv2f64 (bitconvert nxv8i16:$src)), (nxv2f64 ZPR:$src)>;
-    def : Pat<(nxv2f64 (bitconvert nxv4i32:$src)), (nxv2f64 ZPR:$src)>;
-    def : Pat<(nxv2f64 (bitconvert nxv2i64:$src)), (nxv2f64 ZPR:$src)>;
-    def : Pat<(nxv2f64 (bitconvert nxv8f16:$src)), (nxv2f64 ZPR:$src)>;
-    def : Pat<(nxv2f64 (bitconvert nxv4f32:$src)), (nxv2f64 ZPR:$src)>;
-
-    def : Pat<(nxv8bf16 (bitconvert nxv16i8:$src)), (nxv8bf16 ZPR:$src)>;
-    def : Pat<(nxv8bf16 (bitconvert nxv8i16:$src)), (nxv8bf16 ZPR:$src)>;
-    def : Pat<(nxv8bf16 (bitconvert nxv4i32:$src)), (nxv8bf16 ZPR:$src)>;
-    def : Pat<(nxv8bf16 (bitconvert nxv2i64:$src)), (nxv8bf16 ZPR:$src)>;
-    def : Pat<(nxv8bf16 (bitconvert nxv8f16:$src)), (nxv8bf16 ZPR:$src)>;
-    def : Pat<(nxv8bf16 (bitconvert nxv4f32:$src)), (nxv8bf16 ZPR:$src)>;
-    def : Pat<(nxv8bf16 (bitconvert nxv2f64:$src)), (nxv8bf16 ZPR:$src)>;
-
-    def : Pat<(nxv16i8 (bitconvert nxv8bf16:$src)), (nxv16i8 ZPR:$src)>;
-    def : Pat<(nxv8i16 (bitconvert nxv8bf16:$src)), (nxv8i16 ZPR:$src)>;
-    def : Pat<(nxv4i32 (bitconvert nxv8bf16:$src)), (nxv4i32 ZPR:$src)>;
-    def : Pat<(nxv2i64 (bitconvert nxv8bf16:$src)), (nxv2i64 ZPR:$src)>;
-    def : Pat<(nxv8f16 (bitconvert nxv8bf16:$src)), (nxv8f16 ZPR:$src)>;
-    def : Pat<(nxv4f32 (bitconvert nxv8bf16:$src)), (nxv4f32 ZPR:$src)>;
-    def : Pat<(nxv2f64 (bitconvert nxv8bf16:$src)), (nxv2f64 ZPR:$src)>;
-
-    def : Pat<(nxv16i1 (bitconvert aarch64svcount:$src)), (nxv16i1 PPR:$src)>;
-    def : Pat<(aarch64svcount (bitconvert nxv16i1:$src)), (aarch64svcount PNR:$src)>;
-  }
-
-  // These allow casting from/to unpacked predicate types.
-  def : Pat<(nxv16i1 (reinterpret_cast nxv16i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv16i1 (reinterpret_cast nxv8i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv16i1 (reinterpret_cast nxv4i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv16i1 (reinterpret_cast nxv2i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv16i1 (reinterpret_cast nxv1i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv8i1 (reinterpret_cast nxv16i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv8i1 (reinterpret_cast  nxv4i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv8i1 (reinterpret_cast  nxv2i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv8i1 (reinterpret_cast  nxv1i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv4i1 (reinterpret_cast nxv16i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv4i1 (reinterpret_cast  nxv8i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv4i1 (reinterpret_cast  nxv2i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv4i1 (reinterpret_cast  nxv1i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv2i1 (reinterpret_cast nxv16i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv2i1 (reinterpret_cast  nxv8i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv2i1 (reinterpret_cast  nxv4i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv2i1 (reinterpret_cast  nxv1i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv1i1 (reinterpret_cast nxv16i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv1i1 (reinterpret_cast  nxv8i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv1i1 (reinterpret_cast  nxv4i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-  def : Pat<(nxv1i1 (reinterpret_cast  nxv2i1:$src)), (COPY_TO_REGCLASS PPR:$src, PPR)>;
-
-  // These allow casting from/to unpacked floating-point types.
-  def : Pat<(nxv2f16 (reinterpret_cast nxv8f16:$src)), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv8f16 (reinterpret_cast nxv2f16:$src)), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4f16 (reinterpret_cast nxv8f16:$src)), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv8f16 (reinterpret_cast nxv4f16:$src)), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2f32 (reinterpret_cast nxv4f32:$src)), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4f32 (reinterpret_cast nxv2f32:$src)), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2bf16 (reinterpret_cast nxv8bf16:$src)), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv8bf16 (reinterpret_cast nxv2bf16:$src)), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4bf16 (reinterpret_cast nxv8bf16:$src)), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv8bf16 (reinterpret_cast nxv4bf16:$src)), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  // For big endian, only BITCASTs between same-sized vector types whose
+  // elements are also the same size can be isel'd directly.
+  let Predicates = [IsLE] in
+    foreach VT = [ nxv16i8, nxv8i16, nxv4i32, nxv2i64, nxv8f16, nxv4f32, nxv2f64, nxv8bf16 ] in
+      foreach VT2 = [ nxv16i8, nxv8i16, nxv4i32, nxv2i64, nxv8f16, nxv4f32, nxv2f64, nxv8bf16 ] in
+        if !ne(VT,VT2) then
+          def : Pat<(VT (bitconvert (VT2 ZPR:$src))), (VT ZPR:$src)>;
+
+  def : Pat<(nxv8i16 (bitconvert (nxv8f16 ZPR:$src))), (nxv8i16 ZPR:$src)>;
+  def : Pat<(nxv8f16 (bitconvert (nxv8i16 ZPR:$src))), (nxv8f16 ZPR:$src)>;
+
+  def : Pat<(nxv4i32 (bitconvert (nxv4f32 ZPR:$src))), (nxv4i32 ZPR:$src)>;
+  def : Pat<(nxv4f32 (bitconvert (nxv4i32 ZPR:$src))), (nxv4f32 ZPR:$src)>;
+
+  def : Pat<(nxv2i64 (bitconvert (nxv2f64 ZPR:$src))), (nxv2i64 ZPR:$src)>;
+  def : Pat<(nxv2f64 (bitconvert (nxv2i64 ZPR:$src))), (nxv2f64 ZPR:$src)>;
+
+  def : Pat<(nxv8i16 (bitconvert (nxv8bf16 ZPR:$src))), (nxv8i16 ZPR:$src)>;
+  def : Pat<(nxv8bf16 (bitconvert (nxv8i16 ZPR:$src))), (nxv8bf16 ZPR:$src)>;
+
+  def : Pat<(nxv8bf16 (bitconvert (nxv8f16 ZPR:$src))), (nxv8bf16 ZPR:$src)>;
+  def : Pat<(nxv8f16 (bitconvert (nxv8bf16 ZPR:$src))), (nxv8f16 ZPR:$src)>;
+
+  def : Pat<(nxv4bf16 (bitconvert (nxv4f16 ZPR:$src))), (nxv4bf16 ZPR:$src)>;
+  def : Pat<(nxv4f16 (bitconvert (nxv4bf16 ZPR:$src))), (nxv4f16 ZPR:$src)>;
+
+  def : Pat<(nxv2bf16 (bitconvert (nxv2f16 ZPR:$src))), (nxv2bf16 ZPR:$src)>;
+  def : Pat<(nxv2f16 (bitconvert (nxv2bf16 ZPR:$src))), (nxv2f16 ZPR:$src)>;
+
+  def : Pat<(nxv16i1 (bitconvert (aarch64svcount PNR:$src))), (nxv16i1 PPR:$src)>;
+  def : Pat<(aarch64svcount (bitconvert (nxv16i1 PPR:$src))), (aarch64svcount PNR:$src)>;
+
+  // These allow nop casting between predicate vector types.
+  foreach VT = [ nxv16i1, nxv8i1, nxv4i1, nxv2i1, nxv1i1 ] in
+    foreach VT2 = [ nxv16i1, nxv8i1, nxv4i1, nxv2i1, nxv1i1 ] in
+      def : Pat<(VT (reinterpret_cast (VT2 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
+
+  // These allow nop casting between half vector types.
+  foreach VT = [ nxv2f16, nxv4f16, nxv8f16 ] in
+    foreach VT2 = [ nxv2f16, nxv4f16, nxv8f16 ] in
+      def : Pat<(VT (reinterpret_cast (VT2 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+
+  // These allow nop casting between float vector types.
+  foreach VT = [ nxv2f32, nxv4f32 ] in
+    foreach VT2 = [ nxv2f32, nxv4f32 ] in
+      def : Pat<(VT (reinterpret_cast (VT2 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+
+  // These allow nop casting between bfloat vector types.
+  foreach VT = [ nxv2bf16, nxv4bf16, nxv8bf16 ] in
+    foreach VT2 = [ nxv2bf16, nxv4bf16, nxv8bf16 ] in
+      def : Pat<(VT (reinterpret_cast (VT2 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+
+  // These allow nop casting between all packed vector types.
+  foreach VT = [ nxv16i8, nxv8i16, nxv4i32, nxv2i64, nxv8f16, nxv4f32, nxv2f64, nxv8bf16 ] in
+    foreach VT2 = [ nxv16i8, nxv8i16, nxv4i32, nxv2i64, nxv8f16, nxv4f32, nxv2f64, nxv8bf16 ] in
+      def : Pat<(VT (AArch64NvCast (VT2 ZPR:$src))), (VT ZPR:$src)>;
 
   def : Pat<(nxv16i1 (and PPR:$Ps1, PPR:$Ps2)),
             (AND_PPzPP (PTRUE_B 31), PPR:$Ps1, PPR:$Ps2)>;
diff --git a/llvm/test/CodeGen/AArch64/sve-bitcast.ll b/llvm/test/CodeGen/AArch64/sve-bitcast.ll
index 95f43ba5126323..5d12d41ac3332f 100644
--- a/llvm/test/CodeGen/AArch64/sve-bitcast.ll
+++ b/llvm/test/CodeGen/AArch64/sve-bitcast.ll
@@ -13,14 +13,8 @@ define <vscale x 16 x i8> @bitcast_nxv8i16_to_nxv16i8(<vscale x 8 x i16> %v) #0
 ;
 ; CHECK_BE-LABEL: bitcast_nxv8i16_to_nxv16i8:
 ; CHECK_BE:       // %bb.0:
-; CHECK_BE-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK_BE-NEXT:    addvl sp, sp, #-1
 ; CHECK_BE-NEXT:    ptrue p0.h
-; CHECK_BE-NEXT:    ptrue p1.b
-; CHECK_BE-NEXT:    st1h { z0.h }, p0, [sp]
-; CHECK_BE-NEXT:    ld1b { z0.b }, p1/z, [sp]
-; CHECK_BE-NEXT:    addvl sp, sp, #1
-; CHECK_BE-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK_BE-NEXT:    revb z0.h, p0/m, z0.h
 ; CHECK_BE-NEXT:    ret
   %bc = bitcast <vscale x 8 x i16> %v to <vscale x 16 x i8>
   ret <vscale x 16 x i8> %bc
@@ -33,14 +27,8 @@ define <vscale x 16 x i8> @bitcast_nxv4i32_to_nxv16i8(<vscale x 4 x i32> %v) #0
 ;
 ; CHECK_BE-LABEL: bitcast_nxv4i32_to_nxv16i8:
 ; CHECK_BE:       // %bb.0:
-; CHECK_BE-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK_BE-NEXT:    addvl sp, sp, #-1
 ; CHECK_BE-NEXT:    ptrue p0.s
-; CHECK_BE-NEXT:    ptrue p1.b
-; CHECK_BE-NEXT:    st1w { z0.s }, p0, [sp]
-; CHECK_BE-NEXT:    ld1b { z0.b }, p1/z, [sp]
-; CHECK_BE-NEXT:    addvl sp, sp, #1
-; CHECK_BE-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK_BE-NEXT:    revb z0.s, p0/m, z0.s
 ; CHECK_BE-NEXT:    ret
   %bc = bitcast <vscale x 4 x i32> %v to <vscale x 16 x i8>
   ret <vscale x 16 x i8> %bc
@@ -53,14 +41,8 @@ define <vscale x 16 x i8> @bitcast_nxv2i64_to_nxv16i8(<vscale x 2 x i64> %v) #0
 ;
 ; CHECK_BE-LABEL: bitcast_nxv2i64_to_nxv16i8:
 ; CHECK_BE:       // %bb.0:
-; CHECK_BE-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK_BE-NEXT:    addvl sp, sp, #-1
 ; CHECK_BE-NEXT:    ptrue p0.d
-; CHECK_BE-NEXT:    ptrue p1.b
-; CHECK_BE-NEXT:    st1d { z0.d }, p0, [sp]
-; CHECK_BE-NEXT:    ld1b { z0.b }, p1/z, [sp]
-; CHECK_BE-NEXT:    addvl sp, sp, #1
-; CHECK_BE-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK_BE-NEXT:    revb z0.d, p0/m, z0.d
 ; CHECK_BE-NEXT:    ret
   %bc = bitcast <vscale x 2 x i64> %v to <vscale x 16 x i8>
   ret <vscale x 16 x i8> %bc
@@ -73,14 +55,8 @@ define <vscale x 16 x i8> @bitcast_nxv8f16_to_nxv16i8(<vscale x 8 x half> %v) #0
 ;
 ; CHECK_BE-LABEL: bitcast_nxv8f16_to_nxv16i8:
 ; CHECK_BE:       // %bb.0:
-; CHECK_BE-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK_BE-NEXT:    addvl sp, sp, #-1
 ; CHECK_BE-NEXT:    ptrue p0.h
-; CHECK_BE-NEXT:    ptrue p1.b
-; CHECK_BE-NEXT:    st1h { z0.h }, p0, [sp]
-; CHECK_BE-NEXT:    ld1b { z0.b }, p1/z, [sp]
-; CHECK_BE-NEXT:    addvl sp, sp, #1
-; CHECK_BE-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK_BE-NEXT:    revb z0.h, p0/m, z0.h
 ; CHECK_BE-NEXT:    ret
   %bc = bitcast <vscale x 8 x half> %v to <vscale x 16 x i8>
   ret <vscale x 16 x i8> %bc
@@ -93,14 +69,8 @@ define <vscale x 16 x i8> @bitcast_nxv4f32_to_nxv16i8(<vscale x 4 x float> %v) #
 ;
 ; CHECK_BE-LABEL: bitcast_nxv4f32_to_nxv16i8:
 ; CHECK_BE:       // %bb.0:
-; CHECK_BE-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
-; CHECK_BE-NEXT:    addvl sp, sp, #-1
 ; CHECK_BE-NEXT:    ptrue p0.s
-; CHECK_BE-NEXT:    ptrue p1.b
-; CHECK_BE-NEXT:    st1w { z0.s }, p0, [sp]
-; CHECK_BE-NEXT: ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/104769


More information about the llvm-commits mailing list