Help us understand the problem. What is going on with this article?

GoでTCPソケットを使ったときのメモ

More than 1 year has passed since last update.

GoのTCPソケットを触ったときのメモ

ローカルのポートを指定してソケットを生成する方法

remote, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8889")
local, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8899")
conn, err := net.DialTCP("tcp", local, remote)
if err != nil {
    return nil, xerrors.Errorf("net.Conn DialTCP error: %w", err)
}

ポートを指定しない場合はもっと簡単にできます。

conn, _ := net.Dial("tcp", "127.0.0.1:8899")

Read

len, err := conn.Read(buf)

ReadはサーバーがWriteをしない限りブロックします。
タイムアウトをしたい場合、ReadDeadlineを指定します。

err := c.conn.SetReadDeadline(time.Now().Add(3 * time.Second))

Deadlineを超えると*net.OpErrori/o timeoutが返ってきます。
Serverから切断された場合はEOFが返ってきます。EOFの場合、必ずlenは0になります。
なので、切断していても、まだ読み込んでいないデータはReadができます。
そのときはエラーは返ってこず、次にReadしたときにerrにEOFが入ります。

server側で切断されて、EOFをうけっとたコネクションをCloseしてもerrorになりません。
ReadDeadlineの設定されていtimeoutを受け取っても同様です。
2度目は use of closed network connection が返ります(panicにはならない)。

buffer,err := ioutil.ReadAll(conn)

ioutil.ReadAllはsocketが切断されるまでブロックします。
なのでerrにEOFは来ません。

Write

server側で切断されている状態でWriteすると*net.OpErrorwrite: broken pipが返ってきます。
server側で切断されて、write: broken pipをうけっとたコネクションをCloseしてもerrorになりません。
2度目は use of closed network connection が返ります(panicにはならない)。

コネクションを操作するプログラムのテスト

コネクションを操作するプログラムのテストはちょっと厄介です。サーバー側の処理をシナリオごとに変えたい場合に工夫が必要です。

ベースはこのブログを参考にしました。

テスト対象のプログラム

指定したbyte数まで読み込んで返すプログラム、足りない場合は待つ。

package main

import (
    "net"
    "time"

    "golang.org/x/xerrors"
)

var (
    Timeout = 1 * time.Second
)
type ConnectionConfig struct {
    LocalAddr  string
    RemoteAddr string
}

type Connection interface {
    ReadUntil(byteLen int) ([]byte, error)
    Close() error
}

type ConnectionImpl struct {
    conn net.Conn
}

func CreateConnection(option ConnectionConfig) (Connection, error) {
    remote, _ := net.ResolveTCPAddr("tcp", option.RemoteAddr)
    local, _ := net.ResolveTCPAddr("tcp", option.LocalAddr)
    conn, err := net.DialTCP("tcp", local, remote)
    if err != nil {
        return nil, xerrors.Errorf("net.Conn DialTCP error: %w", err)
    }
    return &ConnectionImpl{conn}, nil
}

func (c *ConnectionImpl) ReadUntil(byteLen int) ([]byte, error) {
    count := 0
    res := make([]byte, 0)
    tmp := make([]byte, byteLen)
    _ = c.conn.SetReadDeadline(time.Now().Add(Timeout))
    for {
        l, err := c.conn.Read(tmp)
        if err != nil {
            _ = c.Close()
            return nil, xerrors.Errorf("net.Conn Read error: %w", err)
        }
        count += l
        res = append(res, tmp[:l]...)
        if count >= byteLen {
            return res, nil
        }
        tmp = make([]byte, byteLen-count)
    }
}

func (c *ConnectionImpl) Close() error {
    return c.conn.Close()
}

テストのためサーバー側を作成

テストごとに下記のことを行います。

  • 空いているポートを取得
  • サーバーに接続ポートをキーにサーバー側の処理を登録
  • Accept()
  • テスト実行(登録した関数を実行)
package main

import (
    "fmt"
    "net"
    "os"
    "sync"
    "testing"
)

type tcpServer struct {
    addr    string
    funcMap map[string]func(conn net.Conn)
    server  net.Listener
}

var server *tcpServer

func (t *tcpServer) Run() {
    var err error
    t.server, err = net.Listen("tcp", t.addr)
    if err != nil {
        fmt.Printf("failed Listen: %v\n", err)
        return
    }
    for {
        conn, err := t.server.Accept()
        if err != nil {
            break
        }
        if conn == nil {
            fmt.Printf("conn  nil\n")
            break
        }
        mu.Lock()
        t.funcMap[conn.RemoteAddr().String()](conn)
        mu.Unlock()
    }
}

func (t *tcpServer) SetFunc(addr string, f func(conn net.Conn)) {
    mu.Lock()
    t.funcMap[addr] = f
    mu.Unlock()
}

func (t *tcpServer) Close() (err error) {
    return t.server.Close()
}

func setup() *tcpServer {
    server = &tcpServer{
        addr:    "127.0.0.1:7889",
        funcMap: map[string]func(conn net.Conn){},
    }
    go server.Run()
    return server
}

func teardown() {
    _ = server.Close()
}

func TestMain(m *testing.M) {
    setup()
    ret := m.Run()
    if ret == 0 {
        teardown()
    }
    os.Exit(ret)
}

テスト

下記のパターンのテストをしています。

  • 断続的に送ってきても指定バイト数読み込める
  • 指定バイト数以上Writeがある場合でも指定したバイトだけしか読まない
  • 指定バイト未満で切断した場合エラーを返す
  • 指定バイト読み込み前にタイムアウトした場合エラーを返す
package main

import (
    "fmt"
    "net"
    "reflect"
    "testing"
    "time"

    "github.com/cenkalti/backoff/v3"
)

// 空いているポートを見つける https://github.com/phayes/freeport を参考
func getAddr() int {
    var port int
    f := func() error {
        addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
        if err != nil {
            return err
        }
        l, err := net.ListenTCP("tcp", addr)
        if err != nil {
            return err
        }
        defer func() { _ = l.Close() }()
        port = l.Addr().(*net.TCPAddr).Port
        return nil
    }
    b := backoff.WithMaxRetries(backoff.NewExponentialBackOff(), 5)
    err := backoff.Retry(f, b)
    if err != nil {
        panic("free port not exit")
    }
    return port
}

func getConfigFunc() func() ConnectionConfig {
    return func() ConnectionConfig {
        return ConnectionConfig{
            LocalAddr:  fmt.Sprintf("127.0.0.1:%d", getAddr()),
            RemoteAddr: "127.0.0.1:7889",
        }
    }
}

func TestConnection_ReadUntil(t *testing.T) {
    type args struct {
        byteLen int
    }
    tests := []struct {
        name       string
        configFunc func() ConnectionConfig
        serverFunc func(conn net.Conn)
        args       args
        want       []byte
        wantErr    bool
    }{
        {
            "Success Write = Read",
            getConfigFunc(),
            func(conn net.Conn) {
                _, _ = conn.Write([]byte{1})
                time.Sleep(300 * time.Microsecond)
                _, _ = conn.Write([]byte{2})
                time.Sleep(300 * time.Microsecond)
                _, _ = conn.Write([]byte{3})
                _ = conn.Close()
            },
            args{3},
            []byte{1, 2, 3},
            false,
        },
        {
            "Success Write > Read",
            getConfigFunc(),
            func(conn net.Conn) {
                _, _ = conn.Write([]byte{1, 2, 3})
                _ = conn.Close()
            },
            args{1},
            []byte{1},
            false,
        },
        {
            "Failed Write < Read",
            getConfigFunc(),
            func(conn net.Conn) {
                _, _ = conn.Write([]byte{1})
                _ = conn.Close()
            },
            args{2},
            nil,
            true,
        },
        {
            "Failed Timeout",
            getConfigFunc(),
            func(conn net.Conn) {
                _, _ = conn.Write([]byte{1})
                time.Sleep(2 * time.Second)
                _ = conn.Close()
            },
            args{2},
            nil,
            true,
        },
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            conf := tt.configFunc()
            server.SetFunc(conf.LocalAddr, tt.serverFunc)
            c, err := CreateConnection(conf)
            if err != nil {
                t.Errorf("CreateConnection() error = %v", err)
                return
            }
            defer func() { _ = c.Close() }()
            got, err := c.ReadUntil(tt.args.byteLen)
            if (err != nil) != tt.wantErr {
                t.Errorf("Connection.ReadUntil() error = %v, wantErr %v", err, tt.wantErr)
                return
            }
            if !reflect.DeepEqual(got, tt.want) {
                t.Errorf("Connection.ReadUntil() = %v, want %v", got, tt.want)
            }
            _ = c.Close()
        })
    }
}
t10471
mercari
フリマアプリ「メルカリ」を、グローバルで開発しています。
https://tech.mercari.com/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away