[llvm] 2b8db40 - [SVE] Restrict the usage of REINTERPRET_CAST.

Paul Walker via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 15 03:33:56 PST 2021


Author: Paul Walker
Date: 2021-01-15T11:32:13Z
New Revision: 2b8db40c92186731effd8948049919db8cf37dee

URL: https://github.com/llvm/llvm-project/commit/2b8db40c92186731effd8948049919db8cf37dee
DIFF: https://github.com/llvm/llvm-project/commit/2b8db40c92186731effd8948049919db8cf37dee.diff

LOG: [SVE] Restrict the usage of REINTERPRET_CAST.

In order to limit the number of combinations of REINTERPRET_CAST,
whilst at the same time prevent overlap with BITCAST, this patch
establishes the following rules:

1. The operand and result element types must be the same.
2. The operand and/or result type must be an unpacked type.

Differential Revision: https://reviews.llvm.org/D94593

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index adfe492d6181..d72eee5abc26 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -144,6 +144,25 @@ static inline EVT getPackedSVEVectorVT(EVT VT) {
     return MVT::nxv4f32;
   case MVT::f64:
     return MVT::nxv2f64;
+  case MVT::bf16:
+    return MVT::nxv8bf16;
+  }
+}
+
+// NOTE: Currently there's only a need to return integer vector types. If this
+// changes then just add an extra "type" parameter.
+static inline EVT getPackedSVEVectorVT(ElementCount EC) {
+  switch (EC.getKnownMinValue()) {
+  default:
+    llvm_unreachable("unexpected element count for vector");
+  case 16:
+    return MVT::nxv16i8;
+  case 8:
+    return MVT::nxv8i16;
+  case 4:
+    return MVT::nxv4i32;
+  case 2:
+    return MVT::nxv2i64;
   }
 }
 
@@ -3988,14 +4007,10 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
       !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
     return SDValue();
 
-  // Handle FP data
+  // Handle FP data by using an integer gather and casting the result.
   if (VT.isFloatingPoint()) {
-    ElementCount EC = VT.getVectorElementCount();
-    auto ScalarIntVT =
-        MVT::getIntegerVT(AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
-    PassThru = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL,
-                           MVT::getVectorVT(ScalarIntVT, EC), PassThru);
-
+    EVT PassThruVT = getPackedSVEVectorVT(VT.getVectorElementCount());
+    PassThru = getSVESafeBitCast(PassThruVT, PassThru, DAG);
     InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
   }
 
@@ -4015,7 +4030,7 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
   SDValue Gather = DAG.getNode(Opcode, DL, VTs, Ops);
 
   if (VT.isFloatingPoint()) {
-    SDValue Cast = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Gather);
+    SDValue Cast = getSVESafeBitCast(VT, Gather, DAG);
     return DAG.getMergeValues({Cast, Gather}, DL);
   }
 
@@ -4052,15 +4067,10 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
       !static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasBF16())
     return SDValue();
 
-  // Handle FP data
+  // Handle FP data by casting the data so an integer scatter can be used.
   if (VT.isFloatingPoint()) {
-    VT = VT.changeVectorElementTypeToInteger();
-    ElementCount EC = VT.getVectorElementCount();
-    auto ScalarIntVT =
-        MVT::getIntegerVT(AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
-    StoreVal = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL,
-                           MVT::getVectorVT(ScalarIntVT, EC), StoreVal);
-
+    EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount());
+    StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG);
     InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
   }
 
@@ -17157,3 +17167,40 @@ SDValue AArch64TargetLowering::LowerFixedLengthVectorSetccToSVE(
   auto Promote = DAG.getBoolExtOrTrunc(Cmp, DL, PromoteVT, InVT);
   return convertFromScalableVector(DAG, Op.getValueType(), Promote);
 }
+
+SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
+                                                 SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  EVT InVT = Op.getValueType();
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+
+  assert(VT.isScalableVector() && TLI.isTypeLegal(VT) &&
+         InVT.isScalableVector() && TLI.isTypeLegal(InVT) &&
+         "Only expect to cast between legal scalable vector types!");
+  assert((VT.getVectorElementType() == MVT::i1) ==
+             (InVT.getVectorElementType() == MVT::i1) &&
+         "Cannot cast between data and predicate scalable vector types!");
+
+  if (InVT == VT)
+    return Op;
+
+  if (VT.getVectorElementType() == MVT::i1)
+    return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
+
+  EVT PackedVT = getPackedSVEVectorVT(VT.getVectorElementType());
+  EVT PackedInVT = getPackedSVEVectorVT(InVT.getVectorElementType());
+  assert((VT == PackedVT || InVT == PackedInVT) &&
+         "Cannot cast between unpacked scalable vector types!");
+
+  // Pack input if required.
+  if (InVT != PackedInVT)
+    Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, PackedInVT, Op);
+
+  Op = DAG.getNode(ISD::BITCAST, DL, PackedVT, Op);
+
+  // Unpack result if required.
+  if (VT != PackedVT)
+    Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
+
+  return Op;
+}

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 23d5ce91b3e3..71e59f8f6ed7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -314,6 +314,7 @@ enum NodeType : unsigned {
   DUP_MERGE_PASSTHRU,
   INDEX_VECTOR,
 
+  // Cast between vectors of the same element type but differ in length.
   REINTERPRET_CAST,
 
   LD1_MERGE_ZERO,
@@ -1022,6 +1023,17 @@ class AArch64TargetLowering : public TargetLowering {
   // NEON vector. This changes when OverrideNEON is true, allowing SVE to be
   // used for 64bit and 128bit vectors as well.
   bool useSVEForFixedLengthVectorVT(EVT VT, bool OverrideNEON = false) const;
+
+  // With the exception of data-predicate transitions, no instructions are
+  // required to cast between legal scalable vector types. However:
+  //  1. Packed and unpacked types have different bit lengths, meaning BITCAST
+  //     is not universally useable.
+  //  2. Most unpacked integer types are not legal and thus integer extends
+  //     cannot be used to convert between unpacked and packed types.
+  // These can make "bitcasting" a multiphase process. REINTERPRET_CAST is used
+  // to transition between unpacked and packed types of the same element type,
+  // with BITCAST used otherwise.
+  SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
 };
 
 namespace AArch64 {

diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index cd80f3801fb2..e09b8401c0e0 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -1721,6 +1721,7 @@ let Predicates = [HasSVE] in {
     def : Pat<(nxv2f64 (bitconvert (nxv8bf16 ZPR:$src))), (nxv2f64 ZPR:$src)>;
   }
 
+  // These allow casting from/to unpacked predicate types.
   def : Pat<(nxv16i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv16i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv16i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
@@ -1735,23 +1736,17 @@ let Predicates = [HasSVE] in {
   def : Pat<(nxv2i1 (reinterpret_cast  (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv2i1 (reinterpret_cast  (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
 
-  def : Pat<(nxv2i64 (reinterpret_cast (nxv2f64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2i64 (reinterpret_cast (nxv2f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2i64 (reinterpret_cast (nxv2f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4i32 (reinterpret_cast (nxv4f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4i32 (reinterpret_cast (nxv4f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2i64 (reinterpret_cast (nxv2bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4i32 (reinterpret_cast (nxv4bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-
-  def : Pat<(nxv2f16 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2f32 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2f64 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4f16 (reinterpret_cast (nxv4i32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4f32 (reinterpret_cast (nxv4i32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv8f16 (reinterpret_cast (nxv8i16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv2bf16 (reinterpret_cast (nxv2i64 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv4bf16 (reinterpret_cast (nxv4i32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
-  def : Pat<(nxv8bf16 (reinterpret_cast (nxv8i16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  // These allow casting from/to unpacked floating-point types.
+  def : Pat<(nxv2f16 (reinterpret_cast (nxv8f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv8f16 (reinterpret_cast (nxv2f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv4f16 (reinterpret_cast (nxv8f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv8f16 (reinterpret_cast (nxv4f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv2f32 (reinterpret_cast (nxv4f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv4f32 (reinterpret_cast (nxv2f32 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv2bf16 (reinterpret_cast (nxv8bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv8bf16 (reinterpret_cast (nxv2bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv4bf16 (reinterpret_cast (nxv8bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
+  def : Pat<(nxv8bf16 (reinterpret_cast (nxv4bf16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
 
   def : Pat<(nxv16i1 (and PPR:$Ps1, PPR:$Ps2)),
             (AND_PPzPP (PTRUE_B 31), PPR:$Ps1, PPR:$Ps2)>;


        


More information about the llvm-commits mailing list