0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

pytorchでよく使うものを残しておきたい メモ。

Last updated at Posted at 2022-09-14

同じような記事あるかもしれませんが自分用メモ

pytorchをよく使う機会があり、
なるたけ汎用的にいつも使うものを残しておきたい。

前処理用クラス

#入力画像の前処理クラス

class DataTransform():
  '''
 参考)https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
  画像の前処理用クラス。
  画像のサイズをリサイズ。デフォ値は個人的によく使う
 VGG16 訓練時の平均と分散を設定してます。
  中の処理は必要に応じて修正。
  '''
  def __init__(self,resize=244, mean=( 0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) ):
    self.data_transform = {
        'train' : transforms.Compose([
            transforms.RandomResizedCrop(
              resize, scale=(0.5,1.5)), 
              transforms.RandomHorizontalFlip(),
              transforms.ToTensor(), 
              transforms.Normalize(mean,std) 
        ]),
        'val': transforms.Compose([
            transforms.Resize(resize), 
            transforms.CenterCrop(resize), 
            transforms.ToTensor(), 
            transforms.Normalize(mean,std) 
        ])
    }
  
  def __call__(self,img,phase='train'):
    # phaseで'train'か'val'を指定する
    return self.data_transform[phase](img)

動作例

# 参考)https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#visualize-a-few-images
image_file_path = 'なんか画像のフルパス指定'
img = Image.open(image_file_path)

#いったん表示
print("原画")
plt.imshow(img)
plt.show()

transform = DataTransform() #作成した前処理クラスを生成
inp  = transform(img, phase="train") #trainモードの処理を取得

#torchモードの配列から画像表示用の配列に変換
inp = inp.numpy().transpose((1,2,0)) #Torch型は配列順番が異なるので入れ替える
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean #標準化したものを元に戻す
inp = np.clip(inp, 0, 1)

print("データ拡張後")
plt.imshow(inp)
plt.show()

独自Datasetクラス

class CustomDatasetClass(data.Dataset):
  def __init__(self, file_list,label_list, transform='None', phase='train'):
    '''
    参考)https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
    サンプルとの違い
     サンプルだと、事前にcsvファイルで、以下のようなcsvファイルを作成することになっていますが、
     普通に画像分類したいときに、そんなきれいに整理したデータない事の方が多いので、
     フルパス指定のリストと(file_list)と分類の指定のリスト(label_list)から作れるようにしてます。
      xxx.csv
      tshirt1.jpg, 0
      tshirt2.jpg, 0
      ......
      ankleboot999.jpg, 9
     

    file_list  : 画像パスを格納したリスト
    transform  : 前処理クラスのインスタンス(上記の"DataTransform"を指定する)
    phase      : 使用するフォルダ構成に合わせてphaseを指定。例:訓練=train 検証=val
         https://download.pytorch.org/tutorial/hymenoptera_data.zip
              ■想定するフォルダ構成
                 data
                   ├train
                   │ ├ants
                   │ └bees
                   └val
                     ├ants
                     └bees
    label_list : labelに付与する名称リスト ['ants','bees']  この順番にindexを付与する
          フォルダ名とリストの名称を合わせる。
    '''
    self.file_list = file_list
    self.transform = transform
    self.phase = phase
    self.label_list = label_list

  def __len__(self):
    return len(self.file_list)
  
  def __getitem__(self, index):
    #index番目を取得
    img_path = self.file_list[index]
    img = Image.open(img_path)

    img_transformed = self.transform(img, self.phase)

    pattern = ".*"+self.phase+"/(.*)/.*"
    label_name = re.sub(pattern, "\1", img_path)

    label = self.label_list.index(label_name)

    return img_transformed, label


0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?