takuroooのブログ

勉強したこととか

Vue.jsでChatスタイルのWebSocketクライアントを作る

Vue.jsを触ってみたかったので、Chat風のWebSocketクライアントを作ってみた。

f:id:takuroooooo:20201129093432p:plain

目次

作ったもの

github.com

Chat部分

f:id:takuroooooo:20201129114523p:plain

Chat部分はこちらのCodepenのコードを使った。 codepen.io

メッセージ表示部分

<section ref="messageArea" class="message-area">
    <p v-for="message in messages" class="message"
        :class="{ 'sent-message': message.from === 'client', 'received-message': message.from === 'server', 'system-message': message.from === 'system' }">
        {{ message.date }}
        <br>
        {{ message.body }}
    </p>
</section>
  • messagesにはWebSocketサーバーと送受信した情報が入っている。これをv-forを使ってレンダリングしていく。
  • {{ message.date }}はメッセージを送受信した時刻が入っている。
  • {{ message.body }}は送受信したメッセージそのものが入っている。
  • :class="{ 'sent-message': message.from === 'client', ... }"の部分ではmessage送信者(message.from)に応じて動的にクラスを切り替えている。:class=...v-binad:class=...の省略系となっている。

送信したメッセージ、受信したメッセージ、それ以外のメッセージで文字の色、背景色、配置を変えている。

.sent-message {
    background: #66ff00;
    color: black;
    margin-left: 55%;
}

.received-message {
    background: #d8d8d8;
    color: black;
}

.system-message {
    background: #007bff;
    color: white;
    width: 100%;
    margin: 10px auto;
}

メッセージ送受信後に自動でスクロール実装部分

<section ref="messageArea" class="message-area">
...
  • ref="messageArea"はVueで直接DOM操作するときに使うものらしい(getElementByIdのようなもの)。これを使ってメッセージ送受信後に自動でスクロールするように制御している。
Vue.nextTick(() => {
    let messageArea = this.$refs.messageArea;
    messageArea.scrollTop = messageArea.scrollHeight;
})
  • Vue.nextTickはDOMの更新後に実行される関数。DOM更新後にscrollTopを更新していて最新のmessageが画面に表示されるようにしてる。

WebSocketサーバー接続部分

WebSocketサーバーと接続する部分。 f:id:takuroooooo:20201129120436p:plain

<form>
    <div class="form-group">
        <label for="serverURL">URL</label>
        <input @keydown.enter.prevent="connect" v-model="url" id="serverURL" type="text"
            class="form-control" placeholder="wss://..."></input>
    </div>
    <div>
        <button @click="connect" id="connectBtn" type="button" class="btn btn-primary">Connect</button>
        <button @click="disconnect" id="disconnectBtn" type="button"
            class="btn btn-secondary">Disconnect</button>
    </div>
</form>
  • WebSocketサーバーのURLを入力するinputフォームでEnterキーを押した時に実行される関数は@keydown.enter.prevent="..."で指定している。keydown.enterはEnterキー押下時のイベントを表している。.preventevent.preventDefault()と同じようにデフォルト動作を抑制している。

jp.vuejs.org

  • inputフォームに入力されたURLはv-model="url"を使ってurlという名前の変数に代入している。
  • ボタンが押されたときに実行される関数は@click="..."で指定している。これはv-on:click="..."の省略系となっている。

接続ボタン押下時のロジック

・
・
this.conn = new WebSocket(this.url)

this.conn.onopen = (event) => {
    this.onOpenListener(`Connected: ${this.url}`)
}
this.conn.onclose = (event) => {
    this.onCloseListener(`Disconnected: ${this.url} (code:${event.code} reason:${event.reason})`)
}
this.conn.onerror = (event) => {
    console.error("WebSocket error observed:", event);
    this.onErrorListener(`Error: ${this.url}`);
}
this.conn.onmessage = (msg) => {
    this.onMessageListener(msg.data)
}
・
・
  • this.urlv-modelによってinputフォームと紐付けされた変数。
  • event.codeにはWebSocketのStatus Codeが入っている。

Status Code定義 tools.ietf.org

画面更新時にVueのテンプレートが表示される問題

画面を再読み込みしたりすると{{ message.date }}{{ message.body }}が一瞬画面に表示されてしまう。これを回避するためにv-cloakを使っている。

[v-cloak] {
            display: none;
}

・
・
・

<div id="app" v-cloak>
    <h1>WebSocket Client</h1>

公式ドキュメントによると、

このディレクティブは関連付けられた Vue インスタンスコンパイルが終了するまでの間残存します。[v-cloak] { display: none } のような CSS のルールと組み合わせて、このディレクティブは Vue インスタンス が用意されるまでの間、コンパイルされていない Mustache バインディングを隠すのに使うことができます。

API — Vue.js

とのこと。

参考


AtCoder Beginner Contest 183

4完だったがB問題に解くのにすごく時間がかかってしまった。

atcoder.jp

目次

B - Billiards

X軸への入射角と反射角が必ず等しくなるとき、スタートの座標(Sx,Sy)からゴールの座標(Gx,Gy)にたどり着くにはX軸のどの部分にボールがぶつかればよいか。

答えのX軸の座標をxとすると、 Sy/(x-Sx) = Gy/(Gx-x)になるので、これを式変形すると、x = (Sy*Gx + Gy*Sx) / (Gy+Sy)となるのでこれを実装すればよい。(これに気づくのにすごく時間がかかった...)

SX, SY, GX, GY = list(map(int,input().split()))
print((SY*GX + GY*SX) / (GY+SY))

同じように式変形する問題 atcoder.jp

C - Travel

N個(2<=N<=8)の都市があって、ある都市からある都市への移動時間が与えられる。 都市1から出発して全ての都市を通って、最後に都市1に戻る場合の移動時間の合計がちょうどKになるものはいくつあるか。

都市が3つの場合は
1 - 2 - 3 - 1
1 - 3 - 2 - 1
の二通りの移動の仕方がある。 都市は最大8個しかないので都市の並び方を全て列挙して試してみる。(順列全列挙)

import itertools
 
N, K = list(map(int, input().split()))
 
arr = []
for _ in range(N):
    arr.append(list(map(int, input().split())))
 
ans = 0
t = list(range(N))[1:]
for p in list(itertools.permutations(t, len(t))):
    pp = [0] + list(p) + [0]
    dist = 0
    for src, dst in zip(pp[0:], pp[1:]):
        dist += arr[src][dst]
    if dist == K:
        ans += 1
print(ans)

D - Water Heater

給湯器は毎分Wリットルお湯を供給できる。 N人の人が時刻SからTまでの間(時刻Tは含まない)、お湯を毎分Pリットル使う。 全ての人にお湯を供給することは可能か。

ある時刻でもっともお湯が使われる量を求めて、それがWリットル以下であれば供給可能。 こういうときはいもす法が使える。

imoz.jp

いもす法
1 開始時刻に使用するお湯の量を足す。終了時刻に引く。
2 これをN人分あらかじめ行っておき、累積和を計算する。
3 累積和の結果がその区間で使用するお湯の量となる。 f:id:takuroooooo:20201116223919p:plain

この場合最大で使用されるお湯の量は11リットルになる。

N, W = list(map(int, input().split()))
MAX = (2 * 10**5) + 1

arr = [0] * MAX
for n in range(N):
    s, t, p = list(map(int, input().split()))
    arr[s] += p
    arr[t] -= p

for i in range(1, MAX):
    arr[i] += arr[i - 1]

print(['No','Yes'][max(arr) <= W])

いもす法を使う他の問題
atcoder.jp

Pythonのurllibを使ってImageNetから画像をダウンロードする

*この記事は以前Qiitaで書いたものです。

qiita.com

概要

PythonでImageNetから画像をダウンロードする方法を解説する記事。

ImageNetの画像をダウンロードする方法は2つある。 一つはImageNet経由で一括ダウンロードする方法と、もう一つはImageNetが管理している画像元のURL一覧取得して、そのURLを使って自分でダウンロードする方法である。

前者は、非営利目的の研究/教育目的のみ利用可能。 なのでそれ以外の人は後者の方法をとらなければならない。 今回は後者の方法をPythonを使ってダウンロードする。

ダウンロードに必要な知識と実際のコードをそれぞれ解説していく。

ImageNetとは

f:id:takuroooooo:20201030083153p:plain

  • 研究目的で作成された画像のデータベース
  • 機械学習の学習データによく使用されている。
  • 画像の種類はWordNetで管理 WordNet(Wikipedia)
  • ImageNetが画像の著作権を保持しているわけではない。ImageNetは画像元のURLと画像のサムネイルを提供しているだけ。
  • 画像数:14,197,122
  • バウンディングボックス:1,034,908

WordnetIDとSynsetとは

ImageNetはWordNetという辞書で画像が管理されている。 具体的にはWordNetIDとSynsetという組み合わせで画像の種別を表現している。

例えば、WordNetID 「n02113335」 は、 Synset 「Poodle, poodle dog」 を表現している。

例から分かるように、Synsetは人間が分かる単語で表現され、WordNetIDはその単語と紐付いている番号('n' + 8桁の番号)である。

ImageNetのサイトで画像を探すときはSynsetで検索できる。またSynsetは親(上位概念)と子(下位概念)で構成されている。(「Poodle, poodle dog」の下位概念には「Toy poodle」がいる。)

f:id:takuroooooo:20201030083214p:plain

左のツリーがSynsetの階層構造。 WordNetIDは、右上の黄色いマークを押すと取得できる。

人間に分かりやすいSynsetだけ分かればいいのでは?と思うが、プログラムで画像をダウンロードする際にはWordNetIDでダウンロードしたい種別を選択する必要がある。

ダウンロード方法

プログラムでダウンロードするためにはImageNetが提供しているAPIを使う。APIとはあるURLのことで、このURLにWordNetIDをくっつけてアクセスすると、そのWordNetIDに関する情報が取得できる。 APIにはいろいろな種類がある。

ImageNet API

No. PageName URL
1 Synset検索画面のページ http://www.image-net.org/synset?wnid=***
2 指定したwnidの下位概念のwnid表示 http://www.image-net.org/api/text/wordnet.structure.hyponym?wnid=***
3 指定したwnidの下位概念のwnid表示(最下層まで探索) http://www.image-net.org/api/text/wordnet.structure.hyponym?wnid=***&full=1
4 指定したwnidに対応するSynset表示 http://www.image-net.org/api/text/wordnet.synset.getwords?wnid=***
5 指定したwnidに対応するファイル名と画像のURL表示 http://www.image-net.org/api/text/imagenet.synset.geturls.getmapping?wnid=***
6 指定したwnidに対応する画像のURL表示 http://www.image-net.org/api/text/imagenet.synset.geturls?wnid=***
  • 米印にはWordNetIDが入ります。
  • wnid = WordNetIDです。

順番に見ていくと、 1は「WordnetIDとSynsetとは」で見た検索画面にジャンプするURL。 2と3は、指定したWordNetIDの下位概念を表示する。2は直下のものしか表示しないが、3は最下層まで表示する。


f:id:takuroooooo:20201030083319p:plain

4は、指定したWordNetIDに対応するSynsetを表示するURL。


f:id:takuroooooo:20201030083352p:plain


5,6は、指定したWordNetIDに属する画像のURL一覧が表示される。5はファイル名も合わせて表示される。


f:id:takuroooooo:20201030083419p:plain


n20113335_17679というのがファイル名になっている。 バウンディングボックスなどはファイル名と紐付いているので、5番のAPIでアクセスするのがいいと思う。

ということで画像ダウンロードに使うのは5番です。 方法としては、 1.Pythonurllibを使って5番のページに情報取得の要求を出す。 2.ファイル名とURLの一覧が手に入る。 3.取得したURLに対してさらに情報取得の要求を出す。 4.URL先の画像データが手に入る。 5.画像データをファイルに書く。 以下3-5を繰り返す。

コード

以下サンプルコード

1.ファイル名と画像元URLの取得

from urllib import request
IMG_LIST_URL="http://www.image-net.org/api/text/imagenet.synset.geturls.getmapping?wnid={}"

url = IMG_LIST_URL.format("n02113335")
with request.urlopen(url) as response:
    html = response.read() 

URLをurlopen()で開いて、read()するだけでページの情報が取得できます。この時点でhtmlには、バイナリ型で以下のような文字列が入っている。

n02113335_17679 http://farm1.static.flickr.com/194/467227983_ce131cca2a.jpg
n02113335_4957 http://static.flickr.com/164/388222083_d98ab2ec7e.jpg
n02113335_4907 http://www.dkimages.com/discover/previews/919/65004609.JPG
n02113335_4943 http://farm1.static.flickr.com/82/225053708_e1b941261a.jpg
n02113335_4942 http://farm1.static.flickr.com/17/19821754_6cb866105a.jpg
n02113335_4935 http://farm3.static.flickr.com/2397/2132261952_28dd898274.jpg
・
・
・

文字列として扱いたいのでdecodeする。(バイナリ型->文字列型)

data = html.decode()

あとはこれを1列目と2列目に分割すればファイル名と画像元のURLが取得できる。 例えば以下のようにする。

data = data.split()
# data = [fname_0, url_0, fname_1, url_1, .....]
fnames = data[::2]
urls = data[1::2]

2.画像の取得

先ほど取得した1枚目のURLを使って、1と同様にアクセスする。

url="http://farm1.static.flickr.com/194/467227983_ce131cca2a.jpg"
with request.urlopen(url) as response:
    img = response.read() 

このときimgにバイナリ形式の画像データが入っている。 画像なので、以下のようにそのままバイナリ形式でファイルに書けばダウンロード完了。

with open('n02113335_17679.jpg', 'wb') as f:
    f.write(img)

ImageNet_Downloader

f:id:takuroooooo:20201030083556p:plain

github.com

ここまで紹介したImageNetAPIのラッパークラスを書いてみた。中身の処理は先ほど説明したことがメインなので割愛。以下、クラスの簡単な使い方。

import downloader
import os
root_dir = os.getcwd()
wnid = "n02113335"
api = downloader.ImageNet(root_dir)
api.download(wnid, verbose=True)

downloadでwnidを設定すると、そのwnidの画像をダウンロードし始める。以下のフォルダを作成して順次ダウンロードしていく。n02113335.txtは、ファイル名と画像のURLが書いてあるテキストファイル。

root/  
 ├ img/  
 │ └ n02113335/
 │       └ n02113335_xxx.jpg
 │       └ n02113335_xxx.jpg
 │       └ ....
 ├ list/  
 │ └ n02113335.txt

example.pyを使うと簡単にダウンロードできる。

python example.py <WordnetID> -v

下位層も含めてダウンロードしたい場合は-rをつける。

python example.py <WordnetID> -v -r

ダウンロードする枚数を指定した場合は、-limit <num>を指定する。

python example.py <WordnetID> -v -r -limit 100

まとめ

Pythonのurllibを使うと簡単にネット上の画像をダウンロードできる。

参考

Random Erasingの動きを見てみる

*この記事は以前Qiitaで書いたものです。

qiita.com

目次

概要

データ拡張の一つであるRandom Erasingの処理を説明する記事

論文の内容は以下の記事で別にまとめています。

takuroooooo.hatenablog.com

Random Erasingとは

2017年に発表されたデータ拡張。 Random Erasing Data Augmentation 画像上に矩形を重畳することでデータの水増しを行う。

  • 実装が簡単で。
  • 他のデータ拡張と併用可能で。
  • Occlusionに対して強いモデルを作れる。
    という特徴がある。

アルゴリズム

こちらは論文に書いてあるRandomErasingのアルゴリズム

f:id:takuroooooo:20201030082407p:plain

論文のアルゴリズムPythonで実装すると以下のようになる。 変数名は論文の中の記号と合わせている。

def random_erasing(img, p=0.5, sl=0.02, sh=0.4, r1=0.3, r2=3.3):
    target_img = img.copy()

    if p < np.random.rand():
        # RandomErasingを実行しない
        return target_img 

    H, W, _ = target_img.shape
    S = H * W

    while True:
        Se = np.random.uniform(sl, sh) * S # 画像に重畳する矩形の面積
        re = np.random.uniform(r1, r2) # 画像に重畳する矩形のアスペクト比

        He = int(np.sqrt(Se * re)) # 画像に重畳する矩形のHeight
        We = int(np.sqrt(Se / re)) # 画像に重畳する矩形のWidth

        xe = np.random.randint(0, W) # 画像に重畳する矩形のx座標
        ye = np.random.randint(0, H) # 画像に重畳する矩形のy座標

        if xe + We <= W and ye + He <= H:
            # 画像に重畳する矩形が画像からはみ出していなければbreak
            break

    mask = np.random.randint(0, 255, (He, We, 3)) # 矩形がを生成 矩形内の値はランダム値
    target_img[ye:ye + He, xe:xe + We, :] = mask # 画像に矩形を重畳

    return target_img
記号 意味
p Random Erasingを実行する確率
S 入力画像の面積
H, W 入力画像の高さと幅
sl, sh 入力画像面積に対する矩形面積の比率のレンジ
r1, r2 画像上に描画される矩形のアスペクト比のレンジ
Se 画像上に描画される矩形の面積
He, We 画像上に描画される矩形の高さと幅
re 画像上に描画される矩形のアスペクト比 He/We
xe, ye 画像上に描画される矩形のxy座標

各記号を図形上で表すと以下のようになる。

f:id:takuroooooo:20201030082457p:plain

アルゴリズムの流れ

  1. 確率pを使ってRandomErasingを実行するかを判定する。
  2. slshから矩形の面積Seを求める。
  3. r1r2から矩形のアスペクト比reを求める。
  4. Sereから矩形の縦幅He矩形の横幅Weを求める。
  5. 矩形を重畳する位置xeyeを求める。
  6. (xe,ye,We,He)を使って画像上に矩形を重畳したときに矩形が画像からはみ出していれば1からやり直し。
  7. 矩形内を埋める値をランダムに生成し、矩形を画像に重畳する。
  8. 終わり。

論文では以下の設定が一番性能が良かったとしている。

  • p=0.5
  • sl=0.02, sh=0.4
  • r1=1/r2=0.3, r2=3.3

また

  • 矩形内はランダム値で埋めるのが性能が良い。(ImageNetの平均値や全て0 or 255と比較して)
  • ObjectDetectionの場合は、矩形の生成位置を各BoundingBoxの中+画像全体からランダムで決めると性能が良い
  • RandomCropやRandomFlipと併用しても性能が良くなる。(他のデータ拡張手法と補完的な関係になっている。)

ということが論文の中で述べられている。

RandomErasingのパラメータをいじって結果の変化を可視化する

RandomErasingを使う際にユーザーは以下のパラメータを決めなければいけない。

記号 意味
p Random Erasingを実行する確率
sl, sh 入力画像面積に対する矩形面積の比率のレンジ
r1, r2 画像上に描画される矩形のアスペクト比のレンジ

上記パラメータを決める際に、パラメータを変えるとどんな矩形が生成されるのか可視化できたら便利だと思いmatplotlibでRandomErasingの可視化ツールを作った。

github.com

f:id:takuroooooo:20201030082531g:plain

このツールを使用すると

  • shを大きくする = 画像に対して大きい矩形が生成される
  • shを小さくする = 画像に対して小さい矩形が生成される
  • r1を大きくする = 正方形に近いの矩形が生成される
  • r1を小さくする = 細長い矩形が生成される

ということが分かる。 ちなみに

  • sl=0.02
  • r2=1/r1
  • p = 1.0

に固定している。


PyTorch 入力画像と教師画像の両方にランダムなデータ拡張を実行する方法

*この記事は以前Qiitaで書いたものです。

qiita.com

目次

概要

DeepLearningのタスクの1つであるセマンティックセグメンテーション(Semantic Segmentation)では、分類や検出のタスクと異なって、教師データが画像形式になっている。そのためデータ拡張する場合(クロップや反転など)、入力画像と教師データそれぞれに同じように画像処理を行う必要がある。

この記事では入力画像と教師データの両方に同様のランダムなデータ拡張を実行する方法を紹介する記事。

セマンティックセグメンテーションとは

セマンティックセグメンテーションについては以下が参考になります。 U-NetでPascal VOC 2012の画像をSemantic Segmentationする (TensorFlow)

今回はサンプル画像としてVOCデータセットの画像を使用する。 解像度はどちらも500x281になっている。

・入力画像

f:id:takuroooooo:20201030081329j:plain

・教師画像

f:id:takuroooooo:20201030081340p:plain

ランダムなデータ拡張

今回はPyTorchで予め用意されているtorchvision.transforms.RandomCrop(size, padding=0, pad_if_needed=False)を使って画像に対してクロップを実行する。

関数の中では、乱数でクロップする位置を決めて、指定したsizeでクロップしている。(最終的には内部でtorchvision.transforms.functional.crop(img, i, j, h, w)がコールされている。)

詳細な使い方やパラメータについてはPyTorchのリファレンスを参照してください。 PyTorch TORCHVISION.TRANSFORMS

課題

torchvision.transforms.RandomCropは内部で乱数を発生させているため、実行するたびに結果が異なってしまう。 よって、以下のように実行すると入力画像と教師画像が異なる位置がクロップされてしまう。

from PIL import Image
from torchvision import transforms

trans_crop = transforms.RandomCrop((224,224))

img = Image.open(img_path) # 入力画像
target = Image.open(target_path) # 教師画像

img = trans_crop(img) # 入力画像を(224,224)でランダムクロップ
target = trans_crop(target) # 教師画像を(224,224)でランダムクロップ

img.show()
target.show()

f:id:takuroooooo:20201030081421p:plain f:id:takuroooooo:20201030081454p:plain

このように入力画像と教師画像が一致しないため学習ができなくなってしまう。

解決策1 乱数シードを固定する

PyTorchのクロップ関数の内部ではrandom.randint()で乱数を発生させているので、random.seed()を使って乱数シードを設定すれば同じ結果が得られる。

from PIL import Image
from torchvision import transforms
import random

trans = transforms.RandomCrop((224,224))

img = Image.open(img_path)
target = Image.open(target_path)

seed = random.randint(0, 2**32) # 乱数で乱数シードを決定

random.seed(seed) # 乱数シードを固定
img = trans(img)

random.seed(seed) # こっちでも乱数シードを固定
target = trans(target)

img.show()
target.show()

f:id:takuroooooo:20201030081546p:plain f:id:takuroooooo:20201030081556p:plain

この実装のよくない点はPyTorchの中で「random.randint()で乱数を発生させている」ということを前提としていること。 PyTorchの中で乱数の発生のさせ方が変わると急に上手く動作しなくなったりする。

解決策2 transforms.RandomCrop.get_params(img, output_size))を使う

transforms.RandomCrop.get_params(img, output_size))は乱数で決めたクロップする位置とサイズを返してくれる関数。

torchvision.transforms.RandomCropの中でもこの関数でクロップ位置を決めた後、torchvision.transforms.functional.crop(img, i, j, h, w)でクロップしている。

なので、これと同じような処理の流れを自分で実装すれば解決できる。

from PIL import Image
from torchvision import transforms
from torchvision.transforms import functional as tvf
import random

trans = transforms.RandomCrop((224,224))

img = Image.open(img_path)
target = Image.open(target_path)

# クロップ位置を乱数で決定
i, j, h, w = transforms.RandomCrop.get_params(img, output_size=(224,224))

img = tvf.crop(img, i, j, h, w) # 入力画像を(224,224)でクロップ
target = tvf.crop(target, i, j, h, w) # 教師画像を(224,224)でクロップ

f:id:takuroooooo:20201030081630p:plain f:id:takuroooooo:20201030081640p:plain


PyTorch transforms/Dataset/DataLoaderの基本動作を確認する

*この記事は以前Qiitaで書いたものです。

qiita.com

目次

概要

PyTorchの前処理とデータのロードを担当するtransforms/Dataset/DataLoaderの動作を簡単な例で確認する。

この記事の対象読者

  • これからPyTorchを勉強しようとしている人
  • PyTorchのtransforms/Dataset/DataLoaderの役割を知りたい人
  • オリジナルのtransforms/Dataset/DataLoaderを実装したい人

前置き

DeepLearningのフレームワークではだいたい以下のような機能をサポートしている。

この中で「データの前処理」と「データセットのロード」は自分達の環境によってカスタマイズすることがよくあるので、フレームワークがどのような機能をサポートしているのかを把握することが実装の効率化に繋がる。

今回はPyTorchの「データの前処理」と「データセットのロード」を実現するためのモジュールtransforms/Dataset/DataLoaderの動きを簡単なデータセットを使って確認してみる。

PyTorch Tutorial

今回は、PyTorchのTutorialのDATA LOADING AND PROCESSING TUTORIALで紹介されている内容を参考にしている。

このTutorialでは実際のデータセット(顔画像と顔の特徴点)を使用してtransforms/Dataset/DataLoaderについて解説している。

Tutorialの内容は分かりやすい構成になっているが、画像データやcsvファイルを扱うので、コード上にこれらのファイル特有の処理が入っており、transforms/Dataset/DataLoaderだけの動きを知りたい場合は、やや冗長な部分がある。

よって今回は数字の1次元のリスト形式のデータセットを使って、transforms/Dataset/DataLoaderを動かしていく。

transforms/Dataset/DataLoaderの役割

  • transforms
    • データの前処理を担当するモジュール
  • Dataset
    • データとそれに対応するラベルを1組返すモジュール
    • データを返すときにtransformsを使って前処理したものを返す。
  • DataLoader
    • データセットからデータをバッチサイズに固めて返すモジュール

上記説明にあるとおり Datasetはtransformsを制御して DataLoaderはDatasetを制御する という関係になっている。

なので流れとしては、 1.Datasetクラスをインスタンス化するときに、transformsを引数として渡す。 2.DataLoaderクラスをインスタンス化するときに、Datasetを引数で渡す。 3.トレーニングするときにDataLoaderを使ってデータとラベルをバッチサイズで取得する。 という流れになる。

以下各詳細を、transforms、Dataset、DataLoaderの順に動作を見ていく。

transforms

transformsはデータの前処理を行う。 PyTorchではあらかじめ便利な前処理がいくつか実装されている。 例えば、画像に関する前処理はtorchvision.transformsにまとまっており、CropやFlipなどメジャーな前処理があらかじめ用意されている。

今回は自分で簡単なtransformsを実装することで処理内容の理解を深めていく。

transformsを実装するのに必要な要件

予め用意されているtransformsの動作に習うために「コール可能なクラス」として実装する必要がある。 (「コール可能」とは__call__を実装しているクラスのこと) なぜ「コール可能なクラス」にする必要があるのかというと、Tutorialでは以下のように説明している。

We will write them as callable classes instead of simple functions so that parameters of the transform need not be passed everytime it’s called.

つまり、クラスにしておけば、インスタンス化時に前処理に使うパラメータを全部渡しておけるので、前処理を実行するたびにパラメータを渡す手間が省ける、ということで「コール可能なクラス」を推奨している。

今回はデータとして数字の1次元配列を使うので入力値を二乗するtransformsを実装する。

実装
class Square(object):
    def __init__(self):
        pass
    
    def __call__(self, sample):
        return sample ** 2

実装はこれだけ。

使い方
transform = Square()
print(transform(1)) # -> 1
print(transform(2)) # -> 4
print(transform(3)) # -> 9
print(transform(4)) # -> 16

渡した数値が二乗になっていることが確認できる。 もし画像データ用のtransformsを実装したい場合は、__call__の中に画像処理を実装すればいい。

Dataset

Datasetは、入力データとそれに対応するラベルを1組返すモジュール。 データはtransformsで前処理を行った後に返す。そのためDatasetを作るときは引数でtransformsを渡す必要がある。

PyTorchでは有名なデータセットがあらかじめtorchvision.datasetsに定義されている。(MNIST/CIFAR/STL10など)

自前のデータを扱いたいときは自分のデータをリードして返してくれるDatasetを実装する必要がある。

扱うデータが画像でクラスごとにフォルダ分けされている場合はtorchvision.datasets.ImageFolderという便利なクラスもある。(KerasのImageDataGeneratorflow_from_directory()のような機能)

Datasetを実装するのに必要な要件

オリジナルDatasetを実装するときに守る必要がある要件は以下3つ。

  • torch.utils.data.Datasetを継承する。
  • __len__を実装する。
  • __getitem__を実装する。

__len__は、len(obj)で実行されたときにコールされる関数。 __getitem__は、obj[i]のようにインデックスで指定されたときにコールされる関数。

今回は、データは数字のリスト、ラベルは偶数の場合だけTrueになるものを出力するDatasetを実装する。

実装
import torch
class MyDataset(torch.utils.data.Dataset):
    
    def __init__(self, data_num, transform=None):
        self.transform = transform
        self.data_num = data_num
        self.data = []
        self.label = []
        for x in range(self.data_num):
            self.data.append(x) # 0 から (data_num-1) までのリスト
            self.label.append(x%2 == 0) # 偶数ならTrue 奇数ならFalse
            
    def __len__(self):
        return self.data_num
    
    def __getitem__(self, idx):
        out_data = self.data[idx]
        out_label =  self.label[idx]
        
        if self.transform:
            out_data = self.transform(out_data)
            
        return out_data, out_label

ポイントは__getitem__でデータを返す前にtransformでデータに前処理をしてから返しているところ。 データセットとして画像やcsvファイルを扱う場合は、__init__や__getitem__の中でファイルをオープンする必要がある。

使い方
data_set = MyDataset(10, transform=None)
print(data_set[0]) # -> (0, True)
print(data_set[1]) # -> (1, False)
print(data_set[2]) # -> (2, True)
print(data_set[3]) # -> (3, False)
print(data_set[4]) # -> (4, True)

# 先ほど実装したtransformsを渡してみる.
# データが二乗されていることに注目.
data_set = MyDataset(10, transform=Square())
print(data_set[0]) # -> (0, True)
print(data_set[1]) # -> (1, False)
print(data_set[2]) # -> (4, True)
print(data_set[3]) # -> (9, False)
print(data_set[4]) # -> (16, True)

指定したインデックスのデータとラベルがセットで取得できている。 次に説明するDataLoaderは、この仕組みを利用してバッチサイズ分のデータを生成する。

DataLoader

データセットからデータをバッチサイズに固めて返すモジュール DataLoaderはデータセットを使ってバッチサイズ分のデータを生成する。またデータのシャッフル機能も持つ。 データを返すときは、データをtensor型に変換して返す。 tensor型は、計算グラフを保持することができる変数でDeepLearningの勾配計算に不可欠な変数になっている。

DataLoaderは、torch.utils.data.DataLoaderというクラスが既に用意されている。たいていの場合、このクラスで十分対応できるので、今回はこのクラスにこれまで実装してきたDatasetを渡して動作を見てみる。

使い方
import torch
data_set = MyDataset(10, transform=Square())
dataloader = torch.utils.data.DataLoader(data_set, batch_size=2, shuffle=True)

for i in dataloader:
    print(i)

# [tensor([ 4, 25]), tensor([1, 0])]
# [tensor([64,  0]), tensor([1, 1])]
# [tensor([36, 16]), tensor([1, 1])]
# [tensor([1, 9]), tensor([0, 0])]
# [tensor([81, 49]), tensor([0, 0])]

指定したバッチサイズでかつデータがシャッフルされていることがわかる。transformsで値が二乗に変換されている。 shuffle=Falseにすると順番にデータが出力される。

import torch
data_set = MyDataset(10, transform=Square())
dataloader = torch.utils.data.DataLoader(data_set, batch_size=2, shuffle=False)

for i in dataloader:
    print(i)

# [tensor([0, 1]), tensor([1, 0])]
# [tensor([4, 9]), tensor([1, 0])]
# [tensor([16, 25]), tensor([1, 0])]
# [tensor([36, 49]), tensor([1, 0])]
# [tensor([64, 81]), tensor([1, 0])]

学習のときはdataloaderのループをさらにepochのループでかぶせる。

epochs = 4
for epoch in epochs:
    for i in dataloader:
        # 学習処理

PyTorchでValidation Datasetを作る方法

*この記事は以前Qiitaで書いたものです。

qiita.com

目次

概要

PyTorchにはあらかじめ有名なデータセットがいくつか用意されている(torchvision.datasetsを使ってMNIST/CIFARなどダウロードできる)。しかし、train/testでしか分離されていないので、ここからvalidationデータセットを作ってみる。 例題としてtorchvision.datasets.MNISTを使う。

課題

torchvision.datasets.MNISTを使うと簡単にPyTorchのDatasetを作ることができるが、train/test用のDatasetしか用意されていないためvalidation用のDatasetを自分で作る必要がある。

以下のコードはtrain/test用のDatasetを作っている。

from torchvision import datasets
trainval_dataset = datasets.MNIST('./', train=True, download=True)
test_dataset = datasets.MNIST('./', train=False, download=True)
print(len(trainval_dataset)) # 60000
print(len(test_dataset)) # 10000
print(type(trainval_dataset)) # torchvision.datasets.mnist.MNIST
print(type(test_dataset)) # torchvision.datasets.mnist.MNIST

このtrainval_datasetをtrain/validationに分割したい。 しかし、trainval_datasetは単純なリスト形式ではなく、PyTorchのDatasetになっているため、「Datasetが持つデータを取り出して、それをDatasetクラスに再構成する。」みたいなやり方だと手間がかかる上にうまくいかないことがある。(うまくいかない例としては、DatasetクラスにTransformクラスを渡している場合。)

解決策1 torch.utils.data.Subset

torch.utils.data.Subset(dataset, indices)を使うと簡単にDatasetを分割できる。 PyTorchの中のコードは以下のようにシンプルなクラスになっている。

class Subset(Dataset):
    """
    Subset of a dataset at specified indices.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)

つまり、Datasetとインデックスのリストを受け取って、そのインデックスのリストの範囲内でしかアクセスしないDatasetを生成してくれる。

文章にするとややこしいけどコードの例を見るとわかりやすい。 以下のコードはMNISTの60000のDatasetをtrain:48000とvalidation:12000のDatasetに分割している。

from torchvision import datasets
from torch.utils.data.dataset import Subset
trainval_dataset = datasets.MNIST('./', train=True, download=True)

n_samples = len(trainval_dataset) # n_samples is 60000
train_size = n_samples * 0.8 # train_size is 48000

subset1_indices = list(range(0,train_size)) # [0,1,.....47999]
subset2_indices = list(range(train_size,n_samples)) # [48000,48001,.....59999]
        
train_dataset = Subset(trainval_dataset, subset1_indices)
val_dataset   = Subset(trainval_dataset, subset2_indices)

print(len(train_dataset)) # 48000
print(len(val_dataset)) # 12000

解決策2 torch.utils.data.random_split

解決策1はランダム性がない分割の仕方だったが、torch.utils.data.random_split(dataset, lengths)を使うとランダムに分割することができる。

from torchvision import datasets

trainval_dataset = datasets.MNIST('./', train=True, download=True)

n_samples = len(trainval_dataset) # n_samples is 60000
train_size = int(len(trainval_dataset) * 0.8) # train_size is 48000
val_size = n_samples - train_size # val_size is 48000

# shuffleしてから分割してくれる.
train_dataset, val_dataset = torch.utils.data.random_split(trainval_dataset, [train_size, val_size])

print(len(train_dataset)) # 48000
print(len(val_dataset)) # 12000

Chainerのchainer.datasets.split_dataset_randomについて

ちなみにChainerのchainer.datasets.split_dataset_randomtorch.utils.data.random_splitと同じようなことをしてくれる。

Chainerの実装を参考に同じようなものを作ると以下のようなコードになる。

from torch.utils.data.dataset import Subset

def split_dataset(data_set, split_at, order=None):
    from torch.utils.data.dataset import Subset
    n_examples = len(data_set)
    
    if split_at < 0:
        raise ValueError('split_at must be non-negative')
    if split_at > n_examples:
        raise ValueError('split_at exceeds the dataset size')
        
    if order is not None:
        subset1_indices = order[0:split_at]
        subset2_indices = order[split_at:n_examples]
    else:
        subset1_indices = list(range(0,split_at))
        subset2_indices = list(range(split_at,n_examples))
        
    subset1 = Subset(data_set, subset1_indices)
    subset2 = Subset(data_set, subset2_indices)
    
    return subset1, subset2

def split_dataset_random(data_set, first_size, seed=0):
    order = np.random.RandomState(seed).permutation(len(data_set))
    return split_dataset(data_set, int(first_size), order)

これを使うとランダムに分割できる。

from torchvision import datasets

trainval_dataset = datasets.MNIST('./', train=True, download=True)
n_samples = len(trainval_dataset) # n_samples is 60000
train_size = n_samples * 0.8 # train_size is 48000

train_dataset, val_dataset = split_dataset_random(trainval_dataset, train_size, seed=0)

print(len(train_dataset)) # 48000
print(len(val_dataset)) # 12000

参考