[llvm] e4794ff - [mlgo][nfc] Decouple TensorSpec from tensorflow.

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 21 15:37:11 PDT 2022


Author: Mircea Trofin
Date: 2022-04-21T15:37:01-07:00
New Revision: e4794ff5c685997c55f4308fac30a08eb8769758

URL: https://github.com/llvm/llvm-project/commit/e4794ff5c685997c55f4308fac30a08eb8769758
DIFF: https://github.com/llvm/llvm-project/commit/e4794ff5c685997c55f4308fac30a08eb8769758.diff

LOG: [mlgo][nfc] Decouple TensorSpec from tensorflow.

The motivation is twofold:

1) Allow plugging in a different training-time evaluator (e.g. a
   TFLite-based one).

2) Allow using TensorSpec for AOT, too, to support model evolution: we
   start by extracting a superset of the features currently supported by
   a model. For the tensors the model does not support, we just return a
   valid, but useless, buffer. This makes using a 'smaller' model (fewer
   supported tensors) transparent to the compiler. The key is to
   dimension the buffer appropriately, and TensorSpec already models that
   information.

The only coupling was the reliance on a TF-internal API to get the
element size; for the types we are interested in, `sizeof` is
sufficient.

A subsequent change will move TensorSpec out into its own module.

Differential Revision: https://reviews.llvm.org/D124045

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/Utils/TFUtils.h
    llvm/lib/Analysis/TFUtils.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/Utils/TFUtils.h b/llvm/include/llvm/Analysis/Utils/TFUtils.h
index 785b9fe949a52..386f93333067b 100644
--- a/llvm/include/llvm/Analysis/Utils/TFUtils.h
+++ b/llvm/include/llvm/Analysis/Utils/TFUtils.h
@@ -46,23 +46,48 @@ class EvaluationResultImpl;
 ///
 /// TensorSpec is used to set up a TFModelEvaluator by describing the expected
 /// inputs and outputs.
+
+/// Known tensor types. The left part is the C type, the right is a name we
+/// can use to identify the type (to implement TensorSpec equality checks), and
+/// to use, if needed, when mapping to an underlying evaluator's type system.
+/// The main requirement is that the C type we use has the same size and
+/// encoding (e.g. endianness) as the one used by the evaluator.
+#define SUPPORTED_TENSOR_TYPES(M)                                              \
+  M(float, Float)                                                              \
+  M(double, Double)                                                            \
+  M(int8_t, Int8)                                                              \
+  M(uint8_t, UInt8)                                                            \
+  M(int16_t, Int16)                                                            \
+  M(uint16_t, UInt16)                                                          \
+  M(int32_t, Int32)                                                            \
+  M(uint32_t, UInt32)                                                          \
+  M(int64_t, Int64)                                                            \
+  M(uint64_t, UInt64)
+
+enum class TensorType {
+  Invalid, // Sentinel; not part of SUPPORTED_TENSOR_TYPES.
+#define TFUTILS_TENSOR_TYPE_ENUM_MEMBER(CType, Name) Name,
+  SUPPORTED_TENSOR_TYPES(TFUTILS_TENSOR_TYPE_ENUM_MEMBER)
+#undef TFUTILS_TENSOR_TYPE_ENUM_MEMBER
+};
+
 class TensorSpec final {
 public:
   template <typename T>
   static TensorSpec createSpec(const std::string &Name,
                                const std::vector<int64_t> &Shape,
                                int Port = 0) {
-    return TensorSpec(Name, Port, getDataType<T>(), Shape);
+    return TensorSpec(Name, Port, getDataType<T>(), sizeof(T), Shape);
   }
 
   const std::string &name() const { return Name; }
   int port() const { return Port; }
-  int typeIndex() const { return TypeIndex; }
+  TensorType type() const { return Type; }
   const std::vector<int64_t> &shape() const { return Shape; }
 
   bool operator==(const TensorSpec &Other) const {
-    return Name == Other.Name && Port == Other.Port &&
-           TypeIndex == Other.TypeIndex && Shape == Other.Shape;
+    return Name == Other.Name && Port == Other.Port && Type == Other.Type &&
+           Shape == Other.Shape;
   }
 
   bool operator!=(const TensorSpec &Other) const { return !(*this == Other); }
@@ -70,25 +95,24 @@ class TensorSpec final {
   /// Get the number of elements in a tensor with this shape.
   size_t getElementCount() const { return ElementCount; }
   /// Get the size, in bytes, of one element.
-  size_t getElementByteSize() const;
+  size_t getElementByteSize() const { return ElementSize; }
 
   template <typename T> bool isElementType() const {
-    return getDataType<T>() == TypeIndex;
+    return getDataType<T>() == Type;
   }
 
 private:
-  TensorSpec(const std::string &Name, int Port, int TypeIndex,
-             const std::vector<int64_t> &Shape);
+  TensorSpec(const std::string &Name, int Port, TensorType Type,
+             size_t ElementSize, const std::vector<int64_t> &Shape);
 
-  template <typename T> static int getDataType() {
-    llvm_unreachable("Undefined tensor type");
-  }
+  template <typename T> static TensorType getDataType();
 
   std::string Name;
   int Port = 0;
-  int TypeIndex = 0;
+  TensorType Type = TensorType::Invalid;
   std::vector<int64_t> Shape;
   size_t ElementCount = 0;
+  size_t ElementSize = 0;
 };
 
 /// Construct a TensorSpec from a JSON dictionary of the form:
@@ -262,25 +286,9 @@ class TFModelEvaluator final {
   std::unique_ptr<TFModelEvaluatorImpl> Impl;
 };
 
-/// List of supported types, as a pair:
-/// - C++ type
-/// - enum name (implementation-specific)
-#define TFUTILS_SUPPORTED_TYPES(M)                                             \
-  M(float, TF_FLOAT)                                                           \
-  M(double, TF_DOUBLE)                                                         \
-  M(int8_t, TF_INT8)                                                           \
-  M(uint8_t, TF_UINT8)                                                         \
-  M(int16_t, TF_INT16)                                                         \
-  M(uint16_t, TF_UINT16)                                                       \
-  M(int32_t, TF_INT32)                                                         \
-  M(uint32_t, TF_UINT32)                                                       \
-  M(int64_t, TF_INT64)                                                         \
-  M(uint64_t, TF_UINT64)
-
-#define TFUTILS_GETDATATYPE_DEF(T, E)                                          \
-  template <> int TensorSpec::getDataType<T>();
-
-TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_DEF)
+#define TFUTILS_GETDATATYPE_DEF(T, Name)                                       \
+  template <> TensorType TensorSpec::getDataType<T>();
+SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_DEF) // Defined in TFUtils.cpp.
 
 #undef TFUTILS_GETDATATYPE_DEF
 } // namespace llvm

diff  --git a/llvm/lib/Analysis/TFUtils.cpp b/llvm/lib/Analysis/TFUtils.cpp
index 26bc63983b4ee..ea2308a443e9b 100644
--- a/llvm/lib/Analysis/TFUtils.cpp
+++ b/llvm/lib/Analysis/TFUtils.cpp
@@ -82,6 +82,37 @@ void serialize(const Message &SE, std::string *OutStr) {
     *OutStr = SE.SerializeAsString();
   }
 }
+
+/// Map \p TType to the matching TF_DataType value, returned as an int.
+int getTFTypeIndex(TensorType TType) {
+  switch (TType) {
+  case TensorType::Double:
+    return TF_DOUBLE;
+  case TensorType::Float:
+    return TF_FLOAT;
+  case TensorType::Int8:
+    return TF_INT8;
+  case TensorType::UInt8:
+    return TF_UINT8;
+  case TensorType::Int16:
+    return TF_INT16;
+  case TensorType::UInt16:
+    return TF_UINT16;
+  case TensorType::Int32:
+    return TF_INT32;
+  case TensorType::UInt32:
+    return TF_UINT32;
+  case TensorType::Int64:
+    return TF_INT64;
+  case TensorType::UInt64:
+    return TF_UINT64;
+  case TensorType::Invalid:
+    break;
+  }
+  // Kept outside the switch: GCC does not treat a fully covered enum switch
+  // as exhaustive and would otherwise warn (-Wreturn-type).
+  llvm_unreachable("Unknown tensor type");
+}
 } // namespace
 
 namespace llvm {
@@ -105,15 +132,14 @@ class EvaluationResultImpl {
   std::vector<TF_Tensor *> Output;
 };
 
-size_t TensorSpec::getElementByteSize() const {
-  return TF_DataTypeSize(static_cast<TF_DataType>(TypeIndex));
-}
-
-TensorSpec::TensorSpec(const std::string &Name, int Port, int TypeIndex,
-                       const std::vector<int64_t> &Shape)
-    : Name(Name), Port(Port), TypeIndex(TypeIndex), Shape(Shape),
-      ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
-                                   std::multiplies<int64_t>())) {}
+// std::accumulate accumulates in the type of its seed; seed it with int64_t so
+// shapes with more than INT_MAX elements do not overflow an int accumulator.
+TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
+                       size_t ElementSize, const std::vector<int64_t> &Shape)
+    : Name(Name), Port(Port), Type(Type), Shape(Shape),
+      ElementCount(std::accumulate(Shape.begin(), Shape.end(), int64_t{1},
+                                   std::multiplies<int64_t>())),
+      ElementSize(ElementSize) {}
 
 Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
                                            const json::Value &Value) {
@@ -147,7 +171,7 @@ Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
 #define PARSE_TYPE(T, E)                                                       \
   if (TensorType == #T)                                                        \
     return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
-  TFUTILS_SUPPORTED_TYPES(PARSE_TYPE)
+  SUPPORTED_TENSOR_TYPES(PARSE_TYPE) // Matches the C type name, e.g. "int64_t".
 #undef PARSE_TYPE
   return None;
 }
@@ -390,7 +414,7 @@ TFModelEvaluatorImpl::TFModelEvaluatorImpl(
         InputSpec.port()};
     if (!checkReportAndInvalidate(InputFeed[I], InputSpec))
       return;
-    initInput(I, static_cast<TF_DataType>(InputSpec.typeIndex()),
+    initInput(I, static_cast<TF_DataType>(getTFTypeIndex(InputSpec.type())),
               InputSpec.shape());
   }
   for (size_t I = 0; I < OutputSpecsSize; ++I) {
@@ -496,9 +520,9 @@ TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) const {
 }
 
 #define TFUTILS_GETDATATYPE_IMPL(T, E)                                         \
-  template <> int TensorSpec::getDataType<T>() { return E; }
+  template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; }
 
-TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_IMPL)
+SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL) // Define getDataType<T>().
 
 #undef TFUTILS_GETDATATYPE_IMPL
 


        


More information about the llvm-commits mailing list