stMind

about Tech, Computer vision and Machine learning

Multi Head Attentionの概要を掴む

DeepMindのResearch Scientistの方がツイートしていたMulti Head Attentionのスレッドの紹介。 全部で12個。英語だけど、日本語に翻訳すれば10分くらいで読めるし、コードサンプルと図もあって短い時間でMHAの概要が掴めると感じた。

以下は、自分が理解した内容のメモ書き。

2/n. Attentionとは

例として、sentiment analysisをしたいとする。「Attention is not too shabby.」 shabbyはネガティブを示唆しているけれど、not shabbyであればネガティブではなくポジティブ。正しく分類するためには、文中の全ての単語を考慮しないといけない。

3/n. 全ての単語を考慮する

全ての単語を考慮する最もシンプルな方法は、全ての単語をネットワークに入力すること。これで十分か?というと、そうではない。各単語を考慮するだけでなく、他の単語との関係も理解しないといけない。つまり、notはshabbyに注意を向けているということが重要。そこで出てくるのが、query, key, value(Q,K,V)。

4/n. Value

単語を線形層に通し、そこから得られた出力をValueと呼ぶ。

words.shape # (T,in_dim)
values = Linear(words)
values.shape # (T, D)

では、Value同士の関係をどの様にエンコードするか?それぞれのValueを混合(sumをする)することで、関係を見ることができる。ただし、これには問題がある。

out = np.ones((T, T)) @ values # (T, T) @ (T, D)
out.shape # (T, D)

5. 問題

単純な総和の問題は、全ての関係が等しいと想定していること。isとtoo、notとshabbyでいえば、明らかに後者の方が感情分類に重要。

6. QueryとKey

全てが等しい関係ではなく、word_jに対してword_iがどれだけ有用かを表す様にしたい。そこで、Valueと同じように単語を線形層に通して得られるQueryとKeyを導入する。QueryとKeyから求めた重みWeightsにおいて、w_ijはi番目のQueryに対して、j番目のKeyの間の内積に比例した値になっているはず。

queries.shape # (T, D)
keys.shape # (T, D)
weights = queries @ keys.T # (T, T)

7. Rescaling

QueryとKeyから得られたWeightsの各要素を \sqrt{D}でrescale。

8. Single Head Attention

Weightsの列ベクトルについてSoftmaxを適用して正規化する。直感的には、Qは単語Kに対してのどの程度有用かという質問で、内積が高いということは非常に有用、逆に内積が低い時はあまり役に立たないということ。これがAttention。

attention = softmax(Q @ K.T / sqrt(D), dim=1) @ Value # [T, D] = [T, T] @ [T, D]

9. Single Head Self Attention

10. Why Multi Head

なぜMulti Headなのか?Single Headだと学習データにオーバーフィットするかもしれない。過学習対策の一般的な戦略であるアンサンブルで、複数のAttentionによりロバストな結果を獲得する。(Multi Head Attentionは、Single Head Attentionの[T, D]をN個連結したもので、[T, NxD])

11. Multi Head Self Attention

forwardの入力xは、[T, D]をN個分concatした[B, T, ND](ただし、Bはバッチサイズで、ND = C)と想定している。