takuroooのブログ

勉強したこととか

PyTorch transforms/Dataset/DataLoaderの基本動作を確認する

*この記事は以前Qiitaで書いたものです。

qiita.com

目次

概要

PyTorchの前処理とデータのロードを担当するtransforms/Dataset/DataLoaderの動作を簡単な例で確認する。

この記事の対象読者

  • これからPyTorchを勉強しようとしている人
  • PyTorchのtransforms/Dataset/DataLoaderの役割を知りたい人
  • オリジナルのtransforms/Dataset/DataLoaderを実装したい人

前置き

DeepLearningのフレームワークではだいたい以下のような機能をサポートしている。

この中で「データの前処理」と「データセットのロード」は自分達の環境によってカスタマイズすることがよくあるので、フレームワークがどのような機能をサポートしているのかを把握することが実装の効率化に繋がる。

今回はPyTorchの「データの前処理」と「データセットのロード」を実現するためのモジュールtransforms/Dataset/DataLoaderの動きを簡単なデータセットを使って確認してみる。

PyTorch Tutorial

今回は、PyTorchのTutorialのDATA LOADING AND PROCESSING TUTORIALで紹介されている内容を参考にしている。

このTutorialでは実際のデータセット(顔画像と顔の特徴点)を使用してtransforms/Dataset/DataLoaderについて解説している。

Tutorialの内容は分かりやすい構成になっているが、画像データやcsvファイルを扱うので、コード上にこれらのファイル特有の処理が入っており、transforms/Dataset/DataLoaderだけの動きを知りたい場合は、やや冗長な部分がある。

よって今回は数字の1次元のリスト形式のデータセットを使って、transforms/Dataset/DataLoaderを動かしていく。

transforms/Dataset/DataLoaderの役割

  • transforms
    • データの前処理を担当するモジュール
  • Dataset
    • データとそれに対応するラベルを1組返すモジュール
    • データを返すときにtransformsを使って前処理したものを返す。
  • DataLoader
    • データセットからデータをバッチサイズに固めて返すモジュール

上記説明にあるとおり Datasetはtransformsを制御して DataLoaderはDatasetを制御する という関係になっている。

なので流れとしては、 1.Datasetクラスをインスタンス化するときに、transformsを引数として渡す。 2.DataLoaderクラスをインスタンス化するときに、Datasetを引数で渡す。 3.トレーニングするときにDataLoaderを使ってデータとラベルをバッチサイズで取得する。 という流れになる。

以下各詳細を、transforms、Dataset、DataLoaderの順に動作を見ていく。

transforms

transformsはデータの前処理を行う。 PyTorchではあらかじめ便利な前処理がいくつか実装されている。 例えば、画像に関する前処理はtorchvision.transformsにまとまっており、CropやFlipなどメジャーな前処理があらかじめ用意されている。

今回は自分で簡単なtransformsを実装することで処理内容の理解を深めていく。

transformsを実装するのに必要な要件

予め用意されているtransformsの動作に習うために「コール可能なクラス」として実装する必要がある。 (「コール可能」とは__call__を実装しているクラスのこと) なぜ「コール可能なクラス」にする必要があるのかというと、Tutorialでは以下のように説明している。

We will write them as callable classes instead of simple functions so that parameters of the transform need not be passed everytime it’s called.

つまり、クラスにしておけば、インスタンス化時に前処理に使うパラメータを全部渡しておけるので、前処理を実行するたびにパラメータを渡す手間が省ける、ということで「コール可能なクラス」を推奨している。

今回はデータとして数字の1次元配列を使うので入力値を二乗するtransformsを実装する。

実装
class Square(object):
    def __init__(self):
        pass
    
    def __call__(self, sample):
        return sample ** 2

実装はこれだけ。

使い方
transform = Square()
print(transform(1)) # -> 1
print(transform(2)) # -> 4
print(transform(3)) # -> 9
print(transform(4)) # -> 16

渡した数値が二乗になっていることが確認できる。 もし画像データ用のtransformsを実装したい場合は、__call__の中に画像処理を実装すればいい。

Dataset

Datasetは、入力データとそれに対応するラベルを1組返すモジュール。 データはtransformsで前処理を行った後に返す。そのためDatasetを作るときは引数でtransformsを渡す必要がある。

PyTorchでは有名なデータセットがあらかじめtorchvision.datasetsに定義されている。(MNIST/CIFAR/STL10など)

自前のデータを扱いたいときは自分のデータをリードして返してくれるDatasetを実装する必要がある。

扱うデータが画像でクラスごとにフォルダ分けされている場合はtorchvision.datasets.ImageFolderという便利なクラスもある。(KerasのImageDataGeneratorflow_from_directory()のような機能)

Datasetを実装するのに必要な要件

オリジナルDatasetを実装するときに守る必要がある要件は以下3つ。

  • torch.utils.data.Datasetを継承する。
  • __len__を実装する。
  • __getitem__を実装する。

__len__は、len(obj)で実行されたときにコールされる関数。 __getitem__は、obj[i]のようにインデックスで指定されたときにコールされる関数。

今回は、データは数字のリスト、ラベルは偶数の場合だけTrueになるものを出力するDatasetを実装する。

実装
import torch
class MyDataset(torch.utils.data.Dataset):
    
    def __init__(self, data_num, transform=None):
        self.transform = transform
        self.data_num = data_num
        self.data = []
        self.label = []
        for x in range(self.data_num):
            self.data.append(x) # 0 から (data_num-1) までのリスト
            self.label.append(x%2 == 0) # 偶数ならTrue 奇数ならFalse
            
    def __len__(self):
        return self.data_num
    
    def __getitem__(self, idx):
        out_data = self.data[idx]
        out_label =  self.label[idx]
        
        if self.transform:
            out_data = self.transform(out_data)
            
        return out_data, out_label

ポイントは__getitem__でデータを返す前にtransformでデータに前処理をしてから返しているところ。 データセットとして画像やcsvファイルを扱う場合は、__init__や__getitem__の中でファイルをオープンする必要がある。

使い方
data_set = MyDataset(10, transform=None)
print(data_set[0]) # -> (0, True)
print(data_set[1]) # -> (1, False)
print(data_set[2]) # -> (2, True)
print(data_set[3]) # -> (3, False)
print(data_set[4]) # -> (4, True)

# 先ほど実装したtransformsを渡してみる.
# データが二乗されていることに注目.
data_set = MyDataset(10, transform=Square())
print(data_set[0]) # -> (0, True)
print(data_set[1]) # -> (1, False)
print(data_set[2]) # -> (4, True)
print(data_set[3]) # -> (9, False)
print(data_set[4]) # -> (16, True)

指定したインデックスのデータとラベルがセットで取得できている。 次に説明するDataLoaderは、この仕組みを利用してバッチサイズ分のデータを生成する。

DataLoader

データセットからデータをバッチサイズに固めて返すモジュール DataLoaderはデータセットを使ってバッチサイズ分のデータを生成する。またデータのシャッフル機能も持つ。 データを返すときは、データをtensor型に変換して返す。 tensor型は、計算グラフを保持することができる変数でDeepLearningの勾配計算に不可欠な変数になっている。

DataLoaderは、torch.utils.data.DataLoaderというクラスが既に用意されている。たいていの場合、このクラスで十分対応できるので、今回はこのクラスにこれまで実装してきたDatasetを渡して動作を見てみる。

使い方
import torch
data_set = MyDataset(10, transform=Square())
dataloader = torch.utils.data.DataLoader(data_set, batch_size=2, shuffle=True)

for i in dataloader:
    print(i)

# [tensor([ 4, 25]), tensor([1, 0])]
# [tensor([64,  0]), tensor([1, 1])]
# [tensor([36, 16]), tensor([1, 1])]
# [tensor([1, 9]), tensor([0, 0])]
# [tensor([81, 49]), tensor([0, 0])]

指定したバッチサイズでかつデータがシャッフルされていることがわかる。transformsで値が二乗に変換されている。 shuffle=Falseにすると順番にデータが出力される。

import torch
data_set = MyDataset(10, transform=Square())
dataloader = torch.utils.data.DataLoader(data_set, batch_size=2, shuffle=False)

for i in dataloader:
    print(i)

# [tensor([0, 1]), tensor([1, 0])]
# [tensor([4, 9]), tensor([1, 0])]
# [tensor([16, 25]), tensor([1, 0])]
# [tensor([36, 49]), tensor([1, 0])]
# [tensor([64, 81]), tensor([1, 0])]

学習のときはdataloaderのループをさらにepochのループでかぶせる。

epochs = 4
for epoch in epochs:
    for i in dataloader:
        # 学習処理