author    Mircea Trofin <mtrofin@google.com>  2020-07-30 12:44:07 -0700
committer Mircea Trofin <mtrofin@google.com>  2020-08-03 09:49:31 -0700
commit    4b1b109c5126efc963cc19949df5201e40f1bcc1
tree      0a661390534238081058d3334bbd1e28214fe6bb
parent    caf002c7be44cb6c54de5a1b19aa177f18b6b0c1
[llvm] Add a parser from JSON to TensorSpec

A JSON->TensorSpec utility we will use subsequently to specify
additional outputs needed for certain training scenarios.

Differential Revision: https://reviews.llvm.org/D84976
-rw-r--r--  llvm/include/llvm/Analysis/Utils/TFUtils.h  | 44
-rw-r--r--  llvm/lib/Analysis/TFUtils.cpp               | 59
-rw-r--r--  llvm/unittests/Analysis/TFUtilsTest.cpp     | 29
3 files changed, 103 insertions(+), 29 deletions(-)
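For orientation before the diff: the spec format the new parser accepts is a JSON dict with "name" (string), "port" (int), "type" (one of the short type names from TFUTILS_SUPPORTED_TYPES, e.g. "int32"), and "shape" (array of ints). A minimal round-trip sketch, using only APIs added or exercised in this patch, and assuming a build with LLVM_HAVE_TF_API defined:

#include "llvm/ADT/Optional.h"
#include "llvm/Analysis/Utils/TFUtils.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/JSON.h"
#include <cassert>

using namespace llvm;

void example() {
  // The same spec the unit test below exercises: a 1x4 int32 tensor
  // named "tensor_name", on output port 2.
  Expected<json::Value> Value = json::parse(
      R"({"name": "tensor_name", "port": 2, "type": "int32", "shape": [1, 4]})");
  if (!Value) {
    consumeError(Value.takeError());
    return;
  }
  LLVMContext Ctx;
  // On malformed input this returns None and reports details through Ctx.
  Optional<TensorSpec> Spec = getTensorSpecFromJSON(Ctx, *Value);
  assert(Spec &&
         *Spec == TensorSpec::createSpec<int32_t>("tensor_name", {1, 4}, 2));
  (void)Spec;
}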
diff --git a/llvm/include/llvm/Analysis/Utils/TFUtils.h b/llvm/include/llvm/Analysis/Utils/TFUtils.h
index 512f45bb5671..d4450276a22e 100644
--- a/llvm/include/llvm/Analysis/Utils/TFUtils.h
+++ b/llvm/include/llvm/Analysis/Utils/TFUtils.h
@@ -13,6 +13,7 @@
#ifdef LLVM_HAVE_TF_API
#include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/JSON.h"
#include <memory>
#include <vector>
@@ -58,6 +59,13 @@ public:
int typeIndex() const { return TypeIndex; }
const std::vector<int64_t> &shape() const { return Shape; }
+ bool operator==(const TensorSpec &Other) const {
+ return Name == Other.Name && Port == Other.Port &&
+ TypeIndex == Other.TypeIndex && Shape == Other.Shape;
+ }
+
+ bool operator!=(const TensorSpec &Other) const { return !(*this == Other); }
+
private:
TensorSpec(const std::string &Name, int Port, int TypeIndex,
const std::vector<int64_t> &Shape)
@@ -73,6 +81,9 @@ private:
std::vector<int64_t> Shape;
};
+Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
+ const json::Value &Value);
+
class TFModelEvaluator final {
public:
/// The result of a model evaluation. Handles the lifetime of the output
@@ -124,17 +135,28 @@ private:
std::unique_ptr<TFModelEvaluatorImpl> Impl;
};
-template <> int TensorSpec::getDataType<float>();
-template <> int TensorSpec::getDataType<double>();
-template <> int TensorSpec::getDataType<int8_t>();
-template <> int TensorSpec::getDataType<uint8_t>();
-template <> int TensorSpec::getDataType<int16_t>();
-template <> int TensorSpec::getDataType<uint16_t>();
-template <> int TensorSpec::getDataType<int32_t>();
-template <> int TensorSpec::getDataType<uint32_t>();
-template <> int TensorSpec::getDataType<int64_t>();
-template <> int TensorSpec::getDataType<uint64_t>();
-
+/// List of supported types, as a triple:
+/// C++ type
+/// short name (for strings, for instance)
+/// capitalized short name (for enums, for instance)
+#define TFUTILS_SUPPORTED_TYPES(M) \
+ M(float, float, FLOAT) \
+ M(double, double, DOUBLE) \
+ M(int8_t, int8, INT8) \
+ M(uint8_t, uint8, UINT8) \
+ M(int16_t, int16, INT16) \
+ M(uint16_t, uint16, UINT16) \
+ M(int32_t, int32, INT32) \
+ M(uint32_t, uint32, UINT32) \
+ M(int64_t, int64, INT64) \
+ M(uint64_t, uint64, UINT64)
+
+#define TFUTILS_GETDATATYPE_DEF(T, S, C) \
+ template <> int TensorSpec::getDataType<T>();
+
+TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_DEF)
+
+#undef TFUTILS_GETDATATYPE_DEF
} // namespace llvm
#endif // LLVM_HAVE_TF_API
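The TFUTILS_SUPPORTED_TYPES list above is an X-macro: the supported types are written down once, and each consumer passes its own M(cppType, shortName, capitalizedName) to stamp out declarations, string comparisons, or enum members. A standalone sketch of the same pattern, with illustrative names that are not part of the patch:

#include <cstdint>
#include <string>

// One list, three projections per entry: C++ type, short name, enum name.
#define SUPPORTED_TYPES(M)                                                     \
  M(float, float, FLOAT)                                                       \
  M(int32_t, int32, INT32)

// Projection 1: an enum with one member per supported type.
#define AS_ENUM(T, S, E) E,
enum class TypeId { SUPPORTED_TYPES(AS_ENUM) };
#undef AS_ENUM

// Projection 2: map the short name, as a string, to the enum member.
// This mirrors how PARSE_TYPE is used in TFUtils.cpp below.
#define AS_NAME_CHECK(T, S, E)                                                 \
  if (Name == #S)                                                              \
    return TypeId::E;
inline TypeId typeIdFromName(const std::string &Name) {
  SUPPORTED_TYPES(AS_NAME_CHECK)
  return TypeId::FLOAT; // fallback, for the sketch only
}
#undef AS_NAME_CHECK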
diff --git a/llvm/lib/Analysis/TFUtils.cpp b/llvm/lib/Analysis/TFUtils.cpp
index b0ff19857963..8fd4011e6cd4 100644
--- a/llvm/lib/Analysis/TFUtils.cpp
+++ b/llvm/lib/Analysis/TFUtils.cpp
@@ -13,9 +13,10 @@
#include "llvm/Config/config.h"
#if defined(LLVM_HAVE_TF_API)
-#include "llvm/Analysis/Utils/TFUtils.h"
#include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/Utils/TFUtils.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/JSON.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/raw_ostream.h"
@@ -83,6 +84,41 @@ private:
std::vector<TF_Tensor *> Output;
};
+Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
+ const json::Value &Value) {
+ auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> {
+ std::string S;
+ llvm::raw_string_ostream OS(S);
+ OS << Value;
+ Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
+ return None;
+ };
+ json::ObjectMapper Mapper(Value);
+ if (!Mapper)
+ return EmitError("Value is not a dict");
+
+ std::string TensorName;
+ int TensorPort = -1;
+ std::string TensorType;
+ std::vector<int64_t> TensorShape;
+
+ if (!Mapper.map<std::string>("name", TensorName))
+ return EmitError("'name' property not present or not a string");
+ if (!Mapper.map<std::string>("type", TensorType))
+ return EmitError("'type' property not present or not a string");
+ if (!Mapper.map<int>("port", TensorPort))
+ return EmitError("'port' property not present or not an int");
+ if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
+ return EmitError("'shape' property not present or not an int array");
+
+#define PARSE_TYPE(T, S, E) \
+ if (TensorType == #S) \
+ return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
+ TFUTILS_SUPPORTED_TYPES(PARSE_TYPE)
+#undef PARSE_TYPE
+ return None;
+}
+
class TFModelEvaluatorImpl {
public:
TFModelEvaluatorImpl(StringRef SavedModelPath,
@@ -249,25 +285,12 @@ void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
return TF_TensorData(Impl->getOutput()[Index]);
}
-template <> int TensorSpec::getDataType<float>() { return TF_FLOAT; }
-
-template <> int TensorSpec::getDataType<double>() { return TF_DOUBLE; }
-
-template <> int TensorSpec::getDataType<int8_t>() { return TF_INT8; }
-
-template <> int TensorSpec::getDataType<uint8_t>() { return TF_UINT8; }
-
-template <> int TensorSpec::getDataType<int16_t>() { return TF_INT16; }
-
-template <> int TensorSpec::getDataType<uint16_t>() { return TF_UINT16; }
-
-template <> int TensorSpec::getDataType<int32_t>() { return TF_INT32; }
-
-template <> int TensorSpec::getDataType<uint32_t>() { return TF_UINT32; }
+#define TFUTILS_GETDATATYPE_IMPL(T, S, E) \
+ template <> int TensorSpec::getDataType<T>() { return TF_##E; }
-template <> int TensorSpec::getDataType<int64_t>() { return TF_INT64; }
+TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_IMPL)
-template <> int TensorSpec::getDataType<uint64_t>() { return TF_UINT64; }
+#undef TFUTILS_GETDATATYPE_IMPL
TFModelEvaluator::EvaluationResult::~EvaluationResult() {}
TFModelEvaluator::~TFModelEvaluator() {}
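A hypothetical caller sketch for the new utility: a training tool might read an array of specs rather than a single one. The helper name and the array-of-specs convention here are assumptions for illustration, not part of this patch; errors for individual entries are reported through Ctx by getTensorSpecFromJSON itself.

#include "llvm/ADT/Optional.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/Utils/TFUtils.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/JSON.h"
#include <vector>

using namespace llvm;

// Parse a JSON array of tensor specs, e.g.
// [{"name": ..., "port": ..., "type": ..., "shape": [...]}, ...].
// Returns None on the first entry that fails to parse.
Optional<std::vector<TensorSpec>> parseSpecList(LLVMContext &Ctx,
                                                StringRef JSONText) {
  Expected<json::Value> Parsed = json::parse(JSONText);
  if (!Parsed) {
    consumeError(Parsed.takeError());
    return None;
  }
  const json::Array *Arr = Parsed->getAsArray();
  if (!Arr)
    return None;
  std::vector<TensorSpec> Specs;
  for (const json::Value &V : *Arr) {
    if (Optional<TensorSpec> Spec = getTensorSpecFromJSON(Ctx, V))
      Specs.push_back(*Spec);
    else
      return None;
  }
  return Specs;
}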
diff --git a/llvm/unittests/Analysis/TFUtilsTest.cpp b/llvm/unittests/Analysis/TFUtilsTest.cpp
index e96d34092c7e..abdf2b2b9784 100644
--- a/llvm/unittests/Analysis/TFUtilsTest.cpp
+++ b/llvm/unittests/Analysis/TFUtilsTest.cpp
@@ -94,3 +94,32 @@ TEST(TFUtilsTest, EvalError) {
EXPECT_FALSE(ER.hasValue());
EXPECT_FALSE(Evaluator.isValid());
}
+
+TEST(TFUtilsTest, JSONParsing) {
+ auto Value = json::parse(
+ R"({"name": "tensor_name",
+ "port": 2,
+ "type": "int32",
+ "shape":[1,4]
+ })");
+ EXPECT_TRUE(!!Value);
+ LLVMContext Ctx;
+ Optional<TensorSpec> Spec = getTensorSpecFromJSON(Ctx, *Value);
+ EXPECT_TRUE(Spec.hasValue());
+ EXPECT_EQ(*Spec, TensorSpec::createSpec<int32_t>("tensor_name", {1, 4}, 2));
+}
+
+TEST(TFUtilsTest, JSONParsingInvalidTensorType) {
+ auto Value = json::parse(
+ R"(
+ {"name": "tensor_name",
+ "port": 2,
+ "type": "no such type",
+ "shape":[1,4]
+ }
+ )");
+ EXPECT_TRUE(!!Value);
+ LLVMContext Ctx;
+ auto Spec = getTensorSpecFromJSON(Ctx, *Value);
+ EXPECT_FALSE(Spec.hasValue());
+}
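The invalid-type test above only checks that parsing fails; the message itself goes through LLVMContext::emitError. If a test also wanted to assert on the message text, one way is to install a diagnostic handler. A sketch, assuming the standard DiagnosticHandler callback API; none of this is in the patch:

#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/raw_ostream.h"
#include <string>

// Accumulate rendered diagnostics into a string the test can inspect.
static std::string LastDiagnostic;

static void captureDiag(const llvm::DiagnosticInfo &DI, void *) {
  llvm::raw_string_ostream OS(LastDiagnostic);
  llvm::DiagnosticPrinterRawOStream Printer(OS);
  DI.print(Printer);
}

// In the test body, before calling getTensorSpecFromJSON:
//   Ctx.setDiagnosticHandlerCallBack(captureDiag);
// and afterwards:
//   EXPECT_NE(LastDiagnostic.find("Unable to parse JSON Value"),
//             std::string::npos);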