LoginSignup
3
4

More than 3 years have passed since last update.

GoでTCPソケットを使ったときのメモ

Last updated at Posted at 2019-08-03

GoのTCPソケットを触ったときのメモ

ローカルのポートを指定してソケットを生成する方法

remote, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8889")
local, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8899")
conn, err := net.DialTCP("tcp", local, remote)
if err != nil {
    return nil, xerrors.Errorf("net.Conn DialTCP error: %w", err)
}

ポートを指定しない場合はもっと簡単にできます。

conn, _ := net.Dial("tcp", "127.0.0.1:8899")

Read

len, err := conn.Read(buf)

ReadはサーバーがWriteをしない限りブロックします。
タイムアウトをしたい場合、ReadDeadlineを指定します。

err := c.conn.SetReadDeadline(time.Now().Add(3 * time.Second))

Deadlineを超えると*net.OpErrori/o timeoutが返ってきます。
Serverから切断された場合はEOFが返ってきます。EOFの場合、必ずlenは0になります。
なので、切断していても、まだ読み込んでいないデータはReadができます。
そのときはエラーは返ってこず、次にReadしたときにerrにEOFが入ります。

server側で切断されて、EOFをうけっとたコネクションをCloseしてもerrorになりません。
ReadDeadlineの設定されていtimeoutを受け取っても同様です。
2度目は use of closed network connection が返ります(panicにはならない)。

buffer,err := ioutil.ReadAll(conn)

ioutil.ReadAllはsocketが切断されるまでブロックします。
なのでerrにEOFは来ません。

Write

server側で切断されている状態でWriteすると*net.OpErrorwrite: broken pipが返ってきます。
server側で切断されて、write: broken pipをうけっとたコネクションをCloseしてもerrorになりません。
2度目は use of closed network connection が返ります(panicにはならない)。

コネクションを操作するプログラムのテスト

コネクションを操作するプログラムのテストはちょっと厄介です。サーバー側の処理をシナリオごとに変えたい場合に工夫が必要です。

ベースはこのブログを参考にしました。

テスト対象のプログラム

指定したbyte数まで読み込んで返すプログラム、足りない場合は待つ。

package main

import (
    "net"
    "time"

    "golang.org/x/xerrors"
)

var (
    Timeout = 1 * time.Second
)
type ConnectionConfig struct {
    LocalAddr  string
    RemoteAddr string
}

type Connection interface {
    ReadUntil(byteLen int) ([]byte, error)
    Close() error
}

type ConnectionImpl struct {
    conn net.Conn
}

func CreateConnection(option ConnectionConfig) (Connection, error) {
    remote, _ := net.ResolveTCPAddr("tcp", option.RemoteAddr)
    local, _ := net.ResolveTCPAddr("tcp", option.LocalAddr)
    conn, err := net.DialTCP("tcp", local, remote)
    if err != nil {
        return nil, xerrors.Errorf("net.Conn DialTCP error: %w", err)
    }
    return &ConnectionImpl{conn}, nil
}

func (c *ConnectionImpl) ReadUntil(byteLen int) ([]byte, error) {
    count := 0
    res := make([]byte, 0)
    tmp := make([]byte, byteLen)
    _ = c.conn.SetReadDeadline(time.Now().Add(Timeout))
    for {
        l, err := c.conn.Read(tmp)
        if err != nil {
            _ = c.Close()
            return nil, xerrors.Errorf("net.Conn Read error: %w", err)
        }
        count += l
        res = append(res, tmp[:l]...)
        if count >= byteLen {
            return res, nil
        }
        tmp = make([]byte, byteLen-count)
    }
}

func (c *ConnectionImpl) Close() error {
    return c.conn.Close()
}

テストのためサーバー側を作成

テストごとに下記のことを行います。

  • 空いているポートを取得
  • サーバーに接続ポートをキーにサーバー側の処理を登録
  • Accept()
  • テスト実行(登録した関数を実行)
package main

import (
    "fmt"
    "net"
    "os"
    "sync"
    "testing"
)

type tcpServer struct {
    addr    string
    funcMap map[string]func(conn net.Conn)
    server  net.Listener
}

var server *tcpServer

func (t *tcpServer) Run() {
    var err error
    t.server, err = net.Listen("tcp", t.addr)
    if err != nil {
        fmt.Printf("failed Listen: %v\n", err)
        return
    }
    for {
        conn, err := t.server.Accept()
        if err != nil {
            break
        }
        if conn == nil {
            fmt.Printf("conn  nil\n")
            break
        }
        mu.Lock()
        t.funcMap[conn.RemoteAddr().String()](conn)
        mu.Unlock()
    }
}

func (t *tcpServer) SetFunc(addr string, f func(conn net.Conn)) {
    mu.Lock()
    t.funcMap[addr] = f
    mu.Unlock()
}

func (t *tcpServer) Close() (err error) {
    return t.server.Close()
}

func setup() *tcpServer {
    server = &tcpServer{
        addr:    "127.0.0.1:7889",
        funcMap: map[string]func(conn net.Conn){},
    }
    go server.Run()
    return server
}

func teardown() {
    _ = server.Close()
}

func TestMain(m *testing.M) {
    setup()
    ret := m.Run()
    if ret == 0 {
        teardown()
    }
    os.Exit(ret)
}

テスト

下記のパターンのテストをしています。

  • 断続的に送ってきても指定バイト数読み込める
  • 指定バイト数以上Writeがある場合でも指定したバイトだけしか読まない
  • 指定バイト未満で切断した場合エラーを返す
  • 指定バイト読み込み前にタイムアウトした場合エラーを返す
package main

import (
    "fmt"
    "net"
    "reflect"
    "testing"
    "time"

    "github.com/cenkalti/backoff/v3"
)

// 空いているポートを見つける https://github.com/phayes/freeport を参考
func getAddr() int {
    var port int
    f := func() error {
        addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
        if err != nil {
            return err
        }
        l, err := net.ListenTCP("tcp", addr)
        if err != nil {
            return err
        }
        defer func() { _ = l.Close() }()
        port = l.Addr().(*net.TCPAddr).Port
        return nil
    }
    b := backoff.WithMaxRetries(backoff.NewExponentialBackOff(), 5)
    err := backoff.Retry(f, b)
    if err != nil {
        panic("free port not exit")
    }
    return port
}

func getConfigFunc() func() ConnectionConfig {
    return func() ConnectionConfig {
        return ConnectionConfig{
            LocalAddr:  fmt.Sprintf("127.0.0.1:%d", getAddr()),
            RemoteAddr: "127.0.0.1:7889",
        }
    }
}

func TestConnection_ReadUntil(t *testing.T) {
    type args struct {
        byteLen int
    }
    tests := []struct {
        name       string
        configFunc func() ConnectionConfig
        serverFunc func(conn net.Conn)
        args       args
        want       []byte
        wantErr    bool
    }{
        {
            "Success Write = Read",
            getConfigFunc(),
            func(conn net.Conn) {
                _, _ = conn.Write([]byte{1})
                time.Sleep(300 * time.Microsecond)
                _, _ = conn.Write([]byte{2})
                time.Sleep(300 * time.Microsecond)
                _, _ = conn.Write([]byte{3})
                _ = conn.Close()
            },
            args{3},
            []byte{1, 2, 3},
            false,
        },
        {
            "Success Write > Read",
            getConfigFunc(),
            func(conn net.Conn) {
                _, _ = conn.Write([]byte{1, 2, 3})
                _ = conn.Close()
            },
            args{1},
            []byte{1},
            false,
        },
        {
            "Failed Write < Read",
            getConfigFunc(),
            func(conn net.Conn) {
                _, _ = conn.Write([]byte{1})
                _ = conn.Close()
            },
            args{2},
            nil,
            true,
        },
        {
            "Failed Timeout",
            getConfigFunc(),
            func(conn net.Conn) {
                _, _ = conn.Write([]byte{1})
                time.Sleep(2 * time.Second)
                _ = conn.Close()
            },
            args{2},
            nil,
            true,
        },
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            conf := tt.configFunc()
            server.SetFunc(conf.LocalAddr, tt.serverFunc)
            c, err := CreateConnection(conf)
            if err != nil {
                t.Errorf("CreateConnection() error = %v", err)
                return
            }
            defer func() { _ = c.Close() }()
            got, err := c.ReadUntil(tt.args.byteLen)
            if (err != nil) != tt.wantErr {
                t.Errorf("Connection.ReadUntil() error = %v, wantErr %v", err, tt.wantErr)
                return
            }
            if !reflect.DeepEqual(got, tt.want) {
                t.Errorf("Connection.ReadUntil() = %v, want %v", got, tt.want)
            }
            _ = c.Close()
        })
    }
}
3
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4