天天看點

用Java實作線段樹

線段樹是為區間更新和區間查詢而生的資料結構,旨在快速解決區間問題。

一般來說,線段樹是不會加節點的,也不支援動态添加節點。線段樹也是二叉樹的一種,不過它的節點是以一個區間來定義節點的。具有一個單一區間的就是葉子節點。是以線段樹,本質上就是一棵區間樹。

我們在查找的時候,隻需要找出結果區間由哪些子區間構成即可。

實作代碼

首先定義出基礎的結構

public class SegmentTree {
    
    private Integer value;
    private Integer maxValue;

    private Integer l;
    private Integer r;
    
    private SegmentTree leftChild;
    private SegmentTree rightChild;
}      

l和r用來唯一刻畫這個區間。然後其他的内容,與标準的二叉樹沒得任何差別。

建樹過程

與二叉樹建樹沒得差別,我們這裡采用前序建樹的方式進行。代碼如下:

public static SegmentTree buildTree(int left, int right, int[] value) {
    if (left > right) {
        return null;
    }

    SegmentTree node = new SegmentTree();
    node.setValue(value[left]);
    node.setL(left);
    node.setR(right);
    if (left == right) {
        // TODO: 2022/1/17 退出條件
        node.setMaxValue(node.getValue());
        return node;
    }
    int mid = (left + right) >>> 1;
    node.setLeftChild(buildTree(left, mid, value));
    node.setRightChild(buildTree(mid + 1, right, value));
    if (Objects.isNull(node.getLeftChild())) {
        if (Objects.isNull(node.getRightChild())) {
            node.setMaxValue(node.getValue());
        } else {
            node.setMaxValue(node.getRightChild().getMaxValue());
        }
    } else {
        if (Objects.isNull(node.getRightChild())) {
            node.setMaxValue(node.getLeftChild().getMaxValue());
        } else {
            node.setMaxValue(Math.max(node.getLeftChild().getMaxValue(),
                                      node.getRightChild().getMaxValue()));
        }
    }
    return node;
}      

可以看見,這裡的葉子節點判斷條件就是 left == right。其他方面和二叉樹沒有任何差別。

查詢區間最大值

public static Integer getMaxValue(SegmentTree segmentTree, int left, int right) {
    if (Objects.isNull(segmentTree)) return null;
    if (segmentTree.getL() == left && segmentTree.getR() == right) {
        System.out.println("擷取了區間 [" + left + "," + right + "] 的最大值" + segmentTree.getMaxValue());
        return segmentTree.getMaxValue();
    }
    int segMid = (segmentTree.getL() + segmentTree.getR()) >>> 1;
    if (segMid < left) {
        return getMaxValue(segmentTree.getRightChild(), left, right);
    }
    if (segMid >= right) {
        return getMaxValue(segmentTree.getLeftChild(), left, right);
    }
    // TODO: 2022/1/17 左半邊答案
    Integer leftMax = getMaxValue(segmentTree.getLeftChild(), left, segMid);
    Integer rightMax = getMaxValue(segmentTree.getRightChild(), segMid + 1, right);
    if (Objects.isNull(leftMax)) {
        if (Objects.isNull(rightMax)) {
            return -100000;
        } else {
            return rightMax;
        }
    } else {
        if (Objects.isNull(rightMax)) {
            return leftMax;
        } else {
            return Math.max(leftMax, rightMax);
        }
    }
}      

從上面的代碼分析,設目前節點的區間為【L,R】,那麼對于區間[l,r]的最大值來說,就需要進行分類讨論,如果LR的區間中點Mid在lr區間的左邊,那麼max(lr) = max(右子樹,l,r);如果LR的區間中點在lr區間的右邊,則max(lr) = max(左子樹,l,r);如果Mid在lr區間裡面,則 max(lr) = max(左子樹,l,mid) 和 max(右子樹,mid+1,r)中的較大值。

下面我們來看看測試用例和運作結果:

public static void main(String[] args) {
    int[] a = new int[]{2, 5, 4, 7, 6, 0, 1, -1, 2, 3, 6, 7, 0, 2, 9, 8, 5, 4, 7, 2};
    SegmentTree segmentTree = buildTree(0, a.length - 1, a);
    System.out.println(getMaxValue(segmentTree, 0, 16));
}      

結果如下

擷取了區間 [0,9] 的最大值7

擷取了區間 [10,14] 的最大值9

擷取了區間 [15,16] 的最大值8

9

擷取區間和

現在需要對原來的建樹過程進行改造,首先,在基礎結構中添加sum字段

public class SegmentTree {

    private Integer value;
    private Integer maxValue;
    private Integer sum;

    private Integer l;
    private Integer r;

    private SegmentTree leftChild;
    private SegmentTree rightChild;
}      

然後在建樹方法中,添加對和的維護

public static SegmentTree buildTree(int left, int right, int[] value) {
    if (left > right) {
        return null;
    }

    SegmentTree node = new SegmentTree();
    node.setValue(value[left]);
    node.setL(left);
    node.setR(right);
    if (left == right) {
        // TODO: 2022/1/17 退出條件
        node.setMaxValue(node.getValue());
        node.setSum(node.getValue());
        return node;
    }
    int mid = (left + right) >>> 1;
    node.setLeftChild(buildTree(left, mid, value));
    node.setRightChild(buildTree(mid + 1, right, value));
    if (Objects.isNull(node.getLeftChild())) {
        if (Objects.isNull(node.getRightChild())) {
            node.setMaxValue(node.getValue());
            node.setSum(node.getValue());
        } else {
            node.setMaxValue(node.getRightChild().getMaxValue());
            node.setSum(node.getRightChild().getSum());
        }
    } else {
        if (Objects.isNull(node.getRightChild())) {
            node.setMaxValue(node.getLeftChild().getMaxValue());
            node.setSum(node.getLeftChild().getSum());
        } else {
            node.setMaxValue(Math.max(node.getLeftChild().getMaxValue(),
                                      node.getRightChild().getMaxValue()));
            node.setSum(node.getLeftChild().getSum() + node.getRightChild().getSum());
        }
    }
    return node;
}      

然後擷取總和

public static Integer getSum(SegmentTree segmentTree, int left, int right) {
    if (Objects.isNull(segmentTree)) return null;
    if (segmentTree.getL() == left && segmentTree.getR() == right) {
        System.out.println("擷取了區間 [" + left + "," + right + "] 的和" + segmentTree.getSum());
        return segmentTree.getSum();
    }
    int segMid = (segmentTree.getL() + segmentTree.getR()) >>> 1;
    if (segMid < left) {
        return getSum(segmentTree.getRightChild(), left, right);
    }
    if (segMid >= right) {
        return getSum(segmentTree.getLeftChild(), left, right);
    }
    // TODO: 2022/1/17 左半邊答案
    Integer leftSum = getSum(segmentTree.getLeftChild(), left, segMid);
    Integer rightSum = getSum(segmentTree.getRightChild(), segMid + 1, right);
    if (Objects.isNull(leftSum)) {
        if (Objects.isNull(rightSum)) {
            return segmentTree.getSum();
        } else {
            return rightSum;
        }
    } else {
        if (Objects.isNull(rightSum)) {
            return leftSum;
        } else {
            return leftSum + rightSum;
        }
    }
}      

測試程式和結果如下:

public static void main(String[] args) {
    int[] a = new int[]{2, 5, 4, 7, 6, 0, 1, -1, 2, 3, 6, 7, 0, 2, 9, 8, 5, 4, 7, 2};
    SegmentTree segmentTree = buildTree(0, a.length - 1, a);
    System.out.println(getSum(segmentTree,0,3));
}      

擷取了區間 [0,2] 的和11

擷取了區間 [3,3] 的和7

18

單點更新

/**
     * 這裡的left == right
     *
     * @param segmentTree
     * @param left
     * @param right
     * @param value
     */
public static void update(SegmentTree segmentTree, int left, int right, int value) {
    if (segmentTree.getL() == left && segmentTree.getR() == right) {
        segmentTree.setValue(value);
        segmentTree.setMaxValue(value);
        segmentTree.setSum(value);
        return;
    }
    int mid = (segmentTree.getL() + segmentTree.getR()) >>> 1;
    if (mid >= left) {
        update(segmentTree.getLeftChild(), left, right, value);
    }
    if (mid < left) {
        update(segmentTree.getRightChild(), left, right, value);
    }
    segmentTree.setMaxValue(Math.max(segmentTree.getLeftChild().getMaxValue(),segmentTree.getRightChild().getMaxValue()));
    segmentTree.setSum(segmentTree.getLeftChild().getSum() + segmentTree.getRightChild().getSum());
}      

總結

繼續閱讀