// Source: paste dated 2022-08-30 15:52:18.
/**
 * @file
 * @brief An AVL-tree-based map implementation.
 * @details The map is sorted according to the natural ordering of its keys or
 * by a provided {@code Compare} function. This implementation provides
 * guaranteed O(log n) time cost for the contains, get, insert and remove
 * operations.
 */
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <functional>
#include <memory>
#include <stack>
#include <utility>
#include <vector>
/**
 * An AVL-tree-based ordered map.
 * https://en.wikipedia.org/wiki/AVL_tree
 *
 * Copying a map deep-copies its nodes (the copies share nothing); moving a
 * map transfers its nodes and leaves the source empty.
 *
 * @tparam Key the type of keys maintained by this map
 * @tparam Value the type of mapped values; must be default-constructible for
 *         getOrDefault, the non-const operator[] and getAndRemove
 * @tparam Compare a strict weak ordering on keys; two keys a and b are treated
 *         as the same key when neither compare(a, b) nor compare(b, a) holds
 */
template <typename Key, typename Value, typename Compare = std::less<Key> >
class AvlTreeMap {
  public:
    using K = const Key &;
    using V = const Value &;

    /** A key/value pair copied out of the map. */
    struct Entry {
        Key key;
        Value value;

        bool operator==(const Entry &rhs) const noexcept {
            return this->key == rhs.key && this->value == rhs.value;
        }

        bool operator!=(const Entry &rhs) const noexcept {
            return !(*this == rhs);
        }
    };

    using EntryList = std::vector<Entry>;
    using KeyValueConsumer = const std::function<void(K, V)> &;
    using MutKeyValueConsumer = const std::function<void(K, Value &)> &;
    using KeyValueFilter = const std::function<bool(K, V)> &;

    /**
     * Thrown when a lookup finds no suitable mapping. Inherits publicly so it
     * can also be caught as std::exception.
     */
    class NoSuchMappingException : public std::exception {
      public:
        explicit NoSuchMappingException(const char *msg) noexcept
            : message(msg) {}

        const char *what() const noexcept override {
            return message;
        }

      private:
        // Not owned: every throw site in this class passes a string literal.
        const char *message;
    };

  private:
    using USize = size_t;
    using Factor = int64_t;

    /** A tree node; a leaf has height 1 and an empty subtree height 0. */
    struct Node {
        using Ptr = std::shared_ptr<Node>;
        using Provider = const std::function<Ptr(void)> &;
        using Consumer = const std::function<void(const Ptr &)> &;

        Key key;
        Value value{};
        Ptr left = nullptr;
        Ptr right = nullptr;
        USize height = 1;

        explicit Node(Key k) : key(std::move(k)) {}
        explicit Node(Key k, Value v) : key(std::move(k)), value(std::move(v)) {}

        bool isLeaf() const noexcept {
            return this->left == nullptr && this->right == nullptr;
        }

        /** Height of a possibly-empty subtree (0 for nullptr). */
        static USize heightOf(const Ptr &node) noexcept {
            return node == nullptr ? 0 : node->height;
        }

        /** Recomputes this node's height from its children's heights. */
        void updateHeight() noexcept {
            this->height =
                std::max(heightOf(this->left), heightOf(this->right)) + 1;
        }

        /**
         * Balance factor height(right) - height(left), computed in signed
         * arithmetic so a taller left subtree yields a negative value.
         */
        Factor factor() const noexcept {
            return static_cast<Factor>(heightOf(this->right)) -
                   static_cast<Factor>(heightOf(this->left));
        }

        Entry entry() const {
            return Entry{this->key, this->value};
        }

        static Ptr from(const Key &k) {
            return std::make_shared<Node>(k);
        }

        static Ptr from(const Key &k, const Value &v) {
            return std::make_shared<Node>(k, v);
        }
    };

    using NodePtr = typename Node::Ptr;
    using ConstNodePtr = const NodePtr &;
    using NodeProvider = typename Node::Provider;
    using NodeConsumer = typename Node::Consumer;

    Compare compare = Compare();
    NodePtr root = nullptr;
    USize count = 0;

  public:
    /** Creates an empty map ordered by a default-constructed Compare. */
    AvlTreeMap() noexcept = default;

    /**
     * Creates an empty map ordered by the given comparator instance.
     * @param cmp the key ordering to use
     */
    explicit AvlTreeMap(const Compare &cmp) : compare(cmp) {}

    /** Deep copy: the new map shares no nodes with other. */
    AvlTreeMap(const AvlTreeMap &other)
        : compare(other.compare),
          root(cloneTree(other.root)),
          count(other.count) {}

    /** Takes over other's nodes; other is left empty. */
    AvlTreeMap(AvlTreeMap &&other) noexcept
        : compare(std::move(other.compare)),
          root(std::move(other.root)),
          count(other.count) {
        other.count = 0;
    }

    /** Deep-copy assignment; *this is unchanged if copying throws. */
    AvlTreeMap &operator=(const AvlTreeMap &other) {
        if (this != &other) {
            AvlTreeMap copy(other);
            *this = std::move(copy);
        }
        return *this;
    }

    /** Move assignment; other is left empty. */
    AvlTreeMap &operator=(AvlTreeMap &&other) noexcept {
        if (this != &other) {
            this->compare = std::move(other.compare);
            this->root = std::move(other.root);
            this->count = other.count;
            other.root = nullptr;
            other.count = 0;
        }
        return *this;
    }

    ~AvlTreeMap() = default;

    /**
     * Returns the number of entries in this map.
     * @return size_t
     */
    USize size() const noexcept {
        return this->count;
    }

    /**
     * Returns true if this map contains no entries.
     * @return bool
     */
    bool empty() const noexcept {
        return this->count == 0;
    }

    /**
     * Removes all of the entries from this map.
     */
    void clear() noexcept {
        this->root = nullptr;
        this->count = 0;
    }

    /**
     * Returns a copy of the value to which the specified key is mapped.
     * @param key
     * @return AvlTreeMap<Key, Value>::Value
     * @throws NoSuchMappingException if this map has no mapping for key
     */
    Value get(K key) const {
        NodePtr node = this->getNode(this->root, key);
        if (node == nullptr) {
            throw NoSuchMappingException("Invalid key");
        }
        return node->value;
    }

    /**
     * Returns a reference to the value to which the specified key is mapped.
     * If the key is absent, a mapping to a value-initialized Value is
     * inserted first. The reference stays valid until that key is removed or
     * the map is cleared, assigned to or destroyed.
     * @param key
     * @return AvlTreeMap<Key, Value>::Value &
     */
    Value &getOrDefault(K key) {
        NodePtr node = this->getNodeOrProvide(
            this->root, key, [&key]() { return Node::from(key); });
        return node->value;
    }

    /**
     * Returns true if this map contains a mapping for the specified key.
     * Safe to call on an empty map.
     * @param key
     * @return bool
     */
    bool contains(K key) const {
        return this->getNode(this->root, key) != nullptr;
    }

    /**
     * Associates the specified value with the specified key in this map,
     * replacing any value previously mapped to the key.
     * @param key
     * @param value
     */
    void insert(K key, V value) {
        bool created = false;
        NodePtr node = this->getNodeOrProvide(this->root, key, [&]() {
            created = true;
            return Node::from(key, value);
        });
        if (!created) {
            node->value = value;
        }
    }

    /**
     * If the specified key is not already associated with a value, associates
     * it with the given value and returns true, else returns false and leaves
     * the existing mapping untouched.
     * @param key
     * @param value
     * @return bool
     */
    bool insertIfAbsent(K key, V value) {
        const USize sizeBeforeInsertion = this->size();
        this->getNodeOrProvide(this->root, key,
                               [&]() { return Node::from(key, value); });
        return this->size() > sizeBeforeInsertion;
    }

    /**
     * If the specified key is not already associated with a value, associates
     * it with the given value; in both cases returns a reference to the value
     * now mapped to the key (same validity rules as getOrDefault).
     * @param key
     * @param value
     * @return AvlTreeMap<Key, Value>::Value &
     */
    Value &getOrInsert(K key, V value) {
        NodePtr node = this->getNodeOrProvide(
            this->root, key, [&]() { return Node::from(key, value); });
        return node->value;
    }

    /** Same as get(key). */
    Value operator[](K key) const {
        return this->get(key);
    }

    /** Same as getOrDefault(key). */
    Value &operator[](K key) {
        return this->getOrDefault(key);
    }

    /**
     * Removes the mapping for a key from this map if it is present.
     * @param key the key of the mapping
     * @return bool true if a mapping was removed, false otherwise
     */
    bool remove(K key) {
        return this->remove(this->root, key, [](ConstNodePtr) {});
    }

    /**
     * Removes the mapping for a key from this map and returns the value that
     * was mapped to it.
     * @param key
     * @return AvlTreeMap<Key, Value>::Value
     * @throws NoSuchMappingException if this map has no mapping for key
     */
    Value getAndRemove(K key) {
        Value result{};
        const bool removed = this->remove(
            this->root, key,
            // The node is discarded right after this call, so its value can
            // be moved out instead of copied.
            [&result](ConstNodePtr node) { result = std::move(node->value); });
        if (!removed) {
            throw NoSuchMappingException("Invalid key");
        }
        return result;
    }

    /**
     * Gets the entry for the least key greater than or equal to the specified
     * key.
     * @param key
     * @return AvlTreeMap<Key, Value>::Entry
     * @throws NoSuchMappingException if every key is less than key
     */
    Entry getCeilingEntry(K key) const {
        const Node *best = nullptr;
        const Node *node = this->root.get();
        while (node != nullptr) {
            if (compare(node->key, key)) {
                // node->key < key: the ceiling, if any, is to the right.
                node = node->right.get();
            } else {
                // node->key >= key: a candidate; look left for a smaller one.
                best = node;
                if (!compare(key, node->key)) {
                    break;  // exact match
                }
                node = node->left.get();
            }
        }
        if (best == nullptr) {
            throw NoSuchMappingException("No ceiling entry in this map");
        }
        return best->entry();
    }

    /**
     * Gets the entry for the greatest key less than or equal to the specified
     * key.
     * @param key
     * @return AvlTreeMap<Key, Value>::Entry
     * @throws NoSuchMappingException if every key is greater than key
     */
    Entry getFloorEntry(K key) const {
        const Node *best = nullptr;
        const Node *node = this->root.get();
        while (node != nullptr) {
            if (compare(key, node->key)) {
                // node->key > key: the floor, if any, is to the left.
                node = node->left.get();
            } else {
                // node->key <= key: a candidate; look right for a larger one.
                best = node;
                if (!compare(node->key, key)) {
                    break;  // exact match
                }
                node = node->right.get();
            }
        }
        if (best == nullptr) {
            throw NoSuchMappingException("No floor entry exists in this map");
        }
        return best->entry();
    }

    /**
     * Gets the entry for the least key strictly greater than the specified
     * key.
     * @param key
     * @return AvlTreeMap<Key, Value>::Entry
     * @throws NoSuchMappingException if no key is greater than key
     */
    Entry getHigherEntry(K key) const {
        const Node *best = nullptr;
        const Node *node = this->root.get();
        while (node != nullptr) {
            if (compare(key, node->key)) {
                // node->key > key: a candidate; look left for a smaller one.
                best = node;
                node = node->left.get();
            } else {
                node = node->right.get();
            }
        }
        if (best == nullptr) {
            throw NoSuchMappingException("No higher entry exists in this map");
        }
        return best->entry();
    }

    /**
     * Gets the entry for the greatest key strictly less than the specified
     * key.
     * @param key
     * @return AvlTreeMap<Key, Value>::Entry
     * @throws NoSuchMappingException if no key is less than key
     */
    Entry getLowerEntry(K key) const {
        const Node *best = nullptr;
        const Node *node = this->root.get();
        while (node != nullptr) {
            if (compare(node->key, key)) {
                // node->key < key: a candidate; look right for a larger one.
                best = node;
                node = node->right.get();
            } else {
                node = node->left.get();
            }
        }
        if (best == nullptr) {
            throw NoSuchMappingException("No lower entry exists in this map");
        }
        return best->entry();
    }

    /**
     * Removes all entries that satisfy the filter condition.
     * @param filter predicate called once per entry, in key order
     */
    void removeAll(KeyValueFilter filter) {
        // Collect first: the tree must not be restructured mid-traversal.
        std::vector<Key> doomed;
        this->inorderTraversal([&](ConstNodePtr node) {
            if (filter(node->key, node->value)) {
                doomed.push_back(node->key);
            }
        });
        for (const Key &key : doomed) {
            this->remove(key);
        }
    }

    /**
     * Performs the given action for each key and value entry in this map, in
     * key order. The value is immutable for the action.
     * @param action
     */
    void forEach(KeyValueConsumer action) const {
        this->inorderTraversal(
            [&](ConstNodePtr node) { action(node->key, node->value); });
    }

    /**
     * Performs the given action for each key and value entry in this map, in
     * key order. The value is mutable for the action.
     * @param action
     */
    void forEachMut(MutKeyValueConsumer action) {
        this->inorderTraversal(
            [&](ConstNodePtr node) { action(node->key, node->value); });
    }

    /**
     * Returns a list containing copies of all entries in this map, in key
     * order.
     * @return AvlTreeMap<Key, Value>::EntryList
     */
    EntryList toEntryList() const {
        EntryList entryList;
        entryList.reserve(this->count);
        this->inorderTraversal(
            [&](ConstNodePtr node) { entryList.push_back(node->entry()); });
        return entryList;
    }

  private:
    /** Returns a deep copy of the subtree rooted at node (nullptr if empty). */
    static NodePtr cloneTree(ConstNodePtr node) {
        if (node == nullptr) {
            return nullptr;
        }
        NodePtr copy = Node::from(node->key, node->value);
        copy->left = cloneTree(node->left);
        copy->right = cloneTree(node->right);
        copy->height = node->height;
        return copy;
    }

    /**
     * Left rotation: node's right child S becomes the subtree root, node
     * becomes S's left child and adopts S's former left subtree. Heights of
     * both moved nodes are recomputed. Returns S; the caller must store it in
     * the slot that held node.
     */
    static NodePtr rotateLeft(ConstNodePtr node) {
        NodePtr successor = node->right;
        node->right = successor->left;
        successor->left = node;
        node->updateHeight();
        successor->updateHeight();
        return successor;
    }

    /**
     * Right rotation: mirror image of rotateLeft. node's left child becomes
     * the subtree root. Returns the new root; the caller must store it.
     */
    static NodePtr rotateRight(ConstNodePtr node) {
        NodePtr successor = node->left;
        node->left = successor->right;
        successor->right = node;
        node->updateHeight();
        successor->updateHeight();
        return successor;
    }

    /**
     * Recomputes node's height and, if its balance factor is outside [-1, 1],
     * restores the AVL invariant with one or two rotations, replacing node
     * with the new subtree root. Both children must already be balanced and
     * carry correct heights.
     */
    static void fixBalance(NodePtr &node) {
        node->updateHeight();
        const Factor balance = node->factor();
        if (balance < -1) {
            // Left-heavy. Only a left child leaning right (Left-Right case)
            // needs the extra rotation; a left child with factor 0 (possible
            // after a removal) must take the single rotation.
            if (node->left->factor() > 0) {
                node->left = rotateLeft(node->left);
            }
            node = rotateRight(node);
        } else if (balance > 1) {
            // Right-heavy: mirror image of the case above.
            if (node->right->factor() < 0) {
                node->right = rotateRight(node->right);
            }
            node = rotateLeft(node);
        }
    }

    /**
     * Returns the node for key in the subtree held by the slot node, creating
     * it with provide() (and incrementing count) when absent. Every subtree
     * on the search path is rebalanced on the way back up.
     * @param node the slot (root or a child pointer) holding the subtree;
     *        may be null
     */
    NodePtr getNodeOrProvide(NodePtr &node, K key, NodeProvider provide) {
        if (node == nullptr) {
            node = provide();
            this->count += 1;
            return node;
        }
        NodePtr result;
        if (compare(key, node->key)) {
            result = getNodeOrProvide(node->left, key, provide);
        } else if (compare(node->key, key)) {
            result = getNodeOrProvide(node->right, key, provide);
        } else {
            return node;  // already present: nothing below changed
        }
        // result keeps the target alive, so rotating this slot is safe.
        fixBalance(node);
        return result;
    }

    /**
     * Returns the node for key in the subtree rooted at start, or nullptr if
     * there is none. start may be null.
     */
    NodePtr getNode(ConstNodePtr start, K key) const {
        // Walk via pointers to the child slots to avoid shared_ptr refcount
        // traffic on every step.
        const NodePtr *cursor = &start;
        while (*cursor != nullptr) {
            const NodePtr &node = *cursor;
            if (compare(key, node->key)) {
                cursor = &node->left;
            } else if (compare(node->key, key)) {
                cursor = &node->right;
            } else {
                return node;
            }
        }
        return nullptr;
    }

    /**
     * Removes the node for key from the subtree held by the slot node,
     * calling action on it first. Nodes are relinked rather than having their
     * keys/values swapped, so references to other values stay valid. Every
     * subtree on the path is rebalanced.
     * @return true if a node was removed
     */
    bool remove(NodePtr &node, K key, NodeConsumer action) {
        if (node == nullptr) {
            return false;
        }
        if (compare(key, node->key)) {
            if (!remove(node->left, key, action)) {
                return false;
            }
        } else if (compare(node->key, key)) {
            if (!remove(node->right, key, action)) {
                return false;
            }
        } else {
            action(node);
            if (node->left == nullptr) {
                // No child, or right child only: splice the right subtree in.
                node = node->right;
            } else if (node->right == nullptr) {
                // Left child only: splice the left subtree in.
                node = node->left;
            } else {
                // Two children: the in-order successor takes node's place.
                NodePtr successor = detachMin(node->right);
                successor->left = node->left;
                successor->right = node->right;
                node = successor;
            }
            this->count -= 1;
            if (node == nullptr) {
                return true;
            }
        }
        fixBalance(node);
        return true;
    }

    /**
     * Unlinks and returns the node with the smallest key from the non-empty
     * subtree held by the slot node, rebalancing on the way back up.
     */
    static NodePtr detachMin(NodePtr &node) {
        if (node->left == nullptr) {
            NodePtr smallest = node;
            node = smallest->right;
            smallest->right = nullptr;
            return smallest;
        }
        NodePtr smallest = detachMin(node->left);
        fixBalance(node);
        return smallest;
    }

    /** Calls action on every node in ascending key order (iterative). */
    void inorderTraversal(NodeConsumer action) const {
        std::stack<NodePtr> pending;
        NodePtr node = this->root;
        while (node != nullptr || !pending.empty()) {
            while (node != nullptr) {
                pending.push(node);
                node = node->left;
            }
            node = pending.top();
            pending.pop();
            action(node);
            node = node->right;
        }
    }
};
// Trailing text from the paste page: "No comments so far..."