[llvm] 164f4b9 - [CodeGen][SVE] Calculate correct type legalization for scalable vectors.

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 5 07:21:34 PDT 2020


Author: Sander de Smalen
Date: 2020-06-05T15:20:34+01:00
New Revision: 164f4b9d26fdf3cd640a09b63b5ec44d033cbe8a

URL: https://github.com/llvm/llvm-project/commit/164f4b9d26fdf3cd640a09b63b5ec44d033cbe8a
DIFF: https://github.com/llvm/llvm-project/commit/164f4b9d26fdf3cd640a09b63b5ec44d033cbe8a.diff

LOG: [CodeGen][SVE] Calculate correct type legalization for scalable vectors.

This patch updates TargetLoweringBase::computeRegisterProperties and
TargetLoweringBase::getTypeConversion to support scalable vectors and to
make the right calls on how to legalise them. These changes are required
to legalise both MVTs and EVTs.

Reviewers: efriedma, david-arm, ctetreau

Reviewed By: efriedma

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D80640

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/TargetLowering.h
    llvm/include/llvm/Support/TypeSize.h
    llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeTypesGeneric.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
    llvm/lib/CodeGen/TargetLoweringBase.cpp
    llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 05770163cab4..b3c3bcadc4cd 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -210,6 +210,13 @@ class TargetLoweringBase {
     TypeWidenVector,     // This vector should be widened into a larger vector.
     TypePromoteFloat,    // Replace this float with a larger one.
     TypeSoftPromoteHalf, // Soften half to i16 and use float to do arithmetic.
+    TypeScalarizeScalableVector, // This action is explicitly left unimplemented.
+                                 // While it is theoretically possible to
+                                 // legalize operations on scalable types with a
+                                 // loop that handles the vscale * #lanes of the
+                                 // vector, this is non-trivial at SelectionDAG
+                                 // level and these types are better to be
+                                 // widened or promoted.
   };
 
   /// LegalizeKind holds the legalization kind that needs to happen to EVT
@@ -412,7 +419,7 @@ class TargetLoweringBase {
   virtual TargetLoweringBase::LegalizeTypeAction
   getPreferredVectorAction(MVT VT) const {
     // The default action for one element vectors is to scalarize
-    if (VT.getVectorNumElements() == 1)
+    if (VT.getVectorElementCount() == 1)
       return TypeScalarizeVector;
     // The default action for an odd-width vector is to widen.
     if (!VT.isPow2VectorType())

diff  --git a/llvm/include/llvm/Support/TypeSize.h b/llvm/include/llvm/Support/TypeSize.h
index b4c869fcb9e4..05295f40049f 100644
--- a/llvm/include/llvm/Support/TypeSize.h
+++ b/llvm/include/llvm/Support/TypeSize.h
@@ -15,6 +15,7 @@
 #ifndef LLVM_SUPPORT_TYPESIZE_H
 #define LLVM_SUPPORT_TYPESIZE_H
 
+#include "llvm/Support/MathExtras.h"
 #include "llvm/Support/WithColor.h"
 
 #include <cstdint>
@@ -49,6 +50,12 @@ class ElementCount {
   bool operator!=(const ElementCount& RHS) const {
     return !(*this == RHS);
   }
+  bool operator==(unsigned RHS) const { return Min == RHS && !Scalable; }
+  bool operator!=(unsigned RHS) const { return !(*this == RHS); }
+
+  ElementCount NextPowerOf2() const {
+    return ElementCount(llvm::NextPowerOf2(Min), Scalable);
+  }
 };
 
 // This class is used to represent the size of types. If the type is of fixed

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 70ef59338375..873812d074d9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -344,6 +344,8 @@ SDValue DAGTypeLegalizer::PromoteIntRes_BITCAST(SDNode *N) {
       return DAG.getNode(ISD::ANY_EXTEND, dl, NOutVT,
                          BitConvertToInteger(GetScalarizedVector(InOp)));
     break;
+  case TargetLowering::TypeScalarizeScalableVector:
+    report_fatal_error("Scalarization of scalable vectors is not supported.");
   case TargetLowering::TypeSplitVector: {
     if (!NOutVT.isVector()) {
       // For example, i32 = BITCAST v2i16 on alpha.  Convert the split

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
index 484f62668838..9885110d64f9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
@@ -245,6 +245,9 @@ bool DAGTypeLegalizer::run() {
       case TargetLowering::TypeLegal:
         LLVM_DEBUG(dbgs() << "Legal result type\n");
         break;
+      case TargetLowering::TypeScalarizeScalableVector:
+        report_fatal_error(
+            "Scalarization of scalable vectors is not supported.");
       // The following calls must take care of *all* of the node's results,
       // not just the illegal result they were passed (this includes results
       // with a legal type).  Results can be remapped using ReplaceValueWith,
@@ -307,6 +310,9 @@ bool DAGTypeLegalizer::run() {
       case TargetLowering::TypeLegal:
         LLVM_DEBUG(dbgs() << "Legal operand\n");
         continue;
+      case TargetLowering::TypeScalarizeScalableVector:
+        report_fatal_error(
+            "Scalarization of scalable vectors is not supported.");
       // The following calls must either replace all of the node's results
       // using ReplaceValueWith, and return "false"; or update the node's
       // operands in place, and return "true".

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypesGeneric.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypesGeneric.cpp
index 3c1f8e61b531..666f128a4cc2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypesGeneric.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypesGeneric.cpp
@@ -83,6 +83,8 @@ void DAGTypeLegalizer::ExpandRes_BITCAST(SDNode *N, SDValue &Lo, SDValue &Hi) {
       Lo = DAG.getNode(ISD::BITCAST, dl, NOutVT, Lo);
       Hi = DAG.getNode(ISD::BITCAST, dl, NOutVT, Hi);
       return;
+    case TargetLowering::TypeScalarizeScalableVector:
+      report_fatal_error("Scalarization of scalable vectors is not supported.");
     case TargetLowering::TypeWidenVector: {
       assert(!(InVT.getVectorNumElements() & 1) && "Unsupported BITCAST");
       InOp = GetWidenedVector(InOp);

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index ff2c8d3a8db2..9ebf4ea9637c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1063,6 +1063,8 @@ void DAGTypeLegalizer::SplitVecRes_BITCAST(SDNode *N, SDValue &Lo,
     Lo = DAG.getNode(ISD::BITCAST, dl, LoVT, Lo);
     Hi = DAG.getNode(ISD::BITCAST, dl, HiVT, Hi);
     return;
+  case TargetLowering::TypeScalarizeScalableVector:
+    report_fatal_error("Scalarization of scalable vectors is not supported.");
   }
 
   // In the general case, convert the input to an integer and split it by hand.
@@ -3465,6 +3467,8 @@ SDValue DAGTypeLegalizer::WidenVecRes_BITCAST(SDNode *N) {
   switch (getTypeAction(InVT)) {
   case TargetLowering::TypeLegal:
     break;
+  case TargetLowering::TypeScalarizeScalableVector:
+    report_fatal_error("Scalarization of scalable vectors is not supported.");
   case TargetLowering::TypePromoteInteger: {
     // If the incoming type is a vector that is being promoted, then
     // we know that the elements are arranged differently and that we

diff  --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index 3c8fc5a7e05b..4a304ae99653 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -823,9 +823,7 @@ TargetLoweringBase::getTypeConversion(LLVMContext &Context, EVT VT) const {
            "Promote may not follow Expand or Promote");
 
     if (LA == TypeSplitVector)
-      return LegalizeKind(LA,
-                          EVT::getVectorVT(Context, SVT.getVectorElementType(),
-                                           SVT.getVectorNumElements() / 2));
+      return LegalizeKind(LA, SVT.getHalfNumVectorElementsVT());
     if (LA == TypeScalarizeVector)
       return LegalizeKind(LA, SVT.getVectorElementType());
     return LegalizeKind(LA, NVT);
@@ -852,13 +850,16 @@ TargetLoweringBase::getTypeConversion(LLVMContext &Context, EVT VT) const {
   }
 
   // Handle vector types.
-  unsigned NumElts = VT.getVectorNumElements();
+  ElementCount NumElts = VT.getVectorElementCount();
   EVT EltVT = VT.getVectorElementType();
 
   // Vectors with only one element are always scalarized.
   if (NumElts == 1)
     return LegalizeKind(TypeScalarizeVector, EltVT);
 
+  if (VT.getVectorElementCount() == ElementCount(1, true))
+    report_fatal_error("Cannot legalize this vector");
+
   // Try to widen vector elements until the element type is a power of two and
   // promote it to a legal type later on, for example:
   // <3 x i8> -> <4 x i8> -> <4 x i32>
@@ -866,7 +867,7 @@ TargetLoweringBase::getTypeConversion(LLVMContext &Context, EVT VT) const {
     // Vectors with a number of elements that is not a power of two are always
     // widened, for example <3 x i8> -> <4 x i8>.
     if (!VT.isPow2VectorType()) {
-      NumElts = (unsigned)NextPowerOf2(NumElts);
+      NumElts = NumElts.NextPowerOf2();
       EVT NVT = EVT::getVectorVT(Context, EltVT, NumElts);
       return LegalizeKind(TypeWidenVector, NVT);
     }
@@ -915,7 +916,7 @@ TargetLoweringBase::getTypeConversion(LLVMContext &Context, EVT VT) const {
   // If there is no wider legal type, split the vector.
   while (true) {
     // Round up to the next power of 2.
-    NumElts = (unsigned)NextPowerOf2(NumElts);
+    NumElts = NumElts.NextPowerOf2();
 
     // If there is no simple vector type with this many elements then there
     // cannot be a larger legal vector type.  Note that this assumes that
@@ -938,7 +939,7 @@ TargetLoweringBase::getTypeConversion(LLVMContext &Context, EVT VT) const {
   }
 
   // Vectors with illegal element types are expanded.
-  EVT NVT = EVT::getVectorVT(Context, EltVT, VT.getVectorNumElements() / 2);
+  EVT NVT = EVT::getVectorVT(Context, EltVT, VT.getVectorElementCount() / 2);
   return LegalizeKind(TypeSplitVector, NVT);
 }
 
@@ -1257,7 +1258,7 @@ void TargetLoweringBase::computeRegisterProperties(
       continue;
 
     MVT EltVT = VT.getVectorElementType();
-    unsigned NElts = VT.getVectorNumElements();
+    ElementCount EC = VT.getVectorElementCount();
     bool IsLegalWiderType = false;
     bool IsScalable = VT.isScalableVector();
     LegalizeTypeAction PreferredAction = getPreferredVectorAction(VT);
@@ -1274,8 +1275,7 @@ void TargetLoweringBase::computeRegisterProperties(
         // Promote vectors of integers to vectors with the same number
         // of elements, with a wider element type.
         if (SVT.getScalarSizeInBits() > EltVT.getSizeInBits() &&
-            SVT.getVectorNumElements() == NElts &&
-            SVT.isScalableVector() == IsScalable && isTypeLegal(SVT)) {
+            SVT.getVectorElementCount() == EC && isTypeLegal(SVT)) {
           TransformToType[i] = SVT;
           RegisterTypeForVT[i] = SVT;
           NumRegistersForVT[i] = 1;
@@ -1290,13 +1290,13 @@ void TargetLoweringBase::computeRegisterProperties(
     }
 
     case TypeWidenVector:
-      if (isPowerOf2_32(NElts)) {
+      if (isPowerOf2_32(EC.Min)) {
         // Try to widen the vector.
         for (unsigned nVT = i + 1; nVT <= MVT::LAST_VECTOR_VALUETYPE; ++nVT) {
           MVT SVT = (MVT::SimpleValueType) nVT;
-          if (SVT.getVectorElementType() == EltVT
-              && SVT.getVectorNumElements() > NElts
-              && SVT.isScalableVector() == IsScalable && isTypeLegal(SVT)) {
+          if (SVT.getVectorElementType() == EltVT &&
+              SVT.isScalableVector() == IsScalable &&
+              SVT.getVectorElementCount().Min > EC.Min && isTypeLegal(SVT)) {
             TransformToType[i] = SVT;
             RegisterTypeForVT[i] = SVT;
             NumRegistersForVT[i] = 1;
@@ -1340,10 +1340,12 @@ void TargetLoweringBase::computeRegisterProperties(
           ValueTypeActions.setTypeAction(VT, TypeScalarizeVector);
         else if (PreferredAction == TypeSplitVector)
           ValueTypeActions.setTypeAction(VT, TypeSplitVector);
+        else if (EC.Min > 1)
+          ValueTypeActions.setTypeAction(VT, TypeSplitVector);
         else
-          // Set type action according to the number of elements.
-          ValueTypeActions.setTypeAction(VT, NElts == 1 ? TypeScalarizeVector
-                                                        : TypeSplitVector);
+          ValueTypeActions.setTypeAction(VT, EC.Scalable
+                                                 ? TypeScalarizeScalableVector
+                                                 : TypeScalarizeVector);
       } else {
         TransformToType[i] = NVT;
         ValueTypeActions.setTypeAction(VT, TypeWidenVector);

diff  --git a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
index 848cbc079002..42d47faa1003 100644
--- a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
+++ b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
@@ -17,9 +17,7 @@
 #include "llvm/Target/TargetMachine.h"
 #include "gtest/gtest.h"
 
-using namespace llvm;
-
-namespace {
+namespace llvm {
 
 class AArch64SelectionDAGTest : public testing::Test {
 protected:
@@ -41,8 +39,8 @@ class AArch64SelectionDAGTest : public testing::Test {
       return;
 
     TargetOptions Options;
-    TM = std::unique_ptr<LLVMTargetMachine>(static_cast<LLVMTargetMachine*>(
-        T->createTargetMachine("AArch64", "", "", Options, None, None,
+    TM = std::unique_ptr<LLVMTargetMachine>(static_cast<LLVMTargetMachine *>(
+        T->createTargetMachine("AArch64", "", "+sve", Options, None, None,
                                CodeGenOpt::Aggressive)));
     if (!TM)
       return;
@@ -69,6 +67,14 @@ class AArch64SelectionDAGTest : public testing::Test {
     DAG->init(*MF, ORE, nullptr, nullptr, nullptr, nullptr, nullptr);
   }
 
+  TargetLoweringBase::LegalizeTypeAction getTypeAction(EVT VT) {
+    return DAG->getTargetLoweringInfo().getTypeAction(Context, VT);
+  }
+
+  EVT getTypeToTransformTo(EVT VT) {
+    return DAG->getTargetLoweringInfo().getTypeToTransformTo(Context, VT);
+  }
+
   LLVMContext Context;
   std::unique_ptr<LLVMTargetMachine> TM;
   std::unique_ptr<Module> M;
@@ -377,4 +383,59 @@ TEST_F(AArch64SelectionDAGTest, getSplatSourceVector_Scalable_ADD_of_SPLAT_VECTO
   EXPECT_EQ(SplatIdx, 0);
 }
 
-} // end anonymous namespace
+TEST_F(AArch64SelectionDAGTest, getTypeConversion_SplitScalableMVT) {
+  if (!TM)
+    return;
+
+  MVT VT = MVT::nxv4i64;
+  EXPECT_EQ(getTypeAction(VT), TargetLoweringBase::TypeSplitVector);
+  ASSERT_TRUE(getTypeToTransformTo(VT).isScalableVector());
+}
+
+TEST_F(AArch64SelectionDAGTest, getTypeConversion_PromoteScalableMVT) {
+  if (!TM)
+    return;
+
+  MVT VT = MVT::nxv2i32;
+  EXPECT_EQ(getTypeAction(VT), TargetLoweringBase::TypePromoteInteger);
+  ASSERT_TRUE(getTypeToTransformTo(VT).isScalableVector());
+}
+
+TEST_F(AArch64SelectionDAGTest, getTypeConversion_NoScalarizeMVT_nxv1f32) {
+  if (!TM)
+    return;
+
+  MVT VT = MVT::nxv1f32;
+  EXPECT_NE(getTypeAction(VT), TargetLoweringBase::TypeScalarizeVector);
+  ASSERT_TRUE(getTypeToTransformTo(VT).isScalableVector());
+}
+
+TEST_F(AArch64SelectionDAGTest, getTypeConversion_SplitScalableEVT) {
+  if (!TM)
+    return;
+
+  EVT VT = EVT::getVectorVT(Context, MVT::i64, 256, true);
+  EXPECT_EQ(getTypeAction(VT), TargetLoweringBase::TypeSplitVector);
+  EXPECT_EQ(getTypeToTransformTo(VT), VT.getHalfNumVectorElementsVT(Context));
+}
+
+TEST_F(AArch64SelectionDAGTest, getTypeConversion_WidenScalableEVT) {
+  if (!TM)
+    return;
+
+  EVT FromVT = EVT::getVectorVT(Context, MVT::i64, 6, true);
+  EVT ToVT = EVT::getVectorVT(Context, MVT::i64, 8, true);
+
+  EXPECT_EQ(getTypeAction(FromVT), TargetLoweringBase::TypeWidenVector);
+  EXPECT_EQ(getTypeToTransformTo(FromVT), ToVT);
+}
+
+TEST_F(AArch64SelectionDAGTest, getTypeConversion_NoScalarizeEVT_nxv1f128) {
+  if (!TM)
+    return;
+
+  EVT FromVT = EVT::getVectorVT(Context, MVT::f128, 1, true);
+  EXPECT_DEATH(getTypeAction(FromVT), "Cannot legalize this vector");
+}
+
+} // end namespace llvm


        


More information about the llvm-commits mailing list