diff options
author | Aart Bik <ajcbik@google.com> | 2021-02-18 22:01:39 -0800 |
---|---|---|
committer | Aart Bik <ajcbik@google.com> | 2021-02-18 23:22:14 -0800 |
commit | 2556d622828ae5631ac483d82592440fa1910d80 (patch) | |
tree | 11eb31c6b36b2cf3fdf115a55c51c35df4daa5ec | |
parent | cd4051ac802fdc5664a3432f57d99bbcb4c07a92 (diff) | |
download | llvm-2556d622828ae5631ac483d82592440fa1910d80.tar.gz |
[mlir][sparse] assert fail on mismatch between rank and annotations array
Rationale:
Providing the wrong number of sparse/dense annotations was silently
ignored or caused unrelated crashes. This minor change verifies that
the provided number matches the rank.
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D97034
-rw-r--r-- | mlir/lib/ExecutionEngine/SparseUtils.cpp | 51 |
1 files changed, 33 insertions, 18 deletions
diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp index 0ff6f7d49b46..903b9f115182 100644 --- a/mlir/lib/ExecutionEngine/SparseUtils.cpp +++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp @@ -76,8 +76,8 @@ public: } /// Adds element as indices and value. void add(const std::vector<uint64_t> &ind, double val) { - assert(sizes.size() == ind.size()); - for (int64_t r = 0, rank = sizes.size(); r < rank; r++) + assert(getRank() == ind.size()); + for (int64_t r = 0, rank = getRank(); r < rank; r++) assert(ind[r] < sizes[r]); // within bounds elements.emplace_back(Element(ind, val)); } @@ -85,6 +85,8 @@ public: void sort() { std::sort(elements.begin(), elements.end(), lexOrder); } /// Primitive one-time iteration. const Element &next() { return elements[pos++]; } + /// Returns rank. + uint64_t getRank() const { return sizes.size(); } /// Getter for sizes array. const std::vector<uint64_t> &getSizes() const { return sizes; } /// Getter for elements array. @@ -139,13 +141,13 @@ public: /// Constructs sparse tensor storage scheme following the given /// per-rank dimension dense/sparse annotations. SparseTensorStorage(SparseTensor *tensor, bool *sparsity) - : sizes(tensor->getSizes()), pointers(sizes.size()), - indices(sizes.size()) { + : sizes(tensor->getSizes()), pointers(getRank()), indices(getRank()) { // Provide hints on capacity. // TODO: needs fine-tuning based on sparsity - values.reserve(tensor->getElements().size()); - for (uint64_t d = 0, s = 1, rank = sizes.size(); d < rank; d++) { - s *= tensor->getSizes()[d]; + uint64_t nnz = tensor->getElements().size(); + values.reserve(nnz); + for (uint64_t d = 0, s = 1, rank = getRank(); d < rank; d++) { + s *= sizes[d]; if (sparsity[d]) { pointers[d].reserve(s + 1); indices[d].reserve(s); @@ -153,12 +155,16 @@ public: } } // Then setup the tensor. - traverse(tensor, sparsity, 0, tensor->getElements().size(), 0); + traverse(tensor, sparsity, 0, nnz, 0); } virtual ~SparseTensorStorage() {} + uint64_t getRank() const { return sizes.size(); } + uint64_t getDimSize(uint64_t d) override { return sizes[d]; } + + // Partially specialize these three methods based on template types. void getPointers(std::vector<P> **out, uint64_t d) override { *out = &pointers[d]; } @@ -176,7 +182,7 @@ private: uint64_t d) { const std::vector<Element> &elements = tensor->getElements(); // Once dimensions are exhausted, insert the numerical values. - if (d == sizes.size()) { + if (d == getRank()) { values.push_back(lo < hi ? elements[lo].value : 0.0); return; } @@ -221,9 +227,10 @@ private: /// Templated reader. template <typename P, typename I, typename V> -void *newSparseTensor(char *filename, bool *sparsity) { +void *newSparseTensor(char *filename, bool *sparsity, uint64_t size) { uint64_t idata[64]; SparseTensor *t = static_cast<SparseTensor *>(openTensorC(filename, idata)); + assert(size == t->getRank()); // sparsity array must match rank SparseTensorStorageBase *tensor = new SparseTensorStorage<P, I, V>(t, sparsity); delete t; @@ -481,21 +488,29 @@ void *newSparseTensor(char *filename, bool *abase, bool *adata, uint64_t aoff, assert(astride == 1); bool *sparsity = abase + aoff; if (ptrTp == kU64 && indTp == kU64 && valTp == kF64) - return newSparseTensor<uint64_t, uint64_t, double>(filename, sparsity); + return newSparseTensor<uint64_t, uint64_t, double>(filename, sparsity, + asize); if (ptrTp == kU64 && indTp == kU64 && valTp == kF32) - return newSparseTensor<uint64_t, uint64_t, float>(filename, sparsity); + return newSparseTensor<uint64_t, uint64_t, float>(filename, sparsity, + asize); if (ptrTp == kU64 && indTp == kU32 && valTp == kF64) - return newSparseTensor<uint64_t, uint32_t, double>(filename, sparsity); + return newSparseTensor<uint64_t, uint32_t, double>(filename, sparsity, + asize); if (ptrTp == kU64 && indTp == kU32 && valTp == kF32) - return newSparseTensor<uint64_t, uint32_t, float>(filename, sparsity); + return newSparseTensor<uint64_t, uint32_t, float>(filename, sparsity, + asize); if (ptrTp == kU32 && indTp == kU64 && valTp == kF64) - return newSparseTensor<uint32_t, uint64_t, double>(filename, sparsity); + return newSparseTensor<uint32_t, uint64_t, double>(filename, sparsity, + asize); if (ptrTp == kU32 && indTp == kU64 && valTp == kF32) - return newSparseTensor<uint32_t, uint64_t, float>(filename, sparsity); + return newSparseTensor<uint32_t, uint64_t, float>(filename, sparsity, + asize); if (ptrTp == kU32 && indTp == kU32 && valTp == kF64) - return newSparseTensor<uint32_t, uint32_t, double>(filename, sparsity); + return newSparseTensor<uint32_t, uint32_t, double>(filename, sparsity, + asize); if (ptrTp == kU32 && indTp == kU32 && valTp == kF32) - return newSparseTensor<uint32_t, uint32_t, float>(filename, sparsity); + return newSparseTensor<uint32_t, uint32_t, float>(filename, sparsity, + asize); fputs("unsupported combination of types\n", stderr); exit(1); } |