[llvm] r300842 - [MVT][SVE] Scalable vector MVTs (3/3)

Amara Emerson via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 20 06:54:09 PDT 2017


Author: aemerson
Date: Thu Apr 20 08:54:09 2017
New Revision: 300842

URL: http://llvm.org/viewvc/llvm-project?rev=300842&view=rev
Log:
[MVT][SVE] Scalable vector MVTs (3/3)

Adds MVT::ElementCount to represent the length of a
vector which may be scalable, then adds helper functions
that work with it.

Patch by Graham Hunter.

Differential Revision: https://reviews.llvm.org/D32019


Added:
    llvm/trunk/unittests/CodeGen/ScalableVectorMVTsTest.cpp
Modified:
    llvm/trunk/include/llvm/CodeGen/MachineValueType.h
    llvm/trunk/include/llvm/CodeGen/ValueTypes.h
    llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
    llvm/trunk/unittests/CodeGen/CMakeLists.txt

Modified: llvm/trunk/include/llvm/CodeGen/MachineValueType.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/CodeGen/MachineValueType.h?rev=300842&r1=300841&r2=300842&view=diff
==============================================================================
--- llvm/trunk/include/llvm/CodeGen/MachineValueType.h (original)
+++ llvm/trunk/include/llvm/CodeGen/MachineValueType.h Thu Apr 20 08:54:09 2017
@@ -232,6 +232,49 @@ class MVT {
 
     SimpleValueType SimpleTy;
 
+    /// A class to represent the number of elements in a vector.
+    ///
+    /// For fixed-length vectors, the total number of elements is equal to
+    /// 'Min'. For scalable vectors, the total number of elements is a
+    /// multiple of 'Min'; 'Scalable' records which of the two applies.
+    class ElementCount {
+    public:
+      unsigned Min;
+      bool Scalable;
+
+      constexpr ElementCount(unsigned Min, bool Scalable)
+      : Min(Min), Scalable(Scalable) {}
+
+      /// Multiply the minimum element count; scalability is preserved.
+      ElementCount operator*(unsigned RHS) const {
+        return { Min * RHS, Scalable };
+      }
+
+      ElementCount& operator*=(unsigned RHS) {
+        Min *= RHS;
+        return *this;
+      }
+
+      /// Divide the minimum element count (truncating unsigned division);
+      /// scalability is preserved.
+      ElementCount operator/(unsigned RHS) const {
+        return { Min / RHS, Scalable };
+      }
+
+      ElementCount& operator/=(unsigned RHS) {
+        Min /= RHS;
+        return *this;
+      }
+
+      bool operator==(const ElementCount& RHS) const {
+        return Min == RHS.Min && Scalable == RHS.Scalable;
+      }
+
+      bool operator!=(const ElementCount& RHS) const {
+        return !(*this == RHS);
+      }
+    };
+
     constexpr MVT() : SimpleTy(INVALID_SIMPLE_VALUE_TYPE) {}
     constexpr MVT(SimpleValueType SVT) : SimpleTy(SVT) {}
 
@@ -276,6 +312,15 @@ class MVT {
               SimpleTy <= MVT::LAST_VECTOR_VALUETYPE);
     }
 
+    /// Return true if this is a vector value type whose runtime length is
+    /// machine dependent, i.e. one of the scalable integer or FP vector types.
+    bool isScalableVector() const {
+      return ((SimpleTy >= MVT::FIRST_INTEGER_SCALABLE_VALUETYPE &&
+               SimpleTy <= MVT::LAST_INTEGER_SCALABLE_VALUETYPE) ||
+              (SimpleTy >= MVT::FIRST_FP_SCALABLE_VALUETYPE &&
+               SimpleTy <= MVT::LAST_FP_SCALABLE_VALUETYPE));
+    }
+
     /// Return true if this is a 16-bit vector type.
     bool is16BitVector() const {
       return (SimpleTy == MVT::v2i8  || SimpleTy == MVT::v1i16 ||
@@ -560,6 +605,12 @@ class MVT {
       }
     }
 
+    /// Return the element count of this vector type: its (minimum) number of
+    /// elements and whether the type is scalable.
+    MVT::ElementCount getVectorElementCount() const {
+      return { getVectorNumElements(), isScalableVector() };
+    }
+
     unsigned getSizeInBits() const {
       switch (SimpleTy) {
       default:
@@ -837,6 +886,89 @@ class MVT {
       return (MVT::SimpleValueType)(MVT::INVALID_SIMPLE_VALUE_TYPE);
     }
 
+    /// Return the scalable vector type with a minimum of \p NumElements
+    /// elements of type \p VT, or an invalid MVT (INVALID_SIMPLE_VALUE_TYPE)
+    /// if no such simple type exists.
+    static MVT getScalableVectorVT(MVT VT, unsigned NumElements) {
+      switch (VT.SimpleTy) {
+      default:
+        break;
+      case MVT::i1:
+        if (NumElements == 2)  return MVT::nxv2i1;
+        if (NumElements == 4)  return MVT::nxv4i1;
+        if (NumElements == 8)  return MVT::nxv8i1;
+        if (NumElements == 16) return MVT::nxv16i1;
+        if (NumElements == 32) return MVT::nxv32i1;
+        break;
+      case MVT::i8:
+        if (NumElements == 1)  return MVT::nxv1i8;
+        if (NumElements == 2)  return MVT::nxv2i8;
+        if (NumElements == 4)  return MVT::nxv4i8;
+        if (NumElements == 8)  return MVT::nxv8i8;
+        if (NumElements == 16) return MVT::nxv16i8;
+        if (NumElements == 32) return MVT::nxv32i8;
+        break;
+      case MVT::i16:
+        if (NumElements == 1)  return MVT::nxv1i16;
+        if (NumElements == 2)  return MVT::nxv2i16;
+        if (NumElements == 4)  return MVT::nxv4i16;
+        if (NumElements == 8)  return MVT::nxv8i16;
+        if (NumElements == 16) return MVT::nxv16i16;
+        if (NumElements == 32) return MVT::nxv32i16;
+        break;
+      case MVT::i32:
+        if (NumElements == 1)  return MVT::nxv1i32;
+        if (NumElements == 2)  return MVT::nxv2i32;
+        if (NumElements == 4)  return MVT::nxv4i32;
+        if (NumElements == 8)  return MVT::nxv8i32;
+        if (NumElements == 16) return MVT::nxv16i32;
+        if (NumElements == 32) return MVT::nxv32i32;
+        break;
+      case MVT::i64:
+        if (NumElements == 1)  return MVT::nxv1i64;
+        if (NumElements == 2)  return MVT::nxv2i64;
+        if (NumElements == 4)  return MVT::nxv4i64;
+        if (NumElements == 8)  return MVT::nxv8i64;
+        if (NumElements == 16) return MVT::nxv16i64;
+        if (NumElements == 32) return MVT::nxv32i64;
+        break;
+      case MVT::f16:
+        if (NumElements == 2)  return MVT::nxv2f16;
+        if (NumElements == 4)  return MVT::nxv4f16;
+        if (NumElements == 8)  return MVT::nxv8f16;
+        break;
+      case MVT::f32:
+        if (NumElements == 1)  return MVT::nxv1f32;
+        if (NumElements == 2)  return MVT::nxv2f32;
+        if (NumElements == 4)  return MVT::nxv4f32;
+        if (NumElements == 8)  return MVT::nxv8f32;
+        if (NumElements == 16) return MVT::nxv16f32;
+        break;
+      case MVT::f64:
+        if (NumElements == 1)  return MVT::nxv1f64;
+        if (NumElements == 2)  return MVT::nxv2f64;
+        if (NumElements == 4)  return MVT::nxv4f64;
+        if (NumElements == 8)  return MVT::nxv8f64;
+        break;
+      }
+      return (MVT::SimpleValueType)(MVT::INVALID_SIMPLE_VALUE_TYPE);
+    }
+
+    /// Return the vector type with \p NumElements elements of type \p VT.
+    /// If \p IsScalable is true the result is a scalable vector type and
+    /// \p NumElements is its minimum element count. An invalid MVT is
+    /// returned if no matching simple type exists.
+    static MVT getVectorVT(MVT VT, unsigned NumElements, bool IsScalable) {
+      if (IsScalable)
+        return getScalableVectorVT(VT, NumElements);
+      return getVectorVT(VT, NumElements);
+    }
+
+    /// Return the vector type with element type \p VT and element count \p EC.
+    static MVT getVectorVT(MVT VT, MVT::ElementCount EC) {
+      return getVectorVT(VT, EC.Min, EC.Scalable);
+    }
+
     /// Return the value type corresponding to the specified type.  This returns
     /// all pointers as iPTR.  If HandleUnknown is true, unknown types are
     /// returned as Other, otherwise they are invalid.
@@ -887,6 +1013,18 @@ class MVT {
           MVT::FIRST_FP_VECTOR_VALUETYPE,
           (MVT::SimpleValueType)(MVT::LAST_FP_VECTOR_VALUETYPE + 1));
     }
+    /// Range over the scalable integer vector value types, from
+    /// FIRST_INTEGER_SCALABLE_VALUETYPE to LAST_INTEGER_SCALABLE_VALUETYPE.
+    static mvt_range integer_scalable_vector_valuetypes() {
+      return mvt_range(MVT::FIRST_INTEGER_SCALABLE_VALUETYPE,
+              (MVT::SimpleValueType)(MVT::LAST_INTEGER_SCALABLE_VALUETYPE + 1));
+    }
+    /// Range over the scalable floating-point vector value types, from
+    /// FIRST_FP_SCALABLE_VALUETYPE to LAST_FP_SCALABLE_VALUETYPE (inclusive).
+    static mvt_range fp_scalable_vector_valuetypes() {
+      return mvt_range(MVT::FIRST_FP_SCALABLE_VALUETYPE,
+                   (MVT::SimpleValueType)(MVT::LAST_FP_SCALABLE_VALUETYPE + 1));
+    }
     /// @}
   };
 

Modified: llvm/trunk/include/llvm/CodeGen/ValueTypes.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/CodeGen/ValueTypes.h?rev=300842&r1=300841&r2=300842&view=diff
==============================================================================
--- llvm/trunk/include/llvm/CodeGen/ValueTypes.h (original)
+++ llvm/trunk/include/llvm/CodeGen/ValueTypes.h Thu Apr 20 08:54:09 2017
@@ -67,24 +67,41 @@ namespace llvm {
 
     /// Returns the EVT that represents a vector NumElements in length, where
     /// each element is of type VT.
-    static EVT getVectorVT(LLVMContext &Context, EVT VT, unsigned NumElements) {
-      MVT M = MVT::getVectorVT(VT.V, NumElements);
+    /// If IsScalable is true the vector is scalable and NumElements is its
+    /// minimum element count.
+    static EVT getVectorVT(LLVMContext &Context, EVT VT, unsigned NumElements,
+                           bool IsScalable = false) {
+      MVT M = MVT::getVectorVT(VT.V, NumElements, IsScalable);
       if (M.SimpleTy != MVT::INVALID_SIMPLE_VALUE_TYPE)
         return M;
+
+      // Only simple (MVT) vector types can be scalable for now.
+      assert(!IsScalable && "We don't support extended scalable types yet");
       return getExtendedVectorVT(Context, VT, NumElements);
     }
 
+    /// Returns the EVT that represents a vector EC.Min elements in length,
+    /// where each element is of type VT; the vector is scalable if
+    /// EC.Scalable is set.
+    static EVT getVectorVT(LLVMContext &Context, EVT VT, MVT::ElementCount EC) {
+      return getVectorVT(Context, VT, EC.Min, EC.Scalable);
+    }
+
     /// Return a vector with the same number of elements as this vector, but
     /// with the element type converted to an integer type with the same
     /// bitwidth.
     EVT changeVectorElementTypeToInteger() const {
-      if (!isSimple())
+      if (!isSimple()) {
+        assert(!isScalableVector() &&
+               "We don't support extended scalable types yet");
         return changeExtendedVectorElementTypeToInteger();
+      }
       MVT EltTy = getSimpleVT().getVectorElementType();
       unsigned BitWidth = EltTy.getSizeInBits();
       MVT IntTy = MVT::getIntegerVT(BitWidth);
-      MVT VecTy = MVT::getVectorVT(IntTy, getVectorNumElements());
-      assert(VecTy.SimpleTy >= 0 &&
+      MVT VecTy = MVT::getVectorVT(IntTy, getVectorNumElements(),
+                                   isScalableVector());
+      assert(VecTy.SimpleTy != MVT::INVALID_SIMPLE_VALUE_TYPE &&
              "Simple vector VT not representable by simple integer vector VT!");
       return VecTy;
     }
@@ -132,6 +149,17 @@ namespace llvm {
       return isSimple() ? V.isVector() : isExtendedVector();
     }
 
+    /// Return true if this is a vector type where the runtime
+    /// length is machine dependent (i.e. a scalable vector type).
+    bool isScalableVector() const {
+      // FIXME: We don't support extended scalable types yet, because the
+      // matching IR type doesn't exist. Once it has been added, this can
+      // be changed to call isExtendedScalableVector.
+      if (!isSimple())
+        return false;
+      return V.isScalableVector();
+    }
+
     /// Return true if this is a 16-bit vector type.
     bool is16BitVector() const {
       return isSimple() ? V.is16BitVector() : isExtended16BitVector();
@@ -247,6 +275,18 @@ namespace llvm {
       return getExtendedVectorNumElements();
     }
 
+    /// Given a (possibly scalable) vector type, return its ElementCount.
+    MVT::ElementCount getVectorElementCount() const {
+      assert(isVector() && "Invalid vector type!");
+      if (isSimple())
+        return V.getVectorElementCount();
+
+      // Extended (non-simple) vector types are never scalable for now.
+      assert(!isScalableVector() &&
+             "We don't support extended scalable types yet");
+      return {getExtendedVectorNumElements(), false};
+    }
+
     /// Return the size of the specified value type in bits.
     unsigned getSizeInBits() const {
       if (isSimple())
@@ -301,7 +340,8 @@ namespace llvm {
     EVT widenIntegerVectorElementType(LLVMContext &Context) const {
       EVT EltVT = getVectorElementType();
       EltVT = EVT::getIntegerVT(Context, 2 * EltVT.getSizeInBits());
-      return EVT::getVectorVT(Context, EltVT, getVectorNumElements());
+      // Pass the full ElementCount so a scalable input stays scalable.
+      return EVT::getVectorVT(Context, EltVT, getVectorElementCount());
     }
 
     // Return a VT for a vector type with the same element type but
@@ -309,9 +348,10 @@ namespace llvm {
     // extended type.
     EVT getHalfNumVectorElementsVT(LLVMContext &Context) const {
       EVT EltVT = getVectorElementType();
-      auto EltCnt = getVectorNumElements();
-      assert(!(getVectorNumElements() & 1) &&
-             "Splitting vector, but not in half!");
+      // Halving an ElementCount keeps its scalable flag; only the minimum
+      // element count has to be even.
+      auto EltCnt = getVectorElementCount();
+      assert(!(EltCnt.Min & 1) && "Splitting vector, but not in half!");
       return EVT::getVectorVT(Context, EltVT, EltCnt / 2);
     }
 
@@ -327,7 +365,10 @@ namespace llvm {
       if (!isPow2VectorType()) {
         unsigned NElts = getVectorNumElements();
         unsigned Pow2NElts = 1 <<  Log2_32_Ceil(NElts);
-        return EVT::getVectorVT(Context, getVectorElementType(), Pow2NElts);
+        // Forward the scalable flag so a scalable vector is rounded up to a
+        // scalable power-of-two vector.
+        return EVT::getVectorVT(Context, getVectorElementType(), Pow2NElts,
+                                isScalableVector());
       }
       else {
         return *this;

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp?rev=300842&r1=300841&r2=300842&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp Thu Apr 20 08:54:09 2017
@@ -925,9 +925,11 @@ SDValue DAGTypeLegalizer::BitConvertVect
   assert(Op.getValueType().isVector() && "Only applies to vectors!");
   unsigned EltWidth = Op.getScalarValueSizeInBits();
   EVT EltNVT = EVT::getIntegerVT(*DAG.getContext(), EltWidth);
-  unsigned NumElts = Op.getValueType().getVectorNumElements();
+  // Use the ElementCount so a scalable vector is bitcast to a scalable
+  // integer vector.
+  auto EltCnt = Op.getValueType().getVectorElementCount();
   return DAG.getNode(ISD::BITCAST, SDLoc(Op),
-                     EVT::getVectorVT(*DAG.getContext(), EltNVT, NumElts), Op);
+                     EVT::getVectorVT(*DAG.getContext(), EltNVT, EltCnt), Op);
 }
 
 SDValue DAGTypeLegalizer::CreateStackStoreLoad(SDValue Op,

Modified: llvm/trunk/unittests/CodeGen/CMakeLists.txt
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/unittests/CodeGen/CMakeLists.txt?rev=300842&r1=300841&r2=300842&view=diff
==============================================================================
--- llvm/trunk/unittests/CodeGen/CMakeLists.txt (original)
+++ llvm/trunk/unittests/CodeGen/CMakeLists.txt Thu Apr 20 08:54:09 2017
@@ -9,6 +9,7 @@ set(CodeGenSources
   DIEHashTest.cpp
   LowLevelTypeTest.cpp
   MachineInstrBundleIteratorTest.cpp
+  ScalableVectorMVTsTest.cpp
   )
 
 add_llvm_unittest(CodeGenTests

Added: llvm/trunk/unittests/CodeGen/ScalableVectorMVTsTest.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/unittests/CodeGen/ScalableVectorMVTsTest.cpp?rev=300842&view=auto
==============================================================================
--- llvm/trunk/unittests/CodeGen/ScalableVectorMVTsTest.cpp (added)
+++ llvm/trunk/unittests/CodeGen/ScalableVectorMVTsTest.cpp Thu Apr 20 08:54:09 2017
@@ -0,0 +1,89 @@
+//===-------- llvm/unittest/CodeGen/ScalableVectorMVTsTest.cpp ------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/MachineValueType.h"
+#include "llvm/CodeGen/ValueTypes.h"
+#include "llvm/IR/LLVMContext.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+TEST(ScalableVectorMVTsTest, IntegerMVTs) {
+  for (auto VecTy : MVT::integer_scalable_vector_valuetypes()) {
+    ASSERT_TRUE(VecTy.isValid());
+    ASSERT_TRUE(VecTy.isInteger());
+    ASSERT_TRUE(VecTy.isVector());
+    ASSERT_TRUE(VecTy.isScalableVector());
+    ASSERT_TRUE(VecTy.getScalarType().isValid());
+
+    ASSERT_FALSE(VecTy.isFloatingPoint());
+  }
+}
+
+TEST(ScalableVectorMVTsTest, FloatMVTs) {
+  for (auto VecTy : MVT::fp_scalable_vector_valuetypes()) {
+    ASSERT_TRUE(VecTy.isValid());
+    ASSERT_TRUE(VecTy.isFloatingPoint());
+    ASSERT_TRUE(VecTy.isVector());
+    ASSERT_TRUE(VecTy.isScalableVector());
+    ASSERT_TRUE(VecTy.getScalarType().isValid());
+
+    ASSERT_FALSE(VecTy.isInteger());
+  }
+}
+
+TEST(ScalableVectorMVTsTest, HelperFuncs) {
+  LLVMContext Ctx;
+
+  // Create with scalable flag
+  EVT Vnx4i32 = EVT::getVectorVT(Ctx, MVT::i32, 4, /*Scalable=*/true);
+  ASSERT_TRUE(Vnx4i32.isScalableVector());
+
+  // Create with separate MVT::ElementCount
+  auto EltCnt = MVT::ElementCount(2, true);
+  EVT Vnx2i32 = EVT::getVectorVT(Ctx, MVT::i32, EltCnt);
+  ASSERT_TRUE(Vnx2i32.isScalableVector());
+
+  // Create with inline MVT::ElementCount
+  EVT Vnx2i64 = EVT::getVectorVT(Ctx, MVT::i64, {2, true});
+  ASSERT_TRUE(Vnx2i64.isScalableVector());
+
+  // Check that changing scalar types/element count works
+  EXPECT_EQ(Vnx2i32.widenIntegerVectorElementType(Ctx), Vnx2i64);
+  EXPECT_EQ(Vnx4i32.getHalfNumVectorElementsVT(Ctx), Vnx2i32);
+
+  // Check that overloaded '*' and '/' operators work
+  EXPECT_EQ(EVT::getVectorVT(Ctx, MVT::i64, EltCnt * 2), MVT::nxv4i64);
+  EXPECT_EQ(EVT::getVectorVT(Ctx, MVT::i64, EltCnt / 2), MVT::nxv1i64);
+
+  // Check that float->int conversion works
+  EVT Vnx2f64 = EVT::getVectorVT(Ctx, MVT::f64, {2, true});
+  EXPECT_EQ(Vnx2f64.changeTypeToInteger(), Vnx2i64);
+
+  // Check fields inside MVT::ElementCount ('Min' is unsigned, so compare
+  // against unsigned literals to avoid sign-compare warnings).
+  EltCnt = Vnx4i32.getVectorElementCount();
+  EXPECT_EQ(EltCnt.Min, 4U);
+  ASSERT_TRUE(EltCnt.Scalable);
+
+  // Check that fixed-length vector types aren't scalable.
+  EVT V8i32 = EVT::getVectorVT(Ctx, MVT::i32, 8);
+  ASSERT_FALSE(V8i32.isScalableVector());
+  EVT V4f64 = EVT::getVectorVT(Ctx, MVT::f64, {4, false});
+  ASSERT_FALSE(V4f64.isScalableVector());
+
+  // Check that MVT::ElementCount works for fixed-length types.
+  EltCnt = V8i32.getVectorElementCount();
+  EXPECT_EQ(EltCnt.Min, 8U);
+  ASSERT_FALSE(EltCnt.Scalable);
+}
+
+} // end anonymous namespace




More information about the llvm-commits mailing list