[llvm] bd7f7e2 - [GlobalISel] Add scalable property to LLT types.

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 22 01:25:01 PDT 2021


Author: Sander de Smalen
Date: 2021-06-22T08:43:34+01:00
New Revision: bd7f7e2ebae4e5bc95f0ca65efbc72575ca31c14

URL: https://github.com/llvm/llvm-project/commit/bd7f7e2ebae4e5bc95f0ca65efbc72575ca31c14
DIFF: https://github.com/llvm/llvm-project/commit/bd7f7e2ebae4e5bc95f0ca65efbc72575ca31c14.diff

LOG: [GlobalISel] Add scalable property to LLT types.

This patch adds the scalable property to LLT. The rest of the
patch series changes the interfaces to take/return ElementCount and
TypeSize, both of which can represent the scalable property.

The changes are mostly mechanical and are intended to be non-functional
for fixed-width vectors.

For scalable vectors some unit tests have been added, but no effort has
been put into making any of the GlobalISel algorithms work with scalable
vectors yet. That will be left as future work.

The work is split into a series of 5 patches to make reviews easier.

Reviewed By: arsenm

Differential Revision: https://reviews.llvm.org/D104450

Added: 
    

Modified: 
    llvm/include/llvm/Support/LowLevelTypeImpl.h
    llvm/lib/CodeGen/LowLevelType.cpp
    llvm/lib/Support/LowLevelType.cpp
    llvm/unittests/CodeGen/LowLevelTypeTest.cpp
    llvm/utils/TableGen/GlobalISelEmitter.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Support/LowLevelTypeImpl.h b/llvm/include/llvm/Support/LowLevelTypeImpl.h
index 28efaa06c958d..77049544c2714 100644
--- a/llvm/include/llvm/Support/LowLevelTypeImpl.h
+++ b/llvm/include/llvm/Support/LowLevelTypeImpl.h
@@ -42,31 +42,37 @@ class LLT {
   /// Get a low-level scalar or aggregate "bag of bits".
   static LLT scalar(unsigned SizeInBits) {
     assert(SizeInBits > 0 && "invalid scalar size");
-    return LLT{/*isPointer=*/false, /*isVector=*/false, /*NumElements=*/0,
-               SizeInBits, /*AddressSpace=*/0};
+    return LLT{/*isPointer=*/false, /*isVector=*/false,
+               ElementCount::getFixed(0), SizeInBits,
+               /*AddressSpace=*/0};
   }
 
   /// Get a low-level pointer in the given address space.
   static LLT pointer(unsigned AddressSpace, unsigned SizeInBits) {
     assert(SizeInBits > 0 && "invalid pointer size");
-    return LLT{/*isPointer=*/true, /*isVector=*/false, /*NumElements=*/0,
-               SizeInBits, AddressSpace};
+    return LLT{/*isPointer=*/true, /*isVector=*/false,
+               ElementCount::getFixed(0), SizeInBits, AddressSpace};
   }
 
   /// Get a low-level vector of some number of elements and element width.
   /// \p NumElements must be at least 2.
-  static LLT vector(uint16_t NumElements, unsigned ScalarSizeInBits) {
-    assert(NumElements > 1 && "invalid number of vector elements");
+  static LLT vector(uint16_t NumElements, unsigned ScalarSizeInBits,
+                    bool Scalable = false) {
+    assert((Scalable ? NumElements > 0 : NumElements > 1) &&
+           "invalid number of vector elements");
     assert(ScalarSizeInBits > 0 && "invalid vector element size");
-    return LLT{/*isPointer=*/false, /*isVector=*/true, NumElements,
-               ScalarSizeInBits, /*AddressSpace=*/0};
+    return LLT{/*isPointer=*/false, /*isVector=*/true,
+               ElementCount::get(NumElements, Scalable), ScalarSizeInBits,
+               /*AddressSpace=*/0};
   }
 
   /// Get a low-level vector of some number of elements and element type.
-  static LLT vector(uint16_t NumElements, LLT ScalarTy) {
-    assert(NumElements > 1 && "invalid number of vector elements");
+  static LLT vector(uint16_t NumElements, LLT ScalarTy, bool Scalable = false) {
+    assert((Scalable ? NumElements > 0 : NumElements > 1) &&
+           "invalid number of vector elements");
     assert(!ScalarTy.isVector() && "invalid vector element type");
-    return LLT{ScalarTy.isPointer(), /*isVector=*/true, NumElements,
+    return LLT{ScalarTy.isPointer(), /*isVector=*/true,
+               ElementCount::get(NumElements, Scalable),
                ScalarTy.getSizeInBits(),
                ScalarTy.isPointer() ? ScalarTy.getAddressSpace() : 0};
   }
@@ -79,9 +85,9 @@ class LLT {
     return scalarOrVector(NumElements, LLT::scalar(ScalarSize));
   }
 
-  explicit LLT(bool isPointer, bool isVector, uint16_t NumElements,
+  explicit LLT(bool isPointer, bool isVector, ElementCount EC,
                unsigned SizeInBits, unsigned AddressSpace) {
-    init(isPointer, isVector, NumElements, SizeInBits, AddressSpace);
+    init(isPointer, isVector, EC, SizeInBits, AddressSpace);
   }
   explicit LLT() : IsPointer(false), IsVector(false), RawData(0) {}
 
@@ -98,18 +104,37 @@ class LLT {
   /// Returns the number of elements in a vector LLT. Must only be called on
   /// vector types.
   uint16_t getNumElements() const {
+    if (isScalable())
+      llvm::reportInvalidSizeRequest(
+          "Possible incorrect use of LLT::getNumElements() for "
+          "scalable vector. Scalable flag may be dropped, use "
+          "LLT::getElementCount() instead");
+    return getElementCount().getKnownMinValue();
+  }
+
+  /// Returns true if the LLT is a scalable vector. Must only be called on
+  /// vector types.
+  bool isScalable() const {
+    assert(isVector() && "Expected a vector type");
+    return IsPointer ? getFieldValue(PointerVectorScalableFieldInfo)
+                     : getFieldValue(VectorScalableFieldInfo);
+  }
+
+  ElementCount getElementCount() const {
     assert(IsVector && "cannot get number of elements on scalar/aggregate");
-    if (!IsPointer)
-      return getFieldValue(VectorElementsFieldInfo);
-    else
-      return getFieldValue(PointerVectorElementsFieldInfo);
+    return ElementCount::get(IsPointer
+                                 ? getFieldValue(PointerVectorElementsFieldInfo)
+                                 : getFieldValue(VectorElementsFieldInfo),
+                             isScalable());
   }
 
   /// Returns the total size of the type. Must only be called on sized types.
   unsigned getSizeInBits() const {
     if (isPointer() || isScalar())
       return getScalarSizeInBits();
-    return getScalarSizeInBits() * getNumElements();
+    // FIXME: This should return a TypeSize in order to work for scalable
+    // vectors.
+    return getScalarSizeInBits() * getElementCount().getKnownMinValue();
   }
 
   /// Returns the total size of the type in bytes, i.e. number of whole bytes
@@ -125,7 +150,9 @@ class LLT {
   /// If this type is a vector, return a vector with the same number of elements
   /// but the new element type. Otherwise, return the new element type.
   LLT changeElementType(LLT NewEltTy) const {
-    return isVector() ? LLT::vector(getNumElements(), NewEltTy) : NewEltTy;
+    return isVector() ? LLT::vector(getElementCount().getKnownMinValue(),
+                                    NewEltTy, isScalable())
+                      : NewEltTy;
   }
 
   /// If this type is a vector, return a vector with the same number of elements
@@ -134,13 +161,16 @@ class LLT {
   LLT changeElementSize(unsigned NewEltSize) const {
     assert(!getScalarType().isPointer() &&
            "invalid to directly change element size for pointers");
-    return isVector() ? LLT::vector(getNumElements(), NewEltSize)
+    return isVector() ? LLT::vector(getElementCount().getKnownMinValue(),
+                                    NewEltSize, isScalable())
                       : LLT::scalar(NewEltSize);
   }
 
   /// Return a vector or scalar with the same element type and the new number of
   /// elements.
   LLT changeNumElements(unsigned NewNumElts) const {
+    assert((!isVector() || !isScalable()) &&
+           "Cannot use changeNumElements on a scalable vector");
     return LLT::scalarOrVector(NewNumElts, getScalarType());
   }
 
@@ -237,22 +267,37 @@ class LLT {
   static const constexpr BitFieldInfo PointerSizeFieldInfo{16, 0};
   static const constexpr BitFieldInfo PointerAddressSpaceFieldInfo{
       24, PointerSizeFieldInfo[0] + PointerSizeFieldInfo[1]};
+  static_assert((PointerAddressSpaceFieldInfo[0] +
+                 PointerAddressSpaceFieldInfo[1]) <= 62,
+                "Insufficient bits to encode all data");
   /// * Vector-of-non-pointer (isPointer == 0 && isVector == 1):
   ///   NumElements: 16;
   ///   SizeOfElement: 32;
+  ///   Scalable: 1;
   static const constexpr BitFieldInfo VectorElementsFieldInfo{16, 0};
   static const constexpr BitFieldInfo VectorSizeFieldInfo{
       32, VectorElementsFieldInfo[0] + VectorElementsFieldInfo[1]};
+  static const constexpr BitFieldInfo VectorScalableFieldInfo{
+      1, VectorSizeFieldInfo[0] + VectorSizeFieldInfo[1]};
+  static_assert((VectorScalableFieldInfo[0] + VectorScalableFieldInfo[1]) <= 62,
+                "Insufficient bits to encode all data");
   /// * Vector-of-pointer (isPointer == 1 && isVector == 1):
   ///   NumElements: 16;
   ///   SizeOfElement: 16;
   ///   AddressSpace: 24;
+  ///   Scalable: 1;
   static const constexpr BitFieldInfo PointerVectorElementsFieldInfo{16, 0};
   static const constexpr BitFieldInfo PointerVectorSizeFieldInfo{
       16,
       PointerVectorElementsFieldInfo[1] + PointerVectorElementsFieldInfo[0]};
   static const constexpr BitFieldInfo PointerVectorAddressSpaceFieldInfo{
       24, PointerVectorSizeFieldInfo[1] + PointerVectorSizeFieldInfo[0]};
+  static const constexpr BitFieldInfo PointerVectorScalableFieldInfo{
+      1, PointerVectorAddressSpaceFieldInfo[0] +
+             PointerVectorAddressSpaceFieldInfo[1]};
+  static_assert((PointerVectorScalableFieldInfo[0] +
+                 PointerVectorScalableFieldInfo[1]) <= 62,
+                "Insufficient bits to encode all data");
 
   uint64_t IsPointer : 1;
   uint64_t IsVector : 1;
@@ -273,8 +318,8 @@ class LLT {
     return getMask(FieldInfo) & (RawData >> FieldInfo[1]);
   }
 
-  void init(bool IsPointer, bool IsVector, uint16_t NumElements,
-            unsigned SizeInBits, unsigned AddressSpace) {
+  void init(bool IsPointer, bool IsVector, ElementCount EC, unsigned SizeInBits,
+            unsigned AddressSpace) {
     this->IsPointer = IsPointer;
     this->IsVector = IsVector;
     if (!IsVector) {
@@ -284,15 +329,20 @@ class LLT {
         RawData = maskAndShift(SizeInBits, PointerSizeFieldInfo) |
                   maskAndShift(AddressSpace, PointerAddressSpaceFieldInfo);
     } else {
-      assert(NumElements > 1 && "invalid number of vector elements");
+      assert(EC.isVector() && "invalid number of vector elements");
       if (!IsPointer)
-        RawData = maskAndShift(NumElements, VectorElementsFieldInfo) |
-                  maskAndShift(SizeInBits, VectorSizeFieldInfo);
+        RawData =
+            maskAndShift(EC.getKnownMinValue(), VectorElementsFieldInfo) |
+            maskAndShift(SizeInBits, VectorSizeFieldInfo) |
+            maskAndShift(EC.isScalable() ? 1 : 0, VectorScalableFieldInfo);
       else
         RawData =
-            maskAndShift(NumElements, PointerVectorElementsFieldInfo) |
+            maskAndShift(EC.getKnownMinValue(),
+                         PointerVectorElementsFieldInfo) |
             maskAndShift(SizeInBits, PointerVectorSizeFieldInfo) |
-            maskAndShift(AddressSpace, PointerVectorAddressSpaceFieldInfo);
+            maskAndShift(AddressSpace, PointerVectorAddressSpaceFieldInfo) |
+            maskAndShift(EC.isScalable() ? 1 : 0,
+                         PointerVectorScalableFieldInfo);
     }
   }
 

diff  --git a/llvm/lib/CodeGen/LowLevelType.cpp b/llvm/lib/CodeGen/LowLevelType.cpp
index 2bda586db8c78..9e0b117794ec2 100644
--- a/llvm/lib/CodeGen/LowLevelType.cpp
+++ b/llvm/lib/CodeGen/LowLevelType.cpp
@@ -20,11 +20,11 @@ using namespace llvm;
 
 LLT llvm::getLLTForType(Type &Ty, const DataLayout &DL) {
   if (auto VTy = dyn_cast<VectorType>(&Ty)) {
-    auto NumElements = cast<FixedVectorType>(VTy)->getNumElements();
+    auto EC = VTy->getElementCount();
     LLT ScalarTy = getLLTForType(*VTy->getElementType(), DL);
-    if (NumElements == 1)
+    if (EC.isScalar())
       return ScalarTy;
-    return LLT::vector(NumElements, ScalarTy);
+    return LLT::vector(EC.getKnownMinValue(), ScalarTy, EC.isScalable());
   }
 
   if (auto PTy = dyn_cast<PointerType>(&Ty)) {

diff  --git a/llvm/lib/Support/LowLevelType.cpp b/llvm/lib/Support/LowLevelType.cpp
index 63559d5ac3eee..42b91e9a30116 100644
--- a/llvm/lib/Support/LowLevelType.cpp
+++ b/llvm/lib/Support/LowLevelType.cpp
@@ -18,13 +18,13 @@ using namespace llvm;
 LLT::LLT(MVT VT) {
   if (VT.isVector()) {
-    init(/*IsPointer=*/false, VT.getVectorNumElements() > 1,
-         VT.getVectorNumElements(), VT.getVectorElementType().getSizeInBits(),
+    init(/*IsPointer=*/false, !VT.getVectorElementCount().isScalar(),
+         VT.getVectorElementCount(), VT.getVectorElementType().getSizeInBits(),
          /*AddressSpace=*/0);
   } else if (VT.isValid()) {
     // Aggregates are no different from real scalars as far as GlobalISel is
     // concerned.
     assert(VT.getSizeInBits().isNonZero() && "invalid zero-sized type");
-    init(/*IsPointer=*/false, /*IsVector=*/false, /*NumElements=*/0,
+    init(/*IsPointer=*/false, /*IsVector=*/false, ElementCount::getFixed(0),
          VT.getSizeInBits(), /*AddressSpace=*/0);
   } else {
     IsPointer = false;
@@ -34,9 +34,10 @@ LLT::LLT(MVT VT) {
 }
 
 void LLT::print(raw_ostream &OS) const {
-  if (isVector())
-    OS << "<" << getNumElements() << " x " << getElementType() << ">";
-  else if (isPointer())
+  if (isVector()) {
+    OS << "<";
+    OS << getElementCount() << " x " << getElementType() << ">";
+  } else if (isPointer())
     OS << "p" << getAddressSpace();
   else if (isValid()) {
     assert(isScalar() && "unexpected type");
@@ -49,7 +50,9 @@ const constexpr LLT::BitFieldInfo LLT::ScalarSizeFieldInfo;
 const constexpr LLT::BitFieldInfo LLT::PointerSizeFieldInfo;
 const constexpr LLT::BitFieldInfo LLT::PointerAddressSpaceFieldInfo;
 const constexpr LLT::BitFieldInfo LLT::VectorElementsFieldInfo;
+const constexpr LLT::BitFieldInfo LLT::VectorScalableFieldInfo;
 const constexpr LLT::BitFieldInfo LLT::VectorSizeFieldInfo;
 const constexpr LLT::BitFieldInfo LLT::PointerVectorElementsFieldInfo;
+const constexpr LLT::BitFieldInfo LLT::PointerVectorScalableFieldInfo;
 const constexpr LLT::BitFieldInfo LLT::PointerVectorSizeFieldInfo;
 const constexpr LLT::BitFieldInfo LLT::PointerVectorAddressSpaceFieldInfo;

diff  --git a/llvm/unittests/CodeGen/LowLevelTypeTest.cpp b/llvm/unittests/CodeGen/LowLevelTypeTest.cpp
index adf138e818e50..65f82169050f7 100644
--- a/llvm/unittests/CodeGen/LowLevelTypeTest.cpp
+++ b/llvm/unittests/CodeGen/LowLevelTypeTest.cpp
@@ -11,6 +11,7 @@
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Type.h"
+#include "llvm/Support/TypeSize.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -50,13 +51,19 @@ TEST(LowLevelTypeTest, Vector) {
   DataLayout DL("");
 
   for (unsigned S : {1U, 17U, 32U, 64U, 0xfffU}) {
-    for (uint16_t Elts : {2U, 3U, 4U, 32U, 0xffU}) {
+    for (auto EC :
+         {ElementCount::getFixed(2), ElementCount::getFixed(3),
+          ElementCount::getFixed(4), ElementCount::getFixed(32),
+          ElementCount::getFixed(0xff), ElementCount::getScalable(2),
+          ElementCount::getScalable(3), ElementCount::getScalable(4),
+          ElementCount::getScalable(32), ElementCount::getScalable(0xff)}) {
       const LLT STy = LLT::scalar(S);
-      const LLT VTy = LLT::vector(Elts, S);
+      const LLT VTy = LLT::vector(EC.getKnownMinValue(), S, EC.isScalable());
 
       // Test the alternative vector().
       {
-        const LLT VSTy = LLT::vector(Elts, STy);
+        const LLT VSTy =
+            LLT::vector(EC.getKnownMinValue(), STy, EC.isScalable());
         EXPECT_EQ(VTy, VSTy);
       }
 
@@ -71,9 +78,10 @@ TEST(LowLevelTypeTest, Vector) {
       ASSERT_FALSE(VTy.isPointer());
 
       // Test sizes.
-      EXPECT_EQ(S * Elts, VTy.getSizeInBits());
       EXPECT_EQ(S, VTy.getScalarSizeInBits());
-      EXPECT_EQ(Elts, VTy.getNumElements());
+      EXPECT_EQ(EC, VTy.getElementCount());
+      if (!EC.isScalable())
+        EXPECT_EQ(S * EC.getFixedValue(), VTy.getSizeInBits());
 
       // Test equality operators.
       EXPECT_TRUE(VTy == VTy);
@@ -85,7 +93,7 @@ TEST(LowLevelTypeTest, Vector) {
 
       // Test Type->LLT conversion.
       Type *IRSTy = IntegerType::get(C, S);
-      Type *IRTy = FixedVectorType::get(IRSTy, Elts);
+      Type *IRTy = VectorType::get(IRSTy, EC);
       EXPECT_EQ(VTy, getLLTForType(*IRTy, DL));
     }
   }
@@ -136,6 +144,22 @@ TEST(LowLevelTypeTest, ChangeElementType) {
 
   EXPECT_EQ(V2P1, V2P0.changeElementType(P1));
   EXPECT_EQ(V2S32, V2P0.changeElementType(S32));
+
+  // Similar tests for scalable vectors.
+  const LLT NXV2S32 = LLT::vector(2, 32, true);
+  const LLT NXV2S64 = LLT::vector(2, 64, true);
+
+  const LLT NXV2P0 = LLT::vector(2, P0, true);
+  const LLT NXV2P1 = LLT::vector(2, P1, true);
+
+  EXPECT_EQ(NXV2S64, NXV2S32.changeElementType(S64));
+  EXPECT_EQ(NXV2S32, NXV2S64.changeElementType(S32));
+
+  EXPECT_EQ(NXV2S64, NXV2S32.changeElementSize(64));
+  EXPECT_EQ(NXV2S32, NXV2S64.changeElementSize(32));
+
+  EXPECT_EQ(NXV2P1, NXV2P0.changeElementType(P1));
+  EXPECT_EQ(NXV2S32, NXV2P0.changeElementType(S32));
 }
 
 TEST(LowLevelTypeTest, ChangeNumElements) {
@@ -191,9 +215,14 @@ TEST(LowLevelTypeTest, Pointer) {
   for (unsigned AS : {0U, 1U, 127U, 0xffffU,
         static_cast<unsigned>(maxUIntN(23)),
         static_cast<unsigned>(maxUIntN(24))}) {
-    for (unsigned NumElts : {2, 3, 4, 256, 65535}) {
+    for (ElementCount EC :
+         {ElementCount::getFixed(2), ElementCount::getFixed(3),
+          ElementCount::getFixed(4), ElementCount::getFixed(256),
+          ElementCount::getFixed(65535), ElementCount::getScalable(2),
+          ElementCount::getScalable(3), ElementCount::getScalable(4),
+          ElementCount::getScalable(256), ElementCount::getScalable(65535)}) {
       const LLT Ty = LLT::pointer(AS, DL.getPointerSizeInBits(AS));
-      const LLT VTy = LLT::vector(NumElts, Ty);
+      const LLT VTy = LLT::vector(EC.getKnownMinValue(), Ty, EC.isScalable());
 
       // Test kind.
       ASSERT_TRUE(Ty.isValid());
@@ -222,8 +251,8 @@ TEST(LowLevelTypeTest, Pointer) {
       // Test Type->LLT conversion.
       Type *IRTy = PointerType::get(IntegerType::get(C, 8), AS);
       EXPECT_EQ(Ty, getLLTForType(*IRTy, DL));
-      Type *IRVTy = FixedVectorType::get(
-          PointerType::get(IntegerType::get(C, 8), AS), NumElts);
+      Type *IRVTy =
+          VectorType::get(PointerType::get(IntegerType::get(C, 8), AS), EC);
       EXPECT_EQ(VTy, getLLTForType(*IRVTy, DL));
     }
   }

diff  --git a/llvm/utils/TableGen/GlobalISelEmitter.cpp b/llvm/utils/TableGen/GlobalISelEmitter.cpp
index 06513e8ccdec6..71b73e3906e25 100644
--- a/llvm/utils/TableGen/GlobalISelEmitter.cpp
+++ b/llvm/utils/TableGen/GlobalISelEmitter.cpp
@@ -118,7 +118,9 @@ class LLTCodeGen {
       return;
     }
     if (Ty.isVector()) {
-      OS << "GILLT_v" << Ty.getNumElements() << "s" << Ty.getScalarSizeInBits();
+      OS << (Ty.isScalable() ? "GILLT_nxv" : "GILLT_v")
+         << Ty.getElementCount().getKnownMinValue() << "s"
+         << Ty.getScalarSizeInBits();
       return;
     }
     if (Ty.isPointer()) {
@@ -136,8 +138,8 @@ class LLTCodeGen {
       return;
     }
     if (Ty.isVector()) {
-      OS << "LLT::vector(" << Ty.getNumElements() << ", "
-         << Ty.getScalarSizeInBits() << ")";
+      OS << "LLT::vector(" << Ty.getElementCount().getKnownMinValue() << ", "
+         << Ty.getScalarSizeInBits() << ", " << Ty.isScalable() << ")";
       return;
     }
     if (Ty.isPointer() && Ty.getSizeInBits() > 0) {
@@ -169,9 +171,14 @@ class LLTCodeGen {
     if (Ty.isPointer() && Ty.getAddressSpace() != Other.Ty.getAddressSpace())
       return Ty.getAddressSpace() < Other.Ty.getAddressSpace();
 
-    if (Ty.isVector() && Ty.getNumElements() != Other.Ty.getNumElements())
-      return Ty.getNumElements() < Other.Ty.getNumElements();
+    if (Ty.isVector() && Ty.getElementCount() != Other.Ty.getElementCount())
+      return std::make_tuple(Ty.isScalable(),
+                             Ty.getElementCount().getKnownMinValue()) <
+             std::make_tuple(Other.Ty.isScalable(),
+                             Other.Ty.getElementCount().getKnownMinValue());
 
+    assert((!Ty.isVector() || Ty.isScalable() == Other.Ty.isScalable()) &&
+           "Unexpected mismatch of scalable property");
     return Ty.getSizeInBits() < Other.Ty.getSizeInBits();
   }
 
@@ -187,12 +194,10 @@ class InstructionMatcher;
 static Optional<LLTCodeGen> MVTToLLT(MVT::SimpleValueType SVT) {
   MVT VT(SVT);
 
-  if (VT.isScalableVector())
-    return None;
-
-  if (VT.isFixedLengthVector() && VT.getVectorNumElements() != 1)
-    return LLTCodeGen(
-        LLT::vector(VT.getVectorNumElements(), VT.getScalarSizeInBits()));
+  if (VT.isVector() && !VT.getVectorElementCount().isScalar())
+    return LLTCodeGen(LLT::vector(VT.getVectorMinNumElements(),
+                                  VT.getScalarSizeInBits(),
+                                  VT.isScalableVector()));
 
   if (VT.isInteger() || VT.isFloatingPoint())
     return LLTCodeGen(LLT::scalar(VT.getSizeInBits()));


        


More information about the llvm-commits mailing list