PyTorchでValidation Datasetを作る方法
*この記事は以前Qiitaで書いたものです。
目次
- 目次
- 概要
- 課題
- 解決策1 torch.utils.data.Subset
- 解決策2 torch.utils.data.random_split
- Chainerのchainer.datasets.split_dataset_randomについて
- 参考
概要
PyTorchにはあらかじめ有名なデータセットがいくつか用意されている(torchvision.datasets
を使ってMNIST/CIFARなどダウロードできる)。しかし、train/testでしか分離されていないので、ここからvalidationデータセットを作ってみる。
例題としてtorchvision.datasets.MNIST
を使う。
課題
torchvision.datasets.MNIST
を使うと簡単にPyTorchのDataset
を作ることができるが、train/test用のDataset
しか用意されていないためvalidation用のDataset
を自分で作る必要がある。
以下のコードはtrain/test用のDataset
を作っている。
from torchvision import datasets trainval_dataset = datasets.MNIST('./', train=True, download=True) test_dataset = datasets.MNIST('./', train=False, download=True) print(len(trainval_dataset)) # 60000 print(len(test_dataset)) # 10000 print(type(trainval_dataset)) # torchvision.datasets.mnist.MNIST print(type(test_dataset)) # torchvision.datasets.mnist.MNIST
このtrainval_dataset
をtrain/validationに分割したい。
しかし、trainval_dataset
は単純なリスト形式ではなく、PyTorchのDataset
になっているため、「Dataset
が持つデータを取り出して、それをDataset
クラスに再構成する。」みたいなやり方だと手間がかかる上にうまくいかないことがある。(うまくいかない例としては、Dataset
クラスにTransform
クラスを渡している場合。)
解決策1 torch.utils.data.Subset
torch.utils.data.Subset(dataset, indices)
を使うと簡単にDataset
を分割できる。
PyTorchの中のコードは以下のようにシンプルなクラスになっている。
class Subset(Dataset): """ Subset of a dataset at specified indices. Arguments: dataset (Dataset): The whole Dataset indices (sequence): Indices in the whole set selected for subset """ def __init__(self, dataset, indices): self.dataset = dataset self.indices = indices def __getitem__(self, idx): return self.dataset[self.indices[idx]] def __len__(self): return len(self.indices)
つまり、Dataset
とインデックスのリストを受け取って、そのインデックスのリストの範囲内でしかアクセスしないDataset
を生成してくれる。
文章にするとややこしいけどコードの例を見るとわかりやすい。
以下のコードはMNISTの60000のDataset
をtrain:48000とvalidation:12000のDataset
に分割している。
from torchvision import datasets from torch.utils.data.dataset import Subset trainval_dataset = datasets.MNIST('./', train=True, download=True) n_samples = len(trainval_dataset) # n_samples is 60000 train_size = n_samples * 0.8 # train_size is 48000 subset1_indices = list(range(0,train_size)) # [0,1,.....47999] subset2_indices = list(range(train_size,n_samples)) # [48000,48001,.....59999] train_dataset = Subset(trainval_dataset, subset1_indices) val_dataset = Subset(trainval_dataset, subset2_indices) print(len(train_dataset)) # 48000 print(len(val_dataset)) # 12000
解決策2 torch.utils.data.random_split
解決策1はランダム性がない分割の仕方だったが、torch.utils.data.random_split(dataset, lengths)
を使うとランダムに分割することができる。
from torchvision import datasets trainval_dataset = datasets.MNIST('./', train=True, download=True) n_samples = len(trainval_dataset) # n_samples is 60000 train_size = int(len(trainval_dataset) * 0.8) # train_size is 48000 val_size = n_samples - train_size # val_size is 48000 # shuffleしてから分割してくれる. train_dataset, val_dataset = torch.utils.data.random_split(trainval_dataset, [train_size, val_size]) print(len(train_dataset)) # 48000 print(len(val_dataset)) # 12000
Chainerのchainer.datasets.split_dataset_random
について
ちなみにChainerのchainer.datasets.split_dataset_random
がtorch.utils.data.random_split
と同じようなことをしてくれる。
Chainerの実装を参考に同じようなものを作ると以下のようなコードになる。
from torch.utils.data.dataset import Subset def split_dataset(data_set, split_at, order=None): from torch.utils.data.dataset import Subset n_examples = len(data_set) if split_at < 0: raise ValueError('split_at must be non-negative') if split_at > n_examples: raise ValueError('split_at exceeds the dataset size') if order is not None: subset1_indices = order[0:split_at] subset2_indices = order[split_at:n_examples] else: subset1_indices = list(range(0,split_at)) subset2_indices = list(range(split_at,n_examples)) subset1 = Subset(data_set, subset1_indices) subset2 = Subset(data_set, subset2_indices) return subset1, subset2 def split_dataset_random(data_set, first_size, seed=0): order = np.random.RandomState(seed).permutation(len(data_set)) return split_dataset(data_set, int(first_size), order)
これを使うとランダムに分割できる。
from torchvision import datasets trainval_dataset = datasets.MNIST('./', train=True, download=True) n_samples = len(trainval_dataset) # n_samples is 60000 train_size = n_samples * 0.8 # train_size is 48000 train_dataset, val_dataset = split_dataset_random(trainval_dataset, train_size, seed=0) print(len(train_dataset)) # 48000 print(len(val_dataset)) # 12000