LoginSignup
3
1

More than 1 year has passed since last update.

いまさらCIFAR-10をダウンロードして画像として保存する

Posted at

この記事を読むべき人

簡単に言えば、cifar10を画像として保存して使いたい人。
pytorchとかkerasとかでダウンロードしてチュートリアルはできるけど、凝ったことをするために画像として保存しつつ、色々いじいじしたい人向け。

意外とドンピシャなコードが落ちてないので、備忘録も兼ねて記載。

コード全体

main.py
import os
import tarfile
import requests
import pickle
import joblib
from PIL import Image
import numpy as np
import tqdm


# Download the tar.gz file
def download_cifar10(save_dir):
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    header = requests.head(url).headers
    size = int(header["Content-Length"])
    with open(os.path.join(save_dir, "cifar-10-python.tar.gz"), "wb") as f:
        pbar = tqdm.tqdm(total=size, unit="B", unit_scale=True)
        for chunk in requests.get(url, stream=True).iter_content(chunk_size=1024):
            ff = f.write(chunk)
            pbar.update(len(chunk))
        pbar.close()


# Unpickle the file
def unpickle(file):
    with open(file, "rb") as f:
        dict = pickle.load(f, encoding="bytes")
    return dict


# Save cifar10 image
def save_image(label, data, fname, type="train"):
    img = data.reshape(3, 32, 32)
    img = np.transpose(img, (1, 2, 0))
    pil_img = Image.fromarray(img)
    dir = f"./data/cifar-10/{type}/{label}"
    os.makedirs(dir, exist_ok=True)
    pil_img.save(f"{dir}/{fname.decode('utf-8')}")


save_dir = "./data"
os.makedirs(save_dir, exist_ok=True)
print(f"Process start: save directory={save_dir}")

print("Download the CIFAR10.")
if not os.path.isfile(os.path.join(save_dir, "cifar-10-python.tar.gz")):
    download_cifar10(save_dir)

print("Extract tar.gz.")
with tarfile.open(os.path.join(save_dir, "cifar-10-python.tar.gz"), "r:gz") as tar:
    tar.extractall(path=save_dir)

print("Save the training data as png.")
for i in range(5):
    unp = unpickle(os.path.join(save_dir, f"cifar-10-batches-py/data_batch_{i+1}"))
    joblib.Parallel(n_jobs=-1)(
        joblib.delayed(save_image)(label, data, filename)
        for label, data, filename in zip(
            unp[b"labels"], unp[b"data"], unp[b"filenames"]
        )
    )

print("Save the test data as png.")
unp = unpickle(os.path.join(save_dir, "cifar-10-batches-py/test_batch"))
joblib.Parallel(n_jobs=-1)(
    joblib.delayed(save_image)(label, data, filename, "test")
    for label, data, filename in zip(unp[b"labels"], unp[b"data"], unp[b"filenames"])
)

print("Done.")

実行結果

実行すると、スクリプトと同階層にdataフォルダが作られて、tar.gzがダウンロードされて、展開されて、画像として保存されます。

Process start: save directory=./data
Download the CIFAR10.
100%|██████| 170M/170M [00:27<00:00, 6.19MB/s]
Extract tar.gz.
Save the training data as png.
Save the test data as png.
Done.

以上。

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1