DeepMindのResearch Scientistの方がツイートしていたMulti Head Attentionのスレッドの紹介。 全部で12個。英語だけど、日本語に翻訳すれば10分くらいで読めるし、コードサンプルと図もあって短い時間でMHAの概要が掴めると感じた。
Transformers are arguably the most impactful deep learning architecture from the last 5 yrs.
— Misha Laskin 🇺🇦 (@MishaLaskin) 2022年1月7日
In the next few threads, we’ll cover multi-head attention, GPT and BERT, Vision Transformer, and write these out in code. This thread → understanding multi-head attention.
1/n pic.twitter.com/1rYmPtWGFV
以下は、自分が理解した内容のメモ書き。
2/n. Attentionとは
例として、sentiment analysisをしたいとする。「Attention is not too shabby.」 shabbyはネガティブを示唆しているけれど、not shabbyであればネガティブではなくポジティブ。正しく分類するためには、文中の全ての単語を考慮しないといけない。
3/n. 全ての単語を考慮する
全ての単語を考慮する最もシンプルな方法は、全ての単語をネットワークに入力すること。これで十分か?というと、そうではない。各単語を考慮するだけでなく、他の単語との関係も理解しないといけない。つまり、notはshabbyに注意を向けているということが重要。そこで出てくるのが、query, key, value(Q,K,V)。
4/n. Value
単語を線形層に通し、そこから得られた出力をValueと呼ぶ。
words.shape # (T,in_dim) values = Linear(words) values.shape # (T, D)
では、Value同士の関係をどの様にエンコードするか?それぞれのValueを混合(sumをする)することで、関係を見ることができる。ただし、これには問題がある。
out = np.ones((T, T)) @ values # (T, T) @ (T, D) out.shape # (T, D)
5. 問題
単純な総和の問題は、全ての関係が等しいと想定していること。isとtoo、notとshabbyでいえば、明らかに後者の方が感情分類に重要。
6. QueryとKey
全てが等しい関係ではなく、word_jに対してword_iがどれだけ有用かを表す様にしたい。そこで、Valueと同じように単語を線形層に通して得られるQueryとKeyを導入する。QueryとKeyから求めた重みWeightsにおいて、w_ijはi番目のQueryに対して、j番目のKeyの間の内積に比例した値になっているはず。
queries.shape # (T, D) keys.shape # (T, D) weights = queries @ keys.T # (T, T)
7. Rescaling
QueryとKeyから得られたWeightsの各要素をでrescale。
8. Single Head Attention
Weightsの列ベクトルについてSoftmaxを適用して正規化する。直感的には、Qは単語Kに対してのどの程度有用かという質問で、内積が高いということは非常に有用、逆に内積が低い時はあまり役に立たないということ。これがAttention。
attention = softmax(Q @ K.T / sqrt(D), dim=1) @ Value # [T, D] = [T, T] @ [T, D]
9. Single Head Self Attention
Technically what we’ve shown is called single-head self-attention. Before going to multi-head attention, let’s code up what we’ve done so far.
— Misha Laskin 🇺🇦 (@MishaLaskin) 2022年1月7日
9/n pic.twitter.com/FlBouHYvsr
10. Why Multi Head
なぜMulti Headなのか?Single Headだと学習データにオーバーフィットするかもしれない。過学習対策の一般的な戦略であるアンサンブルで、複数のAttentionによりロバストな結果を獲得する。(Multi Head Attentionは、Single Head Attentionの[T, D]をN個連結したもので、[T, NxD])
11. Multi Head Self Attention
forwardの入力xは、[T, D]をN個分concatした[B, T, ND](ただし、Bはバッチサイズで、ND = C)と想定している。
So multi-head is just a small tweak to single-head attention. In practice, we also add dropout layers to further prevent overfitting and a final linear projection layer. This is what a complete vectorized multi-head self-attention block looks like in PyTorch.
— Misha Laskin 🇺🇦 (@MishaLaskin) 2022年1月7日
11/n pic.twitter.com/77OXUYOwsb