summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAart Bik <ajcbik@google.com>2021-02-18 22:01:39 -0800
committerAart Bik <ajcbik@google.com>2021-02-18 23:22:14 -0800
commit2556d622828ae5631ac483d82592440fa1910d80 (patch)
tree11eb31c6b36b2cf3fdf115a55c51c35df4daa5ec
parentcd4051ac802fdc5664a3432f57d99bbcb4c07a92 (diff)
downloadllvm-2556d622828ae5631ac483d82592440fa1910d80.tar.gz
[mlir][sparse] assert fail on mismatch between rank and annotations array
Rationale: Providing the wrong number of sparse/dense annotations was silently ignored or caused unrelated crashes. This minor change verifies that the provided number matches the rank. Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D97034
-rw-r--r--mlir/lib/ExecutionEngine/SparseUtils.cpp51
1 files changed, 33 insertions, 18 deletions
diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp
index 0ff6f7d49b46..903b9f115182 100644
--- a/mlir/lib/ExecutionEngine/SparseUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp
@@ -76,8 +76,8 @@ public:
}
/// Adds element as indices and value.
void add(const std::vector<uint64_t> &ind, double val) {
- assert(sizes.size() == ind.size());
- for (int64_t r = 0, rank = sizes.size(); r < rank; r++)
+ assert(getRank() == ind.size());
+ for (int64_t r = 0, rank = getRank(); r < rank; r++)
assert(ind[r] < sizes[r]); // within bounds
elements.emplace_back(Element(ind, val));
}
@@ -85,6 +85,8 @@ public:
void sort() { std::sort(elements.begin(), elements.end(), lexOrder); }
/// Primitive one-time iteration.
const Element &next() { return elements[pos++]; }
+ /// Returns rank.
+ uint64_t getRank() const { return sizes.size(); }
/// Getter for sizes array.
const std::vector<uint64_t> &getSizes() const { return sizes; }
/// Getter for elements array.
@@ -139,13 +141,13 @@ public:
/// Constructs sparse tensor storage scheme following the given
/// per-rank dimension dense/sparse annotations.
SparseTensorStorage(SparseTensor *tensor, bool *sparsity)
- : sizes(tensor->getSizes()), pointers(sizes.size()),
- indices(sizes.size()) {
+ : sizes(tensor->getSizes()), pointers(getRank()), indices(getRank()) {
// Provide hints on capacity.
// TODO: needs fine-tuning based on sparsity
- values.reserve(tensor->getElements().size());
- for (uint64_t d = 0, s = 1, rank = sizes.size(); d < rank; d++) {
- s *= tensor->getSizes()[d];
+ uint64_t nnz = tensor->getElements().size();
+ values.reserve(nnz);
+ for (uint64_t d = 0, s = 1, rank = getRank(); d < rank; d++) {
+ s *= sizes[d];
if (sparsity[d]) {
pointers[d].reserve(s + 1);
indices[d].reserve(s);
@@ -153,12 +155,16 @@ public:
}
}
// Then setup the tensor.
- traverse(tensor, sparsity, 0, tensor->getElements().size(), 0);
+ traverse(tensor, sparsity, 0, nnz, 0);
}
virtual ~SparseTensorStorage() {}
+ uint64_t getRank() const { return sizes.size(); }
+
uint64_t getDimSize(uint64_t d) override { return sizes[d]; }
+
+ // Partially specialize these three methods based on template types.
void getPointers(std::vector<P> **out, uint64_t d) override {
*out = &pointers[d];
}
@@ -176,7 +182,7 @@ private:
uint64_t d) {
const std::vector<Element> &elements = tensor->getElements();
// Once dimensions are exhausted, insert the numerical values.
- if (d == sizes.size()) {
+ if (d == getRank()) {
values.push_back(lo < hi ? elements[lo].value : 0.0);
return;
}
@@ -221,9 +227,10 @@ private:
/// Templated reader.
template <typename P, typename I, typename V>
-void *newSparseTensor(char *filename, bool *sparsity) {
+void *newSparseTensor(char *filename, bool *sparsity, uint64_t size) {
uint64_t idata[64];
SparseTensor *t = static_cast<SparseTensor *>(openTensorC(filename, idata));
+ assert(size == t->getRank()); // sparsity array must match rank
SparseTensorStorageBase *tensor =
new SparseTensorStorage<P, I, V>(t, sparsity);
delete t;
@@ -481,21 +488,29 @@ void *newSparseTensor(char *filename, bool *abase, bool *adata, uint64_t aoff,
assert(astride == 1);
bool *sparsity = abase + aoff;
if (ptrTp == kU64 && indTp == kU64 && valTp == kF64)
- return newSparseTensor<uint64_t, uint64_t, double>(filename, sparsity);
+ return newSparseTensor<uint64_t, uint64_t, double>(filename, sparsity,
+ asize);
if (ptrTp == kU64 && indTp == kU64 && valTp == kF32)
- return newSparseTensor<uint64_t, uint64_t, float>(filename, sparsity);
+ return newSparseTensor<uint64_t, uint64_t, float>(filename, sparsity,
+ asize);
if (ptrTp == kU64 && indTp == kU32 && valTp == kF64)
- return newSparseTensor<uint64_t, uint32_t, double>(filename, sparsity);
+ return newSparseTensor<uint64_t, uint32_t, double>(filename, sparsity,
+ asize);
if (ptrTp == kU64 && indTp == kU32 && valTp == kF32)
- return newSparseTensor<uint64_t, uint32_t, float>(filename, sparsity);
+ return newSparseTensor<uint64_t, uint32_t, float>(filename, sparsity,
+ asize);
if (ptrTp == kU32 && indTp == kU64 && valTp == kF64)
- return newSparseTensor<uint32_t, uint64_t, double>(filename, sparsity);
+ return newSparseTensor<uint32_t, uint64_t, double>(filename, sparsity,
+ asize);
if (ptrTp == kU32 && indTp == kU64 && valTp == kF32)
- return newSparseTensor<uint32_t, uint64_t, float>(filename, sparsity);
+ return newSparseTensor<uint32_t, uint64_t, float>(filename, sparsity,
+ asize);
if (ptrTp == kU32 && indTp == kU32 && valTp == kF64)
- return newSparseTensor<uint32_t, uint32_t, double>(filename, sparsity);
+ return newSparseTensor<uint32_t, uint32_t, double>(filename, sparsity,
+ asize);
if (ptrTp == kU32 && indTp == kU32 && valTp == kF32)
- return newSparseTensor<uint32_t, uint32_t, float>(filename, sparsity);
+ return newSparseTensor<uint32_t, uint32_t, float>(filename, sparsity,
+ asize);
fputs("unsupported combination of types\n", stderr);
exit(1);
}