[llvm] 170e7a0 - [AArch64][SME2] Add CodeGen support for target("aarch64.svcount").

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 2 04:09:48 PST 2023


Author: Sander de Smalen
Date: 2023-03-02T12:07:41Z
New Revision: 170e7a0ec2e6d29cb642ece0bf34f395453d5e68

URL: https://github.com/llvm/llvm-project/commit/170e7a0ec2e6d29cb642ece0bf34f395453d5e68
DIFF: https://github.com/llvm/llvm-project/commit/170e7a0ec2e6d29cb642ece0bf34f395453d5e68.diff

LOG: [AArch64][SME2] Add CodeGen support for target("aarch64.svcount").

This patch adds AArch64 CodeGen support such that the type can be passed
and returned to/from functions, and also adds support to use this type in
load/store operations and PHI nodes.

Reviewed By: paulwalker-arm

Differential Revision: https://reviews.llvm.org/D136862

Added: 
    llvm/test/CodeGen/AArch64/sme-aarch64-svcount-O3.ll
    llvm/test/CodeGen/AArch64/sme-aarch64-svcount.ll

Modified: 
    llvm/include/llvm/CodeGen/ValueTypes.h
    llvm/include/llvm/CodeGen/ValueTypes.td
    llvm/include/llvm/IR/Type.h
    llvm/include/llvm/Support/MachineValueType.h
    llvm/lib/Analysis/Loads.cpp
    llvm/lib/CodeGen/CodeGenPrepare.cpp
    llvm/lib/CodeGen/LowLevelType.cpp
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
    llvm/lib/CodeGen/ValueTypes.cpp
    llvm/lib/IR/Type.cpp
    llvm/lib/Support/LowLevelType.cpp
    llvm/lib/Target/AArch64/AArch64CallingConvention.td
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64RegisterInfo.td
    llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
    llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
    llvm/utils/TableGen/CodeGenTarget.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/ValueTypes.h b/llvm/include/llvm/CodeGen/ValueTypes.h
index 9818ff2ee0ca0..0719c742b22cd 100644
--- a/llvm/include/llvm/CodeGen/ValueTypes.h
+++ b/llvm/include/llvm/CodeGen/ValueTypes.h
@@ -122,7 +122,7 @@ namespace llvm {
     /// Test if the given EVT has zero size, this will fail if called on a
     /// scalable type
     bool isZeroSized() const {
-      return !isScalableVector() && getSizeInBits() == 0;
+      return getSizeInBits().isZero();
     }
 
     /// Test if the given EVT is simple (as opposed to being extended).
@@ -150,6 +150,12 @@ namespace llvm {
       return isSimple() ? V.isScalarInteger() : isExtendedScalarInteger();
     }
 
+    /// Return true if this is a vector type where the runtime
+    /// length is machine dependent
+    bool isScalableTargetExtVT() const {
+      return isSimple() && V.isScalableTargetExtVT();
+    }
+
     /// Return true if this is a vector value type.
     bool isVector() const {
       return isSimple() ? V.isVector() : isExtendedVector();
@@ -166,6 +172,11 @@ namespace llvm {
                         : isExtendedFixedLengthVector();
     }
 
+    /// Return true if the type is a scalable type.
+    bool isScalableVT() const {
+      return isScalableVector() || isScalableTargetExtVT();
+    }
+
     /// Return true if this is a 16-bit vector type.
     bool is16BitVector() const {
       return isSimple() ? V.is16BitVector() : isExtended16BitVector();

diff  --git a/llvm/include/llvm/CodeGen/ValueTypes.td b/llvm/include/llvm/CodeGen/ValueTypes.td
index c22553855c556..934800f107473 100644
--- a/llvm/include/llvm/CodeGen/ValueTypes.td
+++ b/llvm/include/llvm/CodeGen/ValueTypes.td
@@ -236,6 +236,8 @@ def funcref   : ValueType<0,    192>;  // WebAssembly's funcref type
 def externref : ValueType<0,    193>;  // WebAssembly's externref type
 def x86amx    : ValueType<8192, 194>;  // X86 AMX value
 def i64x8     : ValueType<512,  195>;  // 8 Consecutive GPRs (AArch64)
+def aarch64svcount
+              : ValueType<16,   196>;  // AArch64 predicate-as-counter
 
 def token      : ValueType<0, 248>;  // TokenTy
 def MetadataVT : ValueType<0, 249>;  // Metadata

diff  --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h
index 2a8fddf0132d1..c516f3755eaf5 100644
--- a/llvm/include/llvm/IR/Type.h
+++ b/llvm/include/llvm/IR/Type.h
@@ -206,6 +206,15 @@ class Type {
   /// Return true if this is a target extension type.
   bool isTargetExtTy() const { return getTypeID() == TargetExtTyID; }
 
+  /// Return true if this is a target extension type with a scalable layout.
+  bool isScalableTargetExtTy() const;
+
+  /// Return true if this is a scalable vector type or a target extension type
+  /// with a scalable layout.
+  bool isScalableTy() const {
+    return getTypeID() == ScalableVectorTyID || isScalableTargetExtTy();
+  }
+
   /// Return true if this is a FP type or a vector of FP.
   bool isFPOrFPVectorTy() const { return getScalarType()->isFloatingPointTy(); }
 

diff  --git a/llvm/include/llvm/Support/MachineValueType.h b/llvm/include/llvm/Support/MachineValueType.h
index f5a1ab942e87b..58f294d1b8731 100644
--- a/llvm/include/llvm/Support/MachineValueType.h
+++ b/llvm/include/llvm/Support/MachineValueType.h
@@ -291,9 +291,10 @@ namespace llvm {
       externref      = 193,    // WebAssembly's externref type
       x86amx         = 194,    // This is an X86 AMX value
       i64x8          = 195,    // 8 Consecutive GPRs (AArch64)
+      aarch64svcount = 196,    // AArch64 predicate-as-counter
 
       FIRST_VALUETYPE =  1,    // This is always the beginning of the list.
-      LAST_VALUETYPE = i64x8,  // This always remains at the end of the list.
+      LAST_VALUETYPE = aarch64svcount, // This always remains at the end of the list.
       VALUETYPE_SIZE = LAST_VALUETYPE + 1,
 
       // This is the current maximum for LAST_VALUETYPE.
@@ -401,6 +402,16 @@ namespace llvm {
               SimpleTy <= MVT::LAST_SCALABLE_VECTOR_VALUETYPE);
     }
 
+    /// Return true if this is a custom target type that has a scalable size.
+    bool isScalableTargetExtVT() const {
+      return SimpleTy == MVT::aarch64svcount;
+    }
+
+    /// Return true if the type is a scalable type.
+    bool isScalableVT() const {
+      return isScalableVector() || isScalableTargetExtVT();
+    }
+
     bool isFixedLengthVector() const {
       return (SimpleTy >= MVT::FIRST_FIXEDLEN_VECTOR_VALUETYPE &&
               SimpleTy <= MVT::LAST_FIXEDLEN_VECTOR_VALUETYPE);
@@ -962,6 +973,7 @@ namespace llvm {
       case v2i8:
       case v1i16:
       case v1f16: return TypeSize::Fixed(16);
+      case aarch64svcount:
       case nxv16i1:
       case nxv2i8:
       case nxv1i16:

diff  --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
index f55333303f8d9..48e435a06ba15 100644
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -204,7 +204,7 @@ bool llvm::isDereferenceableAndAlignedPointer(
     const TargetLibraryInfo *TLI) {
   // For unsized types or scalable vectors we don't know exactly how many bytes
   // are dereferenced, so bail out.
-  if (!Ty->isSized() || isa<ScalableVectorType>(Ty))
+  if (!Ty->isSized() || Ty->isScalableTy())
     return false;
 
   // When dereferenceability information is provided by a dereferenceable

diff  --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index f028381793d3f..54c71c1e39461 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -7696,7 +7696,7 @@ static bool splitMergedValStore(StoreInst &SI, const DataLayout &DL,
   // whereas scalable vectors would have to be shifted by
   // <2log(vscale) + number of bits> in order to store the
   // low/high parts. Bailing out for now.
-  if (isa<ScalableVectorType>(StoreType))
+  if (StoreType->isScalableTy())
     return false;
 
   if (!DL.typeSizeEqualsStoreSize(StoreType) ||

diff  --git a/llvm/lib/CodeGen/LowLevelType.cpp b/llvm/lib/CodeGen/LowLevelType.cpp
index b47c96e508310..539f56fbc5a96 100644
--- a/llvm/lib/CodeGen/LowLevelType.cpp
+++ b/llvm/lib/CodeGen/LowLevelType.cpp
@@ -31,7 +31,7 @@ LLT llvm::getLLTForType(Type &Ty, const DataLayout &DL) {
     return LLT::pointer(AddrSpace, DL.getPointerSizeInBits(AddrSpace));
   }
 
-  if (Ty.isSized()) {
+  if (Ty.isSized() && !Ty.isScalableTargetExtTy()) {
     // Aggregates are no different from real scalars as far as GlobalISel is
     // concerned.
     auto SizeInBits = DL.getTypeSizeInBits(&Ty);

diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 3da873dec1c2b..e29f26f6991c5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -17769,8 +17769,8 @@ SDValue DAGCombiner::ForwardStoreValueToDirectLoad(LoadSDNode *LD) {
   //  2. The store is scalable and the load is fixed width. We could
   //     potentially support a limited number of cases here, but there has been
   //     no cost-benefit analysis to prove it's worth it.
-  bool LdStScalable = LDMemType.isScalableVector();
-  if (LdStScalable != STMemType.isScalableVector())
+  bool LdStScalable = LDMemType.isScalableVT();
+  if (LdStScalable != STMemType.isScalableVT())
     return SDValue();
 
   // If we are dealing with scalable vectors on a big endian platform the
@@ -19925,7 +19925,7 @@ bool DAGCombiner::mergeConsecutiveStores(StoreSDNode *St) {
   // store since we know <vscale x 16 x i8> is exactly twice as large as
   // <vscale x 8 x i8>). Until then, bail out for scalable vectors.
   EVT MemVT = St->getMemoryVT();
-  if (MemVT.isScalableVector())
+  if (MemVT.isScalableVT())
     return false;
   if (!MemVT.isSimple() || MemVT.getSizeInBits() * 2 > MaximumLegalStoreInBits)
     return false;
@@ -26807,7 +26807,7 @@ bool DAGCombiner::parallelizeChainedStores(StoreSDNode *St) {
   // BaseIndexOffset assumes that offsets are fixed-size, which
   // is not valid for scalable vectors where the offsets are
   // scaled by `vscale`, so bail out early.
-  if (St->getMemoryVT().isScalableVector())
+  if (St->getMemoryVT().isScalableVT())
     return false;
 
   // Add ST's interval.

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 1103eeea2a3ee..aac315d081808 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -496,7 +496,6 @@ getCopyToParts(SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
     return getCopyToPartsVector(DAG, DL, Val, Parts, NumParts, PartVT, V,
                                 CallConv);
 
-  unsigned PartBits = PartVT.getSizeInBits();
   unsigned OrigNumParts = NumParts;
   assert(DAG.getTargetLoweringInfo().isTypeLegal(PartVT) &&
          "Copying to an illegal type!");
@@ -512,6 +511,7 @@ getCopyToParts(SelectionDAG &DAG, const SDLoc &DL, SDValue Val, SDValue *Parts,
     return;
   }
 
+  unsigned PartBits = PartVT.getSizeInBits();
   if (NumParts * PartBits > ValueVT.getSizeInBits()) {
     // If the parts cover more bits than the value has, promote the value.
     if (PartVT.isFloatingPoint() && ValueVT.isFloatingPoint()) {

diff  --git a/llvm/lib/CodeGen/ValueTypes.cpp b/llvm/lib/CodeGen/ValueTypes.cpp
index 753650dc3eb0c..b4c873c0b1abd 100644
--- a/llvm/lib/CodeGen/ValueTypes.cpp
+++ b/llvm/lib/CodeGen/ValueTypes.cpp
@@ -174,6 +174,8 @@ std::string EVT::getEVTString() const {
   case MVT::Untyped:   return "Untyped";
   case MVT::funcref:   return "funcref";
   case MVT::externref: return "externref";
+  case MVT::aarch64svcount:
+    return "aarch64svcount";
   }
 }
 
@@ -210,6 +212,8 @@ Type *EVT::getTypeForEVT(LLVMContext &Context) const {
   case MVT::f128:    return Type::getFP128Ty(Context);
   case MVT::ppcf128: return Type::getPPC_FP128Ty(Context);
   case MVT::x86mmx:  return Type::getX86_MMXTy(Context);
+  case MVT::aarch64svcount:
+    return TargetExtType::get(Context, "aarch64.svcount");
   case MVT::x86amx:  return Type::getX86_AMXTy(Context);
   case MVT::i64x8:   return IntegerType::get(Context, 512);
   case MVT::externref: return Type::getWasm_ExternrefTy(Context);
@@ -579,6 +583,12 @@ MVT MVT::getVT(Type *Ty, bool HandleUnknown){
   case Type::DoubleTyID:    return MVT(MVT::f64);
   case Type::X86_FP80TyID:  return MVT(MVT::f80);
   case Type::X86_MMXTyID:   return MVT(MVT::x86mmx);
+  case Type::TargetExtTyID:
+    if (cast<TargetExtType>(Ty)->getName() == "aarch64.svcount")
+      return MVT(MVT::aarch64svcount);
+    if (HandleUnknown)
+      return MVT(MVT::Other);
+    llvm_unreachable("Unknown target ext type!");
   case Type::X86_AMXTyID:   return MVT(MVT::x86amx);
   case Type::FP128TyID:     return MVT(MVT::f128);
   case Type::PPC_FP128TyID: return MVT(MVT::ppcf128);
@@ -590,8 +600,6 @@ MVT MVT::getVT(Type *Ty, bool HandleUnknown){
       getVT(VTy->getElementType(), /*HandleUnknown=*/ false),
             VTy->getElementCount());
   }
-  case Type::TargetExtTyID:
-    return MVT(MVT::Other);
   }
 }
 

diff  --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp
index b2da840ea6931..3b8fc123d83b3 100644
--- a/llvm/lib/IR/Type.cpp
+++ b/llvm/lib/IR/Type.cpp
@@ -80,6 +80,12 @@ bool Type::isIEEE() const {
   return APFloat::getZero(getFltSemantics()).isIEEE();
 }
 
+bool Type::isScalableTargetExtTy() const {
+  if (auto *TT = dyn_cast<TargetExtType>(this))
+    return isa<ScalableVectorType>(TT->getLayoutType());
+  return false;
+}
+
 Type *Type::getFloatingPointTy(LLVMContext &C, const fltSemantics &S) {
   Type *Ty;
   if (&S == &APFloat::IEEEhalf())

diff  --git a/llvm/lib/Support/LowLevelType.cpp b/llvm/lib/Support/LowLevelType.cpp
index 0282cd9bd79ed..e394f98fa62e3 100644
--- a/llvm/lib/Support/LowLevelType.cpp
+++ b/llvm/lib/Support/LowLevelType.cpp
@@ -21,7 +21,7 @@ LLT::LLT(MVT VT) {
     init(/*IsPointer=*/false, asVector, /*IsScalar=*/!asVector,
          VT.getVectorElementCount(), VT.getVectorElementType().getSizeInBits(),
          /*AddressSpace=*/0);
-  } else if (VT.isValid()) {
+  } else if (VT.isValid() && !VT.isScalableTargetExtVT()) {
     // Aggregates are no different from real scalars as far as GlobalISel is
     // concerned.
     init(/*IsPointer=*/false, /*IsVector=*/false, /*IsScalar=*/true,

diff  --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.td b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
index e53f573de66c1..853975c6193d9 100644
--- a/llvm/lib/Target/AArch64/AArch64CallingConvention.td
+++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
@@ -82,9 +82,9 @@ def CC_AArch64_AAPCS : CallingConv<[
             nxv2bf16, nxv4bf16, nxv8bf16, nxv2f32, nxv4f32, nxv2f64],
            CCPassIndirect<i64>>,
 
-  CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1],
+  CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1, aarch64svcount],
            CCAssignToReg<[P0, P1, P2, P3]>>,
-  CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1],
+  CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1, aarch64svcount],
            CCPassIndirect<i64>>,
 
   // Handle i1, i8, i16, i32, i64, f32, f64 and v2f64 by passing in registers,
@@ -149,7 +149,7 @@ def RetCC_AArch64_AAPCS : CallingConv<[
             nxv2bf16, nxv4bf16, nxv8bf16, nxv2f32, nxv4f32, nxv2f64],
            CCAssignToReg<[Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7]>>,
 
-  CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1],
+  CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1, aarch64svcount],
            CCAssignToReg<[P0, P1, P2, P3]>>
 ]>;
 

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index ca77cbc471790..cb3b85c0f69ca 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -415,6 +415,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     }
   }
 
+  if (Subtarget->hasSVE2p1() || Subtarget->hasSME2()) {
+    addRegisterClass(MVT::aarch64svcount, &AArch64::PPRRegClass);
+    setOperationPromotedToType(ISD::LOAD, MVT::aarch64svcount, MVT::nxv16i1);
+    setOperationPromotedToType(ISD::STORE, MVT::aarch64svcount, MVT::nxv16i1);
+
+    setOperationAction(ISD::SELECT, MVT::aarch64svcount, Custom);
+    setOperationAction(ISD::SELECT_CC, MVT::aarch64svcount, Expand);
+  }
+
   // Compute derived properties from the register classes
   computeRegisterProperties(Subtarget->getRegisterInfo());
 
@@ -6429,6 +6438,9 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
                RegVT.getVectorElementType() == MVT::i1) {
         FuncInfo->setIsSVECC(true);
         RC = &AArch64::PPRRegClass;
+      } else if (RegVT == MVT::aarch64svcount) {
+        FuncInfo->setIsSVECC(true);
+        RC = &AArch64::PPRRegClass;
       } else if (RegVT.isScalableVector()) {
         FuncInfo->setIsSVECC(true);
         RC = &AArch64::ZPRRegClass;
@@ -6463,9 +6475,9 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
       case CCValAssign::Full:
         break;
       case CCValAssign::Indirect:
-        assert((VA.getValVT().isScalableVector() ||
-                Subtarget->isWindowsArm64EC()) &&
-               "Indirect arguments should be scalable on most subtargets");
+        assert(
+            (VA.getValVT().isScalableVT() || Subtarget->isWindowsArm64EC()) &&
+            "Indirect arguments should be scalable on most subtargets");
         break;
       case CCValAssign::BCvt:
         ArgValue = DAG.getNode(ISD::BITCAST, DL, VA.getValVT(), ArgValue);
@@ -6544,9 +6556,9 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
     }
 
     if (VA.getLocInfo() == CCValAssign::Indirect) {
-      assert(
-          (VA.getValVT().isScalableVector() || Subtarget->isWindowsArm64EC()) &&
-          "Indirect arguments should be scalable on most subtargets");
+      assert((VA.getValVT().isScalableVT() ||
+              Subtarget->isWindowsArm64EC()) &&
+             "Indirect arguments should be scalable on most subtargets");
 
       uint64_t PartSize = VA.getValVT().getStoreSize().getKnownMinValue();
       unsigned NumParts = 1;
@@ -7399,7 +7411,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
       Arg = DAG.getNode(ISD::FP_EXTEND, DL, VA.getLocVT(), Arg);
       break;
     case CCValAssign::Indirect:
-      bool isScalable = VA.getValVT().isScalableVector();
+      bool isScalable = VA.getValVT().isScalableVT();
       assert((isScalable || Subtarget->isWindowsArm64EC()) &&
              "Indirect arguments should be scalable on most subtargets");
 
@@ -9288,10 +9300,17 @@ SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
   SDLoc DL(Op);
 
   EVT Ty = Op.getValueType();
+  if (Ty == MVT::aarch64svcount) {
+    TVal = DAG.getNode(ISD::BITCAST, DL, MVT::nxv16i1, TVal);
+    FVal = DAG.getNode(ISD::BITCAST, DL, MVT::nxv16i1, FVal);
+    SDValue Sel =
+        DAG.getNode(ISD::SELECT, DL, MVT::nxv16i1, CCVal, TVal, FVal);
+    return DAG.getNode(ISD::BITCAST, DL, Ty, Sel);
+  }
+
   if (Ty.isScalableVector()) {
-    SDValue TruncCC = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, CCVal);
     MVT PredVT = MVT::getVectorVT(MVT::i1, Ty.getVectorElementCount());
-    SDValue SplatPred = DAG.getNode(ISD::SPLAT_VECTOR, DL, PredVT, TruncCC);
+    SDValue SplatPred = DAG.getNode(ISD::SPLAT_VECTOR, DL, PredVT, CCVal);
     return DAG.getNode(ISD::VSELECT, DL, Ty, SplatPred, TVal, FVal);
   }
 
@@ -14876,6 +14895,9 @@ bool AArch64TargetLowering::isLegalAddressingMode(const DataLayout &DL,
     return false;
 
   // FIXME: Update this method to support scalable addressing modes.
+  if (Ty->isScalableTargetExtTy())
+    return AM.HasBaseReg && !AM.BaseOffs && !AM.Scale;
+
   if (isa<ScalableVectorType>(Ty)) {
     uint64_t VecElemNumBytes =
         DL.getTypeSizeInBits(cast<VectorType>(Ty)->getElementType()) / 8;
@@ -20835,7 +20857,7 @@ static SDValue performSelectCombine(SDNode *N,
   if (N0.getOpcode() != ISD::SETCC)
     return SDValue();
 
-  if (ResVT.isScalableVector())
+  if (ResVT.isScalableVT())
     return SDValue();
 
   // Make sure the SETCC result is either i1 (initial DAG), or i32, the lowered
@@ -23224,15 +23246,15 @@ bool AArch64TargetLowering::shouldLocalize(
 }
 
 bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
-  if (isa<ScalableVectorType>(Inst.getType()))
+  if (Inst.getType()->isScalableTy())
     return true;
 
   for (unsigned i = 0; i < Inst.getNumOperands(); ++i)
-    if (isa<ScalableVectorType>(Inst.getOperand(i)->getType()))
+    if (Inst.getOperand(i)->getType()->isScalableTy())
       return true;
 
   if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) {
-    if (isa<ScalableVectorType>(AI->getAllocatedType()))
+    if (AI->getAllocatedType()->isScalableTy())
       return true;
   }
 

diff  --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
index 358f7f6c16563..9f26bd732609e 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
@@ -891,7 +891,7 @@ class ZPRRegOp <string Suffix, AsmOperandClass C, ElementSizeEnum Size,
 // SVE predicate register classes.
 class PPRClass<int firstreg, int lastreg> : RegisterClass<
                                   "AArch64",
-                                  [ nxv16i1, nxv8i1, nxv4i1, nxv2i1, nxv1i1 ], 16,
+                                  [ nxv16i1, nxv8i1, nxv4i1, nxv2i1, nxv1i1, aarch64svcount ], 16,
                                   (sequence "P%u", firstreg, lastreg)> {
   let Size = 16;
 }

diff  --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 8ac7559eceeb3..8c4ad6ca1a20c 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2512,6 +2512,9 @@ let Predicates = [HasSVEorSME] in {
     def : Pat<(nxv8f16 (bitconvert (nxv8bf16 ZPR:$src))), (nxv8f16 ZPR:$src)>;
     def : Pat<(nxv4f32 (bitconvert (nxv8bf16 ZPR:$src))), (nxv4f32 ZPR:$src)>;
     def : Pat<(nxv2f64 (bitconvert (nxv8bf16 ZPR:$src))), (nxv2f64 ZPR:$src)>;
+
+    def : Pat<(nxv16i1 (bitconvert (aarch64svcount PPR:$src))), (nxv16i1 PPR:$src)>;
+    def : Pat<(aarch64svcount (bitconvert (nxv16i1 PPR:$src))), (aarch64svcount PPR:$src)>;
   }
 
   // These allow casting from/to unpacked predicate types.

diff  --git a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
index d1c1e1075fc82..2855211258966 100644
--- a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
@@ -527,10 +527,9 @@ static void handleMustTailForwardedRegisters(MachineIRBuilder &MIRBuilder,
 
 bool AArch64CallLowering::fallBackToDAGISel(const MachineFunction &MF) const {
   auto &F = MF.getFunction();
-  if (isa<ScalableVectorType>(F.getReturnType()))
-    return true;
-  if (llvm::any_of(F.args(), [](const Argument &A) {
-        return isa<ScalableVectorType>(A.getType());
+  if (F.getReturnType()->isScalableTy() ||
+      llvm::any_of(F.args(), [](const Argument &A) {
+        return A.getType()->isScalableTy();
       }))
     return true;
   const auto &ST = MF.getSubtarget<AArch64Subtarget>();

diff  --git a/llvm/test/CodeGen/AArch64/sme-aarch64-svcount-O3.ll b/llvm/test/CodeGen/AArch64/sme-aarch64-svcount-O3.ll
new file mode 100644
index 0000000000000..531c666277aae
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-aarch64-svcount-O3.ll
@@ -0,0 +1,39 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -O3 -mtriple=aarch64 -mattr=+sme -S < %s | FileCheck %s
+
+; Test PHI nodes are allowed with opaque scalable types.
+define target("aarch64.svcount") @test_alloca_store_reload(target("aarch64.svcount") %val0, target("aarch64.svcount") %val1, ptr %iptr, ptr %pptr, i64 %N) nounwind {
+; CHECK-LABEL: @test_alloca_store_reload(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    store i64 0, ptr [[IPTR:%.*]], align 4
+; CHECK-NEXT:    store target("aarch64.svcount") [[VAL0:%.*]], ptr [[PPTR:%.*]], align 2
+; CHECK-NEXT:    [[I1_PEEL:%.*]] = icmp eq i64 [[N:%.*]], 0
+; CHECK-NEXT:    br i1 [[I1_PEEL]], label [[LOOP_EXIT:%.*]], label [[LOOP_BODY:%.*]]
+; CHECK:       loop.body:
+; CHECK-NEXT:    [[IND:%.*]] = phi i64 [ [[IND_NEXT:%.*]], [[LOOP_BODY]] ], [ 1, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[IPTR_GEP:%.*]] = getelementptr i64, ptr [[IPTR]], i64 [[IND]]
+; CHECK-NEXT:    store i64 [[IND]], ptr [[IPTR_GEP]], align 4
+; CHECK-NEXT:    store target("aarch64.svcount") [[VAL1:%.*]], ptr [[PPTR]], align 2
+; CHECK-NEXT:    [[IND_NEXT]] = add i64 [[IND]], 1
+; CHECK-NEXT:    [[I1:%.*]] = icmp eq i64 [[IND]], [[N]]
+; CHECK-NEXT:    br i1 [[I1]], label [[LOOP_EXIT]], label [[LOOP_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK:       loop.exit:
+; CHECK-NEXT:    [[PHI_LCSSA:%.*]] = phi target("aarch64.svcount") [ [[VAL0]], [[ENTRY]] ], [ [[VAL1]], [[LOOP_BODY]] ]
+; CHECK-NEXT:    ret target("aarch64.svcount") [[PHI_LCSSA]]
+;
+entry:
+  br label %loop.body
+
+loop.body:
+  %ind = phi i64 [0, %entry], [%ind.next, %loop.body]
+  %phi = phi target("aarch64.svcount") [%val0, %entry], [%val1, %loop.body]
+  %iptr.gep = getelementptr i64, ptr %iptr, i64 %ind
+  store i64 %ind, ptr %iptr.gep
+  store target("aarch64.svcount") %phi, ptr %pptr
+  %ind.next = add i64 %ind, 1
+  %i1 = icmp eq i64 %ind, %N
+  br i1 %i1, label %loop.exit, label %loop.body
+
+loop.exit:
+  ret target("aarch64.svcount") %phi
+}

diff  --git a/llvm/test/CodeGen/AArch64/sme-aarch64-svcount.ll b/llvm/test/CodeGen/AArch64/sme-aarch64-svcount.ll
new file mode 100644
index 0000000000000..41b7999555f39
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme-aarch64-svcount.ll
@@ -0,0 +1,172 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -O0 -mtriple=aarch64 -mattr=+sve2p1 < %s | FileCheck %s --check-prefixes=CHECK,CHECK-O0
+; RUN: llc -O3 -mtriple=aarch64 -mattr=+sve2p1 < %s | FileCheck %s --check-prefixes=CHECK,CHECK-O3
+
+;
+; Test simple loads, stores and return.
+;
+define target("aarch64.svcount") @test_load(ptr %ptr) nounwind {
+; CHECK-LABEL: test_load:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldr p0, [x0]
+; CHECK-NEXT:    ret
+  %res = load target("aarch64.svcount"), ptr %ptr
+  ret target("aarch64.svcount") %res
+}
+
+define void @test_store(ptr %ptr, target("aarch64.svcount") %val) nounwind {
+; CHECK-LABEL: test_store:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str p0, [x0]
+; CHECK-NEXT:    ret
+  store target("aarch64.svcount") %val, ptr %ptr
+  ret void
+}
+
+define target("aarch64.svcount") @test_alloca_store_reload(target("aarch64.svcount") %val) nounwind {
+; CHECKO0-LABEL: test_alloca_store_reload:
+; CHECKO0:       // %bb.0:
+; CHECKO0-NEXT:    sub sp, sp, #16
+; CHECKO0-NEXT:    add x8, sp, #14
+; CHECKO0-NEXT:    str p0, [x8]
+; CHECKO0-NEXT:    ldr p0, [x8]
+; CHECKO0-NEXT:    add sp, sp, #16
+; CHECKO0-NEXT:    ret
+;
+; CHECKO3-LABEL: test_alloca_store_reload:
+; CHECKO3:       // %bb.0:
+; CHECKO3-NEXT:    sub sp, sp, #16
+; CHECKO3-NEXT:    add x8, sp, #14
+; CHECKO3-NEXT:    str p0, [x8]
+; CHECKO3-NEXT:    add sp, sp, #16
+; CHECKO3-NEXT:    ret
+; CHECK-O0-LABEL: test_alloca_store_reload:
+; CHECK-O0:       // %bb.0:
+; CHECK-O0-NEXT:    sub sp, sp, #16
+; CHECK-O0-NEXT:    add x8, sp, #14
+; CHECK-O0-NEXT:    str p0, [x8]
+; CHECK-O0-NEXT:    ldr p0, [x8]
+; CHECK-O0-NEXT:    add sp, sp, #16
+; CHECK-O0-NEXT:    ret
+;
+; CHECK-O3-LABEL: test_alloca_store_reload:
+; CHECK-O3:       // %bb.0:
+; CHECK-O3-NEXT:    sub sp, sp, #16
+; CHECK-O3-NEXT:    add x8, sp, #14
+; CHECK-O3-NEXT:    str p0, [x8]
+; CHECK-O3-NEXT:    add sp, sp, #16
+; CHECK-O3-NEXT:    ret
+  %ptr = alloca target("aarch64.svcount"), align 1
+  store target("aarch64.svcount") %val, ptr %ptr
+  %res = load target("aarch64.svcount"), ptr %ptr
+  ret target("aarch64.svcount") %res
+}
+
+;
+; Test passing as arguments (from perspective of callee)
+;
+
+define target("aarch64.svcount") @test_return_arg1(target("aarch64.svcount") %arg0, target("aarch64.svcount") %arg1) nounwind {
+; CHECK-LABEL: test_return_arg1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov p0.b, p1.b
+; CHECK-NEXT:    ret
+  ret target("aarch64.svcount") %arg1
+}
+
+define target("aarch64.svcount") @test_return_arg4(target("aarch64.svcount") %arg0, target("aarch64.svcount") %arg1, target("aarch64.svcount") %arg2, target("aarch64.svcount") %arg3, target("aarch64.svcount") %arg4) nounwind {
+; CHECK-LABEL: test_return_arg4:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldr p0, [x0]
+; CHECK-NEXT:    ret
+  ret target("aarch64.svcount") %arg4
+}
+
+;
+; Test passing as arguments (from perspective of caller)
+;
+
+declare void @take_svcount_1(target("aarch64.svcount") %arg)
+define void @test_pass_1arg(target("aarch64.svcount") %arg) nounwind {
+; CHECK-LABEL: test_pass_1arg:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x30, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    bl take_svcount_1
+; CHECK-NEXT:    ldr x30, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+  call void @take_svcount_1(target("aarch64.svcount") %arg)
+  ret void
+}
+
+declare void @take_svcount_5(target("aarch64.svcount") %arg0, target("aarch64.svcount") %arg1, target("aarch64.svcount") %arg2, target("aarch64.svcount") %arg3, target("aarch64.svcount") %arg4)
+define void @test_pass_5args(target("aarch64.svcount") %arg) nounwind {
+; CHECKO0-LABEL: test_pass_5args:
+; CHECKO0:       // %bb.0:
+; CHECKO0-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECKO0-NEXT:    addvl sp, sp, #-1
+; CHECKO0-NEXT:    mov p3.b, p0.b
+; CHECKO0-NEXT:    str p3, [sp, #7, mul vl]
+; CHECKO0-NEXT:    addpl x0, sp, #7
+; CHECKO0-NEXT:    mov p0.b, p3.b
+; CHECKO0-NEXT:    mov p1.b, p3.b
+; CHECKO0-NEXT:    mov p2.b, p3.b
+; CHECKO0-NEXT:    bl take_svcount_5
+; CHECKO0-NEXT:    addvl sp, sp, #1
+; CHECKO0-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECKO0-NEXT:    ret
+;
+; CHECKO3-LABEL: test_pass_5args:
+; CHECKO3:       // %bb.0:
+; CHECKO3-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECKO3-NEXT:    addvl sp, sp, #-1
+; CHECKO3-NEXT:    addpl x0, sp, #7
+; CHECKO3-NEXT:    mov p1.b, p0.b
+; CHECKO3-NEXT:    mov p2.b, p0.b
+; CHECKO3-NEXT:    mov p3.b, p0.b
+; CHECKO3-NEXT:    str p0, [sp, #7, mul vl]
+; CHECKO3-NEXT:    bl take_svcount_5
+; CHECKO3-NEXT:    addvl sp, sp, #1
+; CHECKO3-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECKO3-NEXT:    ret
+; CHECK-O0-LABEL: test_pass_5args:
+; CHECK-O0:       // %bb.0:
+; CHECK-O0-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-O0-NEXT:    addvl sp, sp, #-1
+; CHECK-O0-NEXT:    mov p3.b, p0.b
+; CHECK-O0-NEXT:    str p3, [sp, #7, mul vl]
+; CHECK-O0-NEXT:    addpl x0, sp, #7
+; CHECK-O0-NEXT:    mov p0.b, p3.b
+; CHECK-O0-NEXT:    mov p1.b, p3.b
+; CHECK-O0-NEXT:    mov p2.b, p3.b
+; CHECK-O0-NEXT:    bl take_svcount_5
+; CHECK-O0-NEXT:    addvl sp, sp, #1
+; CHECK-O0-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-O0-NEXT:    ret
+;
+; CHECK-O3-LABEL: test_pass_5args:
+; CHECK-O3:       // %bb.0:
+; CHECK-O3-NEXT:    stp x29, x30, [sp, #-16]! // 16-byte Folded Spill
+; CHECK-O3-NEXT:    addvl sp, sp, #-1
+; CHECK-O3-NEXT:    addpl x0, sp, #7
+; CHECK-O3-NEXT:    mov p1.b, p0.b
+; CHECK-O3-NEXT:    mov p2.b, p0.b
+; CHECK-O3-NEXT:    mov p3.b, p0.b
+; CHECK-O3-NEXT:    str p0, [sp, #7, mul vl]
+; CHECK-O3-NEXT:    bl take_svcount_5
+; CHECK-O3-NEXT:    addvl sp, sp, #1
+; CHECK-O3-NEXT:    ldp x29, x30, [sp], #16 // 16-byte Folded Reload
+; CHECK-O3-NEXT:    ret
+  call void @take_svcount_5(target("aarch64.svcount") %arg, target("aarch64.svcount") %arg, target("aarch64.svcount") %arg, target("aarch64.svcount") %arg, target("aarch64.svcount") %arg)
+  ret void
+}
+
+define target("aarch64.svcount") @test_sel(target("aarch64.svcount") %x, target("aarch64.svcount") %y, i1 %cmp) {
+  %x.y = select i1 %cmp, target("aarch64.svcount") %x, target("aarch64.svcount") %y
+  ret target("aarch64.svcount") %x.y
+}
+
+define target("aarch64.svcount") @test_sel_cc(target("aarch64.svcount") %x, target("aarch64.svcount") %y, i32 %k) {
+  %cmp = icmp sgt i32 %k, 42
+  %x.y = select i1 %cmp, target("aarch64.svcount") %x, target("aarch64.svcount") %y
+  ret target("aarch64.svcount") %x.y
+}

diff  --git a/llvm/utils/TableGen/CodeGenTarget.cpp b/llvm/utils/TableGen/CodeGenTarget.cpp
index f500d304d6f04..672728ae16c20 100644
--- a/llvm/utils/TableGen/CodeGenTarget.cpp
+++ b/llvm/utils/TableGen/CodeGenTarget.cpp
@@ -82,6 +82,7 @@ StringRef llvm::getEnumName(MVT::SimpleValueType T) {
   case MVT::ppcf128:  return "MVT::ppcf128";
   case MVT::x86mmx:   return "MVT::x86mmx";
   case MVT::x86amx:   return "MVT::x86amx";
+  case MVT::aarch64svcount:   return "MVT::aarch64svcount";
   case MVT::i64x8:    return "MVT::i64x8";
   case MVT::Glue:     return "MVT::Glue";
   case MVT::isVoid:   return "MVT::isVoid";


        


More information about the llvm-commits mailing list