[llvm] 90b9c49 - [llvm] Expose type and element count-related APIs on TensorSpec
Mircea Trofin via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 4 17:32:26 PDT 2020
Author: Mircea Trofin
Date: 2020-08-04T17:32:16-07:00
New Revision: 90b9c49ca6477a85e69018967c0a4d4d38ee6e72
URL: https://github.com/llvm/llvm-project/commit/90b9c49ca6477a85e69018967c0a4d4d38ee6e72
DIFF: https://github.com/llvm/llvm-project/commit/90b9c49ca6477a85e69018967c0a4d4d38ee6e72.diff
LOG: [llvm] Expose type and element count-related APIs on TensorSpec
Added a mechanism to check the element type, get the total element
count, and the size of an element.
Differential Revision: https://reviews.llvm.org/D85250
Added:
Modified:
llvm/include/llvm/Analysis/Utils/TFUtils.h
llvm/lib/Analysis/TFUtils.cpp
llvm/unittests/Analysis/TFUtilsTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/Utils/TFUtils.h b/llvm/include/llvm/Analysis/Utils/TFUtils.h
index d4450276a22e..681560e45335 100644
--- a/llvm/include/llvm/Analysis/Utils/TFUtils.h
+++ b/llvm/include/llvm/Analysis/Utils/TFUtils.h
@@ -66,10 +66,18 @@ class TensorSpec final {
bool operator!=(const TensorSpec &Other) const { return !(*this == Other); }
+ /// Get the number of elements in a tensor with this shape, i.e. the product
+ /// of all of its dimensions (1 for an empty shape).
+ size_t getElementCount() const { return ElementCount; }
+ /// Get the size, in bytes, of one element.
+ size_t getElementByteSize() const;
+
+ /// Check whether this spec's element type is T, by comparing the type index
+ /// that T maps to with the one this spec was created with.
+ template <typename T> bool isElementType() const {
+ return getDataType<T>() == TypeIndex;
+ }
+
private:
+ /// Defined out of line in TFUtils.cpp; besides storing its arguments it
+ /// precomputes ElementCount from Shape.
TensorSpec(const std::string &Name, int Port, int TypeIndex,
- const std::vector<int64_t> &Shape)
- : Name(Name), Port(Port), TypeIndex(TypeIndex), Shape(Shape) {}
+ const std::vector<int64_t> &Shape);
template <typename T> static int getDataType() {
llvm_unreachable("Undefined tensor type");
@@ -79,6 +87,7 @@ class TensorSpec final {
int Port = 0;
int TypeIndex = 0;
std::vector<int64_t> Shape;
+ /// Cached product of the dimensions in Shape, set once by the constructor.
+ size_t ElementCount = 0;
};
Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
diff --git a/llvm/lib/Analysis/TFUtils.cpp b/llvm/lib/Analysis/TFUtils.cpp
index 8fd4011e6cd4..b1be027dc940 100644
--- a/llvm/lib/Analysis/TFUtils.cpp
+++ b/llvm/lib/Analysis/TFUtils.cpp
@@ -24,6 +24,7 @@
#include "tensorflow/c/c_api_experimental.h"
#include <cassert>
+#include <numeric>
using namespace llvm;
@@ -84,6 +85,16 @@ class EvaluationResultImpl {
std::vector<TF_Tensor *> Output;
};
+size_t TensorSpec::getElementByteSize() const {
+  // TypeIndex is interpreted as a TensorFlow TF_DataType enumerator; the
+  // TensorFlow C API reports the per-element byte size for that type.
+  const auto ElementDataType = static_cast<TF_DataType>(TypeIndex);
+  return TF_DataTypeSize(ElementDataType);
+}
+
+TensorSpec::TensorSpec(const std::string &Name, int Port, int TypeIndex,
+ const std::vector<int64_t> &Shape)
+ : Name(Name), Port(Port), TypeIndex(TypeIndex), Shape(Shape),
+ ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
+ std::multiplies<int64_t>())) {}
+
Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
const json::Value &Value) {
auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> {
diff --git a/llvm/unittests/Analysis/TFUtilsTest.cpp b/llvm/unittests/Analysis/TFUtilsTest.cpp
index abdf2b2b9784..9e4f2c7faf71 100644
--- a/llvm/unittests/Analysis/TFUtilsTest.cpp
+++ b/llvm/unittests/Analysis/TFUtilsTest.cpp
@@ -123,3 +123,18 @@ TEST(TFUtilsTest, JSONParsingInvalidTensorType) {
auto Spec = getTensorSpecFromJSON(Ctx, *Value);
EXPECT_FALSE(Spec.hasValue());
}
+
+TEST(TFUtilsTest, TensorSpecSizesAndTypes) {
+  auto Spec1D = TensorSpec::createSpec<int16_t>("Hi1", {1});
+  auto Spec2D = TensorSpec::createSpec<int16_t>("Hi2", {1, 1});
+  auto Spec1DLarge = TensorSpec::createSpec<float>("Hi3", {10});
+  auto Spec3DLarge = TensorSpec::createSpec<float>("Hi4", {2, 4, 10});
+  // Element type queries, positive and negative.
+  EXPECT_TRUE(Spec1D.isElementType<int16_t>());
+  EXPECT_FALSE(Spec1D.isElementType<float>());
+  EXPECT_TRUE(Spec3DLarge.isElementType<float>());
+  EXPECT_FALSE(Spec3DLarge.isElementType<double>());
+  // getElementCount() returns size_t; compare against unsigned literals so
+  // EXPECT_EQ does not compare signed with unsigned (-Wsign-compare).
+  EXPECT_EQ(Spec1D.getElementCount(), 1U);
+  EXPECT_EQ(Spec2D.getElementCount(), 1U);
+  EXPECT_EQ(Spec1DLarge.getElementCount(), 10U);
+  EXPECT_EQ(Spec3DLarge.getElementCount(), 80U);
+  EXPECT_EQ(Spec3DLarge.getElementByteSize(), sizeof(float));
+  EXPECT_EQ(Spec1D.getElementByteSize(), sizeof(int16_t));
+}
More information about the llvm-commits
mailing list