Edited at

GoでTCPソケットを使ったときのメモ


GoのTCPソケットを触ったときのメモ


ローカルのポートを指定してソケットを生成する方法

remote, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8889")

local, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8899")
conn, err := net.DialTCP("tcp", local, remote)
if err != nil {
return nil, xerrors.Errorf("net.Conn DialTCP error: %w", err)
}

ポートを指定しない場合はもっと簡単にできます。

conn, _ := net.Dial("tcp", "127.0.0.1:8899")


Read

len, err := conn.Read(buf)

ReadはサーバーがWriteをしない限りブロックします。

タイムアウトをしたい場合、ReadDeadlineを指定します。

err := c.conn.SetReadDeadline(time.Now().Add(3 * time.Second))

Deadlineを超えると*net.OpErrori/o timeoutが返ってきます。

Serverから切断された場合はEOFが返ってきます。EOFの場合、必ずlenは0になります。

なので、切断していても、まだ読み込んでいないデータはReadができます。

そのときはエラーは返ってこず、次にReadしたときにerrにEOFが入ります。

server側で切断されて、EOFをうけっとたコネクションをCloseしてもerrorになりません。

ReadDeadlineの設定されていtimeoutを受け取っても同様です。

2度目は use of closed network connection が返ります(panicにはならない)。

buffer,err := ioutil.ReadAll(conn)

ioutil.ReadAllはsocketが切断されるまでブロックします。

なのでerrにEOFは来ません。


Write

server側で切断されている状態でWriteすると*net.OpErrorwrite: broken pipが返ってきます。

server側で切断されて、write: broken pipをうけっとたコネクションをCloseしてもerrorになりません。

2度目は use of closed network connection が返ります(panicにはならない)。


コネクションを操作するプログラムのテスト

コネクションを操作するプログラムのテストはちょっと厄介です。サーバー側の処理をシナリオごとに変えたい場合に工夫が必要です。

ベースはこのブログを参考にしました。


テスト対象のプログラム

指定したbyte数まで読み込んで返すプログラム、足りない場合は待つ。

package main

import (
"net"
"time"

"golang.org/x/xerrors"
)

var (
Timeout = 1 * time.Second
)
type ConnectionConfig struct {
LocalAddr string
RemoteAddr string
}

type Connection interface {
ReadUntil(byteLen int) ([]byte, error)
Close() error
}

type ConnectionImpl struct {
conn net.Conn
}

func CreateConnection(option ConnectionConfig) (Connection, error) {
remote, _ := net.ResolveTCPAddr("tcp", option.RemoteAddr)
local, _ := net.ResolveTCPAddr("tcp", option.LocalAddr)
conn, err := net.DialTCP("tcp", local, remote)
if err != nil {
return nil, xerrors.Errorf("net.Conn DialTCP error: %w", err)
}
return &ConnectionImpl{conn}, nil
}

func (c *ConnectionImpl) ReadUntil(byteLen int) ([]byte, error) {
count := 0
res := make([]byte, 0)
tmp := make([]byte, byteLen)
_ = c.conn.SetReadDeadline(time.Now().Add(Timeout))
for {
l, err := c.conn.Read(tmp)
if err != nil {
_ = c.Close()
return nil, xerrors.Errorf("net.Conn Read error: %w", err)
}
count += l
res = append(res, tmp[:l]...)
if count >= byteLen {
return res, nil
}
tmp = make([]byte, byteLen-count)
}
}

func (c *ConnectionImpl) Close() error {
return c.conn.Close()
}


テストのためサーバー側を作成

テストごとに下記のことを行います。


  • 空いているポートを取得

  • サーバーに接続ポートをキーにサーバー側の処理を登録

  • Accept()

  • テスト実行(登録した関数を実行)

package main

import (
"fmt"
"net"
"os"
"sync"
"testing"
)

type tcpServer struct {
addr string
funcMap map[string]func(conn net.Conn)
server net.Listener
}

var server *tcpServer

func (t *tcpServer) Run() {
var err error
t.server, err = net.Listen("tcp", t.addr)
if err != nil {
fmt.Printf("failed Listen: %v\n", err)
return
}
for {
conn, err := t.server.Accept()
if err != nil {
break
}
if conn == nil {
fmt.Printf("conn nil\n")
break
}
mu.Lock()
t.funcMap[conn.RemoteAddr().String()](conn)
mu.Unlock()
}
}

func (t *tcpServer) SetFunc(addr string, f func(conn net.Conn)) {
mu.Lock()
t.funcMap[addr] = f
mu.Unlock()
}

func (t *tcpServer) Close() (err error) {
return t.server.Close()
}

func setup() *tcpServer {
server = &tcpServer{
addr: "127.0.0.1:7889",
funcMap: map[string]func(conn net.Conn){},
}
go server.Run()
return server
}

func teardown() {
_ = server.Close()
}

func TestMain(m *testing.M) {
setup()
ret := m.Run()
if ret == 0 {
teardown()
}
os.Exit(ret)
}


テスト

下記のパターンのテストをしています。


  • 断続的に送ってきても指定バイト数読み込める

  • 指定バイト数以上Writeがある場合でも指定したバイトだけしか読まない

  • 指定バイト未満で切断した場合エラーを返す

  • 指定バイト読み込み前にタイムアウトした場合エラーを返す

package main

import (
"fmt"
"net"
"reflect"
"testing"
"time"

"github.com/cenkalti/backoff/v3"
)

// 空いているポートを見つける https://github.com/phayes/freeport を参考
func getAddr() int {
var port int
f := func() error {
addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
if err != nil {
return err
}
l, err := net.ListenTCP("tcp", addr)
if err != nil {
return err
}
defer func() { _ = l.Close() }()
port = l.Addr().(*net.TCPAddr).Port
return nil
}
b := backoff.WithMaxRetries(backoff.NewExponentialBackOff(), 5)
err := backoff.Retry(f, b)
if err != nil {
panic("free port not exit")
}
return port
}

func getConfigFunc() func() ConnectionConfig {
return func() ConnectionConfig {
return ConnectionConfig{
LocalAddr: fmt.Sprintf("127.0.0.1:%d", getAddr()),
RemoteAddr: "127.0.0.1:7889",
}
}
}

func TestConnection_ReadUntil(t *testing.T) {
type args struct {
byteLen int
}
tests := []struct {
name string
configFunc func() ConnectionConfig
serverFunc func(conn net.Conn)
args args
want []byte
wantErr bool
}{
{
"Success Write = Read",
getConfigFunc(),
func(conn net.Conn) {
_, _ = conn.Write([]byte{1})
time.Sleep(300 * time.Microsecond)
_, _ = conn.Write([]byte{2})
time.Sleep(300 * time.Microsecond)
_, _ = conn.Write([]byte{3})
_ = conn.Close()
},
args{3},
[]byte{1, 2, 3},
false,
},
{
"Success Write > Read",
getConfigFunc(),
func(conn net.Conn) {
_, _ = conn.Write([]byte{1, 2, 3})
_ = conn.Close()
},
args{1},
[]byte{1},
false,
},
{
"Failed Write < Read",
getConfigFunc(),
func(conn net.Conn) {
_, _ = conn.Write([]byte{1})
_ = conn.Close()
},
args{2},
nil,
true,
},
{
"Failed Timeout",
getConfigFunc(),
func(conn net.Conn) {
_, _ = conn.Write([]byte{1})
time.Sleep(2 * time.Second)
_ = conn.Close()
},
args{2},
nil,
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conf := tt.configFunc()
server.SetFunc(conf.LocalAddr, tt.serverFunc)
c, err := CreateConnection(conf)
if err != nil {
t.Errorf("CreateConnection() error = %v", err)
return
}
defer func() { _ = c.Close() }()
got, err := c.ReadUntil(tt.args.byteLen)
if (err != nil) != tt.wantErr {
t.Errorf("Connection.ReadUntil() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Connection.ReadUntil() = %v, want %v", got, tt.want)
}
_ = c.Close()
})
}
}