#include <stdio.h>
#include <stdlib.h>
#include <time.h>

// Define MAX_SHARING to skip the rebalancing rotations whenever the subtree
// that would be rotated is referenced from more than one place. This keeps
// more nodes shared between sets at the cost of letting trees drift out of
// balance.
//#define MAX_SHARING

#ifdef MAX_SHARING
#define CHECK(X) ((X)->_ref_count == 1)
#else
#define CHECK(X) true
#endif

// Capacity of the iterator stack and of the print() indentation buffer. Both
// are indexed by tree height without bounds checks, so tree height must stay
// below this value.
#define MAX_TREE_DEPTH 100

/*
 * PersistentSet<T>: an ordered set stored as a binary search tree whose nodes
 * are reference counted and may be shared by several sets.
 *
 * Copying or assigning a set only bumps the root's reference count. Every
 * mutating operation first calls claim() on each node it is about to modify,
 * which replaces a shared node by a private copy, so sets that still share the
 * old nodes are unaffected. Rotations keep the tree height-balanced;
 * unbalanced() reports nodes whose two subtree heights differ by more than one.
 *
 * Requirements on T (as used below):
 *   int  compare(const T& other) const;   // < 0, 0 or > 0
 *   void print() const;                   // used by print() and unbalanced()
 *   copy construction.
 */
template <class T>
class PersistentSet
{
    class Node
    {
    public:
        // Creates a leaf holding `data`.
        explicit Node(const T& data)
            : _left(0), _right(0), _data(data), _ref_count(1), _height(1)
        {
        }

        // Creates an unshared copy of `node` that shares both of its children;
        // the children gain one reference each.
        explicit Node(const Node* node)
            : _left(node->_left), _right(node->_right), _data(node->_data),
              _ref_count(1), _height(node->_height)
        {
            if (_left != 0)
                _left->_ref_count++;
            if (_right != 0)
                _right->_ref_count++;
        }

        // Drops one reference. When the last reference goes, releases both
        // children recursively and deletes this node.
        void purge()
        {
            if (--_ref_count == 0)
            {
                if (_left != 0)
                    _left->purge();
                if (_right != 0)
                    _right->purge();
                delete this;
            }
        }

        // Recomputes _height from the (possibly null) children.
        void calcHeight()
        {
            unsigned int left_height = _left == 0 ? 0 : _left->_height;
            unsigned int right_height = _right == 0 ? 0 : _right->_height;
            _height = 1 + (left_height < right_height ? right_height : left_height);
        }

        Node* _left;
        Node* _right;
        T _data;
        unsigned long _ref_count;  // number of parent nodes / sets pointing here
        unsigned int _height;      // 1 for a leaf
    };

    // Height of a possibly-null subtree; the empty subtree has height 0.
    // A static helper rather than a member so that it is never invoked
    // through a null pointer (which is undefined behaviour).
    static unsigned int height(const Node* node)
    {
        return node == 0 ? 0 : node->_height;
    }

    // Makes `node` safe to modify: if it is referenced from more than one
    // place, replaces it by a private copy and gives up this path's reference
    // to the shared original.
    void claim(Node* &node)
    {
        if (node->_ref_count > 1)
        {
            Node* new_node = new Node(node);
            node->_ref_count--;
            node = new_node;
        }
    }

    // Right rotation: node's left child becomes the root of this subtree.
    // Mutates both nodes, so callers claim() them first.
    void swapLeft(Node* &node)
    {
        Node* left = node->_left;
        node->_left = left->_right;
        node->calcHeight();
        left->_right = node;
        left->calcHeight();
        node = left;
    }

    // Left rotation: node's right child becomes the root of this subtree.
    // Mutates both nodes, so callers claim() them first.
    void swapRight(Node* &node)
    {
        Node* right = node->_right;
        node->_right = right->_left;
        node->calcHeight();
        right->_left = node;
        right->calcHeight();
        node = right;
    }

    // Inserts `data` into the subtree rooted at `node`. `side` tells which
    // child of its parent `node` is ('L', 'R', or ' ' for the root). When the
    // grown child reaches node's own height, the subtree is either rotated or
    // node->_height is increased, depending on `side` and on the heights of
    // the neighbouring subtrees.
    void insert(Node* &node, const T& data, const char side)
    {
        if (node == 0)
        {
            node = new Node(data);
            return;
        }
        int c = data.compare(node->_data);
        if (c == 0)
            return;  // already present
        claim(node);
        if (c < 0)
        {
            insert(node->_left, data, 'L');
            if (node->_height == node->_left->_height)
            {
                if (side == 'R'
                    || (height(node->_right) < node->_height - 1
                        && height(node->_left->_right) < node->_height - 1))
                    swapLeft(node);
                else
                    node->_height++;
            }
        }
        else
        {
            insert(node->_right, data, 'R');
            if (node->_height == node->_right->_height)
            {
                if (side == 'L'
                    || (height(node->_left) < node->_height - 1
                        && height(node->_right->_left) < node->_height - 1))
                    swapRight(node);
                else
                    node->_height++;
            }
        }
    }

    // Called after node's left subtree may have shrunk, while node->_height
    // still holds the pre-removal height (callers refresh it afterwards).
    // If the left side is now too short, rotates left, preceded by a right
    // rotation of the right child when its left grandchild is the taller one.
    // The test is written as `a + 2 < h` instead of `a < h - 2` because the
    // heights are unsigned and `h - 2` wraps around for a leaf.
    void balanceRight(Node* &node)
    {
        if (height(node->_left) + 2 < node->_height && CHECK(node->_right))
        {
            Node* &right = node->_right;
            claim(right);
            if (height(right->_left) > height(right->_right) && CHECK(right->_left))
            {
                claim(right->_left);
                swapLeft(right);
            }
            swapRight(node);
        }
    }

    // Mirror image of balanceRight(): handles a right subtree that may have
    // shrunk.
    void balanceLeft(Node* &node)
    {
        if (height(node->_right) + 2 < node->_height && CHECK(node->_left))
        {
            Node* &left = node->_left;
            claim(left);
            if (height(left->_right) > height(left->_left) && CHECK(left->_right))
            {
                claim(left->_right);
                swapRight(left);
            }
            swapLeft(node);
        }
    }

    // Detaches the leftmost (smallest) node of the subtree `node` into
    // `result`, rebalancing on the way back up. On return `result` is an
    // unshared node with no children.
    void getReplacement(Node* &node, Node* &result)
    {
        claim(node);
        if (node->_left != 0)
        {
            getReplacement(node->_left, result);
            balanceRight(node);
            node->calcHeight();
        }
        else
        {
            result = node;
            node = node->_right;
            result->_right = 0;  // not strictly needed: remove() overwrites it
        }
    }

    // Removes `data` from the subtree rooted at `node`, if present.
    void remove(Node* &node, const T& data)
    {
        if (node == 0)
            return;  // not found
        claim(node);
        int c = data.compare(node->_data);
        if (c != 0)
        {
            if (c < 0)
            {
                remove(node->_left, data);
                balanceRight(node);
            }
            else
            {
                remove(node->_right, data);
                balanceLeft(node);
            }
            node->calcHeight();
            return;
        }

        // Found the node to be removed.
        Node* old_node = node;
        if (node->_left == 0)
            node = node->_right;
        else if (node->_right == 0)
            node = node->_left;
        else
        {
            // Two children: replace the node by the minimum of its right
            // subtree.
            Node* right = node->_right;
            Node* replacement = 0;
            getReplacement(right, replacement);
            node = replacement;
            node->_right = right;
            node->_left = old_node->_left;
        }
        if (node != 0)
        {
            node->calcHeight();
            if (node->_height > 1)
                balanceLeft(node);
        }
        // old_node's children now hang under `node`, so only the node itself
        // is freed (purge() would release the children as well).
        delete old_node;
    }

    // Scratch row of connector characters used by print(), indexed by height.
    char _print_buffer[MAX_TREE_DEPTH];

    // Prints the subtree sideways: right subtree above the node ('/'),
    // left subtree below it ('\'), one node per line. A node whose
    // reference count exceeds one is followed by "(count)".
    void print(const Node* node, char m, unsigned int parent_height, unsigned int root_height)
    {
        if (m == '\\')
            _print_buffer[parent_height] = '|';
        if (node->_right != 0)
            print(node->_right, '/', node->_height, root_height);
        if (m != ' ')
        {
            unsigned int i = root_height;
            for (; i > parent_height; i--)
                printf("%c", _print_buffer[i]);
            printf("%c", m);
            i--;
            for (; i > node->_height; i--)
                printf("-");
        }
        printf("* ");
        node->_data.print();
        if (node->_ref_count > 1)
            printf("(%lu)", node->_ref_count);
        printf("\n");
        _print_buffer[parent_height] = m == '/' ? '|' : ' ';
        if (node->_left != 0)
            print(node->_left, '\\', node->_height, root_height);
        _print_buffer[parent_height] = ' ';
    }

    // Returns the total height imbalance of the subtree: for every node whose
    // child heights differ by more than one, the excess is added. Also prints
    // an error for every node whose stored _height is inconsistent.
    int unbalanced(const Node* node) const
    {
        if (node == 0)
            return 0;
        int left_height = height(node->_left);
        int right_height = height(node->_right);
        int expected = (left_height > right_height ? left_height : right_height) + 1;
        if (static_cast<int>(node->_height) != expected)
        {
            printf("ERROR: Height error %u %d %d at: ", node->_height, left_height, right_height);
            node->_data.print();
            printf("\n");
        }
        int result = 0;
        if (left_height < right_height - 1)
            result = right_height - 1 - left_height;
        else if (right_height < left_height - 1)
            result = left_height - 1 - right_height;
        return result + unbalanced(node->_left) + unbalanced(node->_right);
    }

public:
    PersistentSet() : _tree(0) {}

    ~PersistentSet()
    {
        if (_tree != 0)
            _tree->purge();
    }

    // Shares other's tree; no nodes are copied.
    PersistentSet(const PersistentSet& other) : _tree(other._tree)
    {
        if (_tree != 0)
            _tree->_ref_count++;
    }

    // Shares other's tree. The new reference is taken before the old tree is
    // released, so self-assignment is safe.
    PersistentSet& operator=(const PersistentSet& other)
    {
        Node* old_tree = _tree;
        _tree = other._tree;
        if (_tree != 0)
            _tree->_ref_count++;
        if (old_tree != 0)
            old_tree->purge();
        return *this;
    }

    // Adds `data` unless an equal element (compare() == 0) is present.
    void insert(const T& data) { insert(_tree, data, ' '); }

    // Removes the element equal to `data`; does nothing if there is none.
    void remove(const T& data) { remove(_tree, data); }

    // Prints the tree sideways, followed by an empty line.
    void print()
    {
        for (int i = 0; i < MAX_TREE_DEPTH; i++)
            _print_buffer[i] = ' ';
        if (_tree != 0)
            print(_tree, ' ', 0, _tree->_height);
        printf("\n");
    }

    // See unbalanced(const Node*); 0 means the whole tree is balanced.
    int unbalanced() { return unbalanced(_tree); }

    // In-order (ascending) iterator over a set:
    //   for (Iterator it(set); it.more(); it.next()) use(*it);
    class Iterator
    {
    public:
        Iterator(const PersistentSet& set) : _depth(-1)
        {
            for (Node* node = set._tree; node != 0; node = node->_left)
                _stack[++_depth] = node;
        }

        bool more() { return _depth >= 0; }

        void next()
        {
            if (_stack[_depth]->_right != 0)
            {
                Node* node = _stack[_depth--]->_right;
                for (; node != 0; node = node->_left)
                    _stack[++_depth] = node;
            }
            else
                _depth--;
        }

        // Do not modify the element through this reference: the node may be
        // shared with other sets, and the tree's ordering depends on it.
        T& operator*() { return _stack[_depth]->_data; }

    private:
        Node* _stack[MAX_TREE_DEPTH];
        int _depth;
    };

private:
    Node* _tree;
};

// Minimal element type for exercising PersistentSet.
class Integer
{
public:
    Integer() : _i(0) {}
    Integer(int i) : _i(i) {}

    int compare(const Integer& lhs) const
    {
        if (_i < lhs._i)
            return -1;
        if (_i > lhs._i)
            return 1;
        return 0;
    }

    int value() const { return _i; }

    void print() const { printf("%d", _i); }

private:
    int _i;
};

// Randomised self-test: builds sets that share structure, mutates them and
// checks element counts and balance after every step.
int main(int argc, char *argv[])
{
    srand(time(NULL));
    //srand(1);
    for (int j = 0; j < 100; j++)
    {
        PersistentSet<Integer> intSet1;
        printf("Info: Insert 100 random in set1\n");
        for (int i = 0; i < 100; i++)
        {
            int v = rand() % 200;
            intSet1.insert(Integer(2 * v));
            int unb = intSet1.unbalanced();
            if (unb > 0)
                printf("Warning: set1 %d not balanced after %d\n", unb, i);
        }
        intSet1.print();
        printf("intSet1");
        int org_count = 0;
        for (PersistentSet<Integer>::Iterator it(intSet1); it.more(); it.next(), org_count++)
            printf(" %d", (*it).value());
        printf(" count = %d\n", org_count);

        printf("Info: set2 = set1, insert 100 random in set2 and set3\n");
        PersistentSet<Integer> intSet2;
        PersistentSet<Integer> intSet3;
        intSet2 = intSet1;
        for (int i = 0; i < 100; i++)
        {
            int v = (rand() % 200) * 2 + 1;
            intSet2.insert(Integer(v));
            int unb = intSet2.unbalanced();
            if (unb > 0)
                printf("Warning: set2 %d not balanced after %d\n", unb, i);
            intSet3.insert(Integer(v));
            unb = intSet3.unbalanced();
            if (unb > 0)
                printf("Warning: set3 %d not balanced after %d\n", unb, i);
        }

        // set1 must be unchanged by the insertions into set2.
        printf("intSet1");
        int count1 = 0;
        for (PersistentSet<Integer>::Iterator it(intSet1); it.more(); it.next(), count1++)
            printf(" %d%s", (*it).value(), (*it).value() % 2 == 1 ? "*" : "");
        printf(" count = %d\n", count1);
        if (count1 != org_count)
            printf("ERROR count1 %d org_count %d\n", count1, org_count);
        intSet1.print();

        printf("intSet3");
        int count3 = 0;
        for (PersistentSet<Integer>::Iterator it(intSet3); it.more(); it.next(), count3++)
            printf(" %d", (*it).value());
        printf(" count = %d\n", count3);

        printf("intSet2");
        int count2 = 0;
        for (PersistentSet<Integer>::Iterator it(intSet2); it.more(); it.next(), count2++)
            printf(" %d", (*it).value());
        printf(" count = %d\n", count2);
        intSet2.print();
        if (count2 != count1 + count3)
            printf("ERROR count2 %d count1 %d count3 %d\n", count2, count1, count3);

        PersistentSet<Integer> intSet4;
        intSet4 = intSet2;
        printf("Info: remove numbers from set1 from set2\n");
        int values[300];
        int nr = 0;
        for (PersistentSet<Integer>::Iterator it(intSet1); it.more(); it.next())
            values[nr++] = (*it).value();
        for (; nr > 0; nr--)
        {
            int r = rand() % nr;
            printf("Info: %d remove %d\n", nr, values[r]);
            intSet4.remove(values[r]);
            for (int i = r; i + 1 < nr; i++)
                values[i] = values[i + 1];
            int unb = intSet4.unbalanced();
            if (unb != 0)
                printf("Warning: set4 %d not balanced after %d\n", unb, nr);
        }
        printf("intSet4");
        int count4 = 0;
        for (PersistentSet<Integer>::Iterator it(intSet4); it.more(); it.next(), count4++)
            printf(" %d", (*it).value());
        printf(" count = %d\n", count4);
        if (count4 != count3)
            printf("ERROR count4 %d count3 %d\n", count4, count3);
        intSet4.print();

        printf("Info: add numbers from set1 from set2\n");
        nr = 0;
        for (PersistentSet<Integer>::Iterator it(intSet1); it.more(); it.next())
            values[nr++] = (*it).value();
        for (; nr > 0; nr--)
        {
            int r = rand() % nr;
            printf("Info: %d insert %d\n", nr, values[r]);
            intSet4.insert(values[r]);
            for (int i = r; i + 1 < nr; i++)
                values[i] = values[i + 1];
            int unb = intSet4.unbalanced();
            if (unb != 0)
                printf("Warning: set4 %d not balanced after %d\n", unb, nr);
        }
        printf("intSet4");
        count4 = 0;
        for (PersistentSet<Integer>::Iterator it(intSet4); it.more(); it.next(), count4++)
            printf(" %d", (*it).value());
        printf(" count = %d\n", count4);
        if (count4 != count2)
            printf("ERROR count4 %d count2 %d\n", count4, count2);
        intSet4.print();
    }
    return 0;
}