diff options
author | Mircea Trofin <mtrofin@google.com> | 2020-07-30 12:44:07 -0700 |
---|---|---|
committer | Mircea Trofin <mtrofin@google.com> | 2020-08-03 09:49:31 -0700 |
commit | 4b1b109c5126efc963cc19949df5201e40f1bcc1 (patch) | |
tree | 0a661390534238081058d3334bbd1e28214fe6bb | |
parent | caf002c7be44cb6c54de5a1b19aa177f18b6b0c1 (diff) | |
download | llvm-4b1b109c5126efc963cc19949df5201e40f1bcc1.tar.gz |
[llvm] Add a parser from JSON to TensorSpec
A JSON->TensorSpec utility we will use subsequently to specify
additional outputs needed for certain training scenarios.
Differential Revision: https://reviews.llvm.org/D84976
-rw-r--r-- | llvm/include/llvm/Analysis/Utils/TFUtils.h | 44 | ||||
-rw-r--r-- | llvm/lib/Analysis/TFUtils.cpp | 59 | ||||
-rw-r--r-- | llvm/unittests/Analysis/TFUtilsTest.cpp | 29 |
3 files changed, 103 insertions, 29 deletions
diff --git a/llvm/include/llvm/Analysis/Utils/TFUtils.h b/llvm/include/llvm/Analysis/Utils/TFUtils.h index 512f45bb5671..d4450276a22e 100644 --- a/llvm/include/llvm/Analysis/Utils/TFUtils.h +++ b/llvm/include/llvm/Analysis/Utils/TFUtils.h @@ -13,6 +13,7 @@ #ifdef LLVM_HAVE_TF_API #include "llvm/IR/LLVMContext.h" +#include "llvm/Support/JSON.h" #include <memory> #include <vector> @@ -58,6 +59,13 @@ public: int typeIndex() const { return TypeIndex; } const std::vector<int64_t> &shape() const { return Shape; } + bool operator==(const TensorSpec &Other) const { + return Name == Other.Name && Port == Other.Port && + TypeIndex == Other.TypeIndex && Shape == Other.Shape; + } + + bool operator!=(const TensorSpec &Other) const { return !(*this == Other); } + private: TensorSpec(const std::string &Name, int Port, int TypeIndex, const std::vector<int64_t> &Shape) @@ -73,6 +81,9 @@ private: std::vector<int64_t> Shape; }; +Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx, + const json::Value &Value); + class TFModelEvaluator final { public: /// The result of a model evaluation. Handles the lifetime of the output @@ -124,17 +135,28 @@ private: std::unique_ptr<TFModelEvaluatorImpl> Impl; }; -template <> int TensorSpec::getDataType<float>(); -template <> int TensorSpec::getDataType<double>(); -template <> int TensorSpec::getDataType<int8_t>(); -template <> int TensorSpec::getDataType<uint8_t>(); -template <> int TensorSpec::getDataType<int16_t>(); -template <> int TensorSpec::getDataType<uint16_t>(); -template <> int TensorSpec::getDataType<int32_t>(); -template <> int TensorSpec::getDataType<uint32_t>(); -template <> int TensorSpec::getDataType<int64_t>(); -template <> int TensorSpec::getDataType<uint64_t>(); - +/// List of supported types, as a triple: +/// C++ type +/// short name (for strings, for instance) +/// capitalized short name (for enums, for instance) +#define TFUTILS_SUPPORTED_TYPES(M) \ + M(float, float, FLOAT) \ + M(double, double, DOUBLE) \ + M(int8_t, int8, INT8) \ + M(uint8_t, uint8, UINT8) \ + M(int16_t, int16, INT16) \ + M(uint16_t, uint16, UINT16) \ + M(int32_t, int32, INT32) \ + M(uint32_t, uint32, UINT32) \ + M(int64_t, int64, INT64) \ + M(uint64_t, uint64, UINT64) + +#define TFUTILS_GETDATATYPE_DEF(T, S, C) \ + template <> int TensorSpec::getDataType<T>(); + +TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_DEF) + +#undef TFUTILS_GETDATATYPE_DEF } // namespace llvm #endif // LLVM_HAVE_TF_API diff --git a/llvm/lib/Analysis/TFUtils.cpp b/llvm/lib/Analysis/TFUtils.cpp index b0ff19857963..8fd4011e6cd4 100644 --- a/llvm/lib/Analysis/TFUtils.cpp +++ b/llvm/lib/Analysis/TFUtils.cpp @@ -13,9 +13,10 @@ #include "llvm/Config/config.h" #if defined(LLVM_HAVE_TF_API) -#include "llvm/Analysis/Utils/TFUtils.h" #include "llvm/ADT/Twine.h" +#include "llvm/Analysis/Utils/TFUtils.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/JSON.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Support/raw_ostream.h" @@ -83,6 +84,41 @@ private: std::vector<TF_Tensor *> Output; }; +Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx, + const json::Value &Value) { + auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> { + std::string S; + llvm::raw_string_ostream OS(S); + OS << Value; + Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S); + return None; + }; + json::ObjectMapper Mapper(Value); + if (!Mapper) + return EmitError("Value is not a dict"); + + std::string TensorName; + int TensorPort = -1; + std::string TensorType; + std::vector<int64_t> TensorShape; + + if (!Mapper.map<std::string>("name", TensorName)) + return EmitError("'name' property not present or not a string"); + if (!Mapper.map<std::string>("type", TensorType)) + return EmitError("'type' property not present or not a string"); + if (!Mapper.map<int>("port", TensorPort)) + return EmitError("'port' property not present or not an int"); + if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape)) + return EmitError("'shape' property not present or not an int array"); + +#define PARSE_TYPE(T, S, E) \ + if (TensorType == #S) \ + return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort); + TFUTILS_SUPPORTED_TYPES(PARSE_TYPE) +#undef PARSE_TYPE + return None; +} + class TFModelEvaluatorImpl { public: TFModelEvaluatorImpl(StringRef SavedModelPath, @@ -249,25 +285,12 @@ void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) { return TF_TensorData(Impl->getOutput()[Index]); } -template <> int TensorSpec::getDataType<float>() { return TF_FLOAT; } - -template <> int TensorSpec::getDataType<double>() { return TF_DOUBLE; } - -template <> int TensorSpec::getDataType<int8_t>() { return TF_INT8; } - -template <> int TensorSpec::getDataType<uint8_t>() { return TF_UINT8; } - -template <> int TensorSpec::getDataType<int16_t>() { return TF_INT16; } - -template <> int TensorSpec::getDataType<uint16_t>() { return TF_UINT16; } - -template <> int TensorSpec::getDataType<int32_t>() { return TF_INT32; } - -template <> int TensorSpec::getDataType<uint32_t>() { return TF_UINT32; } +#define TFUTILS_GETDATATYPE_IMPL(T, S, E) \ + template <> int TensorSpec::getDataType<T>() { return TF_##E; } -template <> int TensorSpec::getDataType<int64_t>() { return TF_INT64; } +TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_IMPL) -template <> int TensorSpec::getDataType<uint64_t>() { return TF_UINT64; } +#undef TFUTILS_GETDATATYPE_IMPL TFModelEvaluator::EvaluationResult::~EvaluationResult() {} TFModelEvaluator::~TFModelEvaluator() {} diff --git a/llvm/unittests/Analysis/TFUtilsTest.cpp b/llvm/unittests/Analysis/TFUtilsTest.cpp index e96d34092c7e..abdf2b2b9784 100644 --- a/llvm/unittests/Analysis/TFUtilsTest.cpp +++ b/llvm/unittests/Analysis/TFUtilsTest.cpp @@ -94,3 +94,32 @@ TEST(TFUtilsTest, EvalError) { EXPECT_FALSE(ER.hasValue()); EXPECT_FALSE(Evaluator.isValid()); } + +TEST(TFUtilsTest, JSONParsing) { + auto Value = json::parse( + R"({"name": "tensor_name", + "port": 2, + "type": "int32", + "shape":[1,4] + })"); + EXPECT_TRUE(!!Value); + LLVMContext Ctx; + Optional<TensorSpec> Spec = getTensorSpecFromJSON(Ctx, *Value); + EXPECT_TRUE(Spec.hasValue()); + EXPECT_EQ(*Spec, TensorSpec::createSpec<int32_t>("tensor_name", {1, 4}, 2)); +} + +TEST(TFUtilsTest, JSONParsingInvalidTensorType) { + auto Value = json::parse( + R"( + {"name": "tensor_name", + "port": 2, + "type": "no such type", + "shape":[1,4] + } + )"); + EXPECT_TRUE(!!Value); + LLVMContext Ctx; + auto Spec = getTensorSpecFromJSON(Ctx, *Value); + EXPECT_FALSE(Spec.hasValue()); +} |