*この記事は以前Qiitaで書いたものです。
目次
概要
PyTorchの前処理とデータのロードを担当するtransforms/Dataset/DataLoaderの動作を簡単な例で確認する。
この記事の対象読者
- これからPyTorchを勉強しようとしている人
- PyTorchのtransforms/Dataset/DataLoaderの役割を知りたい人
- オリジナルのtransforms/Dataset/DataLoaderを実装したい人
前置き
DeepLearningのフレームワークではだいたい以下のような機能をサポートしている。
この中で「データの前処理」と「データセットのロード」は自分達の環境によってカスタマイズすることがよくあるので、フレームワークがどのような機能をサポートしているのかを把握することが実装の効率化に繋がる。
今回はPyTorchの「データの前処理」と「データセットのロード」を実現するためのモジュールtransforms/Dataset/DataLoaderの動きを簡単なデータセットを使って確認してみる。
PyTorch Tutorial
今回は、PyTorchのTutorialのDATA LOADING AND PROCESSING TUTORIALで紹介されている内容を参考にしている。
このTutorialでは実際のデータセット(顔画像と顔の特徴点)を使用してtransforms/Dataset/DataLoaderについて解説している。
Tutorialの内容は分かりやすい構成になっているが、画像データやcsvファイルを扱うので、コード上にこれらのファイル特有の処理が入っており、transforms/Dataset/DataLoaderだけの動きを知りたい場合は、やや冗長な部分がある。
よって今回は数字の1次元のリスト形式のデータセットを使って、transforms/Dataset/DataLoaderを動かしていく。
transforms/Dataset/DataLoaderの役割
- transforms
- データの前処理を担当するモジュール
- Dataset
- データとそれに対応するラベルを1組返すモジュール
- データを返すときにtransformsを使って前処理したものを返す。
- DataLoader
- データセットからデータをバッチサイズに固めて返すモジュール
上記説明にあるとおり Datasetはtransformsを制御して DataLoaderはDatasetを制御する という関係になっている。
なので流れとしては、 1.Datasetクラスをインスタンス化するときに、transformsを引数として渡す。 2.DataLoaderクラスをインスタンス化するときに、Datasetを引数で渡す。 3.トレーニングするときにDataLoaderを使ってデータとラベルをバッチサイズで取得する。 という流れになる。
以下各詳細を、transforms、Dataset、DataLoaderの順に動作を見ていく。
transforms
transformsはデータの前処理を行う。 PyTorchではあらかじめ便利な前処理がいくつか実装されている。 例えば、画像に関する前処理はtorchvision.transformsにまとまっており、CropやFlipなどメジャーな前処理があらかじめ用意されている。
今回は自分で簡単なtransformsを実装することで処理内容の理解を深めていく。
transformsを実装するのに必要な要件
予め用意されているtransformsの動作に習うために「コール可能なクラス」として実装する必要がある。 (「コール可能」とは__call__を実装しているクラスのこと) なぜ「コール可能なクラス」にする必要があるのかというと、Tutorialでは以下のように説明している。
We will write them as callable classes instead of simple functions so that parameters of the transform need not be passed everytime it’s called.
つまり、クラスにしておけば、インスタンス化時に前処理に使うパラメータを全部渡しておけるので、前処理を実行するたびにパラメータを渡す手間が省ける、ということで「コール可能なクラス」を推奨している。
今回はデータとして数字の1次元配列を使うので入力値を二乗するtransformsを実装する。
実装
class Square(object): def __init__(self): pass def __call__(self, sample): return sample ** 2
実装はこれだけ。
使い方
transform = Square() print(transform(1)) # -> 1 print(transform(2)) # -> 4 print(transform(3)) # -> 9 print(transform(4)) # -> 16
渡した数値が二乗になっていることが確認できる。 もし画像データ用のtransformsを実装したい場合は、__call__の中に画像処理を実装すればいい。
Dataset
Datasetは、入力データとそれに対応するラベルを1組返すモジュール。 データはtransformsで前処理を行った後に返す。そのためDatasetを作るときは引数でtransformsを渡す必要がある。
PyTorchでは有名なデータセットがあらかじめtorchvision.datasetsに定義されている。(MNIST/CIFAR/STL10など)
自前のデータを扱いたいときは自分のデータをリードして返してくれるDatasetを実装する必要がある。
扱うデータが画像でクラスごとにフォルダ分けされている場合はtorchvision.datasets.ImageFolder
という便利なクラスもある。(KerasのImageDataGenerator
のflow_from_directory()
のような機能)
Datasetを実装するのに必要な要件
オリジナルDatasetを実装するときに守る必要がある要件は以下3つ。
- torch.utils.data.Datasetを継承する。
- __len__を実装する。
- __getitem__を実装する。
__len__は、len(obj)で実行されたときにコールされる関数。 __getitem__は、obj[i]のようにインデックスで指定されたときにコールされる関数。
今回は、データは数字のリスト、ラベルは偶数の場合だけTrueになるものを出力するDatasetを実装する。
実装
import torch class MyDataset(torch.utils.data.Dataset): def __init__(self, data_num, transform=None): self.transform = transform self.data_num = data_num self.data = [] self.label = [] for x in range(self.data_num): self.data.append(x) # 0 から (data_num-1) までのリスト self.label.append(x%2 == 0) # 偶数ならTrue 奇数ならFalse def __len__(self): return self.data_num def __getitem__(self, idx): out_data = self.data[idx] out_label = self.label[idx] if self.transform: out_data = self.transform(out_data) return out_data, out_label
ポイントは__getitem__でデータを返す前にtransformでデータに前処理をしてから返しているところ。 データセットとして画像やcsvファイルを扱う場合は、__init__や__getitem__の中でファイルをオープンする必要がある。
使い方
data_set = MyDataset(10, transform=None) print(data_set[0]) # -> (0, True) print(data_set[1]) # -> (1, False) print(data_set[2]) # -> (2, True) print(data_set[3]) # -> (3, False) print(data_set[4]) # -> (4, True) # 先ほど実装したtransformsを渡してみる. # データが二乗されていることに注目. data_set = MyDataset(10, transform=Square()) print(data_set[0]) # -> (0, True) print(data_set[1]) # -> (1, False) print(data_set[2]) # -> (4, True) print(data_set[3]) # -> (9, False) print(data_set[4]) # -> (16, True)
指定したインデックスのデータとラベルがセットで取得できている。 次に説明するDataLoaderは、この仕組みを利用してバッチサイズ分のデータを生成する。
DataLoader
データセットからデータをバッチサイズに固めて返すモジュール DataLoaderはデータセットを使ってバッチサイズ分のデータを生成する。またデータのシャッフル機能も持つ。 データを返すときは、データをtensor型に変換して返す。 tensor型は、計算グラフを保持することができる変数でDeepLearningの勾配計算に不可欠な変数になっている。
DataLoaderは、torch.utils.data.DataLoader
というクラスが既に用意されている。たいていの場合、このクラスで十分対応できるので、今回はこのクラスにこれまで実装してきたDatasetを渡して動作を見てみる。
使い方
import torch data_set = MyDataset(10, transform=Square()) dataloader = torch.utils.data.DataLoader(data_set, batch_size=2, shuffle=True) for i in dataloader: print(i) # [tensor([ 4, 25]), tensor([1, 0])] # [tensor([64, 0]), tensor([1, 1])] # [tensor([36, 16]), tensor([1, 1])] # [tensor([1, 9]), tensor([0, 0])] # [tensor([81, 49]), tensor([0, 0])]
指定したバッチサイズでかつデータがシャッフルされていることがわかる。transformsで値が二乗に変換されている。
shuffle=False
にすると順番にデータが出力される。
import torch data_set = MyDataset(10, transform=Square()) dataloader = torch.utils.data.DataLoader(data_set, batch_size=2, shuffle=False) for i in dataloader: print(i) # [tensor([0, 1]), tensor([1, 0])] # [tensor([4, 9]), tensor([1, 0])] # [tensor([16, 25]), tensor([1, 0])] # [tensor([36, 49]), tensor([1, 0])] # [tensor([64, 81]), tensor([1, 0])]
学習のときはdataloaderのループをさらにepochのループでかぶせる。
epochs = 4 for epoch in epochs: for i in dataloader: # 学習処理