takuroooのブログ

勉強したこととか

PyTorchでValidation Datasetを作る方法

*この記事は以前Qiitaで書いたものです。

qiita.com

目次

概要

PyTorchにはあらかじめ有名なデータセットがいくつか用意されている(torchvision.datasetsを使ってMNIST/CIFARなどダウロードできる)。しかし、train/testでしか分離されていないので、ここからvalidationデータセットを作ってみる。 例題としてtorchvision.datasets.MNISTを使う。

課題

torchvision.datasets.MNISTを使うと簡単にPyTorchのDatasetを作ることができるが、train/test用のDatasetしか用意されていないためvalidation用のDatasetを自分で作る必要がある。

以下のコードはtrain/test用のDatasetを作っている。

from torchvision import datasets
trainval_dataset = datasets.MNIST('./', train=True, download=True)
test_dataset = datasets.MNIST('./', train=False, download=True)
print(len(trainval_dataset)) # 60000
print(len(test_dataset)) # 10000
print(type(trainval_dataset)) # torchvision.datasets.mnist.MNIST
print(type(test_dataset)) # torchvision.datasets.mnist.MNIST

このtrainval_datasetをtrain/validationに分割したい。 しかし、trainval_datasetは単純なリスト形式ではなく、PyTorchのDatasetになっているため、「Datasetが持つデータを取り出して、それをDatasetクラスに再構成する。」みたいなやり方だと手間がかかる上にうまくいかないことがある。(うまくいかない例としては、DatasetクラスにTransformクラスを渡している場合。)

解決策1 torch.utils.data.Subset

torch.utils.data.Subset(dataset, indices)を使うと簡単にDatasetを分割できる。 PyTorchの中のコードは以下のようにシンプルなクラスになっている。

class Subset(Dataset):
    """
    Subset of a dataset at specified indices.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)

つまり、Datasetとインデックスのリストを受け取って、そのインデックスのリストの範囲内でしかアクセスしないDatasetを生成してくれる。

文章にするとややこしいけどコードの例を見るとわかりやすい。 以下のコードはMNISTの60000のDatasetをtrain:48000とvalidation:12000のDatasetに分割している。

from torchvision import datasets
from torch.utils.data.dataset import Subset
trainval_dataset = datasets.MNIST('./', train=True, download=True)

n_samples = len(trainval_dataset) # n_samples is 60000
train_size = n_samples * 0.8 # train_size is 48000

subset1_indices = list(range(0,train_size)) # [0,1,.....47999]
subset2_indices = list(range(train_size,n_samples)) # [48000,48001,.....59999]
        
train_dataset = Subset(trainval_dataset, subset1_indices)
val_dataset   = Subset(trainval_dataset, subset2_indices)

print(len(train_dataset)) # 48000
print(len(val_dataset)) # 12000

解決策2 torch.utils.data.random_split

解決策1はランダム性がない分割の仕方だったが、torch.utils.data.random_split(dataset, lengths)を使うとランダムに分割することができる。

from torchvision import datasets

trainval_dataset = datasets.MNIST('./', train=True, download=True)

n_samples = len(trainval_dataset) # n_samples is 60000
train_size = int(len(trainval_dataset) * 0.8) # train_size is 48000
val_size = n_samples - train_size # val_size is 48000

# shuffleしてから分割してくれる.
train_dataset, val_dataset = torch.utils.data.random_split(trainval_dataset, [train_size, val_size])

print(len(train_dataset)) # 48000
print(len(val_dataset)) # 12000

Chainerのchainer.datasets.split_dataset_randomについて

ちなみにChainerのchainer.datasets.split_dataset_randomtorch.utils.data.random_splitと同じようなことをしてくれる。

Chainerの実装を参考に同じようなものを作ると以下のようなコードになる。

from torch.utils.data.dataset import Subset

def split_dataset(data_set, split_at, order=None):
    from torch.utils.data.dataset import Subset
    n_examples = len(data_set)
    
    if split_at < 0:
        raise ValueError('split_at must be non-negative')
    if split_at > n_examples:
        raise ValueError('split_at exceeds the dataset size')
        
    if order is not None:
        subset1_indices = order[0:split_at]
        subset2_indices = order[split_at:n_examples]
    else:
        subset1_indices = list(range(0,split_at))
        subset2_indices = list(range(split_at,n_examples))
        
    subset1 = Subset(data_set, subset1_indices)
    subset2 = Subset(data_set, subset2_indices)
    
    return subset1, subset2

def split_dataset_random(data_set, first_size, seed=0):
    order = np.random.RandomState(seed).permutation(len(data_set))
    return split_dataset(data_set, int(first_size), order)

これを使うとランダムに分割できる。

from torchvision import datasets

trainval_dataset = datasets.MNIST('./', train=True, download=True)
n_samples = len(trainval_dataset) # n_samples is 60000
train_size = n_samples * 0.8 # train_size is 48000

train_dataset, val_dataset = split_dataset_random(trainval_dataset, train_size, seed=0)

print(len(train_dataset)) # 48000
print(len(val_dataset)) # 12000

参考