Skip to content

Union Find

并查集（Union Find，又称不相交集合），用于快速判断两个元素是否属于同一个集合，并支持合并两个集合。

定义通用接口

/**
 * A disjoint-set (union-find) container over elements of type [T].
 *
 * Elements can be added, sets can be merged with [union], and [isConnected] reports
 * whether two elements currently belong to the same set.
 */
interface UnionFind<T> {

    /** Number of elements currently stored. */
    fun size(): Int

    /** `true` when no elements are stored, i.e. [size] is zero. */
    fun isEmpty(): Boolean = size() == 0

    /** Whether [value] is stored in this structure. */
    fun contains(value: T): Boolean

    /** Adds [value] as an element; the meaning of the returned flag is implementation-defined. */
    fun add(value: T): Boolean

    /** Merges the set containing [value1] with the set containing [value2]. */
    fun union(value1: T, value2: T): Boolean

    /** Whether [value1] and [value2] currently belong to the same set. */
    fun isConnected(value1: T, value2: T): Boolean

    /** Removes every element. */
    fun clear()
}

实现

共设计了两类、三种实现：

基于数组索引

基于Tree（按元素的唯一性来源分为两种）

  • 元素本身存在唯一ID（见 Tree Id Union Find）
  • 元素本身能保证唯一性（见 Tree Union Find）

Indexed Union Find

/**
 * [UnionFind] backed by arrays addressed by an element-supplied integer index.
 *
 * Each element is mapped to a slot via [indexFunc]. Two values that produce the same index
 * are treated as the same element (a later [add] replaces the stored value but keeps its set).
 *
 * The trees are kept flat:
 * - `parent[i]` always holds the root slot of the set containing slot `i`, or -1 for an unused slot;
 * - `children[root]` lists every other slot of that set.
 *
 * The constructor takes `indexFunc`, which maps an element to its slot index. Negative indices are rejected.
 */
class IndexedUnionFind<T> : UnionFind<T> {

    companion object {
        val log = getLogger(IndexedUnionFind::class.java)
    }

    private val indexFunc: (T) -> Int

    // Stored element per slot; null means the slot is unused.
    private var data: Array<Any?> = arrayOfNulls<Any?>(2)

    // Root slot index of the set each slot belongs to; -1 for unused slots.
    private var parent = Array(2) { -1 }

    // For a root slot: indices of all other slots in its set. Empty for non-root slots.
    private var children = Array<MutableList<Int>>(2) { mutableListOf() }

    private var count: Int = 0

    constructor(indexFunc: (T) -> Int) {
        this.indexFunc = indexFunc
    }

    override fun size(): Int = this.count

    /** Returns true if a value mapping to the same slot index as [value] has been added. */
    override fun contains(value: T): Boolean {
        val index = indexFunc(value)
        return index.indexInArr(data) && data[index] != null
    }

    /**
     * Stores [value] in its slot. If the slot was empty, a new singleton set is created.
     * Otherwise the stored value is replaced and its set membership is kept.
     *
     * @return always `true`.
     * @throws IllegalArgumentException if [indexFunc] yields a negative index.
     */
    override fun add(value: T): Boolean {
        val index = indexFunc(value)
        if (index < 0) {
            throw IllegalArgumentException("Index must be non-negative")
        }
        expand(index + 1)
        if (data[index] == null) {
            // Empty slot: register a new singleton set rooted at itself.
            data[index] = value
            parent[index] = index
            count++
        } else {
            // Occupied slot: replace the value, keep its set membership.
            data[index] = value
        }
        return true
    }

    // Ensures at least `required` slots exist. Capacity at least doubles, so adding increasing
    // indices one by one is amortized O(1) instead of copying the arrays on every add.
    private fun expand(required: Int) {
        if (required <= data.size) {
            return
        }
        val oldSize = data.size
        val newSize = maxOf(required, oldSize * 2)
        data = data.copyOf(newSize)
        parent = Array(newSize) { i -> if (i < oldSize) parent[i] else -1 }
        children = Array(newSize) { i -> if (i < oldSize) children[i] else mutableListOf<Int>() }
    }

    /**
     * Adds both values (if needed) and merges their sets.
     *
     * @return `false` if either index is negative or both values map to the same slot,
     *   otherwise `true`.
     */
    override fun union(value1: T, value2: T): Boolean {
        val v1Index = indexFunc(value1)
        val v2Index = indexFunc(value2)

        if (v1Index < 0) {
            log.error("Index for value1 must be non-negative")
            return false
        }

        if (v2Index < 0) {
            log.error("Index for value2 must be non-negative")
            return false
        }

        if (v1Index == v2Index) {
            log.debug("Both values are the same, no union needed")
            return false
        }

        // Make sure both slots exist and have a valid root before linking.
        this.add(value1)
        this.add(value2)

        this.union(v1Index, v2Index)
        return true
    }

    // Merges the sets containing the two (already added) slots. The smaller set is folded into
    // the larger one, so each slot is re-parented at most O(log n) times over any union sequence.
    private fun union(v1Index: Int, v2Index: Int) {
        val root1 = getParent(v1Index)
        val root2 = getParent(v2Index)

        if (root1 == root2) {
            return
        }

        val (keep, absorb) =
            if (children[root1].size >= children[root2].size) root1 to root2 else root2 to root1

        val moved = children[absorb]
        moved.forEach { member -> parent[member] = keep }
        children[keep].add(absorb)
        children[keep].addAll(moved)
        moved.clear()
        parent[absorb] = keep
    }

    // Root slot of the set containing `valueIndex`, or -1 if the index is outside the arrays
    // or the slot is unused.
    private fun getParent(valueIndex: Int): Int {
        return if (valueIndex.indexInArr(parent)) {
            parent[valueIndex]
        } else {
            -1
        }
    }

    /** Returns true when both values have been added and belong to the same set. */
    override fun isConnected(value1: T, value2: T): Boolean {
        val root1 = getParent(indexFunc(value1))
        val root2 = getParent(indexFunc(value2))
        return root1 != -1 && root1 == root2
    }

    override fun clear() {
        this.data = arrayOfNulls<Any?>(2)
        this.parent = Array(2) { -1 }
        this.children = Array(2) { mutableListOf() }
        this.count = 0
    }

    internal fun <T> Int.indexInArr(arr: Array<T>): Boolean {
        return this >= 0 && this < arr.size
    }

}

Tree Id Union Find

以对象的唯一标识符（ID）作为集合元素进行连接

/**
 * [UnionFind] whose elements are identified by an ID extracted from each value.
 *
 * Values with equal IDs (by `compareTo`, since storage is a [TreeMap]) are the same element.
 * Trees are kept flat: every non-root node points directly at its root, and the root's
 * `children` list holds all other members of its set.
 *
 * @param identifierFunction extracts the unique ID of a value.
 */
class TreeIdUnionFind<T, ID : Comparable<ID>>(
    private val identifierFunction: (T) -> ID
) : UnionFind<T> {

    private val storage: TreeMap<ID, Node> = TreeMap()

    override fun contains(value: T): Boolean = storage.containsKey(getIdentifier(value))

    override fun size(): Int = storage.size

    /** Adds [value]'s ID as a singleton set; returns `false` if that ID is already present. */
    override fun add(value: T): Boolean {
        val identifier = getIdentifier(value)
        if (storage.containsKey(identifier)) {
            return false
        }
        storage[identifier] = Node(identifier)
        return true
    }

    // Returns the node registered for `id`, creating and registering a singleton if absent.
    private fun nodeFor(id: ID): Node = storage.getOrPut(id) { Node(id) }

    /** Adds both values if needed and merges their sets. Always returns `true`. */
    override fun union(value1: T, value2: T): Boolean {
        val xNode = nodeFor(getIdentifier(value1))
        val yNode = nodeFor(getIdentifier(value2))
        merge(xNode, yNode)
        return true
    }

    // Links the roots of the two nodes' sets. The smaller set is folded into the larger one,
    // so each node is re-parented at most O(log n) times instead of the whole accumulated
    // set being moved on every union.
    private fun merge(src: Node, cur: Node) {
        val srcRoot = getParent(src)
        val curRoot = getParent(cur)

        // Already in the same set.
        if (idEquals(srcRoot.id, curRoot.id)) {
            return
        }

        val (keep, absorb) =
            if (srcRoot.children.size >= curRoot.children.size) srcRoot to curRoot else curRoot to srcRoot

        // Flatten: every member of the absorbed set now points directly at the surviving root.
        absorb.children.forEach { child -> child.parent = keep }
        keep.children.addAll(absorb.children)
        absorb.children.clear()
        absorb.parent = keep
        keep.children.add(absorb)
    }

    // Walks parent links up to the root of `node`'s set.
    private fun getParent(node: Node): Node = generateSequence(node) { it.parent }.last()

    /** Returns true when both values' IDs are present and belong to the same set. */
    override fun isConnected(value1: T, value2: T): Boolean {
        val node1 = storage[getIdentifier(value1)] ?: return false
        val node2 = storage[getIdentifier(value2)] ?: return false
        return idEquals(getParent(node1).id, getParent(node2).id)
    }

    private fun getIdentifier(data: T): ID = identifierFunction(data)

    // IDs are equal if `==` holds, falling back to `compareTo` (which is what TreeMap uses).
    private fun idEquals(id1: ID, id2: ID): Boolean {
        return id1 == id2 || id1.compareTo(id2) == 0
    }

    override fun clear() {
        storage.clear()
    }

    /**
     * Internal tree node: one per stored ID.
     */
    private inner class Node(val id: ID) {
        var parent: Node? = null
        val children: MutableList<Node> = mutableListOf()
    }

}

Tree Union Find

直接以对象本身作为集合元素进行连接（通过 Comparator 判断对象是否相同）

/**
 * [UnionFind] over the values themselves, with identity defined by [comparator].
 *
 * Values comparing equal are the same element. Trees are kept flat: every non-root node
 * points directly at its root, and the root's `children` list holds all other set members.
 *
 * @param comparator defines element identity and the ordering of the backing [TreeMap].
 */
class TreeUnionFind<T>(private val comparator: Comparator<T>) : UnionFind<T> {

    private val storage: TreeMap<T, Node<T>> = TreeMap(comparator)

    override fun size(): Int = storage.size

    override fun contains(value: T): Boolean = storage.containsKey(value)

    /** Adds [value] as a singleton set; returns `false` if an equal value is already present. */
    override fun add(value: T): Boolean {
        if (storage.containsKey(value)) {
            return false
        }
        doAdd(value)
        return true
    }

    // Returns the node stored for `value`, creating and registering a singleton if absent.
    private fun doAdd(value: T): Node<T> = storage.getOrPut(value) { Node(value) }

    /**
     * Adds both values if needed and merges their sets.
     *
     * @return `false` if the two values compare equal (same element), otherwise `true`
     *   (including when they were already connected).
     */
    override fun union(value1: T, value2: T): Boolean {
        val node1 = doAdd(value1)
        val node2 = doAdd(value2)

        if (comparator.compare(node1.data, node2.data) == 0) {
            return false
        }

        val root1 = getParent(node1)
        val root2 = getParent(node2)

        if (comparator.compare(root1.data, root2.data) == 0) {
            return true
        }

        // Fold the smaller set into the larger one so each node is re-parented at most
        // O(log n) times, instead of moving the whole accumulated set on every union.
        val (keep, absorb) =
            if (root1.children.size >= root2.children.size) root1 to root2 else root2 to root1

        absorb.children.forEach { child -> child.parent = keep }
        keep.children.addAll(absorb.children)
        absorb.children.clear()
        absorb.parent = keep
        keep.children.add(absorb)

        return true
    }

    // Walks parent links up to the root of `node`'s set.
    private fun getParent(node: Node<T>): Node<T> = generateSequence(node) { it.parent }.last()

    /** Returns true when both values are present and belong to the same set. */
    override fun isConnected(value1: T, value2: T): Boolean {
        val node1 = storage[value1] ?: return false
        val node2 = storage[value2] ?: return false

        val root1 = getParent(node1)
        val root2 = getParent(node2)

        return comparator.compare(root1.data, root2.data) == 0
    }

    override fun clear() {
        storage.clear()
    }

    /**
     * Tree node holding one stored value.
     * `parent` is null for a root. A root's `children` lists every other member of its set.
     */
    inner class Node<E>(val data: E) {
        var parent: Node<E>? = null
        val children: MutableList<Node<E>> = mutableListOf()
    }

}