GoのTCPソケットを触ったときのメモ
ローカルのポートを指定してソケットを生成する方法
remote, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8889")
local, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8899")
conn, err := net.DialTCP("tcp", local, remote)
if err != nil {
return nil, xerrors.Errorf("net.Conn DialTCP error: %w", err)
}
ポートを指定しない場合はもっと簡単にできます。
conn, _ := net.Dial("tcp", "127.0.0.1:8899")
Read
len, err := conn.Read(buf)
ReadはサーバーがWriteをしない限りブロックします。
タイムアウトをしたい場合、ReadDeadlineを指定します。
err := c.conn.SetReadDeadline(time.Now().Add(3 * time.Second))
Deadlineを超えると*net.OpError
のi/o timeout
が返ってきます。
Serverから切断された場合はEOF
が返ってきます。EOF
の場合、必ずlenは0になります。
なので、切断していても、まだ読み込んでいないデータはReadができます。
そのときはエラーは返ってこず、次にReadしたときにerrにEOFが入ります。
server側で切断されて、EOFをうけっとたコネクションをCloseしてもerrorになりません。
ReadDeadlineの設定されていtimeoutを受け取っても同様です。
2度目は use of closed network connection
が返ります(panicにはならない)。
buffer,err := ioutil.ReadAll(conn)
ioutil.ReadAllはsocketが切断されるまでブロックします。
なのでerrにEOF
は来ません。
Write
server側で切断されている状態でWriteすると*net.OpError
のwrite: broken pip
が返ってきます。
server側で切断されて、write: broken pip
をうけっとたコネクションをCloseしてもerrorになりません。
2度目は use of closed network connection
が返ります(panicにはならない)。
コネクションを操作するプログラムのテスト
コネクションを操作するプログラムのテストはちょっと厄介です。サーバー側の処理をシナリオごとに変えたい場合に工夫が必要です。
ベースはこのブログを参考にしました。
テスト対象のプログラム
指定したbyte数まで読み込んで返すプログラム、足りない場合は待つ。
package main
import (
"net"
"time"
"golang.org/x/xerrors"
)
var (
Timeout = 1 * time.Second
)
type ConnectionConfig struct {
LocalAddr string
RemoteAddr string
}
type Connection interface {
ReadUntil(byteLen int) ([]byte, error)
Close() error
}
type ConnectionImpl struct {
conn net.Conn
}
func CreateConnection(option ConnectionConfig) (Connection, error) {
remote, _ := net.ResolveTCPAddr("tcp", option.RemoteAddr)
local, _ := net.ResolveTCPAddr("tcp", option.LocalAddr)
conn, err := net.DialTCP("tcp", local, remote)
if err != nil {
return nil, xerrors.Errorf("net.Conn DialTCP error: %w", err)
}
return &ConnectionImpl{conn}, nil
}
func (c *ConnectionImpl) ReadUntil(byteLen int) ([]byte, error) {
count := 0
res := make([]byte, 0)
tmp := make([]byte, byteLen)
_ = c.conn.SetReadDeadline(time.Now().Add(Timeout))
for {
l, err := c.conn.Read(tmp)
if err != nil {
_ = c.Close()
return nil, xerrors.Errorf("net.Conn Read error: %w", err)
}
count += l
res = append(res, tmp[:l]...)
if count >= byteLen {
return res, nil
}
tmp = make([]byte, byteLen-count)
}
}
func (c *ConnectionImpl) Close() error {
return c.conn.Close()
}
テストのためサーバー側を作成
テストごとに下記のことを行います。
- 空いているポートを取得
- サーバーに接続ポートをキーにサーバー側の処理を登録
- Accept()
- テスト実行(登録した関数を実行)
package main
import (
"fmt"
"net"
"os"
"sync"
"testing"
)
type tcpServer struct {
addr string
funcMap map[string]func(conn net.Conn)
server net.Listener
}
var server *tcpServer
func (t *tcpServer) Run() {
var err error
t.server, err = net.Listen("tcp", t.addr)
if err != nil {
fmt.Printf("failed Listen: %v\n", err)
return
}
for {
conn, err := t.server.Accept()
if err != nil {
break
}
if conn == nil {
fmt.Printf("conn nil\n")
break
}
mu.Lock()
t.funcMap[conn.RemoteAddr().String()](conn)
mu.Unlock()
}
}
func (t *tcpServer) SetFunc(addr string, f func(conn net.Conn)) {
mu.Lock()
t.funcMap[addr] = f
mu.Unlock()
}
func (t *tcpServer) Close() (err error) {
return t.server.Close()
}
func setup() *tcpServer {
server = &tcpServer{
addr: "127.0.0.1:7889",
funcMap: map[string]func(conn net.Conn){},
}
go server.Run()
return server
}
func teardown() {
_ = server.Close()
}
func TestMain(m *testing.M) {
setup()
ret := m.Run()
if ret == 0 {
teardown()
}
os.Exit(ret)
}
テスト
下記のパターンのテストをしています。
- 断続的に送ってきても指定バイト数読み込める
- 指定バイト数以上Writeがある場合でも指定したバイトだけしか読まない
- 指定バイト未満で切断した場合エラーを返す
- 指定バイト読み込み前にタイムアウトした場合エラーを返す
package main
import (
"fmt"
"net"
"reflect"
"testing"
"time"
"github.com/cenkalti/backoff/v3"
)
// 空いているポートを見つける https://github.com/phayes/freeport を参考
func getAddr() int {
var port int
f := func() error {
addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
if err != nil {
return err
}
l, err := net.ListenTCP("tcp", addr)
if err != nil {
return err
}
defer func() { _ = l.Close() }()
port = l.Addr().(*net.TCPAddr).Port
return nil
}
b := backoff.WithMaxRetries(backoff.NewExponentialBackOff(), 5)
err := backoff.Retry(f, b)
if err != nil {
panic("free port not exit")
}
return port
}
func getConfigFunc() func() ConnectionConfig {
return func() ConnectionConfig {
return ConnectionConfig{
LocalAddr: fmt.Sprintf("127.0.0.1:%d", getAddr()),
RemoteAddr: "127.0.0.1:7889",
}
}
}
func TestConnection_ReadUntil(t *testing.T) {
type args struct {
byteLen int
}
tests := []struct {
name string
configFunc func() ConnectionConfig
serverFunc func(conn net.Conn)
args args
want []byte
wantErr bool
}{
{
"Success Write = Read",
getConfigFunc(),
func(conn net.Conn) {
_, _ = conn.Write([]byte{1})
time.Sleep(300 * time.Microsecond)
_, _ = conn.Write([]byte{2})
time.Sleep(300 * time.Microsecond)
_, _ = conn.Write([]byte{3})
_ = conn.Close()
},
args{3},
[]byte{1, 2, 3},
false,
},
{
"Success Write > Read",
getConfigFunc(),
func(conn net.Conn) {
_, _ = conn.Write([]byte{1, 2, 3})
_ = conn.Close()
},
args{1},
[]byte{1},
false,
},
{
"Failed Write < Read",
getConfigFunc(),
func(conn net.Conn) {
_, _ = conn.Write([]byte{1})
_ = conn.Close()
},
args{2},
nil,
true,
},
{
"Failed Timeout",
getConfigFunc(),
func(conn net.Conn) {
_, _ = conn.Write([]byte{1})
time.Sleep(2 * time.Second)
_ = conn.Close()
},
args{2},
nil,
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
conf := tt.configFunc()
server.SetFunc(conf.LocalAddr, tt.serverFunc)
c, err := CreateConnection(conf)
if err != nil {
t.Errorf("CreateConnection() error = %v", err)
return
}
defer func() { _ = c.Close() }()
got, err := c.ReadUntil(tt.args.byteLen)
if (err != nil) != tt.wantErr {
t.Errorf("Connection.ReadUntil() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Connection.ReadUntil() = %v, want %v", got, tt.want)
}
_ = c.Close()
})
}
}