LoginSignup
0
0

More than 3 years have passed since last update.

c++のstd::set::lower_boundとupper_boundをGoで作ってみました

Last updated at Posted at 2020-12-02

はじめに

零細企業の事務をやっているおじさんです。Go歴は半年です。

螺旋本のpart3 第16章の16.13線分交差問題AOJ CGL6_JをGoで解こうとしたところ、そこに載っているサンプルコードでc++のlower_boundとupper_boundが、Goにはないようだったので、自作しました。

注意
・二分探索木で作っています
・イテレータではなく、Nodeを返します
・sliceには使えません
・初心者なので、ここで紹介しているものが不完全である可能性は十分あります
・四角い車輪の再発明の可能性大

Go版lower_boundのコード


func lowerBoundNode(n *Node, x interface{}) *Node {
    if n.key < x.(int) {
        if n.right != nil {
            n = lowerBoundNode(n.right, x)
        }
    } else if n.key >= x.(int) {
        if n.left != nil && n.left.key >= x.(int) {
            n = lowerBoundNode(n.left, x)
        }
    }
    return n
}
func (t *Tree) lowerBound(x interface{}) *Node {
    return lowerBoundNode(t.root, x)
}

t.lowerBound(x)で、x <= n.key となる最初の*Nodeを返します。
n.key < x なら右に行ってn.key >= xとなるノードを探し、n.key >= x なら左に行って最初のノードを探しています。

Go版upper_boundのコード


func upperBoundNode(n *Node, x interface{}) *Node {
    if n.key <= x.(int) {
        if n.right != nil {
            n = upperBoundNode(n.right, x)
        }
    } else if n.key > x.(int) {
        if n.left != nil && n.left.key > x.(int) {
            n = upperBoundNode(n.left, x)
        }
    }
    return n
}
func (t *Tree)upperBound(x interface{})*Node{
    return upperBoundNode(t.root,x)
}

t.upperBound(x)で、x < n.key となる最初の*Nodeを返します。
lowerBoundと大体同じです。

コード全部

AOJ CGL6_Jへの提出用ではなく、AOJ ALDS1_8_Cへ提出したものを改変しています。
AOJ ALDS1_8_Cにある入力例を使用し実行→paiza.io

もしかしたらsetとして使えるかもしれませんが、多分遅いです。


package main

import (
    "bufio"
    "fmt"
    "os"
    "strconv"
    "strings"
)

var rdr = bufio.NewReaderSize(os.Stdin, 1024*1024)

func readLine() string {
    buf := []byte{}
    for {
        l, p, e := rdr.ReadLine()
        if e != nil {
            panic(e)
        }
        buf = append(buf, l...)
        if !p {
            break
        }
    }
    return string(buf)
}
func readInts() []int {
    s := strings.Split(readLine(), " ")
    res := []int{}
    for _, v := range s {
        i, _ := strconv.Atoi(v)
        res = append(res, i)
    }
    return res
}

type Node struct {
    key    int
    parent *Node
    left   *Node
    right  *Node
}

func NewNode() *Node {
    res := Node{}
    return &res
}

type Tree struct {
    root *Node
}

func NewTree() *Tree {
    res := Tree{
        root: nil,
    }
    return &res
}

// 先行順巡回アルゴリズムでNodeのkeyを返す
func (t *Tree) preParse(z *Node) string {
    if z == nil {
        return ""
    }
    return " " + strconv.Itoa(z.key) + t.preParse(z.left) + t.preParse(z.right)
}

// 中間順巡回アルゴリズムでNodeのkeyを返す
func (t *Tree) inParse(z *Node) string {
    //fmt.Println(z)
    if z == nil {
        return ""
    }
    return t.inParse(z.left) + " " + strconv.Itoa(z.key) + t.inParse(z.right)

}

func find(sl []int, x int) int {
    for i := 0; i < len(sl); i++ {
        if sl[i] == x {
            return i
        }
    }
    return -1
}

/*
1 insert(T, z)
2     y = NIL // x の親
3     x = 'T の根'
4     while x ≠ NIL
5         y = x // 親を設定
6         if z.key < x.key
7             x = x.left // 左の子へ移動
8         else
9             x = x.right // 右の子へ移動
10    z.p = y
11
12    if y == NIL // T が空の場合
13        'T の根' = z
14    else if z.key < y.key
15        y.left = z // z を y の左の子にする
16    else
17        y.right = z // z を y の右の子にする
*/
func (t *Tree) insert(k int) {
    //fmt.Println("k",k)
    x := t.root
    var y *Node
    //var flag bool=false

    for x != nil {
        y = x
        if k == x.key { //既にxをkeyに持つNodeがあった場合
            return
        } else if k < x.key {
            x = x.left
        } else {
            x = x.right
        }
    }

    var new_node *Node
    new_node = &Node{left: nil, right: nil, key: k}
    new_node.parent = y

    if y == nil {
        t.root = new_node
    } else if new_node.key < y.key {
        y.left = new_node
    } else {
        y.right = new_node
    }
    //fmt.Println(new_node)
}
func (t *Tree) find(k int) *Node {
    x := t.root
    for x != nil && k != x.key {
        if k < x.key {
            x = x.left
        } else {
            x = x.right
        }
    }
    return x
}
func (t *Tree) deleteNode(z *Node) {
    //yを削除対象とする
    var y *Node
    if z.left == nil || z.right == nil {
        y = z
    } else {
        y = getSuccessor(z)
    }

    //yの子xを決める
    var x *Node
    if y.left != nil {
        x = y.left
    } else {
        x = y.right
    }
    //yの親を設定する
    if x != nil {
        x.parent = y.parent
    }
    // 削除する
    if y.parent == nil {
        t.root = x
    } else if y == y.parent.left {
        y.parent.left = x
    } else {
        y.parent.right = x
    }

    if y != z {
        z.key = y.key
    }
}
func getSuccessor(x *Node) *Node {
    if x.right != nil {
        return getMinimum(x.right)
    }
    var y *Node
    y = x.parent
    for y != nil && x == y.right {
        x = y
        y = y.parent
    }
    return y
}
func getMinimum(x *Node) *Node {
    for x.left != nil {
        x = x.left
    }
    return x
}

func lowerBoundNode(n *Node, x interface{}) *Node {
    if n.key < x.(int) {
        if n.right != nil {
            n = lowerBoundNode(n.right, x)
        }
    } else if n.key >= x.(int) {
        if n.left != nil && n.left.key >= x.(int) {
            n = lowerBoundNode(n.left, x)
        }
    }
    return n
}
func (t *Tree) lowerBound(x interface{}) *Node {
    return lowerBoundNode(t.root, x)
}
func upperBoundNode(n *Node, x interface{}) *Node {
    //fmt.Println("lowerBoundNode n:",&n,n,"x",x)
    if n.key <= x.(int) {
        if n.right != nil {
            n = upperBoundNode(n.right, x)
        }
    } else if n.key > x.(int) {
        if n.left != nil && n.left.key > x.(int) {
            n = upperBoundNode(n.left, x)
        }
    }
    return n
}
func (t *Tree) upperBound(x interface{}) *Node {
    return upperBoundNode(t.root, x)
}
func main() {
    n := readInts()[0]
    t := NewTree()
    //fmt.Println(n,*t)
    for i := 0; i < n; i++ {
        tmp := strings.Split(readLine(), " ")
        if tmp[0] == "insert" {
            v, _ := strconv.Atoi(tmp[1])
            t.insert(v)
        } else if tmp[0] == "print" {
            fmt.Println(t.inParse(t.root))
            fmt.Println(t.preParse(t.root))
            fmt.Println("t.root", t.root)
            fmt.Println("t.lowerBound(t.root,x)", t.lowerBound(1))
            fmt.Println("t.upperBound(t.root,x)", t.upperBound(1))
        } else if tmp[0] == "find" {
            v, _ := strconv.Atoi(tmp[1])
            res := t.find(v)
            if res != nil {
                fmt.Println("yes")
            } else {
                fmt.Println("no")
            }
        } else if tmp[0] == "delete" {
            v, _ := strconv.Atoi(tmp[1])
            //fmt.Println(t.find(v))
            t.deleteNode(t.find(v))
        }
    }
}

参考

お気楽 Go 言語プログラミング入門(二分探索木)
プログラミングコンテスト攻略のためのアルゴリズムとデータ構造(Amazon)
https://cpprefjp.github.io/reference/set/set/lower_bound.html
https://cpprefjp.github.io/reference/set/set/upper_bound.html

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0