takuroooのブログ

勉強したこととか

Kerasを勉強した後にPyTorchを勉強して躓いたこと

*この記事は以前Qiitaで書いたものです。

qiita.com

目次

概要

DeppLearningのフレームワークで最初にKerasを勉強した後に、Define by RunのPyTorch勉強してみて躓いたポイントをまとめてみる。

この記事の対象読者

Kerasの次にPyTorchを勉強してみようと思っている人。

はじめに

今回いくつか挙げている躓いたポイントはPyTorchに限らないものがある。またKerasといえばバックエンドはTensorFlowのものを指す。バックエンドがTensorFlowでない場合は話が当てはまらないものもあるので注意。

今回挙げたポイントは以下の5つ

  1. Channel First
  2. GPUへの転送
  3. CrossEntropyがSoftmax+CrossEntropyになっている
  4. CrossEntropyがone-hot-vectorに対応していない
  5. 学習と評価を区別する

以下、各ポイントの詳細について説明していく。

Channel First

PyTorchではモデルの入力と出力がChannel Firstの形式になっている。Channel Firstとは画像の次元の並びが(C, H, W)のようにChannelの次元が最初になっていること。 KerasではChannel Lastになっているため、(H, W, C)のようにChannelの次元が最後にくる。

実際にモデルに入力するときは、バッチサイズも合わせた4次元で表現する必要があるため、 PyTorch:(N, C, H, W) Keras:(N, H, W, C) となる。

記号の意味は N:バッチサイズ C:チャネル数 H:画像のHeight W:画像のWidth

画像を読み込む際は、OpenCVかPILを使用する場合が多いが、これらのモジュールはChannel Lastで画像を扱う仕様になっている。なので、PyTorchのモデルに入力する前に以下のコードのようにChannel Firstに変換する必要がある。

img = cv2.imread(img_path)
img = img.transpose((2, 0, 1)) # H x W x C -> C x H x W

モデルの出力もChannel Firstなのでmatplotlibなどで表示したい場合はChannel Lastに変換してから表示する。

output = output.numpy() # tensor -> ndarray
output = output.transpose(1, 2, 0) # C x H x W -> H x W x C

GPUへの転送

KerasではGPUを使う場合、GPU側のメモリを意識することがなかったが、PyTorchではGPUを使用する場合、明示的に学習するパラメータや入力データをGPU側のメモリに転送しなければならない。 以下のコードではモデルと入力データをGPUに転送している。

device = torch.device("cuda:0")
# modelはnn.Moduleを継承したクラス
model = model.to(device) # GPUへ転送
・
・
・
for imgs, labels in train_loader:
    imgs, labels = imgs.to(device), labels.to(device) # GPUへ転送

GPU上にあるデータCPUに転送したい場合も以下のようにコードを書く必要がある。

device = torch.device("cpu")
model.to(device)

CrossEntropyがSoftmax+CrossEntropyになっている

Kerasで多クラスの識別モデルを学習するときは、モデルの最終層でsoftmaxを実行してからcategorical_crossentropyでロスを計算する流れになっている。 一方PyTorchではロス関数であるtorch.nn.CrossEntropyの中でSoftmaxの計算も一緒に行っているので、モデルの最終層でSoftmaxは不要になる。

たまにPyTorchのサンプルコードで最終層にtorch.nn.LogSoftmaxを置いて、ロス関数にtorch.nn.NLLLossを指定している場合がある。これは最終層を恒等関数にしてtorch.nn.CrossEntropyを使っているのと同じになる。 つまり、 torch.nn.CrossEntropy=torch.nn.LogSoftmaxtorch.nn.NLLLoss という関係になっている。

torch.nn.LogSoftmaxは名前の通りSoftmaxの計算にLogをかぶせたものになっている。

f:id:takuroooooo:20201030075738p:plain

このLogはCrossEntropyの式にあるLogを持ってきているのだが、LogとSoftmaxを先に一緒に計算しておくことで、計算結果を安定させている。 なぜLog+Softmaxが計算的に安定するかは以下のページで解説されている。

Tricks of the Trade: LogSumExp

ちなみにtorch.nn.NLLLossはCrossEntropyのLogを抜いた他の計算を行っている。

CrossEntropyがone-hot-vectorに対応していない

Kerasではロスを計算するときに、labelはone-hot-vector形式で渡す必要があるがPyTorchでは正解の値をそのまま渡す。

例えば、3クラスの分類で正解が2番目のクラスの場合、Kerasでは[0, 1, 0]というリストをロス関数に渡すが、PyTorchでは2という値を渡す。

学習と評価を区別する

PyTorchでは、モデルを動作させるときに学習中なのか評価中なのかを明示的にコードで示す必要がある。なぜこれが必要なのかは理由が2つある。

1.学習中と評価中に挙動が変わるレイヤーがあるから 2.学習中には必要で評価中には不必要な計算があるから

1は、DropOutやBatchNormalizationなどのことで、これらのレイヤーは学習中と評価中で動作が変わる。よって、コードでこれから動作するのが学習なのか評価なのかを知らせる必要がある。 具体的には以下のようなコードになる。

# modelはnn.Moduleを継承したクラス
model.train() # 学習モードに遷移
model.eval() # 評価モードに遷移

2の不必要な計算とは計算グラフを作ることである。学習中は計算グラフを作って、誤差逆伝播法で誤差を計算グラフ上に伝播させて重みを更新する必要がある。しかし、学習以外の処理ではこの計算グラフの構築が不要になるので「計算グラフを作りません」とコードで示す必要がある。 具体的にはwith torch.no_grad()を使う。

model.eval() # 評価モードに遷移
with torch.no_grad(): # この中では計算グラフは作らない
    output = model(imgs)

Pythonで公約数の列挙

この記事で公約数列挙の仕方が2つあることを学んだのでメモ

qiita.com

公約数の列挙は

  1. 二つの整数を割り切れる数をループで探す
  2. 最大公約数の約数を列挙する

の2通りあるらしい。 これをPythonで実装してみる。

目次

1.二つの整数を割り切れる数をループで探す

こちらの方法は特に難しいことはなく、二つの整数を割り切れる数を愚直に探索すれば公約数を求めることができる。

def common_divisor(a, b):
    arr = []
    for i in range(1, min(a, b) + 1):
        if a % i == 0 and b % i == 0:
            arr.append(i)
    return arr

print(common_divisor(12, 18))  # [1, 2, 3, 6]

2.最大公約数の約数を列挙する

AtCoder 版!マスター・オブ・整数 (素因数分解編) - Qiita
この記事によると

二つの整数 a,b の公約数は、a,b の最大公約数の約数である。

とのこと。

つまり、

  1. 2つの整数の最大公約数を求めて
  2. この最大公約数の約数を列挙

すれば公約数を列挙したことになる。

どうしてこれが公約数になるのかイメージできなかったが、AtCoder Beginner Contest 142 の「D - Disjoint Set of Common Divisors」の解説で図を使った面白い解説があった。

youtu.be

(以下動画の解説をそのまま図にしたもの)

12と18の公約数列挙を考える。 まずは12を素因数分解する。出てきた素数を二次元上に並べて掛け合わせると、掛け合わせた答えが12の約数になっていることがわかる。

f:id:takuroooooo:20201018183453p:plain

同様に18も素因数分解してみる。

f:id:takuroooooo:20201018183754p:plain

二つの表を見ると右上に元となる数(12と18)があり、そこから左下に向かって約数が広がっているという特徴がある。

もう一つの特徴として、二つの表には重なるところがある。実はこの重なったところが12と18の公約数となっている。

f:id:takuroooooo:20201018183956p:plain

そして最大公約数はどこにあるかというと赤枠内の右上にある。

f:id:takuroooooo:20201018184454p:plain

この図のルールとして、右上の数字(緑のマス)の約数は緑マスを最右端として構成される四角形内の数字なので、つまり緑マス6(12と18の最大公約数)の約数は赤枠内の1,2,3,6(12と18の公約数)となる。

これは
「二つの整数 a,b の公約数は、a,b の最大公約数の約数である。」
を意味している。

このことから、繰り返しになるが、

  1. 2つの整数の最大公約数を求めて
  2. この最大公約数の約数を列挙

を実装すれば公約数を列挙することができる。 最大公約数はmathモジュールのgcdで求めることができる。
約数の列挙は以下記事のコードを使用した。

qiita.com

def divisors(n):
    lower_divisors, upper_divisors = [], []
    i = 1
    while i * i <= n:
        if n % i == 0:
            lower_divisors.append(i)
            if i != n // i:
                upper_divisors.append(n // i)
        i += 1
    return lower_divisors + upper_divisors[::-1]

import math
print(divisors(math.gcd(12, 18)))  # [1, 2, 3, 6]

AtCoder Regular Contest 105

Aしか解けなかった...

atcoder.jp

A - Fourtune Cookies

takuroooooo.hatenablog.com

この記事で練習したbit全探索を使ってAC 解説読むと等符号の関係(A<=B<=C<=D)からA+D=B+CA+B+C=Dのにパターンだけを確認すればよいらしい。

提出したコード

arr = list(map(int,input().split()))
 
n = len(arr)
for bit in range(1<<n):
  x = 0
  y = 0
  for i in range(n):
    if bit & (1<<i):
      x += arr[i]
    else:
      y += arr[i]
  if x == y:
    print('Yes')
    exit(0)
print('No')

B - Max-=min

他の人は結構解けていたみたいだけど、解けなかった... 与えられた数列のgcdを計算すればよいとのこと。

解説みてACしたコード

import math
import functools
N=int(input())
A=list(map(int,input().split()))
ans = functools.reduce(math.gcd, A)
print(ans)

これと全く同じ問題が過去のABCで出ていた。 atcoder.jp

解けた人のコメントを読むと

  • N=2を手計算していたらユーグリッドの互除法と同じ操作であることに気づいた
  • N=2でいくつかパターンを計算したら答えあがgcdあることに気づいた

といった感じなので簡単な例で試すってこととgcdの性質をちゃんと理解していることが必要だと思った。(0<=x<=yのときgcd(x,y)==gcd(x,y-x)は同じになるなど)

gcdの記事があったのでこれ読んで勉強してみよう。 qiita.com

Python bit全探索

こちらの記事の「bit全探索」をやってみた。

qiita.com

目次

bit全探索とは

例えば["hoge", "fuga", "piyo"]というリストがあってこれらすべての組み合わせを網羅的に列挙したいときにbit全探索が役に立つ。

組み合わせを網羅的に列挙とは下記パターンを列挙すること (左側の数字はただのインデックス番号、真ん中はインデックス番号をbitで表したもの)

0 :   0 : [ ]
1 :   1 : [ hoge ]
2 :  10 : [ fuga ]
3 :  11 : [ hoge fuga ]
4 : 100 : [ piyo ]
5 : 101 : [ hoge piyo ]
6 : 110 : [ fuga piyo ]
7 : 111 : [ hoge fuga piyo ]

このとき真ん中のbitが立っている位置と組み合わせで選ばれているもののリストのインデックス番号が一致している。 (1 : 1 : [ hoge ]は0bit目が立っている。hogeはリストのインデックス0の値。)

bitの立っている位置を組み合わせを列挙したいリストのインデックスとして利用することで全探索が可能になる。これをbit全探索と呼ぶらしい。

先の出力をするコードは以下のようになる。

w = ["hoge", "fuga", "piyo"]
n_bit = len(w)

for bit in range(1 << n_bit): # 1 << n_bit == 2 ** n_bit
    arr = []
    for i in range(n_bit):
        if bit & (1 << i): # bitの立っている位置を確認
            arr.append(w[i])

    # ここはprintしているだけなのでbit全探索とは関係ない
    print(f'{bit} : {bin(bit).replace("0b", ""):>3} : [ ', end='')
    for a in arr:
        print(a, end=' ')
    print(']')

リストの要素数がn個のとき、組み合わせ数は2**nになる。 この例では要素数が3なので2**3 == 8個の組み合わせがprintされる。(8個の中には空集合も含む)

bit全探索が何者か分かったので冒頭のQiitaの記事にあるAtCoderの問題を解いてみる。

ABC061 C - たくさんの数式

C - Many Formulas

問題

  • 125みたいな文字列Sが与えられたときに、 (Sは10文字以内)
  • +を使って考えられる数式を全て作り、
  • その計算結果の合計値を出力しなさい
    という問題

125を例にすると125、1+25、12+5、1+2+5 という4通りがある。

解答

  • +記号は文字の間にしか入れられないので、文字の間に+を入れるか入れないかを全探索する。(bitが立っていたら+を入れる、立っていなかったら+を入れない。)

  • +を入れられる隙間の数は文字列の長さ-1なので、125の場合は2箇所に+を入れることができる。この2箇所に対して全探索するので(この例では)式は2**2 == 4個の組み合わせがでる。

  • Pythonでは組み込み関数のevalを使うことで文字列の式を評価できる。(intで結果を返してくれる。)

print(eval("1+1")) # 2 

解答コード

s = input()

n = len(s) - 1
total = 0
for bit in range(1 << n):
  siki = s[0]
  for i in range(n):
    if bit & (1 << i):
      siki += '+'
    siki += s[i + 1]
  total += eval(siki)
print(total)

ABC079 C - Train Ticket

C - Train Ticket

問題

  • 1222 みたいな文字列が与えられたときに、
  • +-記号を使って、
  • 答えが7になる式を出力しなさい という問題 各文字列間に必ず+-を入れなくてはいけない。

解答

  • 「ABC061 C - たくさんの数式」と同様に文字間に+を入れるか、-を入れるかを全探索する。
s = input()
 
def solve(): # 答えがでたときにループを抜けるのが面倒くさかったので関数にしてある
  l = len(s) - 1
  for bit in range(1<<l):
    siki = s[0]
    for i in range(l):
      if bit & (1<<i):
        siki += "+"
      else:
        siki += "-"
      siki += s[i+1]
      
    if eval(siki) == 7:
      siki += "=7"
      print(siki)
      return
 
solve()

ABC104 C - All Green

この問題は難しくて解けなかった...
C - All Green

問題

  • 目標スコアGと問題数piとその問題数を全て解いた時にもらえるボーナスciが与えられた時に
  • 目標スコア以上の点数になるためには最低何問の問題を解けばよいか出力しなさい という問題 ボーナスという設定がなければ、点数の高い問題から解いていけばよいが、ボーナスをうまく利用すると点数の低い問題を解くだけで最短で目標点数を達成することがある。

解答

  • ボーナスを獲得したパターンを全探索すればいいらしい。
  • 目標点数に達していない場合は、これに加えて点数の高い問題から解いていく。このときボーナスが発生するまで解く必要はない。ボーナスが発生するパターンは全探索するから。)
D, G = map(int, input().split())
 
P = []
C = []
for _ in range(D):
  p, c = map(int, input().split())
  P.append(p)
  C.append(c)
 
ans = []
for bit in range(1<<D):
  total_score = 0
  solved = 0
  # ボーナスを獲得するパターンを全探索
  for i in range(D):
    if bit & (1<<i):
      total_score += C[i] + ((i+1)*100) * P[i] # 解いた問題の点数 + ボーナスを加算
      solved += P[i]
 
  # まだ目標点数に達していない場合は点数の高い問題を解いていく
  for i in range(D - 1, -1, -1):
    if G <= total_score:
      ans.append(solved)
      break
      
    # bitが立っている問題はすでに解いてあるのでskip
    if bit & (1<<i):
      continue
      
    s = (i+1) * 100
    required_score = G - total_score # 目標達成までに必要な点数
    required_problem = (required_score + s - 1) // s # 目標達成までに解く必要のある問題数
    total_score += min(required_problem, P[i]-1) * s # ボーナスが発生しないようにP[i]-1 で1問解かないようにする
    solved += min(required_problem, P[i]-1)
 
    if G <= total_score:
      ans.append(solved)
    
    break
    
print(min(ans))

ARC029 A - 高橋君とお肉

A - 高橋君とお肉

問題

  • N個のお肉があって、(各お肉には焼けるまでの時間が設定されている)
  • 2つの肉焼き器を使った場合に最短で全てのお肉が焼けるまでの時間を出力しなさい という問題

解答

  • お肉を2つの肉焼き器のうちどちらで焼くかを全探索すればいい
  • bitが立っているお肉は肉焼き器1、bitが立っていないお肉は肉焼き器2みたいな感じ
n = int(input())
arr = [int(input()) for _ in range(n)]

ans = float("inf")

for bit in range(1<<n):
  device1 = []
  device2 = []
  for i in range(n):
    if bit & (1<<i):
      device1.append(arr[i]) # 肉焼き器1にお肉を置く
    else:
      device2.append(arr[i]) # 肉焼き器2にお肉を置く

  elaps = max(sum(device1), sum(device2))
  ans = min(ans, elaps)

print(ans)

ABC002 D - 派閥

解けなかった...
D - 派閥

問題

  • 議員の関係性(どの議員とどの議員が知り合いか)が与えられて、そこから派閥の最大数を求めよ という問題
    派閥はすべての議員同士が知り合いでなければいけない。

解答

  • 議員の最大数が12人程度なので派閥を全探索していけばいいらしい
  • 全探索で議員の組み合わせを作り、そこからその組み合わせの派閥が成り立つかをチェックしていく
N, M = map(int, input().split())
con = [[False] * N for _ in range(N)]

for i in range(M):
  x, y = map(int, input().split())
  con[x-1][y-1] = True
  con[y-1][x-1] = True
    
ans = 0
for bit in range(1<<N):
  ok = True
  cnt = bin(bit).count('1') # bitの数が派閥に所属する議員数
  if cnt <= ans:
    continue
  for i in range(N):
    for j in range(N):
      if bit & (1<<i) and bit & (1<<j) and \
         i != j and not con[i][j]: # con[i][j]がTrueのときお互いが知り合いである
        ok = False
  if ok:
      ans = cnt
print(ans)

YouTube Live Streaming API でライブ配信する

YouTube Live Streaming APIを触ってみたのでその覚え書き

目次

YouTube Live Streaming APIとは

公式ページからの引用

The YouTube Live Streaming API lets you create, update, and manage live events on YouTube. Using the API, you can schedule events (broadcasts) and associate them with video streams, which represent the actual broadcast content.

developers.google.com

つまり下記のようなことがAPIでできる。

  • YouTube Liveのイベントを管理(生成、更新、削除)できる。
  • YouTube Liveに配信する映像の設定ができる。(配信プロトコルの設定、解像度、フレームレート、映像を受け付けるURLとストリームキーの取得など)

YouTube関連のAPIにはYoutube Data APIというものもある。こちらはYouTube関連のデータ操作を行えるAPIになっている。YouTube Live Streaming APIYoutube Data APIの機能を使ってるようで、後述するようにAPIを有効にする操作ではYoutube Data APIを有効にすることでYouTube Live Streaming APIが使えるようになる。

YouTube Data API の概要  |  Google Developershttps://developers.google.com/youtube/v3/getting-started?hl=ja

とりあえずAPIを触ってみる

ここでAPIのサンプルを動かすことができる。 developers.google.com

ここではさらにクライアントライブラリを使ったサンプルコードも表示してくれるのでとても便利。

APIを使う前にやること

事前準備

APIを使う前の事前準備としてやっておくことがある。

  1. Googleアカウント作成する。
  2. Google APIsでプロジェクトを作成し、YouTube Data API v3を有効化する。
  3. Google APIsで「OAuth同意画面」を作成(同意画面作成時はとりあえずアプリケーション名だけを設定)
  4. Google APIsで「認証情報」を取得。「APIキー」と「OAuth クライアントID」を取得。OAuthクライアント情報はJSONファイルとしてダウンロードする。

APIキーとOAuth クライアントID

YouTubeのパブリックなデータを操作するときは、APIキーがあればYouTubeAPIを使うことができる。しかし、ユーザーのプライベートなデータを操作する場合は、そのユーザーから許可をもらう必要がある。

OAuthとはユーザーからデータへのアクセス許可をもらう手順を規格化したもので、YouTube Live Streaming APIを使う場合はOAuthのシーケンスを踏まなければならない。

そのための事前準備として「OAuth同意画面」と「OAuth クライアントID」が必要になる。

クライアントライブラリのインストール

APIの呼び出しやOAuthのシーケンス実装はクライアントライブラリを使うと実装が簡単になる。

Client Libraries  |  YouTube Live Streaming API  |  Google Developers

クライアントライブラリのサンプルもリポジトリにまとまっている。

github.com

こちらはOAuthのクライアントライブラリ(Python) Documentationに簡単なサンプルコードがある。

github.com

クライアントライブラリとしてPythonを使うなら下記コマンドで必要なライブラリがインストールできる。

pip install google-api-python-client
pip install google-auth-oauthlib

リソースについて

YouTube Live Streaming APIで操作できるリソースはいくつかあるが基本的なものとしてLiveBroadcastsLiveStreamsがある。

LiveBroadcastsリソース

YouTubeライブ配信のイベントを作成するときに最初に「配信タイトル」や「公開範囲」「配信スケジュール」などを設定できるが、これらの操作はLiveBroadcastsリソースを使うことで実現できる。

LiveStreamsリソース

YouTubeライブ配信するときには、どんな映像を取り扱うかを決めなければならない。(映像の解像度、フレームレート、エンコーダからYouTubeサーバーへ送信するプロトコルなど)これらの設定をするために使用するのがLiveStreamsリソースになる。
この他にLiveStreamsリソースを使うとエンコーダの配信先URLやストリームキー(APIの中ではストリームネームと呼ばれている)を取得することができる。

各リソースがどんなメソッドがあるのかはReferenceに詳細に書かれている。

developers.google.com

Pythonのクライアントライブラリ関連のメソッド(YouTube Data API v3)は下記ページで見れる。(liveBroadcastsとliveStreamsというメソッドがある。)

http://googleapis.github.io/google-api-python-client/docs/dyn/youtube_v3.html

APIを使った配信イベントの作成から配信までのシーケンス

APIを使ってどうやって配信イベントを作成して、配信するのかは下記URL先で解説されている。 developers.google.com

ここの内容を参考に自分でシーケンスを組んでみた。 github.com

コードの大まかな流れとしては、下記のようなことをやっている。

  1. OAuthシーケンスでユーザーからデータアクセスの許可をもらう。
  2. liveBroadcastsinsertメソッドでライブイベントを作成する。
  3. liveStreamsinsertメソッドでストリームを作成する。
  4. 2で作ったイベントと3で作ったストリームをliveBroadcastsbindメソッドで紐づける。
  5. liveStreamslistメソッドでエンコーダの配信先URL(cdn.ingestionInfo.ingestionAddress)とストリームキー(cdn.ingestionInfo.streamName)を取得する。
  6. 配信者は配信ソフトと5で取得したURLとストリームキーを使って配信する。(これはAPIの操作ではなく手動で行うこと)

この後はtransitionメソッドを使ってliveBroadcastsのステータス(status.liveCycleStatus)を適切に変更してライブ配信をしている。

参考情報

YouTube Live Streaming API

OAuth/OpenID Connect

こちらの動画ではOAuth2.0とOpen ID Connectについてとてもわかりやすく解説されている。 www.youtube.com

Go JSONをデコードする

jsonパッケージのDecoderを使うとjsonデータを簡単にパースできる。

golang.org

目次

変数名で対応付ける

package main

import (
    "encoding/json"
    "fmt"
    "strings"
)

func main() {
    json_data := `{"name":"Gopher","age":10}` // 入力のjsonデータ

    var d struct {
        Name string
        Age  int
    }

    decoder := json.NewDecoder(strings.NewReader(json_data))
    err := decoder.Decode(&d)
    fmt.Println(d.Name, d.Age, err) // Gopher 10 <nil>
}
  • json.NewDecoder()json.Decorderのポインタを取得する。
  • json.NewDecoder()にはio.Readerを渡さなければいけないので、strings.NewReader(json_data)strings.Readerのポインタを渡すようにしている。

jsonデータのkeyとデータを受け取る変数名が同じだと自動で対応づけてくれるらしい。最後の出力は見ると"name":"Gopher"はNameに、"age":10はAgeに紐づいているのが分かる。
名前の対応付がうまくいかないと値が代入されない。試しに入力のjsonのageをheightに変更すると値が代入されないことがわかる。

func main() {
    json_data := `{"name":"Gopher","height":10}` // ageをheightに変更した

    var d struct {
        Name string
        Age  int
    }

    decoder := json.NewDecoder(strings.NewReader(json_data))
    err := decoder.Decode(&d)
    // d.Ageに値が代入されない
    fmt.Println(d.Name, d.Age, err) // Gopher 0 <nil>
}

構造体タグ名で対応付ける

構造体にタグ(json:xxx)を付与するとDecorderがそのタグを使って値を代入してくれる。

func main() {
    json_data := `{"Name":"Gopher","Age":10}`

    var d struct {
        ID  string `json:"name"` // jsonのkeyと同じ名前を指定する
        Old int    `json:"age"`  // 大文字小文字は区別されない
    }

    decoder := json.NewDecoder(strings.NewReader(json_data))
    err := decoder.Decode(&d)
    // jsonのkeyと変数名は一致していないが値が代入されている
    fmt.Println(d.ID, d.Old, err) // Gopher 10 <nil>
}

json:がないと値の代入はうまくいかない。

// これだとうまく代入されない
var d struct {
    ID  string `"name"`
    Old int    `"age"`
}

一致しないkeyがあったときにエラーを返す

これまでの例では一致しないkeyがあってもerr := decoder.Decode(&d)でerrが帰ってくることはなかった。(常にnilだった)
decoder.DisallowUnknownFields()をコールしておくと、一致しないkeyがあったときにerrorとして返してくれる。

func main() {
    json_data := `{"name":"Gopher","age":10}`

    var d struct {
        Name string `json:"name"` // ageに対応する変数がない
    }

    decoder := json.NewDecoder(strings.NewReader(json_data))
    decoder.DisallowUnknownFields() // Decodeの前にコールする
    err := decoder.Decode(&d)
    // ageに対応する変数がなかったのでerrが返ってくる
    fmt.Println(d.Name, err) // Gopher json: unknown field "age"
}

Go シグナルをチャネルで受け取る

golang.org

signal.Notifyを使うとチャネルでシグナルを受け取ることができる。

package main

import (
    "fmt"
    "os"
    "os/signal"
    "syscall"
)

func main() {
    sigCh := make(chan os.Signal)
    signal.Notify(sigCh, syscall.SIGINT)

    s := <-sigCh

    fmt.Printf("%T %v\n", s, s)
}

Ctrl+cを実行するとSIGINTがプログラムに通知されてsの中に受信したシグナルが代入される。

signal.Notifyを使ってシグナルをプログラムで受信することで、(シグナルを受信した後に何かしらの停止処理を実行することで)プログラムを安全に止めることができる。

シグナル受信用のチャネルの型はos.Signalを使う。os.Signalはインターフェースとして定義されている。

type Signal interface {
    String() string
    Signal() // to distinguish from other Stringers
}

os - The Go Programming Language

通知されるシグナルは全てこのインターフェースを実装している。

通知されるシグナルはsyscallに定義されていて、signal.Notifyで受信したいシグナルを指定するときはここにある定義を使う。

syscall - The Go Programming Language

const (
    SIGABRT   = Signal(0x6)
    SIGALRM   = Signal(0xe)
    SIGBUS    = Signal(0x7)
    SIGCHLD   = Signal(0x11)
    SIGCLD    = Signal(0x11)
    SIGCONT   = Signal(0x12)
    SIGFPE    = Signal(0x8)
    SIGHUP    = Signal(0x1)
    SIGILL    = Signal(0x4)
    SIGINT    = Signal(0x2)
    ・
    ・
    ・

↑ 定義を見るとSignal型でキャストされて整数を設定している。ここのSignalos.Signalではなく、syscall.Signalである。

syscallで定義されているSignalは以下のようになっている。

type Signal int

これに加えて、os.Signalインターフェースを実装している。なので先にあったように受信する型をos.Signalに指定すればsyscall.Signalをそのまま受け取ることができる。