同じような記事あるかもしれませんが自分用メモ
pytorchをよく使う機会があり、
なるたけ汎用的にいつも使うものを残しておきたい。
前処理用クラス
#入力画像の前処理クラス
class DataTransform():
'''
参考)https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
画像の前処理用クラス。
画像のサイズをリサイズ。デフォ値は個人的によく使う
VGG16 訓練時の平均と分散を設定してます。
中の処理は必要に応じて修正。
'''
def __init__(self,resize=244, mean=( 0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) ):
self.data_transform = {
'train' : transforms.Compose([
transforms.RandomResizedCrop(
resize, scale=(0.5,1.5)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean,std)
]),
'val': transforms.Compose([
transforms.Resize(resize),
transforms.CenterCrop(resize),
transforms.ToTensor(),
transforms.Normalize(mean,std)
])
}
def __call__(self,img,phase='train'):
# phaseで'train'か'val'を指定する
return self.data_transform[phase](img)
動作例
# 参考)https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#visualize-a-few-images
image_file_path = 'なんか画像のフルパス指定'
img = Image.open(image_file_path)
#いったん表示
print("原画")
plt.imshow(img)
plt.show()
transform = DataTransform() #作成した前処理クラスを生成
inp = transform(img, phase="train") #trainモードの処理を取得
#torchモードの配列から画像表示用の配列に変換
inp = inp.numpy().transpose((1,2,0)) #Torch型は配列順番が異なるので入れ替える
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean #標準化したものを元に戻す
inp = np.clip(inp, 0, 1)
print("データ拡張後")
plt.imshow(inp)
plt.show()
独自Datasetクラス
class CustomDatasetClass(data.Dataset):
def __init__(self, file_list,label_list, transform='None', phase='train'):
'''
参考)https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
サンプルとの違い
サンプルだと、事前にcsvファイルで、以下のようなcsvファイルを作成することになっていますが、
普通に画像分類したいときに、そんなきれいに整理したデータない事の方が多いので、
フルパス指定のリストと(file_list)と分類の指定のリスト(label_list)から作れるようにしてます。
xxx.csv
tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
file_list : 画像パスを格納したリスト
transform : 前処理クラスのインスタンス(上記の"DataTransform"を指定する)
phase : 使用するフォルダ構成に合わせてphaseを指定。例:訓練=train 検証=val
https://download.pytorch.org/tutorial/hymenoptera_data.zip
■想定するフォルダ構成
data
├train
│ ├ants
│ └bees
└val
├ants
└bees
label_list : labelに付与する名称リスト ['ants','bees'] この順番にindexを付与する
フォルダ名とリストの名称を合わせる。
'''
self.file_list = file_list
self.transform = transform
self.phase = phase
self.label_list = label_list
def __len__(self):
return len(self.file_list)
def __getitem__(self, index):
#index番目を取得
img_path = self.file_list[index]
img = Image.open(img_path)
img_transformed = self.transform(img, self.phase)
pattern = ".*"+self.phase+"/(.*)/.*"
label_name = re.sub(pattern, "\1", img_path)
label = self.label_list.index(label_name)
return img_transformed, label