#include <stdio.h>
#include <stdlib.h>
#include <time.h>

// When true, IntervalSet::unbalanced() also prints each node whose two
// subtrees differ in height by more than one.
bool debug = false;

// A set of long values stored as closed intervals [from, to], one interval per
// node of an AVL tree ordered by position.
//
// Invariant: stored intervals never overlap and never touch (for neighbours a
// before b, a.to + 1 < b.from), so every maximal run of consecutive members is
// exactly one node. All mutators are built from standard AVL insert/delete
// steps (plus in-place shrinking of a node, which cannot change the order),
// so node heights stay exact and the tree stays height-balanced.
class IntervalSet
{
public:
    IntervalSet() : _root(0) {}
    ~IntervalSet() { delete _root; }

private:
    // Not copyable: the destructor owns the node tree, so a memberwise copy
    // would delete the same nodes twice. Declared but intentionally undefined.
    IntervalSet(const IntervalSet&);
    IntervalSet& operator=(const IntervalSet&);

    // One stored interval plus its AVL bookkeeping. Deleting a node deletes
    // its entire subtree.
    class Node
    {
    public:
        Node(long f, long t) : from(f), to(t), _left(0), _right(0), _height(1) {}
        ~Node() { delete _left; delete _right; }

        long from;
        long to;
        Node* _left;
        Node* _right;
        int _height;  // height of the subtree rooted here; a leaf has height 1

    private:
        Node(const Node&);
        Node& operator=(const Node&);
    };

public:
    // Returns true if value is a member of the set.
    bool contains(long value) const
    {
        return findOverlapping(value, value) != 0;
    }

    // Adds a single value; equivalent to addInterval(value, value).
    void addValue(long value) { addInterval(value, value); }

    // Adds every value in [from, to]. Stored intervals that overlap or touch
    // the range are absorbed into a single node. An empty range (to < from)
    // is ignored.
    void addInterval(long from, long to)
    {
        if (to < from) return;

        Node* touching = findTouching(from, to);
        // A node covering the whole range is the only node that can touch it,
        // so the set already contains every value.
        if (touching != 0 && touching->from <= from && to <= touching->to) return;

        // Widen [from, to] by each touching interval and delete that node. The
        // widened range may now touch further neighbours, so search again.
        while (touching != 0) {
            if (touching->from < from) from = touching->from;
            if (to < touching->to) to = touching->to;
            removeNode(_root, touching->from);
            touching = findTouching(from, to);
        }
        insertNode(_root, from, to);
    }

    // Removes a single value; equivalent to removeInterval(value, value).
    void removeValue(long value) { removeInterval(value, value); }

    // Removes every value in [from, to]. Each overlapping interval is deleted,
    // trimmed, or split in two. An empty range (to < from) is ignored.
    void removeInterval(long from, long to)
    {
        if (to < from) return;

        for (Node* hit = findOverlapping(from, to); hit != 0; hit = findOverlapping(from, to)) {
            if (from <= hit->from && hit->to <= to) {
                // The whole interval lies inside the removed range.
                removeNode(_root, hit->from);
            } else if (hit->from < from && to < hit->to) {
                // The removed range lies strictly inside: keep the left part in
                // this node and insert the right part as a new node.
                long right_to = hit->to;
                hit->to = from - 1;              // from > hit->from, no overflow
                insertNode(_root, to + 1, right_to);  // to < right_to, no overflow
            } else if (hit->from < from) {
                hit->to = from - 1;              // only the upper end is removed
            } else {
                hit->from = to + 1;              // only the lower end is removed
            }
            // Shrinking a node in place keeps it between its neighbours, so the
            // search order and all heights remain valid.
        }
    }

    // Prints the intervals in order, e.g. "[ 1 3-5]", followed by a newline.
    void print() const
    {
        printf("[");
        printNode(_root);
        printf("]\n");
    }

    // Prints the tree structure, one node per line, with each right subtree
    // printed above its parent and each left subtree below it.
    void printTree()
    {
        for (int i = 0; i < kPrintBufferSize; i++) _print_buffer[i] = ' ';
        if (_root != 0) printTreeNode(_root, ' ', 0, _root->_height);
        printf("\n");
    }

    // Self-check: prints "ERROR: Height error ..." for every node whose
    // cached height is wrong and returns the total amount by which nodes
    // exceed the AVL balance limit (0 for a valid AVL tree).
    int unbalanced() const { return unbalanced(_root); }

    // In-order iteration over the stored intervals, lowest first. The set
    // must not be modified while an iterator is in use.
    class IntervalIterator
    {
    public:
        IntervalIterator(const IntervalSet& set) : _depth(-1) { pushLeftSpine(set._root); }

        bool more() const { return _depth >= 0; }

        // Advances to the next interval; only valid while more() is true.
        void next()
        {
            Node* right = _stack[_depth--]->_right;
            pushLeftSpine(right);
        }

        long from() const { return _stack[_depth]->from; }
        long to() const { return _stack[_depth]->to; }

    private:
        // Pushes node and all of its left descendants; the top of the stack
        // is then the smallest not-yet-visited interval.
        void pushLeftSpine(Node* node)
        {
            for (; node != 0; node = node->_left) _stack[++_depth] = node;
        }

        // The stack never holds more entries than the tree height. An AVL tree
        // of height h has at least Fib(h+2)-1 nodes, so 64 levels would need
        // roughly 2.7e13 nodes.
        static const int kMaxDepth = 64;
        Node* _stack[kMaxDepth];
        int _depth;
    };

    // Iteration over the individual member values in ascending order.
    class Iterator
    {
    public:
        Iterator(const IntervalSet& set) : _intervalIt(set), _more(_intervalIt.more()), _value(0)
        {
            if (_more) _value = _intervalIt.from();
        }

        long value() const { return _value; }
        bool more() const { return _more; }

        // Advances to the next member; only valid while more() is true.
        void next()
        {
            // Compare before incrementing so an interval ending at LONG_MAX
            // does not overflow _value.
            if (_value < _intervalIt.to()) {
                _value++;
                return;
            }
            _intervalIt.next();
            _more = _intervalIt.more();
            if (_more) _value = _intervalIt.from();
        }

    private:
        IntervalIterator _intervalIt;
        bool _more;
        long _value;
    };

private:
    friend class IntervalIterator;

    static const int kPrintBufferSize = 100;

    // Height of a possibly empty subtree; an empty subtree has height 0.
    static int height(const Node* node) { return node == 0 ? 0 : node->_height; }

    // Recomputes node->_height from the (already correct) heights of its children.
    static void updateHeight(Node* node)
    {
        int left_height = height(node->_left);
        int right_height = height(node->_right);
        node->_height = 1 + (left_height < right_height ? right_height : left_height);
    }

    // Lifts node's left child into node's place.
    static void rotateRight(Node*& node)
    {
        Node* left = node->_left;
        node->_left = left->_right;
        updateHeight(node);
        left->_right = node;
        updateHeight(left);
        node = left;
    }

    // Lifts node's right child into node's place.
    static void rotateLeft(Node*& node)
    {
        Node* right = node->_right;
        node->_right = right->_left;
        updateHeight(node);
        right->_left = node;
        updateHeight(right);
        node = right;
    }

    // Refreshes node's height and restores the AVL property at node, assuming
    // each subtree is balanced and changed height by at most one.
    static void rebalance(Node*& node)
    {
        updateHeight(node);
        int diff = height(node->_left) - height(node->_right);
        if (diff > 1) {
            if (height(node->_left->_right) > height(node->_left->_left))
                rotateLeft(node->_left);   // left-right case: double rotation
            rotateRight(node);
        } else if (diff < -1) {
            if (height(node->_right->_left) > height(node->_right->_right))
                rotateRight(node->_right); // right-left case: double rotation
            rotateLeft(node);
        }
    }

    // True when at least one value lies strictly between left_to and
    // right_from, i.e. left_to + 1 < right_from, evaluated without overflow.
    static bool gapBetween(long left_to, long right_from)
    {
        // The first test guarantees right_from > LONG_MIN, so right_from - 1
        // is well defined.
        return left_to < right_from && left_to < right_from - 1;
    }

    // Returns some node whose interval overlaps or touches [from, to], or 0.
    Node* findTouching(long from, long to) const
    {
        Node* node = _root;
        while (node != 0) {
            if (gapBetween(to, node->from))
                node = node->_left;
            else if (gapBetween(node->to, from))
                node = node->_right;
            else
                return node;
        }
        return 0;
    }

    // Returns some node whose interval shares at least one value with
    // [from, to], or 0.
    Node* findOverlapping(long from, long to) const
    {
        Node* node = _root;
        while (node != 0) {
            if (to < node->from)
                node = node->_left;
            else if (node->to < from)
                node = node->_right;
            else
                return node;
        }
        return 0;
    }

    // Standard AVL insert of [from, to] as a new node. The caller guarantees
    // the interval neither overlaps nor touches any stored interval.
    static void insertNode(Node*& node, long from, long to)
    {
        if (node == 0) {
            node = new Node(from, to);
            return;
        }
        if (to < node->from)
            insertNode(node->_left, from, to);
        else
            insertNode(node->_right, from, to);
        rebalance(node);
    }

    // Unlinks the leftmost node of the non-empty subtree rooted at node,
    // rebalancing on the way back up, and returns it with no children.
    static Node* detachMin(Node*& node)
    {
        if (node->_left == 0) {
            Node* min = node;
            node = node->_right;
            min->_right = 0;
            return min;
        }
        Node* min = detachMin(node->_left);
        rebalance(node);
        return min;
    }

    // Standard AVL delete of the node whose interval starts at key; does
    // nothing if no such node exists.
    static void removeNode(Node*& node, long key)
    {
        if (node == 0) return;
        if (key < node->from) {
            removeNode(node->_left, key);
        } else if (node->from < key) {
            removeNode(node->_right, key);
        } else {
            Node* old_node = node;
            if (old_node->_left == 0) {
                node = old_node->_right;
            } else if (old_node->_right == 0) {
                node = old_node->_left;
            } else {
                // Two children: the in-order successor takes old_node's place.
                Node* successor = detachMin(old_node->_right);
                successor->_left = old_node->_left;
                successor->_right = old_node->_right;
                node = successor;
            }
            // Detach the children so deleting old_node frees only itself.
            old_node->_left = 0;
            old_node->_right = 0;
            delete old_node;
            if (node == 0) return;
        }
        rebalance(node);
    }

    // In-order print of " from" or " from-to" for every interval in the subtree.
    static void printNode(const Node* node)
    {
        if (node == 0) return;
        printNode(node->_left);
        printf(" %ld", node->from);
        if (node->from < node->to) printf("-%ld", node->to);
        printNode(node->_right);
    }

    // Recursive worker for printTree(). m is the connector drawn in front of
    // this node (' ' for the root, '/' for a right child, '\\' for a left
    // child); _print_buffer holds the vertical bars of the levels above.
    void printTreeNode(const Node* node, char m, unsigned int parent_height, unsigned int root_height)
    {
        if (m == '\\') _print_buffer[parent_height] = '|';
        if (node->_right != 0) printTreeNode(node->_right, '/', node->_height, root_height);
        if (m != ' ') {
            unsigned int i = root_height;
            for (; i > parent_height; i--) printf("%c", _print_buffer[i]);
            printf("%c", m);
            i--;
            for (; i > static_cast<unsigned int>(node->_height); i--) printf("-");
        }
        printf("* ");
        printf("%ld-%ld", node->from, node->to);
        printf("\n");
        _print_buffer[parent_height] = m == '/' ? '|' : ' ';
        if (node->_left != 0) printTreeNode(node->_left, '\\', node->_height, root_height);
        _print_buffer[parent_height] = ' ';
    }

    // Recursive worker for unbalanced(): reports wrong cached heights and sums
    // how far each node exceeds the allowed height difference of one.
    int unbalanced(const Node* node) const
    {
        if (node == 0) return 0;
        int left_height = height(node->_left);
        int right_height = height(node->_right);
        if (node->_height != (left_height > right_height ? left_height : right_height) + 1) {
            printf("ERROR: Height error %d %d %d at: ", node->_height, left_height, right_height);
            printNode(node);
            printf("\n");
        }
        int result = 0;
        if (left_height < right_height - 1)
            result = right_height - 1 - left_height;
        else if (right_height < left_height - 1)
            result = left_height - 1 - right_height;
        if (result != 0 && debug) {
            printf("Not balanced %d %d at: ", left_height, right_height);
            printNode(node);
            printf("\n");
        }
        return result + unbalanced(node->_left) + unbalanced(node->_right);
    }

    char _print_buffer[kPrintBufferSize];
    Node* _root;
};

#define TEST_RANGE 50

// Loads every value i with filled[i] into set (runs as single intervals),
// then reads the set back through IntervalSet::Iterator into target and
// reports any value that differs from filled.
void initIntervalSet(IntervalSet &set, bool filled[], bool target[])
{
    for (long i = 0; i < TEST_RANGE; i++) {
        if (!filled[i]) continue;
        long first = i;
        while (i + 1 < TEST_RANGE && filled[i + 1]) i++;
        if (first == i)
            set.addValue(i);
        else
            set.addInterval(first, i);
    }
    for (long i = 0; i < TEST_RANGE; i++) target[i] = false;
    for (IntervalSet::Iterator it(set); it.more(); it.next()) target[it.value()] = true;
    for (long i = 0; i < TEST_RANGE; i++)
        if (target[i] != filled[i])
            printf("ERROR: (it) %ld %s\n", i, target[i] ? "t->f" : "f->t");
}

// Compares the set (read through IntervalIterator) with the expected
// membership in target, reports touching intervals, and prints the tree's
// unbalance figure.
void checkIntervalSet(IntervalSet &set, bool target[])
{
    bool result[TEST_RANGE];
    for (int i = 0; i < TEST_RANGE; i++) result[i] = false;
    long prev = -10;
    for (IntervalSet::IntervalIterator it(set); it.more(); it.next()) {
        if (prev + 1 == it.from()) printf("ERROR: Touching interval at %ld\n", it.from());
        for (long i = it.from(); i <= it.to(); i++) result[i] = true;
        prev = it.to();
    }
    for (int i = 0; i < TEST_RANGE; i++)
        if (target[i] != result[i])
            printf("ERROR: %d %s\n", i, target[i] ? "t->f" : "f->t");
    printf("Unbalance %d\n", set.unbalanced());
}

// Randomized exhaustive test: for a random initial set over [0, TEST_RANGE),
// applies every single-value and every interval add/remove and checks the result.
// Usage: prog [seed] -- pass the printed seed to replay a failing run.
int main(int argc, char *argv[])
{
    unsigned int seed = argc > 1 ? static_cast<unsigned int>(strtoul(argv[1], 0, 10))
                                 : static_cast<unsigned int>(time(NULL));
    printf("Seed %u\n", seed);
    srand(seed);

    bool filled[TEST_RANGE];
    bool target[TEST_RANGE];
    for (int i = 0; i < TEST_RANGE; i++) filled[i] = (rand() % 100) < 50;
    for (int i = 0; i < TEST_RANGE; i++)
        if (filled[i]) printf("%d ", i);
    printf("\n");

    for (long value = 0; value < TEST_RANGE; value++) {
        printf("Test addValue(%ld)\n", value);
        IntervalSet set;
        initIntervalSet(set, filled, target);
        set.addValue(value);
        target[value] = true;
        checkIntervalSet(set, target);
    }
    for (long from = 0; from < TEST_RANGE - 1; from++)
        for (long to = from; to < TEST_RANGE; to++) {
            printf("Test addInterval(%ld, %ld)\n", from, to);
            IntervalSet set;
            initIntervalSet(set, filled, target);
            set.addInterval(from, to);
            for (long i = from; i <= to; i++) target[i] = true;
            checkIntervalSet(set, target);
        }
    for (long value = 0; value < TEST_RANGE; value++) {
        printf("Test removeValue(%ld)\n", value);
        IntervalSet set;
        initIntervalSet(set, filled, target);
        set.removeValue(value);
        target[value] = false;
        checkIntervalSet(set, target);
    }
    for (long from = 0; from < TEST_RANGE - 1; from++)
        for (long to = from; to < TEST_RANGE; to++) {
            printf("Test removeInterval(%ld, %ld)\n", from, to);
            IntervalSet set;
            initIntervalSet(set, filled, target);
            set.removeInterval(from, to);
            for (long i = from; i <= to; i++) target[i] = false;
            checkIntervalSet(set, target);
        }
    return 0;
}