Skip to content

AVL Tree

基础的二分搜索树在增删节点时本身不保证平衡;AVL 树通过旋转保证任意节点左右子树的高度差不超过 1,从而使树始终保持平衡

AVL树的实现

树的基础定义引用二分搜索树

/**
 * AVL tree: a binary search tree that re-balances itself after every insertion and removal.
 *
 * Both recursive helpers ([add] and the private `remove` overload) return the (possibly new)
 * root of the subtree they were given, so every caller must re-attach that result to the
 * parent link. On the way back up the recursion each visited node has its height recomputed
 * and is passed through `rebalance`.
 *
 * @param K key type; keys are ordered through [Comparable]
 * @param V value type stored alongside each key
 */
class AVLTree<K : Comparable<K>, V> : BST<K, V> {
    private var root: BSTNode<K, V>? = null
    private var count = 0

    /** Number of keys currently stored in the tree. */
    override fun size(): Int = this.count

    /** Root node of the tree, or `null` when the tree is empty. */
    override fun getRoot(): BSTNode<K, V>? = this.root

    /**
     * Inserts [key] with [value]; when [key] is already present only its value is replaced.
     */
    override fun insert(key: K, value: V) {
        root = add(root, key, value)
    }

    /**
     * Inserts `k -> v` into the subtree rooted at [node].
     *
     * @return the root of the subtree after insertion and re-balancing
     */
    private fun add(node: BSTNode<K, V>?, k: K, v: V): BSTNode<K, V> {
        if (node == null) {
            // Reached an empty slot: this is where the new key lives.
            count++
            return BasicBSTNode(k, v)
        }
        when {
            k < node.getKey() -> node.setLeft(add(node.getLeft(), k, v))

            k > node.getKey() -> node.setRight(add(node.getRight(), k, v))

            // Key already present: overwrite the value, structure is unchanged.
            else -> node.setValue(v)
        }
        // Height is recomputed once here, after the child link has been updated.
        return node.updateHeight().rebalance(Action.ADD)
    }

    /**
     * Removes the entry for [k].
     *
     * @return the value that was associated with [k], or `null` when the lookup found no node
     */
    override fun remove(k: K): V? {
        return getNode(k)?.let { node ->
            // Capture the value first: the node may be overwritten in place during removal.
            val rtV = node.getValue()
            root = remove(root, k, rtV)
            rtV
        }
    }

    /**
     * Removes key [k] from the subtree rooted at [node].
     *
     * A node with children is not unlinked directly; instead it takes over the key/value of its
     * in-order predecessor (or successor when there is no left subtree), and that donor entry is
     * then removed from the corresponding subtree.
     *
     * @return the root of the subtree after removal and re-balancing, or `null` if it became empty
     */
    private fun remove(node: BSTNode<K, V>?, k: K, v: V): BSTNode<K, V>? {
        if (node == null) {
            return null
        }
        when {
            k < node.getKey() -> node.setLeft(remove(node.getLeft(), k, v))

            k > node.getKey() -> node.setRight(remove(node.getRight(), k, v))

            else -> {
                if (node.getLeft() == null && node.getRight() == null) {
                    // Leaf: drop it by returning null to the parent link.
                    count--
                    return null
                }

                // Three remaining cases: left+right, left only, right only.
                if (node.getLeft() != null) {
                    // left+right or left only: replace with the in-order predecessor.
                    // `!!` is safe here: the left subtree was just checked to be non-empty.
                    val leftMax = getMax(node.getLeft())!!

                    node.setKey(leftMax.getKey())
                        .setValue(leftMax.getValue())
                    // The recursive result MUST be re-attached: the left child may itself be the
                    // removed leaf (result is null) or the subtree root may change by rotation.
                    node.setLeft(remove(node.getLeft(), leftMax.getKey(), leftMax.getValue()))
                } else {
                    // right only: replace with the in-order successor.
                    // `!!` is safe here: the right subtree is non-empty in this branch.
                    val rightMin = getMin(node.getRight())!!
                    node.setKey(rightMin.getKey())
                        .setValue(rightMin.getValue())
                    node.setRight(remove(node.getRight(), rightMin.getKey(), rightMin.getValue()))
                }

            }
        }
        return node.updateHeight().rebalance(Action.REMOVE)
    }

    /** Drops every entry, leaving an empty tree. */
    override fun clear() {
        this.root = null
        this.count = 0
    }

}

/**
 * Null-tolerant height lookup for a tree node.
 *
 * An absent node (empty subtree) counts as height 0.
 *
 * @param node the node to inspect; may be null
 * @return the height stored on [node], or 0 when there is no node
 */
internal fun <K : Comparable<K>, V> getNodeHeight(node: BSTNode<K, V>?): Int {
    val storedHeight = node?.getHeight()
    return storedHeight ?: 0
}

/**
 * Recomputes this node's height from its children: one more than the taller child subtree.
 *
 * Only this node is updated; the children's stored heights are read as they are.
 *
 * @return this node, to allow call chaining
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.updateHeight(): BSTNode<K, V> {
    val leftHeight = getNodeHeight(this.getLeft())
    val rightHeight = getNodeHeight(this.getRight())
    this.setHeight(maxOf(leftHeight, rightHeight) + 1)
    return this
}

关键的旋转逻辑

/** The kind of tree mutation that triggered a re-balance. */
internal enum class Action {
    /** A key was inserted (or its value replaced). */
    ADD,

    /** A key was removed. */
    REMOVE
}

/**
 * Balance factor of this node: height of the left subtree minus height of the right subtree.
 *
 * A positive result means the left side is taller, a negative one means the right side is taller.
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.getBalanceFactor(): Int {
    val leftHeight = getNodeHeight(this.getLeft())
    val rightHeight = getNodeHeight(this.getRight())
    return leftHeight - rightHeight
}

/**
 * Tests whether this integer lies in the closed interval `[min, max]`.
 *
 * @param min inclusive lower bound
 * @param max inclusive upper bound
 * @return `true` when `min <= this <= max`
 * @throws IllegalArgumentException if [min] is greater than [max]
 */
internal fun Int.valueIn(min: Int, max: Int): Boolean {
    require(min <= max) { "min should be less than or equal to max" }
    return min <= this && this <= max
}

/**
 * Restores the AVL balance at this node, if needed, and returns the root of the resulting subtree.
 *
 * A node whose balance factor is within `-1..1` is returned untouched. Otherwise the rotation is
 * picked from the node's balance factor and the balance factor of its taller child:
 *
 * - after [Action.ADD] the child factor must be exactly `1` or `-1`;
 * - after [Action.REMOVE] a child factor of `0` is also accepted and handled by a single rotation.
 *
 * @param action the mutation that led to this call; selects how strictly the child factor is matched
 * @return the node that now roots this subtree
 * @throws IllegalStateException if the balance factors match none of the handled cases
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.rebalance(action: Action): BSTNode<K, V> {
    val balanceFactor = this.getBalanceFactor()
    if (balanceFactor.valueIn(-1, 1)) {
        return this
    }

    log.debug("balanceFactor: {}, action: {}, node: {}", balanceFactor, action, this)

    // Shared failure path for every combination that no rotation covers.
    fun unexpectedState(): Nothing =
        throw IllegalStateException("balanceFactor: $balanceFactor, action: $action, node: $this")

    return when (balanceFactor) {
        // Left side is too tall: inspect the left child.
        2 -> {
            val leftFactor = this.getLeft()!!.getBalanceFactor()
            when (action) {
                Action.ADD -> when (leftFactor) {
                    1 -> this.ll()
                    -1 -> this.lr()
                    else -> unexpectedState()
                }

                Action.REMOVE -> if (leftFactor >= 0) this.ll() else this.lr()
            }
        }

        // Right side is too tall: inspect the right child.
        -2 -> {
            val rightFactor = this.getRight()!!.getBalanceFactor()
            when (action) {
                Action.ADD -> when (rightFactor) {
                    -1 -> this.rr()
                    1 -> this.rl()
                    else -> unexpectedState()
                }

                Action.REMOVE -> if (rightFactor <= 0) this.rr() else this.rl()
            }
        }

        else -> unexpectedState()
    }
}

/**
 * Left-left case: a single rotation that lifts the left child above this node.
 *
 * @return the new root of this subtree
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.ll(): BSTNode<K, V> = rotateNodeAndLeft()

/**
 * Right-right case: a single rotation that lifts the right child above this node.
 *
 * @return the new root of this subtree
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.rr(): BSTNode<K, V> = rotateNodeAndRight()

/**
 * Left-right case: double rotation.
 *
 * First the left child is rotated with its own right child, then this node is rotated with its
 * (new) left child.
 *
 * @return the new root of this subtree
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.lr(): BSTNode<K, V> {
    /*
        x
       /
       y
        \
        z
     */
    // Step 1: straighten the zig-zag by rotating inside the left subtree.
    val straightenedLeft = this.getLeft()?.rotateNodeAndRight()
    val relinked = this.setLeft(straightenedLeft)
    // The left subtree changed, so refresh the height before the second rotation.
    relinked.updateHeight()
    // Step 2: rotate with the left child.
    return relinked.rotateNodeAndLeft()
}

/**
 * Right-left case: double rotation.
 *
 * First the right child is rotated with its own left child, then this node is rotated with its
 * (new) right child.
 *
 * @return the new root of this subtree
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.rl(): BSTNode<K, V> {
    /*
       x
        \
         y
        /
      z
    */
    // Step 1: straighten the zig-zag by rotating inside the right subtree.
    val straightenedRight = this.getRight()?.rotateNodeAndLeft()
    val relinked = this.setRight(straightenedRight)
    // The right subtree changed, so refresh the height before the second rotation.
    relinked.updateHeight()
    // Step 2: rotate with the right child.
    return relinked.rotateNodeAndRight()
}

/**
 * Rotates this node with its right child, so the right child becomes the subtree root.
 *
 * When there is no right child nothing is changed and this node is returned.
 *
 * @return the former right child (now the subtree root), or this node if no rotation happened
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.rotateNodeAndRight(): BSTNode<K, V> {
    /*
        n                 x
       / \               / \
      ?   x            n   b?
         / \          / \
        a?  b?       ?  a?
     */
    val pivot = this.getRight()
    if (pivot == null) {
        return this
    }
    // Update the height of the lower (demoted) node first ...
    val demoted = this.setRight(pivot.getLeft())
    demoted.updateHeight()
    // ... then the height of the new subtree root, which depends on it.
    val promoted = pivot.setLeft(this)
    return promoted.updateHeight()
}

/**
 * Rotates this node with its left child, so the left child becomes the subtree root.
 *
 * When there is no left child nothing is changed and this node is returned.
 *
 * @return the former left child (now the subtree root), or this node if no rotation happened
 */
internal fun <K : Comparable<K>, V> BSTNode<K, V>.rotateNodeAndLeft(): BSTNode<K, V> {
    /*
            n              x
           / \             / \
          x   ?          a?  n
         / \                / \
        a?  b?             b?  ?
     */
    val pivot = this.getLeft()
    if (pivot == null) {
        return this
    }
    // Update the height of the lower (demoted) node first ...
    val demoted = this.setLeft(pivot.getRight())
    demoted.updateHeight()
    // ... then the height of the new subtree root, which depends on it.
    val promoted = pivot.setRight(this)
    return promoted.updateHeight()
}