LoginSignup
0
0

More than 3 years have passed since last update.

Go, gRPC, LightGBMモデルを使って推論APIを作る

Last updated at Posted at 2021-01-29

1.はじめに

GoでLightGBMモデルを使用したgRPC推論APIを作るというような記事がほぼなかったので、メモ兼情報共有として記事に残しておこうと思います。

2.前提

irisデータセットの特徴量をリクエストすると、probaのリストとpredictのラベルを返すような基本的なAPIを作ってみます
irisデータセットを学習させたモデルを用意し、S3から読み込んでくるということを想定してます

環境
- Ubuntu16.04
- Go: 1.15.6
- leaves(https://github.com/dmitryikh/leaves)
→LightGBMモデル読み込み、推論に使用、loadできるのはtxtかjsonファイルに限られているみたいなので注意です

gRPC, LightGBMなどの基本知識は割愛します。

3 protoの作成、コード生成

3.1 protoの作成

simple.protoでinputとoutputを定義してあげます
proba_listはリストなのでrepeatedを付けてあげます

simple.proto
syntax = "proto3";

package simple;

// request
message SimpleRequest{
    double sepal_length = 1;
    double sepal_width = 2;
    double petal_length = 3;
    double petal_width = 4;
}

// response
message SimpleResponse{
    int64 predict_label = 1;
    repeated double proba_list = 2;
}

// interface
service SimpleService{
    rpc SimpleSend (SimpleRequest) returns (SimpleResponse) {}
}

3.2 コードの生成

作成したsimple.protoを使ってsimple.pb.go, simple_grpc.pb.goのコード生成をします
protocコマンド
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative simple.proto

4 utilsの作成

utilsでは、
- リスト内のargmaxを返してくれる関数(Argmax)
- s3からモデルをダウンロードしてくる関数(DownloadModelFile)
を作成します

今回、AWSとの認証はIAMユーザーのcredentialsによって解決しています
(本来であればroleで解決することを推奨します)

utils.go
package utils

import(
    "os"
    "log"

    "github.com/aws/aws-sdk-go/aws"
    "github.com/aws/aws-sdk-go/aws/session"
    "github.com/aws/aws-sdk-go/service/s3"
    "github.com/aws/aws-sdk-go/service/s3/s3manager"
)

func max(a []float64) float64 {
    max := a[0]
    for _, i := range a {
        if i > max {
            max = i
        }
    }
    return max
}

func Argmax(a []float64) int64 {
    max := max(a)
    for i, _ := range a {
        if max == a[i] {
            return int64(i)
        }
    }
    return -1
}

func DownloadModelFile(bucket_name string, object_key string) string {
    sess := session.Must(session.NewSessionWithOptions(session.Options{
        SharedConfigState: session.SharedConfigEnable,
    }))
    model_name := "model.txt"
    f, err := os.Create(model_name)
    if err != nil {
        log.Fatal(err)
    }

    bucketName := bucket_name
    objectKey := object_key

    downloader := s3manager.NewDownloader(sess)
    n, err := downloader.Download(f, &s3.GetObjectInput{
        Bucket: aws.String(bucketName),
        Key:    aws.String(objectKey),
    })
    if err != nil {
        log.Fatal(err)
    }

    log.Printf("DownloadedSize: %d byte", n)

    return model_name
}

5 serverスクリプトの作成

server用のスクリプトを作成します
モデルはAPI呼び出しの度に読み込むとボトルネックになってしまうので、グローバル変数として持っておくことにします

server.go
package main

import (
    "context"
    "log"
    "net"
    "fmt"
    "time"

    "google.golang.org/grpc"
    "google.golang.org/grpc/keepalive"
    "github.com/dmitryikh/leaves"

    pb "iris/pb"
    utils "iris/utils"
)

type server struct {
    pb.SimpleServiceServer
}

// モデルをs3からダウンロード、グローバル変数として読み込み
var model_name = utils.DownloadModelFile("bucket_no_namae", "model.txt")
var model, err = leaves.LGEnsembleFromFile(model_name, true)

// requestに対するresponseまでの処理を記述
func (s *server) SimpleSend(ctx context.Context, input *pb.SimpleRequest) (result *pb.SimpleResponse, err error) {
    var sepal_length float64 = input.SepalLength
    var sepal_width float64 = input.SepalWidth
    var petal_length float64 = input.PetalLength
    var petal_width float64 = input.PetalWidth

    data := []float64{sepal_length, sepal_width, petal_length, petal_width}
    predictions := make([]float64, 1*model.NOutputGroups())
    model.Predict(data, 100, predictions)
    var predict_label int64 = utils.Argmax(predictions)
    return &pb.SimpleResponse{
            PredictLabel: predict_label,
            ProbaList: predictions},
            nil
}

func main() {
    // サーバー起動時にモデルの諸情報を記述
    fmt.Printf("Name: %s\n", model.Name())
    fmt.Printf("NFeatures: %d\n", model.NFeatures())
    fmt.Printf("NOutputGroups: %d\n", model.NOutputGroups())
    fmt.Printf("NEstimators: %d\n", model.NEstimators())
    fmt.Printf("Transformation: %s\n", model.Transformation().Name())

    lis, err := net.Listen("tcp", "localhost:50051")
    if err != nil {
        log.Fatalf("failed to listen: %v", err)
    }

    s := grpc.NewServer(grpc.KeepaliveParams(keepalive.ServerParameters{ MaxConnectionIdle: 5 * time.Minute,}))
    pb.RegisterSimpleServiceServer(s, &server{})
    if err := s.Serve(lis); err != nil {
        log.Fatalf("failed to serve: %v", err)
    }
}

6 clientスクリプトの作成

最後にclient用のスクリプトを作成します

client.go
package main

import (
    "context"
    "log"
    "time"

    "google.golang.org/grpc"
    pb "iris/pb"
)


func main() {
    conn, err := grpc.Dial("localhost:50051", grpc.WithInsecure())
    if err != nil {
        log.Fatal("client connection error:", err)
    }
    defer conn.Close()
    client := pb.NewSimpleServiceClient(conn)
    ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    defer cancel()
    // 引数などで指定する仕様に直した方が良い
    response, err := client.SimpleSend(ctx, &pb.SimpleRequest{SepalLength: 3.5, SepalWidth: 3.5, PetalLength: 3.5, PetalWidth: 3.5})
    if err != nil {
        log.Fatalf("could not greet: %v", err)
    }
    // responseの情報を出力
    log.Printf("PredictLabel: %d", response.GetPredictLabel())
    log.Printf("PlobaList: %f", response.GetProbaList())
}

これでserver.goを起動した状態でclient.goを実行すればレスポンスが返ってくるはずです

7 参考

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0