diff options
Diffstat (limited to 'src/trie.c')
-rw-r--r-- | src/trie.c | 290 |
1 files changed, 290 insertions, 0 deletions
diff --git a/src/trie.c b/src/trie.c new file mode 100644 index 000000000..c26368edc --- /dev/null +++ b/src/trie.c @@ -0,0 +1,290 @@ +/* + * Simple trie implementation for key-value mapping storage + * + * Copyright (c) 2020 Ákos Uzonyi <uzonyi.akos@gmail.com> + * All rights reserved. + * + * SPDX-License-Identifier: LGPL-2.1-or-later + */ + +#ifdef HAVE_CONFIG_H +# include "config.h" +#endif + +#include <stdlib.h> +#include <stdio.h> + +#include "trie.h" +#include "xmalloc.h" + +static const uint8_t ptr_sz_lg = (sizeof(void *) == 8 ? 6 : 5); + +/** + * Returns lg2 of node size in bits for the specific level of the trie. + */ +static uint8_t +trie_get_node_size(struct trie *t, uint8_t depth) +{ + /* Last level contains data and we allow it having a different size */ + if (depth == t->max_depth) + return t->data_block_key_bits + t->item_size_lg; + /* Last level of the tree can be smaller */ + if (depth == t->max_depth - 1) + return (t->key_size - t->data_block_key_bits - 1) % + t->node_key_bits + 1 + ptr_sz_lg; + + return t->node_key_bits + ptr_sz_lg; +} + +/** + * Provides starting offset of bits in key corresponding to the node index + * at the specific level. + */ +static uint8_t +trie_get_node_bit_offs(struct trie *t, uint8_t depth) +{ + uint8_t offs; + + if (depth == t->max_depth) + return 0; + + offs = t->data_block_key_bits; + + if (depth == t->max_depth - 1) + return offs; + + /* data_block_size + remainder */ + offs += trie_get_node_size(t, t->max_depth - 1) - ptr_sz_lg; + offs += (t->max_depth - depth - 2) * t->node_key_bits; + + return offs; +} + +struct trie * +trie_create(uint8_t key_size, uint8_t item_size_lg, uint8_t node_key_bits, + uint8_t data_block_key_bits, uint64_t empty_value) +{ + if (item_size_lg > 6) + return NULL; + if (key_size > 64) + return NULL; + if (node_key_bits < 1) + return NULL; + if (data_block_key_bits < 1 || data_block_key_bits > key_size) + return NULL; + + struct trie *t = malloc(sizeof(*t)); + if (!t) + return NULL; + + t->empty_value = empty_value; + t->data = NULL; + t->item_size_lg = item_size_lg; + t->node_key_bits = node_key_bits; + t->data_block_key_bits = data_block_key_bits; + t->key_size = key_size; + t->max_depth = (key_size - data_block_key_bits + node_key_bits - 1) + / t->node_key_bits; + + if (item_size_lg != 6) + t->empty_value &= (((uint64_t) 1 << (1 << t->item_size_lg)) - 1); + + return t; +} + +static void * +trie_create_data_block(struct trie *t) +{ + uint64_t fill_value = t->empty_value; + for (int i = 1; i < 1 << (6 - t->item_size_lg); i++) { + fill_value <<= (1 << t->item_size_lg); + fill_value |= t->empty_value; + } + + uint8_t sz = t->data_block_key_bits + t->item_size_lg; + if (sz < 6) + sz = 6; + + size_t count = 1 << (sz - 6); + uint64_t *data_block = xcalloc(count, 8); + + for (size_t i = 0; i < count; i++) + data_block[i] = fill_value; + + return data_block; +} + +static uint64_t * +trie_get_node(struct trie *t, uint64_t key, bool auto_create) +{ + void **cur_node = &(t->data); + + if (t->key_size < 64 && key > (uint64_t) 1 << t->key_size) + return NULL; + + for (uint8_t cur_depth = 0; cur_depth <= t->max_depth; cur_depth++) { + uint8_t offs = trie_get_node_bit_offs(t, cur_depth); + uint8_t sz = trie_get_node_size(t, cur_depth); + + if (!*cur_node) { + if (!auto_create) + return NULL; + + if (cur_depth == t->max_depth) + *cur_node = trie_create_data_block(t); + else + *cur_node = xcalloc(1 << sz, 1); + } + + if (cur_depth == t->max_depth) + break; + + size_t pos = (key >> offs) & ((1 << (sz - ptr_sz_lg)) - 1); + cur_node = (((void **) (*cur_node)) + pos); + } + + return (uint64_t *) (*cur_node); +} + +static void +trie_data_block_calc_pos(struct trie *t, uint64_t key, + uint64_t *pos, uint64_t *mask, uint64_t *offs) +{ + uint64_t key_mask; + + key_mask = (1 << t->data_block_key_bits) - 1; + *pos = (key & key_mask) >> (6 - t->item_size_lg); + + if (t->item_size_lg == 6) { + *offs = 0; + *mask = -1; + return; + } + + key_mask = (1 << (6 - t->item_size_lg)) - 1; + *offs = (key & key_mask) * (1 << t->item_size_lg); + + *mask = (((uint64_t) 1 << (1 << t->item_size_lg)) - 1) << *offs; +} + +bool +trie_set(struct trie *t, uint64_t key, uint64_t val) +{ + uint64_t *data = trie_get_node(t, key, true); + if (!data) + return false; + + uint64_t pos, mask, offs; + trie_data_block_calc_pos(t, key, &pos, &mask, &offs); + + data[pos] &= ~mask; + data[pos] |= (val << offs) & mask; + + return true; +} + +static uint64_t +trie_data_block_get(struct trie *t, uint64_t *data, uint64_t key) +{ + if (!data) + return t->empty_value; + + uint64_t pos, mask, offs; + trie_data_block_calc_pos(t, key, &pos, &mask, &offs); + + return (data[pos] & mask) >> offs; +} + +uint64_t +trie_get(struct trie *b, uint64_t key) +{ + return trie_data_block_get(b, trie_get_node(b, key, false), key); +} + +static uint64_t +trie_iterate_keys_node(struct trie *t, + trie_iterate_fn fn, void *fn_data, + void *node, uint64_t start, uint64_t end, + uint8_t depth) +{ + if (start > end || !node) + return 0; + + if (t->key_size < 64) { + uint64_t key_max = ((uint64_t) 1 << t->key_size) - 1; + if (end > key_max) + end = key_max; + } + + if (depth == t->max_depth) { + for (uint64_t i = start; i <= end; i++) + fn(fn_data, i, trie_data_block_get(t, + (uint64_t *) node, i)); + + return end - start + 1; + } + + uint8_t parent_node_bit_off = depth == 0 ? + t->key_size : + trie_get_node_bit_offs(t, depth - 1); + + uint64_t first_key_in_node = start & + (uint64_t) -1 << parent_node_bit_off; + + uint8_t node_bit_off = trie_get_node_bit_offs(t, depth); + uint8_t node_key_bits = parent_node_bit_off - node_bit_off; + uint64_t mask = ((uint64_t) 1 << (node_key_bits)) - 1; + uint64_t start_index = (start >> node_bit_off) & mask; + uint64_t end_index = (end >> node_bit_off) & mask; + uint64_t child_key_count = (uint64_t) 1 << node_bit_off; + + uint64_t count = 0; + + for (uint64_t i = start_index; i <= end_index; i++) { + uint64_t child_start = first_key_in_node + i * child_key_count; + uint64_t child_end = first_key_in_node + + (i + 1) * child_key_count - 1; + + if (child_start < start) + child_start = start; + if (child_end > end) + child_end = end; + + count += trie_iterate_keys_node(t, fn, fn_data, + ((void **) node)[i], child_start, child_end, + depth + 1); + } + + return count; +} + +uint64_t trie_iterate_keys(struct trie *t, uint64_t start, uint64_t end, + trie_iterate_fn fn, void *fn_data) +{ + return trie_iterate_keys_node(t, fn, fn_data, t->data, + start, end, 0); +} + +static void +trie_free_node(struct trie *t, void *node, uint8_t depth) +{ + if (!node) + return; + + if (depth >= t->max_depth) + goto free_node; + + size_t sz = 1 << (trie_get_node_size(t, depth) - ptr_sz_lg); + for (size_t i = 0; i < sz; i++) + trie_free_node(t, ((void **) node)[i], depth + 1); + +free_node: + free(node); +} + +void +trie_free(struct trie *t) +{ + trie_free_node(t, t->data, 0); + free(t); +} |