takuroooのブログ

勉強したこととか

PyTorch 入力画像と教師画像の両方にランダムなデータ拡張を実行する方法

*この記事は以前Qiitaで書いたものです。

qiita.com

目次

概要

DeepLearningのタスクの1つであるセマンティックセグメンテーション(Semantic Segmentation)では、分類や検出のタスクと異なって、教師データが画像形式になっている。そのためデータ拡張する場合(クロップや反転など)、入力画像と教師データそれぞれに同じように画像処理を行う必要がある。

この記事では入力画像と教師データの両方に同様のランダムなデータ拡張を実行する方法を紹介する記事。

セマンティックセグメンテーションとは

セマンティックセグメンテーションについては以下が参考になります。 U-NetでPascal VOC 2012の画像をSemantic Segmentationする (TensorFlow)

今回はサンプル画像としてVOCデータセットの画像を使用する。 解像度はどちらも500x281になっている。

・入力画像

f:id:takuroooooo:20201030081329j:plain

・教師画像

f:id:takuroooooo:20201030081340p:plain

ランダムなデータ拡張

今回はPyTorchで予め用意されているtorchvision.transforms.RandomCrop(size, padding=0, pad_if_needed=False)を使って画像に対してクロップを実行する。

関数の中では、乱数でクロップする位置を決めて、指定したsizeでクロップしている。(最終的には内部でtorchvision.transforms.functional.crop(img, i, j, h, w)がコールされている。)

詳細な使い方やパラメータについてはPyTorchのリファレンスを参照してください。 PyTorch TORCHVISION.TRANSFORMS

課題

torchvision.transforms.RandomCropは内部で乱数を発生させているため、実行するたびに結果が異なってしまう。 よって、以下のように実行すると入力画像と教師画像が異なる位置がクロップされてしまう。

from PIL import Image
from torchvision import transforms

trans_crop = transforms.RandomCrop((224,224))

img = Image.open(img_path) # 入力画像
target = Image.open(target_path) # 教師画像

img = trans_crop(img) # 入力画像を(224,224)でランダムクロップ
target = trans_crop(target) # 教師画像を(224,224)でランダムクロップ

img.show()
target.show()

f:id:takuroooooo:20201030081421p:plain f:id:takuroooooo:20201030081454p:plain

このように入力画像と教師画像が一致しないため学習ができなくなってしまう。

解決策1 乱数シードを固定する

PyTorchのクロップ関数の内部ではrandom.randint()で乱数を発生させているので、random.seed()を使って乱数シードを設定すれば同じ結果が得られる。

from PIL import Image
from torchvision import transforms
import random

trans = transforms.RandomCrop((224,224))

img = Image.open(img_path)
target = Image.open(target_path)

seed = random.randint(0, 2**32) # 乱数で乱数シードを決定

random.seed(seed) # 乱数シードを固定
img = trans(img)

random.seed(seed) # こっちでも乱数シードを固定
target = trans(target)

img.show()
target.show()

f:id:takuroooooo:20201030081546p:plain f:id:takuroooooo:20201030081556p:plain

この実装のよくない点はPyTorchの中で「random.randint()で乱数を発生させている」ということを前提としていること。 PyTorchの中で乱数の発生のさせ方が変わると急に上手く動作しなくなったりする。

解決策2 transforms.RandomCrop.get_params(img, output_size))を使う

transforms.RandomCrop.get_params(img, output_size))は乱数で決めたクロップする位置とサイズを返してくれる関数。

torchvision.transforms.RandomCropの中でもこの関数でクロップ位置を決めた後、torchvision.transforms.functional.crop(img, i, j, h, w)でクロップしている。

なので、これと同じような処理の流れを自分で実装すれば解決できる。

from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as tvf
import random

trans = transforms.RandomCrop((224,224))

img = Image.open(img_path)
target = Image.open(target_path)

# クロップ位置を乱数で決定
i, j, h, w = transforms.RandomCrop.get_params(img, output_size=(224,224))

img = tvf.crop(img, i, j, h, w) # 入力画像を(224,224)でクロップ
target = tvf.crop(target, i, j, h, w) # 教師画像を(224,224)でクロップ

f:id:takuroooooo:20201030081630p:plain f:id:takuroooooo:20201030081640p:plain