LoginSignup
4
0

More than 3 years have passed since last update.

便利にkaggle Datasetにアップロード

Posted at

kaggle notebook縛りのcode competition

最近kaggleでは、推論時にkaggleのnotebook環境しか使えないcode competitionが増えて、深層学習を使う系のコンペですと頻繁にローカルで学習済みのモデルのパラメータファイルをkaggle Datasetにアップロードして使うことがあります。

kaggle APIコマンドでもまだ面倒

Kaggle APIコマンドを使うことで、WebUIでの手作業は省けてデータダウンロード・アップロードが自動化できるのですが、メタデータのJSONファイルやAPIコマンドの編集・作成が面倒だったりします。
そこで、pythonで実行できるwrapper関数を作りましたので、供養しておきます。
関数の入力と、実験パラメータの記載されているyamlファイルと連携すると、実験条件等を自動でデータセットのコメントなどに反映できミスの予防や省力化に繋がります。

必要な前準備

  • kaggle APIのインストールとAPIトークンの生成が必要です。詳細は関連記事をご覧ください。
  • 当然ながら、アップロードしたいファイルのパスにデータがないといけません。
    • この関数はmodelというディレクトリに様々な実験ごとに更にmodel_exp_XXとサブディレクトリがあり、その中にモデルのパラメータファイルがあることを前提しています。
    • 関数の引数ではモデルファイルの拡張子を指定して、.pth、.h5など適宜変更します。
  • loggerを用意すると、ログファイルに出力するように一応しています。

import subprocess
import glob
import json
import os
def upload_to_kaggle(

                     title: str, 
                     k_id: str,  
                     path: str, 
                     comments: str,
                     update:bool,
                     logger=None,
                     extension = '.pth',
                     subtitle='', 
                     description="",
                     isPrivate = True,
                     licenses = "unknown" ,
                     keywords = [],
                     collaborators = []
                     ):
    '''
    >> upload_to_kaggle(title, k_id, path,  comments, update)

    Arguments
    =========
     title: the title of your dataset.
     k_id: kaggle account id.
     path: non-default string argument of the file path of the data to be uploaded.
     comments:non-default string argument of the comment or the version about your upload.
     logger: logger object if you use logging, default is None.
     extension: the file extension of model weight files, default is ".pth"
     subtitle: the subtitle of your dataset, default is empty string.
     description: dataset description, default is empty string.
     isPrivate: boolean to show wheather to make the data public, default is True.
     licenses = the licenses description, default is "unkown"; must be one of /
     ['CC0-1.0', 'CC-BY-SA-4.0', 'GPL-2.0', 'ODbL-1.0', 'CC-BY-NC-SA-4.0', 'unknown', 'DbCL-1.0', 'CC-BY-SA-3.0', 'copyright-authors', 'other', 'reddit-api', 'world-bank'] .
     keywords : the list of keywords about the dataset, default is empty list.
     collaborators: the list of dataset collaborators, default is empty list.
   '''
    model_list = glob.glob(path+f'/*{extension}')
    if len(model_list) == 0:
        raise FileExistsError('File does not exist, check the file extention is correct \
        or the file directory exist.')

    if path[-1] == '/':
        raise ValueError('Please remove the backslash in the end of the path')

    data_json =  {
        "title": title,
        "id": f"{k_id}/{title}",
        "subtitle": subtitle,
        "description": description,
        "isPrivate": isPrivate,
        "licenses": [
            {
                "name": licenses
            }
        ],
        "keywords": [],
        "collaborators": [],
        "data": [

        ]
    }

    data_list = []
    for mdl in model_list:
        mdl_nm = mdl.replace(path+'/', '')
        mdl_size = os.path.getsize(mdl) 
        data_dict = {
            "description": comments,
            "name": mdl_nm,
            "totalBytes": mdl_size,
            "columns": []
        }
        data_list.append(data_dict)
    data_json['data'] = data_list


    with open(path+'/dataset-metadata.json', 'w') as f:
        json.dump(data_json, f)

    script0 = ['kaggle',  'datasets', 'create', '-p', f'{path}' , '-m' , f'\"{comments}\"']
    script1 = ['kaggle',  'datasets', 'version', '-p', f'{path}' , '-m' , f'\"{comments}\"']

    #script0 = ['echo', '1']
    #script1 = ['echo', '2']

    if logger:    
        logger.info(data_json)

        if update:
            logger.info(script1)
            logger.info(subprocess.check_output(script1))
        else:
            logger.info(script0)
            logger.info(script1)
            logger.info(subprocess.check_output(script0))
            logger.info(subprocess.check_output(script1))

    else:
        print(data_json)

        if update:
            print(script1)
            print(subprocess.check_output(script1))
        else:
            print(script0)
            print(script1)
            print(subprocess.check_output(script0))
            print(subprocess.check_output(script1))

こうすることでもっと効率よくできてるよという方がいましたら、ぜひコメントください。

関連記事:
- Kaggle APIで楽にGCPにデータをダウンロード
- Github 上の自分のコードを Kaggle Code Competition で使うのを CI で自動化

4
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
0