Saturday, July 6, 2013

Generic AVL Tree Implementation in Java

UPDATE: Thanks to +Josh Dotson, who found the big mistakes I made, which had turned the AVL insert into an O(n) implementation. PS: caching the depth in each node, instead of recomputing it, is important.

Speaking of AVL trees, I guess most people with a Computer Science (CS) background are familiar with them. The AVL tree is one of the most famous self-balancing binary search trees. For basic background information, please check this Wikipedia link. One very important property is that every basic operation on an AVL tree (search, insert, delete) takes O(log n) time in the worst case. That is wonderful.
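Why O(log n)? A standard argument counts the fewest nodes an AVL tree of height h can contain. The sparsest such tree has one subtree of height h-1 and the other of height h-2, so N(h) = N(h-1) + N(h-2) + 1, with N(0) = 0 and N(1) = 1. Here height counts the nodes on the longest root-to-leaf path, matching the depth field used later. This recurrence grows like the Fibonacci numbers (N(h) = F(h+2) - 1), so the height of an n-node AVL tree stays below roughly 1.44 log2(n).

The hypothetical helper below just tabulates the recurrence. It is a sketch for illustration, not part of the tree code.

```java
// Hypothetical helper: the smallest possible AVL tree for each height.
class AvlHeightBound {

    // Minimum number of nodes in an AVL tree of height h, where height counts
    // the nodes on the longest root-to-leaf path (empty tree = 0, one node = 1).
    static long minNodes(int h) {
        if (h <= 0) return 0;
        if (h == 1) return 1;
        long prev2 = 0, prev1 = 1; // N(0), N(1)
        long cur = 1;
        for (int i = 2; i <= h; i++) {
            // Sparsest tree: one subtree of height i-1, the other of height i-2.
            cur = prev1 + prev2 + 1;
            prev2 = prev1;
            prev1 = cur;
        }
        return cur;
    }

    public static void main(String[] args) {
        for (int h = 1; h <= 10; h++) {
            double log2 = Math.log(minNodes(h)) / Math.log(2);
            System.out.printf("height %2d needs at least %3d nodes (log2 = %.2f)%n",
                    h, minNodes(h), log2);
        }
    }
}
```

Height grows by one while the minimum node count grows by a roughly constant factor, which is exactly what a logarithmic height means.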

I assume readers of this blog already know what generic programming in Java means, even if they have never implemented a generic class themselves. I believe this code is pretty easy to understand.

First of all, what do we store in the tree? Let's define a data type for the values it holds. We'll call it Node.

```java
/**
 * @author antonio081014
 * @time Jul 5, 2013, 9:31:32 PM
 */
public class Node<T extends Comparable<T>> implements Comparable<Node<T>> {

    private T data;
    private Node<T> left;
    private Node<T> right;
    // Level of this node in the tree; only used by PrintTree.
    public int level;
    // Height of the subtree rooted here (a leaf has depth 1). It is cached so
    // the tree can check balance in O(1) instead of recomputing it.
    private int depth;

    public Node(T data) {
        this(data, null, null);
    }

    public Node(T data, Node<T> left, Node<T> right) {
        super();
        this.data = data;
        this.left = left;
        this.right = right;
        // Compute the depth once, from the children's cached depths.
        if (left == null && right == null)
            setDepth(1);
        else if (left == null)
            setDepth(right.getDepth() + 1);
        else if (right == null)
            setDepth(left.getDepth() + 1);
        else
            setDepth(Math.max(left.getDepth(), right.getDepth()) + 1);
    }

    public T getData() {
        return data;
    }

    public void setData(T data) {
        this.data = data;
    }

    public Node<T> getLeft() {
        return left;
    }

    // Note: setLeft/setRight do not update depth. AVLTree never calls them;
    // it rebuilds nodes through the constructor instead.
    public void setLeft(Node<T> left) {
        this.left = left;
    }

    public Node<T> getRight() {
        return right;
    }

    public void setRight(Node<T> right) {
        this.right = right;
    }

    /**
     * @return the depth of the subtree rooted at this node
     */
    public int getDepth() {
        return depth;
    }

    /**
     * @param depth the depth to set
     */
    public void setDepth(int depth) {
        this.depth = depth;
    }

    @Override
    public int compareTo(Node<T> o) {
        return this.data.compareTo(o.data);
    }

    @Override
    public String toString() {
        return "Level " + level + ": " + data;
    }
}
```

Each node holds one piece of generic data, a left child, and a right child, just like a node in an ordinary binary tree. It also caches its depth (the height of the subtree rooted at it), which the tree uses to check balance. The level field is used only for printing. The generic data must be Comparable, so that we can search for values and keep them in order.
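The depth bookkeeping boils down to one rule. A node's depth is 1 + max(depth(left), depth(right)), where an empty child counts as 0. The difference between the two child depths is the node's balance factor, which an AVL tree keeps within [-1, 1] at every node.

The sketch below uses a stripped-down, hypothetical IntNode rather than the post's Node<T>. It folds Node's four constructor branches into a single expression.

```java
// Hypothetical, stripped-down node that shows only the depth bookkeeping.
class IntNode {
    final int data;
    final IntNode left, right;
    final int depth; // height of the subtree rooted here; a leaf has depth 1

    IntNode(int data, IntNode left, IntNode right) {
        this.data = data;
        this.left = left;
        this.right = right;
        // Computed once from the children's cached depths: O(1), no recursion.
        this.depth = Math.max(depthOf(left), depthOf(right)) + 1;
    }

    static int depthOf(IntNode n) {
        return n == null ? 0 : n.depth; // an empty subtree has depth 0
    }

    // Positive: left side is deeper; negative: right side is deeper.
    int balanceFactor() {
        return depthOf(left) - depthOf(right);
    }

    public static void main(String[] args) {
        IntNode one = new IntNode(1, null, null);
        IntNode two = new IntNode(2, one, null);
        IntNode three = new IntNode(3, two, null);
        System.out.println(three.depth + " " + three.balanceFactor()); // 3 2
    }
}
```

A balance factor of 2 at the root, as in this left-leaning chain, is exactly the situation that triggers a rotation.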

Second, let's use this Node structure to build our tree.

```java
import java.util.LinkedList;
import java.util.Queue;

public class AVLTree<T extends Comparable<T>> {
    Node<T> root;

    public AVLTree() {
        root = null;
    }

    public T Maximum() {
        Node<T> local = root;
        if (local == null)
            return null;
        // The largest value is in the right-most node.
        while (local.getRight() != null)
            local = local.getRight();
        return local.getData();
    }

    public T Minimum() {
        Node<T> local = root;
        if (local == null)
            return null;
        // The smallest value is in the left-most node.
        while (local.getLeft() != null)
            local = local.getLeft();
        return local.getData();
    }

    // Returns the cached depth in O(1). Recomputing it recursively here
    // (1 + max of the children's depths) is what made insert O(n).
    private int depth(Node<T> node) {
        if (node == null)
            return 0;
        return node.getDepth();
    }

    public Node<T> insert(T data) {
        // The recursive insert already rebalances every node on the path,
        // including the root, so no extra rotation is needed here.
        root = insert(root, data);
        return root;
    }

    private Node<T> insert(Node<T> node, T data) {
        if (node == null)
            return new Node<T>(data);
        // Rebuild the node on the way back up so its depth is recomputed.
        if (node.getData().compareTo(data) > 0) {
            node = new Node<T>(node.getData(), insert(node.getLeft(), data),
                    node.getRight());
        } else if (node.getData().compareTo(data) < 0) {
            node = new Node<T>(node.getData(), node.getLeft(),
                    insert(node.getRight(), data));
        }
        // Duplicates are ignored. After inserting, check and rebalance the
        // current node if necessary.
        return rebalance(node);
    }

    private Node<T> rebalance(Node<T> node) {
        switch (balanceNumber(node)) {
        case 1: // right subtree is too deep
            // Right-left case: rotate the right child right first.
            if (depth(node.getRight().getLeft()) > depth(node.getRight().getRight()))
                node = new Node<T>(node.getData(), node.getLeft(),
                        rotateRight(node.getRight()));
            return rotateLeft(node);
        case -1: // left subtree is too deep
            // Left-right case: rotate the left child left first.
            if (depth(node.getLeft().getRight()) > depth(node.getLeft().getLeft()))
                node = new Node<T>(node.getData(), rotateLeft(node.getLeft()),
                        node.getRight());
            return rotateRight(node);
        default:
            return node;
        }
    }

    // -1: left side at least 2 deeper; 1: right side at least 2 deeper; 0: balanced.
    private int balanceNumber(Node<T> node) {
        int L = depth(node.getLeft());
        int R = depth(node.getRight());
        if (L - R >= 2)
            return -1;
        else if (L - R <= -2)
            return 1;
        return 0;
    }

    // q's right child p becomes the subtree root; q becomes p's left child.
    private Node<T> rotateLeft(Node<T> node) {
        Node<T> q = node;
        Node<T> p = q.getRight();
        Node<T> c = q.getLeft();
        Node<T> a = p.getLeft();
        Node<T> b = p.getRight();
        q = new Node<T>(q.getData(), c, a);
        p = new Node<T>(p.getData(), q, b);
        return p;
    }

    // q's left child p becomes the subtree root; q becomes p's right child.
    private Node<T> rotateRight(Node<T> node) {
        Node<T> q = node;
        Node<T> p = q.getLeft();
        Node<T> c = q.getRight();
        Node<T> a = p.getLeft();
        Node<T> b = p.getRight();
        q = new Node<T>(q.getData(), b, c);
        p = new Node<T>(p.getData(), a, q);
        return p;
    }

    public boolean search(T data) {
        Node<T> local = root;
        while (local != null) {
            int cmp = local.getData().compareTo(data);
            if (cmp == 0)
                return true;
            else if (cmp > 0)
                local = local.getLeft();
            else
                local = local.getRight();
        }
        return false;
    }

    @Override
    public String toString() {
        return root == null ? "" : root.toString();
    }

    public void PrintTree() {
        if (root == null)
            return;
        // Breadth-first search: a FIFO queue visits the nodes level by level.
        Queue<Node<T>> queue = new LinkedList<Node<T>>();
        root.level = 0;
        queue.add(root);
        while (!queue.isEmpty()) {
            Node<T> node = queue.poll();
            System.out.println(node);
            int level = node.level;
            Node<T> left = node.getLeft();
            Node<T> right = node.getRight();
            if (left != null) {
                left.level = level + 1;
                queue.add(left);
            }
            if (right != null) {
                right.level = level + 1;
                queue.add(right);
            }
        }
    }
}
```

The AVLTree class keeps a reference to the root node, which is where every operation starts. Each method does what its name suggests: insert adds a new value to the tree, Maximum returns the largest value, and Minimum returns the smallest.

Each time insert adds a new node, it must check whether every ancestor of that node is still balanced. Balanced means the depths of a node's two subtrees differ by at most 1. If every ancestor is balanced, nothing needs to be done. Otherwise, we perform one or two rotations (rotateLeft, rotateRight) to restore the balance.
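When experimenting with the insert code, it helps to verify after each insertion that every node is still balanced. Below is a hedged sketch of such a checker. It uses its own hypothetical CheckNode type so that it compiles on its own, but the same logic can be pointed at Node<T>. It recomputes heights from scratch in O(n), so it is meant for testing only, never for use inside insert.

```java
// Hypothetical test helper: checks the AVL balance invariant at every node.
class AvlChecker {

    static class CheckNode {
        int data;
        CheckNode left, right;

        CheckNode(int data, CheckNode left, CheckNode right) {
            this.data = data;
            this.left = left;
            this.right = right;
        }
    }

    // Returns the height of the subtree, or -1 if any node inside it has
    // children whose heights differ by more than 1.
    static int checkedHeight(CheckNode n) {
        if (n == null)
            return 0;
        int l = checkedHeight(n.left);
        int r = checkedHeight(n.right);
        if (l < 0 || r < 0 || Math.abs(l - r) > 1)
            return -1;
        return Math.max(l, r) + 1;
    }

    static boolean isBalanced(CheckNode root) {
        return checkedHeight(root) >= 0;
    }

    public static void main(String[] args) {
        // 2 with children 1 and 3: balanced.
        CheckNode ok = new CheckNode(2, new CheckNode(1, null, null),
                new CheckNode(3, null, null));
        // 3 -> 2 -> 1 down the left side: unbalanced at the root.
        CheckNode bad = new CheckNode(3,
                new CheckNode(2, new CheckNode(1, null, null), null), null);
        System.out.println(isBalanced(ok) + " " + isBalanced(bad)); // true false
    }
}
```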

search, Maximum, and Minimum each walk down from the root toward a leaf in a simple loop (they could equally be written recursively). Each therefore takes time proportional to the tree's height, which is O(log n) for an AVL tree.
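For comparison, here is a recursive version of search over a minimal, hypothetical node type. It uses the same `T extends Comparable<T>` bound as the post. Each call descends one level, so on an AVL tree it makes O(log n) calls.

```java
// Hypothetical recursive counterpart to the iterative search().
class RecursiveSearch {

    static class SNode<T extends Comparable<T>> {
        T data;
        SNode<T> left, right;

        SNode(T data, SNode<T> left, SNode<T> right) {
            this.data = data;
            this.left = left;
            this.right = right;
        }
    }

    static <T extends Comparable<T>> boolean contains(SNode<T> node, T target) {
        if (node == null)
            return false; // fell off the tree: the value is not present
        int cmp = target.compareTo(node.data);
        if (cmp == 0)
            return true;
        // Smaller values live in the left subtree, larger ones in the right.
        return cmp < 0 ? contains(node.left, target) : contains(node.right, target);
    }

    public static void main(String[] args) {
        SNode<Integer> root = new SNode<Integer>(2,
                new SNode<Integer>(1, null, null),
                new SNode<Integer>(3, null, null));
        System.out.println(contains(root, 3)); // true
        System.out.println(contains(root, 4)); // false
    }
}
```

The loop version in the post avoids growing the call stack. On a balanced tree the recursion depth is small either way, so the choice is mostly a matter of taste.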

PrintTree uses breadth-first search (BFS) with a queue to visit the tree level by level, printing each node to standard output as it goes.
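PrintTree stores a level field in each node. An alternative that needs no extra field is to process the queue one level at a time: record the queue size at the start of each level and poll exactly that many nodes. Here is a sketch on a hypothetical minimal node type that collects the values level by level.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;

// Hypothetical level-order traversal that groups values by level.
class LevelOrder {

    static class LNode {
        int data;
        LNode left, right;

        LNode(int data, LNode left, LNode right) {
            this.data = data;
            this.left = left;
            this.right = right;
        }
    }

    static List<List<Integer>> levels(LNode root) {
        List<List<Integer>> result = new ArrayList<>();
        if (root == null)
            return result;
        Queue<LNode> queue = new ArrayDeque<>();
        queue.add(root);
        while (!queue.isEmpty()) {
            int size = queue.size(); // number of nodes on the current level
            List<Integer> level = new ArrayList<>();
            for (int i = 0; i < size; i++) {
                LNode node = queue.poll();
                level.add(node.data);
                // Children are queued behind the rest of this level.
                if (node.left != null) queue.add(node.left);
                if (node.right != null) queue.add(node.right);
            }
            result.add(level);
        }
        return result;
    }

    public static void main(String[] args) {
        LNode root = new LNode(4,
                new LNode(2, new LNode(1, null, null), new LNode(3, null, null)),
                new LNode(6, new LNode(5, null, null), new LNode(7, null, null)));
        System.out.println(levels(root)); // [[4], [2, 6], [1, 3, 5, 7]]
    }
}
```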

For the rotations themselves, the rotation picture on the Wikipedia AVL tree page explains a lot.

As the picture indicates, there are four cases to handle: left-left, right-right, left-right, and right-left. The first two are fixed with a single rotation and the last two with a double rotation. My code inserts the new node with a recursive, depth-first descent. If the new node lands in 4's subtree, then on the way back up we check 4's parent to see whether it is still balanced, and rotate it if not. Then we check 4's grandparent, rotate it if it is unbalanced, and keep going until the recursion returns to the root.
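The four cases can be told apart by looking at the unbalanced node and its deeper child. If the child leans the opposite way (left-right or right-left), rotate the child first, which turns the case into left-left or right-right. Then rotate the node itself.

Below is a self-contained sketch of that decision. It uses a hypothetical immutable RNode that rebuilds nodes on the way up, in the same style as the post's code. All names here are illustrative.

```java
// Hypothetical sketch of AVL insertion with all four rebalancing cases.
class FourCases {

    static final class RNode {
        final int data;
        final RNode left, right;
        final int depth;

        RNode(int data, RNode left, RNode right) {
            this.data = data;
            this.left = left;
            this.right = right;
            this.depth = Math.max(depth(left), depth(right)) + 1;
        }
    }

    static int depth(RNode n) {
        return n == null ? 0 : n.depth;
    }

    // The right child moves up; the old root becomes its left child.
    static RNode rotateLeft(RNode q) {
        RNode p = q.right;
        return new RNode(p.data, new RNode(q.data, q.left, p.left), p.right);
    }

    // The left child moves up; the old root becomes its right child.
    static RNode rotateRight(RNode q) {
        RNode p = q.left;
        return new RNode(p.data, p.left, new RNode(q.data, p.right, q.right));
    }

    static RNode rebalance(RNode n) {
        int diff = depth(n.left) - depth(n.right);
        if (diff > 1) { // left-heavy
            if (depth(n.left.right) > depth(n.left.left)) // left-right case
                n = new RNode(n.data, rotateLeft(n.left), n.right);
            return rotateRight(n); // left-left case (or finishing left-right)
        }
        if (diff < -1) { // right-heavy
            if (depth(n.right.left) > depth(n.right.right)) // right-left case
                n = new RNode(n.data, n.left, rotateRight(n.right));
            return rotateLeft(n); // right-right case (or finishing right-left)
        }
        return n;
    }

    static RNode insert(RNode n, int value) {
        if (n == null)
            return new RNode(value, null, null);
        if (value < n.data)
            n = new RNode(n.data, insert(n.left, value), n.right);
        else if (value > n.data)
            n = new RNode(n.data, n.left, insert(n.right, value));
        return rebalance(n); // checked at every ancestor on the way back up
    }

    public static void main(String[] args) {
        RNode root = null;
        for (int v : new int[] {3, 1, 2}) // triggers the left-right case
            root = insert(root, v);
        System.out.println(root.data + " " + root.left.data + " " + root.right.data); // 2 1 3
    }
}
```

With only single rotations, inserting 3, 1, 2 would rotate 3 to the right and leave 1 at the root with a right subtree of depth 2. That is still unbalanced, which is why the child rotation is needed first.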

To test the AVL tree API, here is a simple demo.

```java
/**
 * @author antonio081014
 * @time Jul 5, 2013, 9:31:10 PM
 */
public class Main {

    public static void main(String[] args) {
        AVLTree<Integer> tree = new AVLTree<Integer>();
        // Ascending inserts would turn a plain BST into a linked list;
        // the AVL rotations keep this tree balanced.
        for (int i = 1; i <= 7; i++)
            tree.insert(i); // autoboxed to Integer
        tree.PrintTree();
    }
}
```
The output is:
```
Level 0: 4
Level 1: 2
Level 1: 6
Level 2: 1
Level 2: 3
Level 2: 5
Level 2: 7
```
Drawn as a tree, it looks like this:
```
      4
    /   \
   2     6
  / \   / \
 1   3 5   7
```
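This is exactly the balanced shape we want: inserting 1 through 7 in sorted order would turn a plain binary search tree into a linked list, but the AVL rotations keep the height at 3.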

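You can store any type that implements Comparable, such as String or Double. Primitive types such as int and double cannot be used directly as type arguments; use their wrapper classes (Integer, Double) instead, and autoboxing converts the values for you.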
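To see which types the `T extends Comparable<T>` bound accepts, here is a sketch with a hypothetical generic helper that uses the same bound as AVLTree<T>. String and Double values are accepted directly, while int literals are autoboxed to Integer.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper with the same type bound as AVLTree<T>.
class BoundDemo {

    static <T extends Comparable<T>> T largest(List<T> items) {
        T best = null;
        for (T item : items)
            if (best == null || item.compareTo(best) > 0)
                best = item;
        return best; // null for an empty list, like Maximum() on an empty tree
    }

    public static void main(String[] args) {
        System.out.println(largest(Arrays.asList("pear", "apple", "zucchini"))); // zucchini
        System.out.println(largest(Arrays.asList(2.5, -1.0, 9.75)));            // 9.75
        System.out.println(largest(Arrays.asList(3, 7, 5)));                     // 7
        // BoundDemo.<int>largest(...) would not compile: type arguments
        // cannot be primitive types, which is why the ints above are boxed.
    }
}
```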

For a clearer view, visit my gist page.