#include <stdio.h>
#include <stdlib.h>
#include <time.h>

/**
 * Ordered set whose copies share structure.
 *
 * Elements live in a height-balanced binary search tree of reference-counted
 * nodes. Copying or assigning a set only shares the root; a later insert() or
 * remove() first duplicates every shared node on the path it modifies (see
 * claim()), so other sets that share the tree never observe the change.
 *
 * Requirements on T, as used by this class:
 *   - copy constructible;
 *   - `int compare(const T&) const` returning <0, 0 or >0;
 *   - a `print()` member callable on a non-const T (used by print() and by
 *     the trace line that remove() writes).
 */
template <class T>
class PersistentSet
{
    /** Reference-counted tree node. */
    class Node
    {
    public:
        /** Creates a leaf holding a copy of `data`, owned by one referrer. */
        Node(const T& data)
            : _left(0), _right(0), _data(data), _ref_count(1), _height(1) {}

        /**
         * Creates a private duplicate of `node`. The duplicate points at the
         * same children, so each child gains one referrer.
         */
        Node(const Node* node)
            : _left(node->_left), _right(node->_right), _data(node->_data),
              _ref_count(1), _height(node->_height)
        {
            if (_left != 0)
                _left->_ref_count++;
            if (_right != 0)
                _right->_ref_count++;
        }

        /**
         * Drops one reference. When the last reference goes away the
         * children are released recursively and the node deletes itself.
         */
        void purge()
        {
            if (--_ref_count == 0) {
                if (_left != 0)
                    _left->purge();
                if (_right != 0)
                    _right->purge();
                delete this;
            }
        }

        /**
         * Height of the subtree rooted at `node`; 0 for an empty subtree.
         * Static on purpose: calling a member function through a null
         * pointer is undefined behaviour, and compilers remove a
         * `this == 0` test.
         */
        static unsigned int heightOf(const Node* node)
        {
            return node == 0 ? 0 : node->_height;
        }

        /** Height of this (non-null) node. */
        unsigned int height() const { return _height; }

        /** Recomputes _height from the current children. */
        void calcHeight()
        {
            const unsigned int left_height = heightOf(_left);
            const unsigned int right_height = heightOf(_right);
            _height = 1 + (left_height < right_height ? right_height : left_height);
        }

        Node* _left;
        Node* _right;
        T _data;
        unsigned long _ref_count;
        unsigned int _height;
    };

    /**
     * Makes `node` safe to modify: if it has more than one referrer it is
     * replaced, in the caller's link, by a private duplicate and the shared
     * original loses this referrer.
     */
    void claim(Node* &node)
    {
        if (node != 0 && node->_ref_count > 1) {
            Node* new_node = new Node(node);
            node->_ref_count--;
            node = new_node;
        }
    }

    /**
     * Rotation that lifts the left child into the place of `node`.
     * `node` must be exclusively owned and must have a left child.
     */
    void swapLeft(Node* &node)
    {
        claim(node->_left);
        Node* left = node->_left;
        node->_left = left->_right;
        node->calcHeight();
        left->_right = node;
        left->calcHeight();
        node = left;
    }

    /**
     * Rotation that lifts the right child into the place of `node`.
     * `node` must be exclusively owned and must have a right child.
     */
    void swapRight(Node* &node)
    {
        claim(node->_right);
        Node* right = node->_right;
        node->_right = right->_left;
        node->calcHeight();
        right->_left = node;
        right->calcHeight();
        node = right;
    }

    /**
     * Rebalances `node` when its right subtree is at least two levels taller
     * than its left one; otherwise does nothing. `node` must be exclusively
     * owned. Heights of the rotated nodes are recomputed by the rotations.
     */
    void fixRightHeavy(Node* &node)
    {
        if (Node::heightOf(node->_right) <= Node::heightOf(node->_left) + 1)
            return;
        Node* &right = node->_right;
        claim(right);
        // Inner grandchild taller than the outer one: double rotation.
        if (Node::heightOf(right->_left) > Node::heightOf(right->_right))
            swapLeft(right);
        swapRight(node);
    }

    /** Mirror image of fixRightHeavy(). */
    void fixLeftHeavy(Node* &node)
    {
        if (Node::heightOf(node->_left) <= Node::heightOf(node->_right) + 1)
            return;
        Node* &left = node->_left;
        claim(left);
        if (Node::heightOf(left->_right) > Node::heightOf(left->_left))
            swapRight(left);
        swapLeft(node);
    }

    /**
     * Inserts `data` into the subtree referenced by `node`. Duplicates are
     * ignored. Nodes on the search path are claimed before being changed.
     */
    void insert(Node* &node, const T& data)
    {
        if (node == 0) {
            node = new Node(data);
            return;
        }
        int c = data.compare(node->_data);
        if (c == 0)
            return;
        claim(node);
        if (c < 0) {
            insert(node->_left, data);
            Node* left = node->_left;
            // node->_height still holds the height from before the insert,
            // so equality means the left subtree has just grown to it.
            if (node->_height == left->_height) {
                if (Node::heightOf(node->_right) == node->_height - 1) {
                    // The sibling was level with the old left subtree:
                    // still balanced, just one level taller.
                    node->_height++;
                } else {
                    if (Node::heightOf(left->_right) > Node::heightOf(left->_left))
                        swapRight(node->_left);
                    swapLeft(node);
                }
            }
        } else {
            insert(node->_right, data);
            Node* right = node->_right;
            if (node->_height == right->_height) {
                if (Node::heightOf(node->_left) == node->_height - 1) {
                    node->_height++;
                } else {
                    if (Node::heightOf(right->_left) > Node::heightOf(right->_right))
                        swapLeft(node->_right);
                    swapRight(node);
                }
            }
        }
    }

    /**
     * Detaches the left-most node of the subtree referenced by `node` and
     * hands it back in `result`; the subtree is rebalanced on the way up.
     * `node` must not be null.
     */
    void getReplacement(Node* &node, Node* &result)
    {
        claim(node);
        if (node->_left != 0) {
            getReplacement(node->_left, result);
            fixRightHeavy(node);
            node->calcHeight();
        } else {
            result = node;
            node = node->_right;
            result->_right = 0; // Not really needed
        }
    }

    /**
     * Removes `data` from the subtree referenced by `node`; does nothing if
     * it is absent. Writes a "delete <address>: <value>" trace line to
     * stdout for the node that is freed.
     */
    void remove(Node* &node, const T& data)
    {
        if (node == 0)
            return; // not found
        claim(node);
        int c = data.compare(node->_data);
        if (c != 0) {
            if (c < 0) {
                remove(node->_left, data);
                fixRightHeavy(node);
            } else {
                remove(node->_right, data);
                fixLeftHeavy(node);
            }
            node->calcHeight();
            return;
        }

        // Found the node to be removed. It was claimed above, so this link
        // is its only referrer and its child links can be handed over.
        Node* old_node = node;
        if (node->_left == 0) {
            node = node->_right;
        } else if (node->_right == 0) {
            node = node->_left;
        } else {
            // Two children: the smallest element of the right subtree takes
            // the place of the removed node.
            Node* right = node->_right;
            Node* replacement = 0;
            getReplacement(right, replacement);
            node = replacement;
            node->_right = right;
            node->_left = old_node->_left;
            node->calcHeight();
            // Only the right subtree can have shrunk.
            fixLeftHeavy(node);
        }

        printf("delete %p: ", static_cast<void*>(old_node));
        old_node->_data.print();
        printf("\n");
        delete old_node;
    }

    /** Column markers used by print(); indexed by node height. */
    char _print_buffer[100];

    /**
     * Prints the subtree rooted at `node` sideways (right subtree first).
     * `m` is '/' for a right child, '\\' for a left child and ' ' for the
     * root; `parent_height` is the height of the parent (0 for the root).
     * Shared nodes are followed by their reference count in parentheses.
     */
    void print(Node* node, char m, unsigned int parent_height, unsigned int root_height)
    {
        if (m == '\\')
            _print_buffer[parent_height] = '|';
        if (node->_right != 0)
            print(node->_right, '/', node->_height, root_height);
        if (m != ' ') {
            unsigned int i = root_height;
            for (; i > parent_height; i--)
                printf("%c", _print_buffer[i]);
            printf("%c", m);
            i--;
            for (; i > node->_height; i--)
                printf("-");
        }
        printf("* ");
        node->_data.print();
        if (node->_ref_count > 1)
            printf("(%lu)", node->_ref_count);
        printf("\n");
        _print_buffer[parent_height] = m == '/' ? '|' : ' ';
        if (node->_left != 0)
            print(node->_left, '\\', node->_height, root_height);
        _print_buffer[parent_height] = ' ';
    }

public:
    /** Creates an empty set. */
    PersistentSet() : _tree(0) {}

    /** Releases this set's reference to the tree. */
    ~PersistentSet()
    {
        if (_tree != 0)
            _tree->purge();
    }

    /** Shares the tree of `lhs`; no elements are copied. */
    PersistentSet(const PersistentSet& lhs) : _tree(0)
    {
        if (lhs._tree != 0) {
            _tree = lhs._tree;
            _tree->_ref_count++;
        }
    }

    /**
     * Shares the tree of `lhs` and releases the previous one. The new tree
     * is referenced before the old one is released, so self-assignment is
     * safe.
     */
    PersistentSet& operator=(const PersistentSet& lhs)
    {
        Node* _old_tree = _tree;
        _tree = 0;
        if (lhs._tree != 0) {
            _tree = lhs._tree;
            _tree->_ref_count++;
        }
        if (_old_tree != 0)
            _old_tree->purge();
        return *this;
    }

    /** Adds `data` to the set; a value comparing equal is left untouched. */
    void insert(const T& data) { insert(_tree, data); }

    /** Removes the value comparing equal to `data`, if present. */
    void remove(const T& data) { remove(_tree, data); }

    /** Prints the tree structure to stdout, followed by an empty line. */
    void print()
    {
        for (int i = 0; i < 100; i++)
            _print_buffer[i] = ' ';
        if (_tree != 0)
            print(_tree, ' ', 0, _tree->_height);
        printf("\n");
    }

    /**
     * In-order (ascending) iterator:
     *   for (Iterator it(set); it.more(); it.next()) use(*it);
     * It keeps raw pointers into the tree and a fixed stack of 50 levels.
     */
    class Iterator
    {
    public:
        /** Positions the iterator on the smallest element of `set`. */
        Iterator(const PersistentSet& set) : _depth(-1)
        {
            Node* node = set._tree;
            for (; node != 0; node = node->_left)
                _stack[++_depth] = node;
        }

        /** True while the iterator refers to an element. */
        bool more() { return _depth >= 0; }

        /** Advances to the next larger element. */
        void next()
        {
            if (_stack[_depth]->_right != 0) {
                // Replace the current node by the left spine of its right
                // subtree.
                Node* node = _stack[_depth--]->_right;
                for (; node != 0; node = node->_left)
                    _stack[++_depth] = node;
            } else
                _depth--;
        }

        /** The current element; only valid while more() is true. */
        T& operator*() { return _stack[_depth]->_data; }

    private:
        Node* _stack[50];
        int _depth;
    };

private:
    Node* _tree;
};

/** Minimal element type for exercising PersistentSet. */
class Integer
{
public:
    Integer() : _i(0) {}
    Integer(int i) : _i(i) {}

    /** Three-way comparison: -1, 0 or 1. */
    int compare(const Integer& lhs) const
    {
        if (_i < lhs._i)
            return -1;
        if (_i > lhs._i)
            return 1;
        return 0;
    }

    int value() const { return _i; }

    void print() const { printf("%d", _i); }

private:
    int _i;
};

/**
 * Randomised self-check, repeated 100 times:
 *   intSet1 receives even numbers; intSet2 starts as a copy of intSet1 and,
 *   like intSet3, receives odd numbers; finally every element of intSet1 is
 *   removed from intSet2. An "ERROR" line is printed when intSet1 changed
 *   size, when intSet2 is not the sum of both, or when intSet2 does not end
 *   up with as many elements as intSet3.
 */
int main(int argc, char *argv[])
{
    (void)argc;
    (void)argv;
    srand(static_cast<unsigned int>(time(NULL)));
    for (int j = 0; j < 100; j++) {
        PersistentSet<Integer> intSet1;
        for (int i = 0; i < 100; i++) {
            int v = rand()%200;
            intSet1.insert(Integer(2*v));
        }
        intSet1.print();
        printf("intSet1");
        int org_count = 0;
        for (PersistentSet<Integer>::Iterator it(intSet1); it.more(); it.next(), org_count++)
            printf(" %d", (*it).value());
        printf(" count = %d\n", org_count);

        PersistentSet<Integer> intSet2;
        PersistentSet<Integer> intSet3;
        intSet2 = intSet1;
        for (int i = 0; i < 100; i++) {
            int v = (rand()%200)*2+1;
            intSet2.insert(Integer(v));
            intSet3.insert(Integer(v));
        }

        // An odd value showing up in intSet1 (marked '*') would mean that a
        // change made through intSet2 leaked into the shared tree.
        printf("intSet1");
        int count1 = 0;
        for (PersistentSet<Integer>::Iterator it(intSet1); it.more(); it.next(), count1++)
            printf(" %d%s", (*it).value(), (*it).value()%2 == 1 ? "*" : "");
        printf(" count = %d\n", count1);
        if (count1 != org_count)
            printf("ERROR count1 %d org_count %d\n", count1, org_count);
        intSet1.print();

        printf("intSet3");
        int count3 = 0;
        for (PersistentSet<Integer>::Iterator it(intSet3); it.more(); it.next(), count3++)
            printf(" %d", (*it).value());
        printf(" count = %d\n", count3);

        printf("intSet2");
        int count2 = 0;
        for (PersistentSet<Integer>::Iterator it(intSet2); it.more(); it.next(), count2++)
            printf(" %d", (*it).value());
        printf(" count = %d\n", count2);
        intSet2.print();
        if (count2 != count1 + count3)
            printf("ERROR count2 %d count1 %d count3 %d\n", count2, count1, count3);

        // Remove the elements of intSet1 from intSet2 in random order.
        int values[300];
        int nr = 0;
        for (PersistentSet<Integer>::Iterator it(intSet1); it.more(); it.next())
            values[nr++] = (*it).value();
        for (; nr > 0; nr--) {
            int r = rand() % nr;
            intSet2.remove(values[r]);
            for (int i = r; i+1 < nr; i++)
                values[i] = values[i+1];
        }

        printf("intSet2");
        count2 = 0;
        for (PersistentSet<Integer>::Iterator it(intSet2); it.more(); it.next(), count2++)
            printf(" %d", (*it).value());
        printf(" count = %d\n", count2);
        if (count3 != count2)
            printf("ERROR count3 %d count2 %d\n", count3, count2);
        intSet2.print();
    }
    return 0;
}