LoginSignup
1
1

More than 5 years have passed since last update.

pytorchの学習の際のデータcsv作成

Posted at

Data Loading and Processing Tutorialに使われているcsvを作成する方法の備忘録になります.

pytorchにおけるデータ入力

公式のチュートリアルによると,csvファイルによって入力することになっています.
チュートリアルで使用されているのは画像名と特徴量をまとめたものになっています.
今回は,画像名とクラスをまとめたものにしています.
こんな感じ.

data.csv
001.jpg,1
002.jpg,1
003.jpg,0
004.jpg,0

まずはこの形の csv を作っていきます.

フォルダ構造

画像は次のフォルダのように格納されています.
data/
 ├ positive/
 │ ├ 001.jpg
 │ ├ 002.jpg
 │ └ ...
 │
 └ negative/
   ├ 003.jpg
   ├ 004.jpg
   └ ...

csv作成

基本的に,glob で画像名を取得して,csv に書き込む形です.
トレーニングデータをまとめたtrain.csvと,テストデータをまとめたtest.csvを作成します.
今回は, train : test = 7 : 3 にしました.

mkdata.py
    import csv
    import glob
    import numpy as np
    from sklearn.model_selection import train_test_split


    x = glob.glob(normalimg_path)
    positive = glob.glob(positiveimg_path)
    print(len(normal))
    print(len(positive))

    y = np.zeros_like(normal, dtype=int).tolist()
    positive_y = np.ones_like(positive, dtype=int).tolist()

    x.extend(positive)
    y.extend(positive_y)

    (X_train, X_test,
    y_train, y_test) = train_test_split(
        x, y, test_size=0.3,
    )

ここまでで,train, testができましたので,あとはcsvに書き込むだけです.

mkdata.py
    with open('val.csv', 'w', newline='') as f: 
        writer = csv.writer(f)
        for data in zip(X_test, y_test):
            print(data)
            writer.writerow(data)   

    with open('train.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        for data in zip(X_train, y_train):
            print(data)
            writer.writerow(data)   

tran.csv, test.csvができました.

次回

次回は,実際に学習までやる予定です.

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1