2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

概要

ABC435のEで、区間をsetで管理する問題が出題されました。

BTreeSetで頑張って管理することでこの問題は解くことができたのですが、コンテスト後にIntervalSetなるライブラリがあることを知りました。

今回はRustでIntervalSetを実装した記事になります。

IntervalSetとはなにか

IntervalSetとは、数直線上の区間 [l, r) を互いに重ならない形で管理し区間ごとに値 V を持たせるデータ構造です。

  • 区間は常に 互いに disjoint
  • 同じ値 V を持つ隣接区間は 自動的にマージ
  • 区間の分割・削除・上書きを効率的に行える

という嬉しい点があります。

実装

pub mod interval_set {
    use std::{
        cmp::{max, min, Ordering},
        collections::BTreeSet,
        fmt::Debug,
    };
    use num_traits::Bounded;

    #[derive(Clone, Debug)]
    pub struct Node<T, V> {
        pub l: T,
        pub r: T,
        pub val: V,
    }

    impl<T: Ord, V> PartialEq for Node<T, V> {
        fn eq(&self, other: &Self) -> bool {
            self.l == other.l && self.r == other.r
        }
    }
    impl<T: Ord, V> Eq for Node<T, V> {}
    impl<T: Ord, V> PartialOrd for Node<T, V> {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }
    impl<T: Ord, V> Ord for Node<T, V> {
        fn cmp(&self, other: &Self) -> Ordering {
            match self.l.cmp(&other.l) {
                Ordering::Equal => self.r.cmp(&other.r),
                o => o,
            }
        }
    }

    pub struct IntervalSet<T, V> {
        identity: V,
        set: BTreeSet<Node<T, V>>,
    }

    impl<T, V> IntervalSet<T, V>
    where
        T: Ord + Copy + Debug + Bounded,
        V: Clone + PartialEq + Default + Debug,
    {
        pub fn new(identity: V) -> Self {
            Self {
                identity,
                set: BTreeSet::new(),
            }
        }

        pub fn iter(&self) -> impl Iterator<Item = &Node<T, V>> {
            self.set.iter()
        }

        pub fn get_node(&self, p: T) -> Option<&Node<T, V>> {
            let key = Node {
                l: p,
                r: T::max_value(),
                val: V::default(),
            };
            let mut range = self.set.range(..=key);
            if let Some(node) = range.next_back() {
                if node.l <= p && p < node.r {
                    return Some(node);
                }
            }
            None
        }

        pub fn next_p(&self, p: T) -> Option<&Node<T, V>> {
            if let Some(n) = self.get_node(p) {
                return Some(n);
            }
            let key = Node {
                l: p,
                r: p,
                val: V::default(),
            };
            self.set.range(key..).next()
        }

        pub fn covered_point(&self, p: T) -> bool {
            self.get_node(p).is_some()
        }

        pub fn covered_range(&self, l: T, r: T) -> bool {
            assert!(l <= r);
            if l == r {
                return true;
            }
            if let Some(node) = self.get_node(l) {
                r <= node.r
            } else {
                false
            }
        }

        pub fn same(&self, p: T, q: T) -> bool {
            let h1 = self.get_node(p);
            let h2 = self.get_node(q);
            match (h1, h2) {
                (Some(a), Some(b)) => a.l == b.l && a.r == b.r,
                _ => false,
            }
        }

        pub fn get_val(&self, p: T) -> &V {
            if let Some(node) = self.get_node(p) {
                &node.val
            } else {
                &self.identity
            }
        }

        pub fn mex(&self, p: T) -> T {
            let key = Node {
                l: p,
                r: T::max_value(),
                val: V::default(),
            };
            let mut range = self.set.range(..=key);
            if let Some(node) = range.next_back() {
                if node.l <= p && p < node.r {
                    return node.r;
                }
            }
            p
        }

        pub fn update<F>(&mut self, mut l: T, mut r: T, val: V, mut f: F)
        where
            F: FnMut(bool, T, T, &V),
        {
            assert!(l <= r);
            if l == r {
                return;
            }

            let mut to_process: Vec<Node<T, V>> = Vec::new();

            if let Some(node) = self.get_node(l) {
                if node.l < r {
                    to_process.push(node.clone());
                }
            }
            let key = Node {
                l,
                r: l,
                val: V::default(),
            };
            for node in self.set.range(key..) {
                if node.l >= r {
                    break;
                }
                to_process.push(node.clone());
            }

            for node in to_process.iter() {
                if self.set.remove(node) {
                    f(false, node.l, node.r, &node.val);

                    if node.l < l {
                        let left_l = node.l;
                        let left_r = min(node.r, l);
                        if left_l < left_r {
                            let left_val = node.val.clone();
                            self.set.insert(Node {
                                l: left_l,
                                r: left_r,
                                val: left_val.clone(),
                            });
                            f(true, left_l, left_r, &left_val);
                        }
                    }
                    if node.r > r {
                        let right_l = max(node.l, r);
                        let right_r = node.r;
                        if right_l < right_r {
                            let right_val = node.val.clone();
                            self.set.insert(Node {
                                l: right_l,
                                r: right_r,
                                val: right_val.clone(),
                            });
                            f(true, right_l, right_r, &right_val);
                        }
                    }
                }
            }

            let left_key = Node {
                l,
                r: l,
                val: V::default(),
            };
            if let Some(prev) = self.set.range(..left_key).next_back().cloned() {
                if prev.r == l && prev.val == val {
                    if self.set.remove(&prev) {
                        f(false, prev.l, prev.r, &prev.val);
                        l = prev.l;
                    }
                }
            }

            let right_key = Node {
                l: r,
                r,
                val: V::default(),
            };
            if let Some(next) = self.set.range(right_key..).next().cloned() {
                if next.l == r && next.val == val {
                    if self.set.remove(&next) {
                        f(false, next.l, next.r, &next.val);
                        r = next.r;
                    }
                }
            }

            if l < r {
                self.set.insert(Node {
                    l,
                    r,
                    val: val.clone(),
                });
                f(true, l, r, &val);
            }
        }

        pub fn erase<F>(&mut self, l: T, r: T, mut f: F)
        where
            F: FnMut(bool, T, T, &V),
        {
            assert!(l <= r);
            if l == r {
                return;
            }

            let mut to_process: Vec<Node<T, V>> = Vec::new();

            if let Some(node) = self.get_node(l) {
                if node.l < r {
                    to_process.push(node.clone());
                }
            }
            let key = Node {
                l,
                r: l,
                val: V::default(),
            };
            for node in self.set.range(key..) {
                if node.l >= r {
                    break;
                }
                to_process.push(node.clone());
            }

            for node in to_process.iter() {
                if self.set.remove(node) {
                    f(false, node.l, node.r, &node.val);

                    if node.l < l {
                        let left_l = node.l;
                        let left_r = min(node.r, l);
                        if left_l < left_r {
                            let left_val = node.val.clone();
                            self.set.insert(Node {
                                l: left_l,
                                r: left_r,
                                val: left_val.clone(),
                            });
                            f(true, left_l, left_r, &left_val);
                        }
                    }
                    if node.r > r {
                        let right_l = max(node.l, r);
                        let right_r = node.r;
                        if right_l < right_r {
                            let right_val = node.val.clone();
                            self.set.insert(Node {
                                l: right_l,
                                r: right_r,
                                val: right_val.clone(),
                            });
                            f(true, right_l, right_r, &right_val);
                        }
                    }
                }
            }
        }
    }
}

各APIの概要と計算量

区間数をKとして表現しています。

API 概要 計算量
iter() 現在保持している区間を左端昇順で列挙 O(K)
get_node(p) p を含む区間を取得。存在しなければ None O(log K)
next_p(p) p を含む区間、または p 以上で最初に現れる区間を取得 O(log K)
covered_point(p) p がいずれかの区間に含まれているか O(log K)
covered_range(l, r) 区間 [l, r) が 1 つの区間で完全に覆われているか O(log K)
same(p, q) pq が同一の区間に属しているか O(log K)
get_val(p) p に対応する値を取得。含まれなければ identity O(log K)
mex(p) mexを返す O(log K)
update(l, r, val, f) 区間 [l, r) を値 val で上書き。既存区間は分割・削除され、隣接区間は自動マージ O((k+1) log K)
erase(l, r, f) 区間 [l, r) を削除。交差する区間は分割される O((k+1) log K)

注意点

C++の実装ではaddとdeleteの2つの関数が用意されていますが、本実装では一つの関数fで賄っています。
第一引数がtrueの時はadd, falseの時はdeleteが実行されるように実装して下さい。

verify

素で書くと結構な実装量になりますが、
Intervalsetを用いれば貼るだけの問題になります。

fn main() {
    input! {
        n: usize,
        q: usize,
        lr: [(usize, usize); q],
    }

    let mut set = IntervalSet::new(false);
    let mut b = 0;
    for (l, r) in lr {
        set.update(l, r + 1, true, |added, ll, rr, _seg_val| {
            if added {
                b += rr - ll;
            } else {
                b -= rr - ll;
            }
        });
        let ans = n - b;
        println!("{}", ans);
    }
}

終わりに

よりよい実装があればおしえてください。

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?