LoginSignup
0
0

More than 1 year has passed since last update.

Amazon ComprehendをAWS SDK for Go v2を使って呼び出す

Last updated at Posted at 2021-07-24

はじめに

AWS SDK for Goがありますが、現在(2021年7月24日)のところ、AWS SDK for Go v2のAmazon Comporehendのコードサンプルがなかったので、とりあえず自分で作ってみることにしました。
個人的にはPythonを使用することが多いのですが、Go言語も業務で使用することがあるので、Go言語でもチャレンジしてみました。
※他の言語や他のサービスのコードサンプルは以下から参照できます。

Amazon Comprehendを使用するための事前準備

Go言語のプログラムをLinuxで実行してAmazon Comprehendのエンティティ認識のバッチを呼び出して結果をダウンロードして表示させることを考えます。オンラインのAPI呼び出しではなく、バッチ処理にしたのはAPI呼び出しだと文字数に(少し)きつい制限があるためです。エンティティ認識の処理だと、UTF-8のエンコーディングで5,000バイトが処理の上限です。
Amazon Comprehendを使用するには事前準備として権限の付与とロールの定義が必要です。

  • コマンド(Goのソースコードをコンパイルしたもの)を実行するユーザ(aws configureで設定したアクセスキーやシークレットキーを持つユーザ)にAmazon Comprehendの呼び出しを許可する
  • コマンドを実行するユーザにAWS のサービスにロールを渡すアクセス権限を付与する。(こちらを参考にしてください。)
  • Amazon Comprehendに渡すロールを定義する。(こちらを参考にしてください。)

処理の流れ

  1. 処理対象ファイルをS3にアップロードする。
  2. Amazon Comprehendを呼び出す(エンティティ認識のバッチ処理)
  3. Amazon Comprehendのジョブの完了を一定間隔で確認する。
  4. 処理結果のtar.gzファイルをダウンロードする。
  5. tar.gzファイルに含まれる結果のJSONからScoreの高い順にエンティティのキーワードと種別、スコアの値を表示する(デフォルトでは20個)

ソースコード(Go言語)

ソースコードは以下のとおりです。(GitHubにも置いてあります。)
43行目の「roleArn string = "arn:aws:iam::123456789012:role/comprehend-access-role"」のところは、事前準備で定義した「Amazon Comprehendに渡すロール」のARNを指定します。

main.go
package main

import (
    "archive/tar"
    "bufio"
    "compress/gzip"
    "context"
    "encoding/json"
    "flag"
    "fmt"
    "io"
    "net/url"
    "os"
    "sort"
    "time"

    "github.com/aws/aws-sdk-go-v2/config"
    "github.com/aws/aws-sdk-go-v2/service/comprehend"
    comprehend_types "github.com/aws/aws-sdk-go-v2/service/comprehend/types"
    "github.com/aws/aws-sdk-go-v2/service/s3"
)

type Entity struct {
    BeginOffset int     `json:"BeginOffset"`
    EndOffset   int     `json:"EndOffset"`
    Score       float64 `json:"Score"`
    Text        string  `json:"Text"`
    Type        string  `json:"Type"`
}

type JsonOutput struct {
    File     string   `json:"File"`
    Entities []Entity `json:"Entities"`
}

const (
    resultFileName string = "output"
    bufferSize     int    = 1024 * 1024
    interval       int    = 10
    timeoutCount   int    = 100
    jobName        string = "sample-entities-detection-job"
    // ここのロールは置き換えること
    roleArn string = "arn:aws:iam::123456789012:role/comprehend-access-role"
)

func main() {
    pBucketName := flag.String("bucket", "testcomprehend-tn", "Bucket to put a content on.")
    pPrefixName := flag.String("prefix", "comprehend/", "Prefix to store a content file")
    pContentFileName := flag.String("file", "content.txt", "The content file")
    pLimitNumber := flag.Int("limit", 20, "The limit to display keywords")

    flag.Parse()

    fmt.Println("START!!!!")
    fmt.Println("Bucket:", *pBucketName)
    fmt.Println("Prefix:", *pPrefixName)
    fmt.Println("Limit:", *pLimitNumber)
    fmt.Println("File:", *pContentFileName)

    cfg, err := config.LoadDefaultConfig(context.TODO())
    if err != nil {
        panic("configuration error, " + err.Error())
    }

    s3Client := s3.NewFromConfig(cfg)
    comprehendClient := comprehend.NewFromConfig(cfg)

    // 処理対象ファイルをS3にアップロードする
    file, err := os.Open(*pContentFileName)
    if err != nil {
        fmt.Println("Unable to open file " + *pContentFileName)
        return
    }
    defer file.Close()

    objectName := *pPrefixName + "input/" + *pContentFileName
    fmt.Println("Object Name: " + objectName)
    input := &s3.PutObjectInput{
        Bucket: pBucketName,
        Key:    &objectName,
        Body:   file,
    }

    _, err = s3Client.PutObject(context.TODO(), input)
    if err != nil {
        fmt.Println("Got error uploading file:")
        fmt.Println(err)
        return
    }
    // Amazon Comprehend 呼び出し
    inputS3Uri := "s3://" + *pBucketName + "/" + objectName
    fmt.Println("InputS3URI: " + inputS3Uri)
    inputConfig := &comprehend_types.InputDataConfig{
        S3Uri:       &inputS3Uri,
        InputFormat: comprehend_types.InputFormatOneDocPerFile,
    }
    outputS3Uri := "s3://" + *pBucketName + "/" + *pPrefixName + "output/"
    fmt.Println("OutputS3URI: " + outputS3Uri)
    outputConfig := &comprehend_types.OutputDataConfig{
        S3Uri: &outputS3Uri,
    }
    roleArnForInput := roleArn
    jobNameForInput := jobName
    jobInput := &comprehend.StartEntitiesDetectionJobInput{
        DataAccessRoleArn: &roleArnForInput,
        InputDataConfig:   inputConfig,
        LanguageCode:      comprehend_types.LanguageCodeJa,
        OutputDataConfig:  outputConfig,
        JobName:           &jobNameForInput,
    }
    out, err := comprehendClient.StartEntitiesDetectionJob(context.TODO(), jobInput)
    if err != nil {
        fmt.Println("Starting an entities detection job Error:")
        fmt.Println(err)
        return
    }
    jobId := *out.JobId
    fmt.Println("Job ID: " + jobId)

    // Amazon Comprehend 完了確認
    describeJobInput := &comprehend.DescribeEntitiesDetectionJobInput{
        JobId: &jobId,
    }
    var outDesc *comprehend.DescribeEntitiesDetectionJobOutput
    for i := 0; i < timeoutCount; i++ {
        fmt.Println("In Progress...")
        time.Sleep(time.Duration(interval) * time.Second)
        outDesc, err = comprehendClient.DescribeEntitiesDetectionJob(context.TODO(), describeJobInput)
        if err != nil {
            fmt.Println("Getting a status of the entities detection job Error:")
            fmt.Println(err)
            return
        }
        if outDesc.EntitiesDetectionJobProperties.JobStatus == comprehend_types.JobStatusCompleted {
            fmt.Println("Job Completed.")
            break
        }
    }
    if outDesc.EntitiesDetectionJobProperties.JobStatus != comprehend_types.JobStatusCompleted {
        fmt.Println("Job Timeout.")
        return
    }

    // Amazon Comprehend 結果ダウンロード
    downloadS3Uri := *outDesc.EntitiesDetectionJobProperties.OutputDataConfig.S3Uri
    fmt.Println("Output S3 URI: " + downloadS3Uri)
    parsedUri, err := url.Parse(downloadS3Uri)
    if err != nil {
        fmt.Println("URI Parse Error:")
        fmt.Println(err)
        return
    }
    downloadObjectName := parsedUri.Path[1:]
    fmt.Println("Download Object Name: " + downloadObjectName)

    getObjectInput := &s3.GetObjectInput{
        Bucket: pBucketName,
        Key:    &downloadObjectName,
    }
    outGetObject, err := s3Client.GetObject(context.TODO(), getObjectInput)
    if err != nil {
        fmt.Println("Getting S3 object Error:")
        fmt.Println(err)
        return
    }
    defer outGetObject.Body.Close()

    // tar.gzファイルから結果のJSONを取り出す
    gzipReader, err := gzip.NewReader(outGetObject.Body)
    if err != nil {
        fmt.Println("Reading a gzip file Error:")
        fmt.Println(err)
        return
    }
    defer gzipReader.Close()
    tarfileReader := tar.NewReader(gzipReader)

    var jsonOutput JsonOutput
    for {
        tarfileHeader, err := tarfileReader.Next()
        if err == io.EOF {
            break
        }
        if err != nil {
            fmt.Println("Reading a tar file Error:")
            fmt.Println(err)
            return
        }

        if tarfileHeader.Name == resultFileName {
            jsonReader := bufio.NewReaderSize(tarfileReader, bufferSize)
            for {
                // TODO: isPrefixを見て、行がバッファに対して長すぎる場合の処理を行なう
                line, _, err := jsonReader.ReadLine()
                if err == io.EOF {
                    break
                }
                if err != nil {
                    fmt.Println("Reading a line Error:")
                    fmt.Println(err)
                    return
                }
                json.Unmarshal([]byte(line), &jsonOutput)
                fmt.Println()
                fmt.Println("File: " + jsonOutput.File)
                // Score順にソートする
                entities := jsonOutput.Entities
                sort.Slice(entities, func(i, j int) bool { return entities[i].Score > entities[j].Score })
                // Amazon Comprehend 結果表示
                for i := 0; i < *pLimitNumber; i++ {
                    fmt.Printf("%s (%s): %f\n", entities[i].Text, entities[i].Type, entities[i].Score)
                }
            }
        }
    }

    fmt.Println()
    fmt.Println()
    fmt.Println("END!!!!")
}

go buildでビルドできます。また、以下のコマンドでオプション(引数)を確認できます。-fileの指定が解析したいファイルになります。デフォルトはcontent.txtという名前です。UTF-8で文章を格納してください。

./mycomprehend -h

Usage of ./mycomprehend:
  -bucket string
        Bucket to put a content on. (default "testcomprehend-tn")
  -file string
        The content file (default "content.txt")
  -limit int
        The limit to display keywords (default 20)
  -prefix string
        Prefix to store a content file (default "comprehend/")

実行結果

content.txtについてエンティティ認識の処理をした結果は以下のようになりました。

(前略)
File: content.txt
2020年12月 (DATE): 0.999014
3時間 (QUANTITY): 0.997277
Linux Academy (ORGANIZATION): 0.995822
3年 (QUANTITY): 0.995796
180分 (QUANTITY): 0.995555
2回 (QUANTITY): 0.995218
33,000円 (QUANTITY): 0.993613
https://qiita.com/takanattie/items/7dd188ce14a2a5b9ef14 (OTHER): 0.990118
1度 (QUANTITY): 0.988592
3ヶ月半 (QUANTITY): 0.988081
AWS (ORGANIZATION): 0.987514
AWS (ORGANIZATION): 0.980833
U.S. (ORGANIZATION): 0.976798
AWS (ORGANIZATION): 0.976790
AWS (ORGANIZATION): 0.975587
3年間 (QUANTITY): 0.975035
2周 (QUANTITY): 0.972928
Linux Academy (ORGANIZATION): 0.966450
AWS (ORGANIZATION): 0.966190
Qiita Advent Calendar 2020 (TITLE): 0.962282

認識されたエンティティのタイプ(ORGANIZATIONなど)の情報もあるので、そのタイプでフィルタリングすることもできると思います。

補足

上記のコードは、サンプルとなるように書いたので不充分な点があります。特に以下の点は改善の余地があると考えています。

  • コマンドの戻り値(Exit Code)を設定する
  • S3にAmazon Comprehendの結果が格納されるので、それをトリガとしてLambdaを起動して結果を取り出す、などの処理を行なうようにする。(イベントドリブンにする)

また、今回はエンティティ認識だけでしたが、キーフレーズ抽出なども同じように実行できるはずです(コードの修正は必要になりますが)。APIの定義はソースコードとしてGitHubにもありますし、APIのドキュメントもありますので、そこを読み解けばコードをかけるはずですが、やはり一連の流れを示したサンプルコードが欲しいところです。

以上。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0