[llvm] 4b1b109 - [llvm] Add a parser from JSON to TensorSpec

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 3 09:49:43 PDT 2020


Author: Mircea Trofin
Date: 2020-08-03T09:49:31-07:00
New Revision: 4b1b109c5126efc963cc19949df5201e40f1bcc1

URL: https://github.com/llvm/llvm-project/commit/4b1b109c5126efc963cc19949df5201e40f1bcc1
DIFF: https://github.com/llvm/llvm-project/commit/4b1b109c5126efc963cc19949df5201e40f1bcc1.diff

LOG: [llvm] Add a parser from JSON to TensorSpec

A JSON->TensorSpec utility we will use subsequently to specify
additional outputs needed for certain training scenarios.

Differential Revision: https://reviews.llvm.org/D84976

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/Utils/TFUtils.h
    llvm/lib/Analysis/TFUtils.cpp
    llvm/unittests/Analysis/TFUtilsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/Utils/TFUtils.h b/llvm/include/llvm/Analysis/Utils/TFUtils.h
index 512f45bb5671..d4450276a22e 100644
--- a/llvm/include/llvm/Analysis/Utils/TFUtils.h
+++ b/llvm/include/llvm/Analysis/Utils/TFUtils.h
@@ -13,6 +13,7 @@
 
 #ifdef LLVM_HAVE_TF_API
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/JSON.h"
 
 #include <memory>
 #include <vector>
@@ -58,6 +59,13 @@ class TensorSpec final {
   int typeIndex() const { return TypeIndex; }
   const std::vector<int64_t> &shape() const { return Shape; }
 
+  /// Two specs are equal iff name, port, type index and shape all match.
+  bool operator==(const TensorSpec &Other) const {
+    return TypeIndex == Other.TypeIndex && Port == Other.Port &&
+           Shape == Other.Shape && Name == Other.Name;
+  }
+  bool operator!=(const TensorSpec &Other) const { return !operator==(Other); }
+
 private:
   TensorSpec(const std::string &Name, int Port, int TypeIndex,
              const std::vector<int64_t> &Shape)
@@ -73,6 +81,9 @@ class TensorSpec final {
   std::vector<int64_t> Shape;
 };
 
+/// Parses {name, port, type, shape} JSON into a TensorSpec, or returns None.
+Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
+                                           const json::Value &Value);
 class TFModelEvaluator final {
 public:
   /// The result of a model evaluation. Handles the lifetime of the output
@@ -124,17 +135,28 @@ class TFModelEvaluator final {
   std::unique_ptr<TFModelEvaluatorImpl> Impl;
 };
 
-template <> int TensorSpec::getDataType<float>();
-template <> int TensorSpec::getDataType<double>();
-template <> int TensorSpec::getDataType<int8_t>();
-template <> int TensorSpec::getDataType<uint8_t>();
-template <> int TensorSpec::getDataType<int16_t>();
-template <> int TensorSpec::getDataType<uint16_t>();
-template <> int TensorSpec::getDataType<int32_t>();
-template <> int TensorSpec::getDataType<uint32_t>();
-template <> int TensorSpec::getDataType<int64_t>();
-template <> int TensorSpec::getDataType<uint64_t>();
-
+/// X-macro over the supported tensor element types. Each is passed to M as
+/// M(CppType, ShortName, UpperName): the C++ type; its lowercase short name,
+/// which is how the type is spelled in a JSON spec's "type" field; and its
+/// uppercase name, which completes the TF_<UpperName> enumerator.
+#define TFUTILS_SUPPORTED_TYPES(M)                                             \
+  M(float, float, FLOAT)                                                       \
+  M(double, double, DOUBLE)                                                    \
+  M(int8_t, int8, INT8)                                                        \
+  M(uint8_t, uint8, UINT8)                                                     \
+  M(int16_t, int16, INT16)                                                     \
+  M(uint16_t, uint16, UINT16)                                                  \
+  M(int32_t, int32, INT32)                                                     \
+  M(uint32_t, uint32, UINT32)                                                  \
+  M(int64_t, int64, INT64)                                                     \
+  M(uint64_t, uint64, UINT64)
+
+#define TFUTILS_GETDATATYPE_DEF(T, S, C)                                       \
+  template <> int TensorSpec::getDataType<T>();
+
+TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_DEF)
+
+#undef TFUTILS_GETDATATYPE_DEF
 } // namespace llvm
 
 #endif // LLVM_HAVE_TF_API

diff  --git a/llvm/lib/Analysis/TFUtils.cpp b/llvm/lib/Analysis/TFUtils.cpp
index b0ff19857963..8fd4011e6cd4 100644
--- a/llvm/lib/Analysis/TFUtils.cpp
+++ b/llvm/lib/Analysis/TFUtils.cpp
@@ -13,9 +13,10 @@
 #include "llvm/Config/config.h"
 #if defined(LLVM_HAVE_TF_API)
 
-#include "llvm/Analysis/Utils/TFUtils.h"
 #include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/Utils/TFUtils.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/JSON.h"
 #include "llvm/Support/ManagedStatic.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -83,6 +84,41 @@ class EvaluationResultImpl {
   std::vector<TF_Tensor *> Output;
 };
 
+Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
+                                           const json::Value &Value) {
+  // Expects {"name": str, "port": int, "type": str, "shape": [int...]}.
+  auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> {
+    std::string S;
+    llvm::raw_string_ostream OS(S);
+    OS << Value; // Read back through str(): it flushes any buffered output.
+    Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " +
+                  OS.str());
+    return None;
+  };
+  json::ObjectMapper Mapper(Value);
+  if (!Mapper)
+    return EmitError("Value is not a dict");
+  std::string TensorName;
+  int TensorPort = -1;
+  std::string TensorType;
+  std::vector<int64_t> TensorShape;
+  if (!Mapper.map<std::string>("name", TensorName))
+    return EmitError("'name' property not present or not a string");
+  if (!Mapper.map<std::string>("type", TensorType))
+    return EmitError("'type' property not present or not a string");
+  if (!Mapper.map<int>("port", TensorPort))
+    return EmitError("'port' property not present or not an int");
+  if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
+    return EmitError("'shape' property not present or not an int array");
+#define PARSE_TYPE(T, S, E)                                                    \
+  if (TensorType == #S)                                                        \
+    return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
+  TFUTILS_SUPPORTED_TYPES(PARSE_TYPE)
+#undef PARSE_TYPE
+  // An unrecognized "type" string yields None without emitting a diagnostic.
+  return None;
+}
+
 class TFModelEvaluatorImpl {
 public:
   TFModelEvaluatorImpl(StringRef SavedModelPath,
@@ -249,25 +285,12 @@ void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
   return TF_TensorData(Impl->getOutput()[Index]);
 }
 
-template <> int TensorSpec::getDataType<float>() { return TF_FLOAT; }
-
-template <> int TensorSpec::getDataType<double>() { return TF_DOUBLE; }
-
-template <> int TensorSpec::getDataType<int8_t>() { return TF_INT8; }
-
-template <> int TensorSpec::getDataType<uint8_t>() { return TF_UINT8; }
-
-template <> int TensorSpec::getDataType<int16_t>() { return TF_INT16; }
-
-template <> int TensorSpec::getDataType<uint16_t>() { return TF_UINT16; }
-
-template <> int TensorSpec::getDataType<int32_t>() { return TF_INT32; }
-
-template <> int TensorSpec::getDataType<uint32_t>() { return TF_UINT32; }
+#define TFUTILS_GETDATATYPE_IMPL(CppType, ShortName, UpperName)                \
+  template <> int TensorSpec::getDataType<CppType>() { return TF_##UpperName; }
 
-template <> int TensorSpec::getDataType<int64_t>() { return TF_INT64; }
+TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_IMPL)
 
-template <> int TensorSpec::getDataType<uint64_t>() { return TF_UINT64; }
+#undef TFUTILS_GETDATATYPE_IMPL
 
 TFModelEvaluator::EvaluationResult::~EvaluationResult() {}
 TFModelEvaluator::~TFModelEvaluator() {}

diff  --git a/llvm/unittests/Analysis/TFUtilsTest.cpp b/llvm/unittests/Analysis/TFUtilsTest.cpp
index e96d34092c7e..abdf2b2b9784 100644
--- a/llvm/unittests/Analysis/TFUtilsTest.cpp
+++ b/llvm/unittests/Analysis/TFUtilsTest.cpp
@@ -94,3 +94,32 @@ TEST(TFUtilsTest, EvalError) {
   EXPECT_FALSE(ER.hasValue());
   EXPECT_FALSE(Evaluator.isValid());
 }
+
+TEST(TFUtilsTest, JSONParsing) {
+  auto Value = json::parse(
+      R"({"name": "tensor_name",
+        "port": 2,
+        "type": "int32",
+        "shape":[1,4]
+        })");
+  ASSERT_TRUE(!!Value) << toString(Value.takeError());
+  LLVMContext Ctx;
+  Optional<TensorSpec> Spec = getTensorSpecFromJSON(Ctx, *Value);
+  ASSERT_TRUE(Spec.hasValue()); // Stop here rather than dereference None.
+  EXPECT_EQ(*Spec, TensorSpec::createSpec<int32_t>("tensor_name", {1, 4}, 2));
+}
+
+TEST(TFUtilsTest, JSONParsingInvalidTensorType) {
+  auto Value = json::parse(
+      R"(
+        {"name": "tensor_name",
+        "port": 2,
+        "type": "no such type",
+        "shape":[1,4]
+        }
+      )");
+  ASSERT_TRUE(!!Value) << toString(Value.takeError());
+  LLVMContext Ctx;
+  auto Spec = getTensorSpecFromJSON(Ctx, *Value);
+  EXPECT_FALSE(Spec.hasValue()); // "no such type" is not a supported type.
+}


        


More information about the llvm-commits mailing list