2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

RedshiftMLでtitanic

Posted at

はじめに

2021年にGAされたRedshiftMLについて触ってみようと思い、皆さん大好きTitanicでやってみました。
参考にしたのは、いつもお世話になっているクラメソさんのブログ「AWS RedShift MLでSQLを使った機械学習をする」です。(クラメソさんは別のデータで実施してます)

事前準備

データ取得

Kaggleのtitanicデータページから3つのファイルを取得。
・titanic_train.csv
・titanic_test.csv
・gender_submission.csv

IAMロール作成&Redshiftクラスタ作成

(IMAロール系は詳しくなく)「AmazonS3FullAccess」、「AmazonSageMakerFullAccess」、「AmazonRedshiftDataFullAccess」を付けたRedshift用ロールを作ってRedshiftクラスターを作成。
(おそらくいらないポリシーもある)
Redshiftインスタンスはdc2.large×1nodeでやってます。

テーブル作成

titanicデータ投入用のテーブルを作成。testテーブルは予測項目の「Survived 」を保有していないことに注意。

train用テーブル
CREATE TABLE titanic_train(
PassengerId integer,
Survived integer,
Pclass integer,
Name varchar,
Sex varchar,
Age decimal,
SibSp integer,
Parch integer,
Ticket varchar,
Fare varchar, 
Cabin varchar,
Embarked varchar);
test用テーブル(Survived無し)
CREATE TABLE titanic_test(
PassengerId integer,
Pclass integer,
Name varchar,
Sex varchar,
Age decimal,
SibSp integer,
Parch integer,
Ticket varchar,
Fare varchar, 
Cabin varchar,
Embarked varchar);
gender_submission
CREATE TABLE gender_submission(
PassengerId integer,
Survived integer);

データ投入

S3からテーブルにデータをロード。(他のわかりやすいページで解説されていると思いますので省略)

モデル作成

CREATE MODEL

このあたりは開発ガイドも参照しながら作成。
今回は単純なモデル作成で実行してますが、XGBoostを指定したトレーニングもできるようですね。

CREATE MODEL
CREATE MODEL titanic_survived
FROM titanic_train
TARGET Survived
FUNCTION predict_survived
IAM_ROLE '(IAMロールのarn)'
SETTINGS (
  S3_BUCKET '(S3バケット名)'
);

SHOW MODEL(TRAINING中)

SHOW MODEL(TRAINING中)
dev=# SHOW MODEL titanic_survived;
           Key            |                                    Value
--------------------------+------------------------------------------------------------------------------
 Model Name               | titanic_survived
 Schema Name              | public
 Owner                    | postgres01
 Creation Time            | Mon, 14.02.2022 03:58:14
 Model State              | TRAINING            ←★実行中はTRAINING
                          |
 TRAINING DATA:           |
 Query                    | SELECT *
                          | FROM "TITANIC_TRAIN"
 Target Column            | SURVIVED
                          |
 PARAMETERS:              |
 Model Type               | auto
 Problem Type             |
 Objective                |
 AutoML Job Name          | redshiftml-20220214035814151238
 Function Name            | predict_survived
 Function Parameters      | passengerid pclass name sex age sibsp parch ticket fare cabin embarked
 Function Parameter Types | int4 int4 varchar varchar numeric int4 int4 varchar varchar varchar varchar
 IAM Role                 | (IAMロールのarn)
 S3 Bucket                | (S3バケット)
 Max Runtime              | 5400
(22 rows)

SHOW MODEL(TRAINING完了)

待つこと・・・、1時間半・・・。

SHOW MODEL(TRAINING完了)
dev=# SHOW MODEL titanic_survived;
           Key            |                                    Value
--------------------------+------------------------------------------------------------------------------
 Model Name               | titanic_survived
 Schema Name              | public
 Owner                    | postgres01
 Creation Time            | Mon, 14.02.2022 03:58:14
 Model State              | READY            ←★Readyになったら完了★
 Training Job Status      | MaxAutoMLJobRuntimeReached
 validation:f1_binary     | 0.783060
 Estimated Cost           | 9.685500
                          |
 TRAINING DATA:           |
 Query                    | SELECT *
                          | FROM "TITANIC_TRAIN"
 Target Column            | SURVIVED
                          |
 PARAMETERS:              |
 Model Type               | auto
 Problem Type             | BinaryClassification
 Objective                | F1
 AutoML Job Name          | redshiftml-20220214035814151238
 Function Name            | predict_survived
 Function Parameters      | passengerid pclass name sex age sibsp parch ticket fare cabin embarked
 Function Parameter Types | int4 int4 varchar varchar numeric int4 int4 varchar varchar varchar varchar
 IAM Role                 | (IAMロールのarn)
 S3 Bucket                | (S3バケット)
 Max Runtime              | 5400
(25 rows)

予測

TRAINデータで予測

予測実行
SELECT predict_survived, Survived, COUNT(*)
 FROM (SELECT predict_survived(
                  PassengerId ,Pclass ,Name ,Sex ,Age ,SibSp ,Parch ,
                  Ticket ,Fare , Cabin ,Embarked), Survived
         FROM titanic_train)
 GROUP BY predict_survived, Survived order by predict_survived, Survived;
結果
 predict_survived | survived | count
------------------+----------+-------
                0 |        0 |   534
                0 |        1 |    17
                1 |        0 |    15
                1 |        1 |   325
(4 rows)

TRAINデータの正解率は96.4%

TESTデータで予測

予測実行
SELECT predict_survived, Survived, COUNT(*)
  FROM (SELECT predict_survived(
                   t.PassengerId ,Pclass ,Name ,Sex ,Age ,SibSp ,Parch ,
                   Ticket ,Fare , Cabin ,Embarked), Survived
          FROM titanic_test t inner join gender_submission s on t.PassengerId = s.PassengerId)
 GROUP BY predict_survived, Survived order by predict_survived, Survived;
結果
 predict_survived | survived | count
------------------+----------+-------
                0 |        0 |   241
                0 |        1 |    44
                1 |        0 |    25
                1 |        1 |   108
(4 rows)

TESTデータの正解率は83.4%
何も頭を使わずに作成したモデルの精度としては十分な感じでしょうか。

最後に

あまり詰まることなく、サクサク実行できた感触です。
CREATE MODELが終わるまでRedshiftを立ち上げっぱなしなのが個人アカウントで利用しているとややしんどい・・・。

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?