LoginSignup
8
8

More than 1 year has passed since last update.

ビジネスサイドと喧嘩しないための機械学習モデル管理

Last updated at Posted at 2019-12-09

はじめに

オフィス賃料は,住宅と違って募集賃料があまり公開されておらず(都心5区だと3分の1以下の物件しか公開されていない),また成約賃料は基本手に入れられないため 正解データ集めが困難 です.
そんな中で不動産賃料を推定するにあたって,ビジネスサイドのオフィス不動産のプロの目も取り入れながら,共同でモデルの精度検証を行なっています.

ビジネスサイトと多くやりとりが発生するため,特に注意していること・工夫していることについて書きます.

課題

上記の通り,ビジネスサイドのメンバと共同で対応するため,MLエンジニアだけで開発している場合よりも以下のような課題が発生しやすいです.

課題1. 過去のモデルの管理

ビジネスサイドのエンジニアからのフィードバックやソースデータの拡充に伴い,
外れ値物件除去,特徴量追加,学習データ変更などのロジック変更を頻繁に行ない,より精度の高いモデルを作成しようと日々改善しています.
ただ,前述の通りこれらのモデルの精度は必ずしも数値的な指標だけで評価しきれない(プロの目も必要となる)ため,
過去よりも良いモデルを作ったつもりだったのに,

「あれ,これ2ヶ月前のモデルだったらここの値いい感じになってたのに変な値になってるぞ!」

なんてことがよく起こります.
大抵そういうことに気づくのは切羽詰まっているときなので,ピリピリします.

そんな際に,すぐに過去のモデルの状態に切り戻すことができると,すぐに原因調査をしたり,本番データに反映したりといった対応が高速で可能になります

課題2. モデルの推定値に対する説明

共同で精度を検証するにあたり,「ここの推定値,なんでこんな値になってんの??」という原因説明を求められることが多くあります.
そのような際に,「この特徴量の影響を大きく受けている」など説明できると,より有意義なディスカッションを続けることができます.

課題3. 精度検証のためのアウトプットの可視化

上記の通り,日々新しいモデルを複数作成してモデルのアウトプットを出力し,高速でビジネスサイドに精度を確認してもらっています.
その際,表形式での出力だと直感的な分析が難しく,このエリアの値だけパッと見せてよというコミュニケーションが数往復に渡り発生してしまうことがあります.

やっていること

上記を解決するために以下の対応をしています.

対応1 モデルのバージョン管理

コード管理

ロジック(ハイパラメータ含む),学習データ変更のチケットをgithubのissue管理して,issueごとにbranchを切っています.
次にリリースするversionがv1.1.0で, 対応したissue番号が4と6あればdev/v1.1.0/issue4_6のようなbranch名にしています.release時には一旦v1.1.0branchにmergeして,タグ管理もしています.

中間生成ファイルのバージョン管理

学習や推定に用いるファイルは全てs3管理しています.
s3に機械学習用のバケットを用意して,branch名と同じディレクトリ構成(dev/v1.1.0/issue4_6)以下に中間ファイルを格納しています.

対応2 SHAPを用いた原因説明

「ここの推定値,なんでこんな値になってんの??」といわれたとき,SHAPを使うとなんとなく原因がわかることがまあまああるので,推定する際にshap_valueカラムも追加してあげるうようにしています.
SHAPについてはいろいろ記事があるのでそちらを参照してください.

Shapを用いた機械学習モデルの解釈説明

簡単にいうと,「各特徴量が推定値にどのくらい寄与したのか」を教えてくれるものです.

import shap
def calc_shap(df_, feature_list, model, rank_th=5) -> pd.core.frame.DataFrame:
    '''shap_valueカラムを追加する
    Args:
        df_ (pd.core.frame.DataFrame): 特徴量追加後の,推定賃料を算出したいデータ
        feature_list ([str]): 特徴量名リスト
        model : 学習モデル
        rank_th (int): shap_valueの高い特徴量を何個まで表示するか.default 5.
    Returns:
        pd.core.frame.DataFrame.
    '''
    df = df_.copy()
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(df[feature_list])
    shap_df = pd.DataFrame(shap_values, columns=feature_list) # df[feature_values]の値が全て寄与度になったデータフレーム
    shap_rank = shap_df.applymap(lambda x: abs(x)).rank(axis=1, ascending=False, method='min') # 各レコードごとに,絶対値の大きいもの(⇔寄与度の高いもの)から順位づけする
    main_contri_col = {i: [col for col in r.keys() if r[col] <= rank_th] for i, r in shap_rank.iterrows()} # 各レコードごとの寄与度rank_th位までのカラムリストを取得
    main_contri_val = [shap_df.loc[i, main_contri_col[i]].to_dict() for i in main_contri_col.keys()] # 各レコードごとの寄与度rank_th位までのカラムとその寄与度を取得
    df['shap_value'] = main_contri_val
    return df

このshap_valueカラムの値は,json文字列で
{'面積': 22627, '築年数': 717, 'hoge1': -5409, 'hoge2': 2968, 'hoge3': 3791}
のような感じになります.これによって「面積がめちゃくちゃ寄与しているんだな,」とわかり,調べてみたら推定レコードの面積orderが間違っていた,などの発見もできたり便利です.

対応3 精度検証用の地図上への可視化

ビジネスサイドメンバが,独自に高速に精度を確認するために,
単純な精度や前回のロジックとの結果の差分だけでなく,学習データと様々な物件に対する推定値を可視化しています.
以下の画像は過去のサンプルですが,推定値が高い物件はオレンジ,安い物件は水色でplotするようにしています.(普段は学習データも黒丸でplotしています)

image.png

可視化コードは以下です.

'''foliumによる推定賃料の可視化
Required:
    pandas
    folium
    matplotlib
'''

import subprocess
import pandas as pd
import folium
import matplotlib.colors as cl

def calc_RGB_value(norm_rent: float) -> str:
    '''0-1スケールに圧縮された数値をRGBの16進数表記として返す
    一番安い物件が水色, 高い物件がオレンジとなるようにする
    Args:
        norm_rent (float): 0-1スケールに圧縮された数値
    Returns:
        str
        ex: #54b0c5
    '''
    R_val = 41 + (255 - 41) * norm_rent
    G_val = 182 + (150 - 182) * norm_rent
    B_val = 246 + (0 - 246) * norm_rent
    return cl.to_hex((R_val / 255, G_val / 255, B_val / 255, 1))

def add_color_col(df_: pd.core.frame.DataFrame) -> pd.core.frame.DataFrame:
    '''colorカラムを追加
    Args:
        df_ (pd.core.frame.DataFrame): estimated_rentカラムを含むデータフレーム
    Returns:
        pd.core.frame.DataFrame
        'color'カラムを追加して返す
    '''
    df = df_.copy()
    norm = cl.Normalize(vmin=df['estimated_rent'].min(), vmax=df['estimated_rent'].max())
    norm_rent_ = [norm(v) for v in df['estimated_rent']]  # 推定賃料を0,1スケールにする
    color_ = [calc_RGB_value(norm_rent) for norm_rent in norm_rent_]
    df["color"] = color_
    return df

class Drawer:
    def __init__(self, ld_path, ed_path):
        self.read_ld(ld_path)
        self.read_ed(ed_path)
        self.add_color_col()
        self.init_map()
    def read_ld(self, ld_path):
        '''学習データの読み込み
        '''
        self.ld = pd.read_csv(ld_path)
        assert 'answer_rent' in self.ld.columns
    def read_ed(self, ed_path):
        '''推定賃料算出後のデータ読み込み
        '''
        self.ed = pd.read_csv(ed_path)
        assert 'estimated_rent' in self.ld.columns
    def add_color_col(self):
        self.ed = add_color_col(self.ed)
        self.ld['color'] = '#262626' #黒
    def init_map(self):
        '''mapを初期化する
        '''
        self.map = folium.Map(
            location=[self.ed.latitude.mean(),self.ed.longitude.mean()],
            zoom_start=6, tiles='cartodbpositron')
    def add_ld_plot(self, size=15):
        '''学習データのplot
        size (int): plotの円のsize.default 15.
        '''
        for i, row in self.ld.iterrows():
            folium.Circle(
                radius=size, location=[row['latitude'], row['longitude']],
                popup='物件名称: %s' % (row['物件名称'] if '物件名称' in row.keys() else '' +
                '<br/>正解賃料: {:,.0f}円/坪'.format(row['ans_rent']),
                color=row['color'], fill_color=row['color']).add_to(self.map)
    def add_ed_plot(self, size=5):
        '''推定賃料のplot
        size (int): plotの円のsize.default 5.
        '''
        for i, row in self.ld.iterrows():
            folium.Circle(
                radius=size, location=[row['latitude'], row['longitude']],
                popup='物件名称: %s' % (row['物件名称'] if '物件名称' in row.keys() else '' +
                '<br/>推定賃料: {:,.0f}円/坪'.format(row['estimated_rent']),
                color=row['color'], fill_color=row['color']).add_to(self.map)
if __name__ == '__main__':
    drawer = Drawer(
        ld_path='s3://hogehoge/dev/v1.1.0/issue4_6/ld.csv',
        ed_path='s3://hogehoge/dev/v1.1.0/issue4_6/ed.csv')
    drawer.add_ld_plot()
    drawer.add_ed_plot()
    drawer.map.save('map.html')
    subprocess.call(
        ['aws', 's3', 'mv', 'map.html', 's3://hogehoge/dev/v1.1.0/issue4_6/map.html'])

おわりに

今後

割とやり方がまだ原始的な部分も残っていますが,上のやり方でデータ・コードのバージョン同期管理を行なっています.
今後,管理をよりしやすくするためにMLflowの導入も考えていますが,導入次第続編を書こうかと思います.

8
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
8