#ifndef __BINARY_SEARCH_H__
#define __BINARY_SEARCH_H__
#include <assert.h>
#include <iostream>
#include <stdexcept>
template <typename Key, typename Value>
class BinarySearchTree;
template <typename Key, typename Value>
std::ostream& operator<<(std::ostream &out, BinarySearchTree<Key, Value>&);

/// Unbalanced binary search tree mapping keys to values.
///
/// Keys are ordered with `operator<` only. The tree owns its nodes: they are
/// freed by deleteNode()/deleteMin() and by the destructor, and copies are
/// deep. Every operation walks a single root-to-leaf path, so its cost is
/// proportional to the tree height; no rebalancing is performed.
/// There is no internal synchronization.
template <typename Key, typename Value>
class BinarySearchTree {
    friend std::ostream& operator<< <Key, Value>(std::ostream &out, BinarySearchTree<Key, Value> &tree);

    struct Node {
        Key key;
        Value val;
        Node *left;
        Node *right;
        Node(const Key &pkey, const Value &pval)
            : key(pkey), val(pval), left(nullptr), right(nullptr) {}
    };

    Node *root;  // nullptr when the tree is empty

    // Writes the subtree rooted at `node` as "(key, val) " pairs, in-order
    // (left subtree, node, right subtree), i.e. in ascending key order.
    void traverse(const Node *node, std::ostream &out) const {
        if (node == nullptr)
            return;
        traverse(node->left, out);
        out << "(" << node->key << ", " << node->val << ") ";
        traverse(node->right, out);
    }

    // Returns a deep copy of the subtree rooted at `src` (nullptr if empty).
    // If copying a key/value throws, the partial copy is freed and the
    // exception is rethrown.
    static Node* clone(const Node *src) {
        if (src == nullptr)
            return nullptr;
        Node *copy = new Node(src->key, src->val);
        try {
            copy->left = clone(src->left);
            copy->right = clone(src->right);
        } catch (...) {
            destroy(copy);
            throw;
        }
        return copy;
    }

    // Frees every node of the subtree rooted at `node` (children first).
    static void destroy(Node *node) {
        if (node == nullptr)
            return;
        destroy(node->left);
        destroy(node->right);
        delete node;
    }

    // Returns the link (`root` or a child pointer) that holds `key`, or the
    // null link where `key` would be attached if it is absent. Working on the
    // link itself lets callers replace a node, including the root, uniformly.
    Node** findLink(const Key &key) {
        Node **link = &root;
        while (*link != nullptr) {
            if (key < (*link)->key)
                link = &(*link)->left;
            else if ((*link)->key < key)
                link = &(*link)->right;
            else
                break;
        }
        return link;
    }

public:
    /// Creates an empty tree.
    BinarySearchTree() : root(nullptr) {}

    /// Creates a tree holding the single pair (pkey, pval).
    BinarySearchTree(Key pkey, Value pval) : root(new Node(pkey, pval)) {}

    /// Deep copy: the new tree owns its own nodes.
    BinarySearchTree(const BinarySearchTree &other) : root(clone(other.root)) {}

    /// Takes over `other`'s nodes, leaving `other` empty.
    BinarySearchTree(BinarySearchTree &&other) noexcept : root(other.root) {
        other.root = nullptr;
    }

    /// Copy-and-swap assignment; serves as both copy and move assignment and
    /// is safe for self-assignment. The old nodes are freed when `other` dies.
    BinarySearchTree& operator=(BinarySearchTree other) noexcept {
        Node *tmp = root;
        root = other.root;
        other.root = tmp;
        return *this;
    }

    ~BinarySearchTree() { destroy(root); }

    /// Inserts (key, val) if `key` is not present. If `key` already exists,
    /// the tree is left unchanged (the stored value is NOT overwritten).
    void put(const Key &key, const Value &val) {
        Node **link = findLink(key);
        if (*link == nullptr)
            *link = new Node(key, val);
    }

    /// Returns a copy of the value stored under `key`.
    /// @throws std::out_of_range if `key` is not in the tree.
    Value get(const Key &key) const {
        const Node *node = root;
        while (node != nullptr) {
            if (key < node->key)
                node = node->left;
            else if (node->key < key)
                node = node->right;
            else
                return node->val;
        }
        throw std::out_of_range("BinarySearchTree::get: key not found");
    }

    /// Removes and frees the node with the smallest key in the subtree rooted
    /// at `head`.
    /// @return the new root of that subtree (`head->right` when `head` itself
    ///         is the minimum), or nullptr for an empty subtree. The caller must
    ///         store the result back into the link that held `head`.
    Node* deleteMin(Node* head) {
        if (head == nullptr)
            return nullptr;
        Node **link = &head;
        while ((*link)->left != nullptr)
            link = &(*link)->left;
        Node *minNode = *link;
        *link = minNode->right;  // the minimum has no left child
        delete minNode;
        return head;
    }

    /// Returns the node with the smallest key in the subtree rooted at
    /// `head`, or nullptr if the subtree is empty.
    Node* min(Node *head) {
        if (head == nullptr)
            return nullptr;
        Node *node = head;
        while (node->left != nullptr)
            node = node->left;
        return node;
    }

    /// Removes `key` and its value from the tree; does nothing if absent.
    void deleteNode(const Key &key) {
        Node **link = findLink(key);
        Node *node = *link;
        if (node == nullptr)
            return;

        if (node->left == nullptr) {
            // No children, or only a right child: lift the right subtree.
            *link = node->right;
        } else if (node->right == nullptr) {
            // Only a left child: lift the left subtree.
            *link = node->left;
        } else {
            // Two children: unlink the in-order successor (minimum of the
            // right subtree) and put it in `node`'s place.
            Node **succLink = &node->right;
            while ((*succLink)->left != nullptr)
                succLink = &(*succLink)->left;
            Node *succ = *succLink;
            *succLink = succ->right;  // when succ == node->right this updates node->right
            succ->left = node->left;
            succ->right = node->right;
            *link = succ;
        }
        delete node;
    }
};
/// Streams every (key, value) pair of `tree` in ascending key order, each
/// formatted as "(key, val) ", and returns the same stream for chaining.
template <typename Key, typename Value>
std::ostream& operator<<(std::ostream &out, BinarySearchTree<Key, Value> &tree)
{
    std::ostream &sink = out;
    tree.traverse(tree.root, sink);
    return sink;
}
#endif