[llvm] b18c41c - [TFUtils] Expose untyped accessor to evaluation result tensors

Mircea Trofin via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 5 10:22:54 PDT 2020


Author: Mircea Trofin
Date: 2020-08-05T10:22:45-07:00
New Revision: b18c41c66fd16bde1a0a80a94f03815bc58dcc5a

URL: https://github.com/llvm/llvm-project/commit/b18c41c66fd16bde1a0a80a94f03815bc58dcc5a
DIFF: https://github.com/llvm/llvm-project/commit/b18c41c66fd16bde1a0a80a94f03815bc58dcc5a.diff

LOG: [TFUtils] Expose untyped accessor to evaluation result tensors

These were an implementation detail, but became necessary for generic data
copying.

Also added const variants of them, and a move assignment operator, since we
already had a move ctor (the move assignment also helps in a subsequent patch).

Differential Revision: https://reviews.llvm.org/D85262

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/Utils/TFUtils.h
    llvm/lib/Analysis/TFUtils.cpp
    llvm/unittests/Analysis/TFUtilsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/Utils/TFUtils.h b/llvm/include/llvm/Analysis/Utils/TFUtils.h
index 681560e45335..a6cfb16113c1 100644
--- a/llvm/include/llvm/Analysis/Utils/TFUtils.h
+++ b/llvm/include/llvm/Analysis/Utils/TFUtils.h
@@ -101,18 +101,29 @@ class TFModelEvaluator final {
   class EvaluationResult {
   public:
     EvaluationResult(const EvaluationResult &) = delete;
+    EvaluationResult &operator=(const EvaluationResult &Other) = delete;
+
     EvaluationResult(EvaluationResult &&Other);
+    EvaluationResult &operator=(EvaluationResult &&Other);
+
     ~EvaluationResult();
 
-    /// Get a pointer to the first element of the tensor at Index.
+    /// Get a (const) pointer to the first element of the tensor at Index.
     template <typename T> T *getTensorValue(size_t Index) {
       return static_cast<T *>(getUntypedTensorValue(Index));
     }
 
+    template <typename T> const T *getTensorValue(size_t Index) const {
+      return static_cast<const T *>(getUntypedTensorValue(Index));
+    }
+
+    /// Get a (const) pointer to the untyped data of the tensor at Index.
+    void *getUntypedTensorValue(size_t Index);
+    const void *getUntypedTensorValue(size_t Index) const;
+
   private:
     friend class TFModelEvaluator;
     EvaluationResult(std::unique_ptr<EvaluationResultImpl> Impl);
-    void *getUntypedTensorValue(size_t Index);
     std::unique_ptr<EvaluationResultImpl> Impl;
   };
 

diff  --git a/llvm/lib/Analysis/TFUtils.cpp b/llvm/lib/Analysis/TFUtils.cpp
index b1be027dc940..99b63305121d 100644
--- a/llvm/lib/Analysis/TFUtils.cpp
+++ b/llvm/lib/Analysis/TFUtils.cpp
@@ -292,10 +292,21 @@ TFModelEvaluator::EvaluationResult::EvaluationResult(
 TFModelEvaluator::EvaluationResult::EvaluationResult(EvaluationResult &&Other)
     : Impl(std::move(Other.Impl)) {}
 
+TFModelEvaluator::EvaluationResult &
+TFModelEvaluator::EvaluationResult::operator=(EvaluationResult &&Other) {
+  Impl = std::move(Other.Impl); // Other is left holding a null Impl.
+  return *this;
+}
+
 void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
   return TF_TensorData(Impl->getOutput()[Index]);
 }
 
+const void *
+TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) const {
+  return TF_TensorData(Impl->getOutput()[Index]); // Read-only view.
+}
+
 #define TFUTILS_GETDATATYPE_IMPL(T, S, E)                                      \
   template <> int TensorSpec::getDataType<T>() { return TF_##E; }
 

diff  --git a/llvm/unittests/Analysis/TFUtilsTest.cpp b/llvm/unittests/Analysis/TFUtilsTest.cpp
index 9e4f2c7faf71..c33a5fd859a4 100644
--- a/llvm/unittests/Analysis/TFUtilsTest.cpp
+++ b/llvm/unittests/Analysis/TFUtilsTest.cpp
@@ -56,6 +56,8 @@ TEST(TFUtilsTest, LoadAndExecuteTest) {
     EXPECT_TRUE(ER.hasValue());
     float Ret = *ER->getTensorValue<float>(0);
     EXPECT_EQ(static_cast<size_t>(Ret), 80);
+    EXPECT_EQ(ER->getUntypedTensorValue(0),
+              reinterpret_cast<const void *>(ER->getTensorValue<float>(0)));
   }
   // The input vector should be unchanged
   for (auto I = 0; I < KnownSize; ++I) {
@@ -137,4 +139,4 @@ TEST(TFUtilsTest, TensorSpecSizesAndTypes) {
   EXPECT_EQ(Spec3DLarge.getElementCount(), 80);
   EXPECT_EQ(Spec3DLarge.getElementByteSize(), sizeof(float));
   EXPECT_EQ(Spec1D.getElementByteSize(), sizeof(int16_t));
-}
\ No newline at end of file
+}


        


More information about the llvm-commits mailing list