1.はじめに
GoでLightGBMモデルを使用したgRPC推論APIを作るというような記事がほぼなかったので、メモ兼情報共有として記事に残しておこうと思います。
2.前提
irisデータセットの特徴量をリクエストすると、probaのリストとpredictのラベルを返すような基本的なAPIを作ってみます
irisデータセットを学習させたモデルを用意し、S3から読み込んでくるということを想定してます
環境
- Ubuntu16.04
- Go: 1.15.6
- leaves(https://github.com/dmitryikh/leaves)
→LightGBMモデル読み込み、推論に使用、loadできるのはtxtかjsonファイルに限られているみたいなので注意です
gRPC, LightGBMなどの基本知識は割愛します。
3 protoの作成、コード生成
3.1 protoの作成
simple.proto
でinputとoutputを定義してあげます
proba_list
はリストなのでrepeated
を付けてあげます
syntax = "proto3";
package simple;
// request
message SimpleRequest{
double sepal_length = 1;
double sepal_width = 2;
double petal_length = 3;
double petal_width = 4;
}
// response
message SimpleResponse{
int64 predict_label = 1;
repeated double proba_list = 2;
}
// interface
service SimpleService{
rpc SimpleSend (SimpleRequest) returns (SimpleResponse) {}
}
3.2 コードの生成
作成したsimple.protoを使ってsimple.pb.go
, simple_grpc.pb.go
のコード生成をします
protocコマンド
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative simple.proto
4 utilsの作成
utilsでは、
- リスト内のargmaxを返してくれる関数(Argmax)
- s3からモデルをダウンロードしてくる関数(DownloadModelFile)
を作成します
今回、AWSとの認証はIAMユーザーのcredentialsによって解決しています
(本来であればroleで解決することを推奨します)
package utils
import(
"os"
"log"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
)
func max(a []float64) float64 {
max := a[0]
for _, i := range a {
if i > max {
max = i
}
}
return max
}
func Argmax(a []float64) int64 {
max := max(a)
for i, _ := range a {
if max == a[i] {
return int64(i)
}
}
return -1
}
func DownloadModelFile(bucket_name string, object_key string) string {
sess := session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
}))
model_name := "model.txt"
f, err := os.Create(model_name)
if err != nil {
log.Fatal(err)
}
bucketName := bucket_name
objectKey := object_key
downloader := s3manager.NewDownloader(sess)
n, err := downloader.Download(f, &s3.GetObjectInput{
Bucket: aws.String(bucketName),
Key: aws.String(objectKey),
})
if err != nil {
log.Fatal(err)
}
log.Printf("DownloadedSize: %d byte", n)
return model_name
}
5 serverスクリプトの作成
server用のスクリプトを作成します
モデルはAPI呼び出しの度に読み込むとボトルネックになってしまうので、グローバル変数として持っておくことにします
package main
import (
"context"
"log"
"net"
"fmt"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/dmitryikh/leaves"
pb "iris/pb"
utils "iris/utils"
)
type server struct {
pb.SimpleServiceServer
}
// モデルをs3からダウンロード、グローバル変数として読み込み
var model_name = utils.DownloadModelFile("bucket_no_namae", "model.txt")
var model, err = leaves.LGEnsembleFromFile(model_name, true)
// requestに対するresponseまでの処理を記述
func (s *server) SimpleSend(ctx context.Context, input *pb.SimpleRequest) (result *pb.SimpleResponse, err error) {
var sepal_length float64 = input.SepalLength
var sepal_width float64 = input.SepalWidth
var petal_length float64 = input.PetalLength
var petal_width float64 = input.PetalWidth
data := []float64{sepal_length, sepal_width, petal_length, petal_width}
predictions := make([]float64, 1*model.NOutputGroups())
model.Predict(data, 100, predictions)
var predict_label int64 = utils.Argmax(predictions)
return &pb.SimpleResponse{
PredictLabel: predict_label,
ProbaList: predictions},
nil
}
func main() {
// サーバー起動時にモデルの諸情報を記述
fmt.Printf("Name: %s\n", model.Name())
fmt.Printf("NFeatures: %d\n", model.NFeatures())
fmt.Printf("NOutputGroups: %d\n", model.NOutputGroups())
fmt.Printf("NEstimators: %d\n", model.NEstimators())
fmt.Printf("Transformation: %s\n", model.Transformation().Name())
lis, err := net.Listen("tcp", "localhost:50051")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
s := grpc.NewServer(grpc.KeepaliveParams(keepalive.ServerParameters{ MaxConnectionIdle: 5 * time.Minute,}))
pb.RegisterSimpleServiceServer(s, &server{})
if err := s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}
6 clientスクリプトの作成
最後にclient用のスクリプトを作成します
package main
import (
"context"
"log"
"time"
"google.golang.org/grpc"
pb "iris/pb"
)
func main() {
conn, err := grpc.Dial("localhost:50051", grpc.WithInsecure())
if err != nil {
log.Fatal("client connection error:", err)
}
defer conn.Close()
client := pb.NewSimpleServiceClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
// 引数などで指定する仕様に直した方が良い
response, err := client.SimpleSend(ctx, &pb.SimpleRequest{SepalLength: 3.5, SepalWidth: 3.5, PetalLength: 3.5, PetalWidth: 3.5})
if err != nil {
log.Fatalf("could not greet: %v", err)
}
// responseの情報を出力
log.Printf("PredictLabel: %d", response.GetPredictLabel())
log.Printf("PlobaList: %f", response.GetProbaList())
}
これでserver.goを起動した状態でclient.goを実行すればレスポンスが返ってくるはずです