今回はNLP界でも話題沸騰中のトランスフォーマーの仕組みについて解説していこうと思います。Transformerの前に、Attentionの技術の理解が必要で、Attentionがどんな物か簡単に(人のYoutubeビデオで)紹介し、そのあとTransformerの技術説明をしていこうと思います。TransformerはDeep Learning 研究の中でも非常に大きなブレイクスルーを起こした技術の一つです。Transformerの派生版である GoogleのBERTやGPT-3など聞いたことがあると思いますが、翻訳のレベルが数段上がりました。
本ブログではできるだけ数式を使わず、簡単に説明したいと思っているのですが、所々に難しいところがあると思います。疑問点など有りましたらご連絡をいただければ幸いです。
Attentionとは
Transformerを学ぶ前にTransformerの目玉の一つでもあるAttentionという技術はどういう技術か紹介します。
「Attention」は注目という意味となります。つまるところAttentionとは、入力の注目する部分を見つける!ということになります。物を学習する時に、画像処理で言えば例えば犬だけに注目して処理したもの(犬を全景処理で切り出したもの)と、画像全体で処理したものを使っての学習では、効率や結果が変わりそうです。この犬だけに注目したい!というのがAttentionという技術となります。この技術は文章にも適用ができ、例えば翻訳の部分では翻訳する際にどこにまず注目した方が良いのか、や、文章の単語同士の関わりの指標みたいなのを抽出することが出来るようになるのです。
言葉で説明してもわかりにくいかもしれません。この技術をYoutubeで15分で紹介している方が居ました。素晴らしく解りやすいので、是非ご覧になってください。
Attentionは対象物のみに集中する、という意味で逆に言えばそれ以外を無視するという雑な言い方も出来ます。実際に我々の脳でも同じことをやっています。例えば集合写真で人数を数えてくれって言われたら、頭の数を数えてほかは無視するようにすると思います。このように、人間と同じようなことをするという方法がAttentionという方法です。
Seqence to sequence with the Attention
さて、Seq2Seqモデルで翻訳を前回の記事では実装しました。Seq2Seqは素晴らしい技術ですが長い文章の翻訳が弱いという弱点があります。それはエンコーダーからデコーダーへ渡されるものがContextベクトルだけであるため、例えば翻訳の最初の部分の表現が弱くなってしまう可能性があるという問題点があったためです。Seq2Seqのモデルはわかりやすく表すと下図のような仕組みとなっていました。なおSeq2Seqに関しては前回の記事を参照ください

Attentionはこうした問題に対応するべく産まれました[ 2 Bahdanau et al, 2015 ]。Seq2Seqで実際に、どのようなことをAttentionを実現しているかと言えば、雑に言うと、seq2seqのエンコードの部分で利用したHiddenベクトルを、まとめてデコーダー部分にわたしている、というのが特徴になります。上のSeq2Seqのモデルと比較してみてください。

Seqence to seqence でAttentionの実現には、デコーダー内でまとめて渡されたHidden ベクトルの内どれが重要であるかを単語単位で調べていき、重み付けをしてSoftmaxし、加重平均しています。詳細な流れですが長くなるので説明しません。というのも、その方法よりもTransformerの方が精度があがっているため、いまさらSeq2SeqのAttentionの実装方式にフォーカスするよりは、Transformerを学んだほうが有用であるためです。(なにより、Transformerの説明の記事なので)
Seq2Seq with Attentionの実現方法については、リクエストが有れば別の機会にご紹介できればと思います。なお、Seqence to sequence のAttentionの実装方法も奥が深く、Global, local attentionなどの手法も提案されています。
Transformer
本題のトランスフォーマーです。トランスフォーマーが産まれたのは、雑に言えばSeq2Seq+Attentionの並列化をして高速に学習したかったからです。
並列化と聞いてまず思い浮かぶのがCNNなどコンボリューションで、並列処理が得意なGPUに向けて行列演算に落とし込めば並列化が見込めます。単純にはSeq2Seqはバケツリレー形式で情報を運んでいる(前の入力に対しての処理をしている)ので並列化が叶いません。
そこで、コンテキストを送ってAttentionを計算していたところから、Encoder,Decoder各層内でself-attention層をつくり、行列演算を実現することで並列化を目指します。
Attentionという名前がついていますが、今までのSeq2Seqで扱っていた方式と全く別なので注意してください。Seq2Seqでは全体でAttentionを実現しているイメージでしたが、Transformerでは層という単位に落とし込んで実現しています。下記に簡単にTransformerの全体図を示します。

右側DecoderのOutputでは最後にSoftmax処理が入る。Encoder,Decoderで、Self-attention層があり今までのSeq2Seqとは様相が異なる。
上図がTransformerの全体図になります。Encoder部とDecoder部の2部構成はSeq2Seqと同じです。Encoder,Decoderとも多段で接続(オリジナルの論文では6段)で構成しています。
Transformerでの肝はずばりSelf-attention層となっています。Self-attention層は同じセンテンス内の単語群が、翻訳したい単語にどれだけ影響しているかというものを検出します。
EncoderとDecoderともSelf-attetionやNormalizationなど殆ど同じ層を使いますが、唯一Decoderにしかないものがあり、Encoder-decoder attention層です。後ほど説明しますが、Encoder側からバイパスしてもらう特徴量(最後の段のEncoderから2つの特徴行列(Key, Valueを生成する行列))を用いいてSelf attentionを行います。これにより入力のどのパーツに注目するべきかを考えることができます。なお、全てのDecoder群のEncoder-decoder attention層では、最後のEncoderの特徴(key, valueの生成行列 )を使うことになります。
Transformer 説明の流れ
全体の流れがサラッとわかったところでそれぞれの層がどのような処理をしているかを記載していきます。それぞれの層の説明、特にSelf-attentionはなかなか濃ゆいと思うのでじっくり理解していってください。
Encoder、Decoderそれぞれでは
- 事前処理Embedding
- 事前処理Positional Encoding
- self-attention層( + Encoder-Decoder attention層( decoderのみ ) )
- Normalization層
- Feed forward層
というような処理を順次行っています。それぞれの処理について説明していきます。
Embedding処理
Embeddingは文字の埋め込みを行います。Word2Vecの様な処理です。具体的には例えば512次元のベクトルに単語を変換します。例えば I am an studentであれば、
I = [ 0.1, 0.2, 0.03, 0.78 …. 0.04 ], am = [ 0.41, 06., 0.033, 0.378 …. 0.3 ], an = [ 0.41, 0.52, 0.073, 0.5678 …. 0.34 ], student = [ 0.81, 0.62, 0.023, 0.478 …. 0.54 ]
のように一つ一つの単語に対してベクトルに変換をします。pytrochではEmbeddingは関数一つで実現でき、上記例では512次元のベクトルに変換するのにembedding = nn.Embedding(4, 512)としておき、embeddingを呼ぶことで実現できます。Embeddingの方法やアルゴリズムについては[10]を参照してみてください。要するにここで文字列をベクトル化してやって学習するようにするのです。
Positional Encoding(PE)処理
Positional Encoding (PE)は順序の定義を行います。今回のTransformerでは行列演算により並列化し従来よりも高速に計算をするという目的がありました。並列に行列計算を行うため、文章の単語に対して順序(もとにいた位置)を考慮させる必要があります。Seq2Seqのようにバケツリレーではないので、バラバラにしてももとの位置が解るような情報量を与える必要があります。
順序でいえば、例えば0,1,2,3とindex値を割り振ればいいと思うかもしれませんが、文字数が大きくなると、値も大きなってしまいます。そもそも-1〜1の範囲で正規化して学習するため整数のindex値は相性が悪そうです。じゃあ正規化をすれば良いのではないかと思うのですが、今度は別の問題が発生します。例えば、5文字のセンテンスがあったとして0.8は四番目の文字を意味します。所が10文字の0.8は8番目になります。つまり、同じ値に場所が違う問題が発生します。
では、文字列の次元を固定(例えば512次元)して、割り振ればいいと思うかもしれませんが、ほとんどの文章は20単語くらいでしょう。そうすると、ものすごい無駄な空間(20次元以降の情報は使われることは殆ど無い)となります。
さて、こうした問題を解決したのが余弦、正弦を使った順序の定義です。数式は次のとおり表されます。


上記式を例えば128次元(d_model)と仮定すると、以下のような画像となります。[11]

Embeddingしたinputを、このPositional Encodeing の値と足し込めたものをインプットとして利用します。
x = Embedding( input ) + PositionalEncoding( input )
Self-attention層
肝のSelf attention層です。センテンスにある各単語同士の注目度を算出していきます。注目度の算出は雑に言うと内積で求めていきます。
Self-Attentionでは、Embedding+PEされたベクトルから、Query, Key, Valueという3つのベクトルを作成します。この3つのベクトルを作るためには3つの行列が必要なのですが、Self-attention層ではその3つの行列を学習で作ることが目標です。
注目度である内積値はQuery, Key, Valueのベクトルから求めます。3つのベクトルの次元数は全て同じで、入力のEmbedding+PEされた次元数より少なくなります。例えば、オリジナルの論文では512次元のEmbedding+PEベクトルがあるとすると、Query, Key, Valueベクトルは64次元まで減っています。
さてさて、ここでいうQuery, Key, Valueとはなんでしょうか?
わたしもそうでしたが、Query, Key, Valueって何で3つに分解するんだろう、それぞれの持つ意味は何だろうかと最初は考えては調べていき、最後は思考停止に陥っていました。今は取り敢えず深く考えずに3つに分解する、それを以ってなんとかSelf-attentionを実現する、という考えで読み勧めてください。
上で説明した話の概念図を下に示します。
今ここで, You love me( お前は俺を愛してる)の3ワードのセンテンスを例としましょう。説明のため、各ワードは16次元でembedding+PE されているとします。まずは、16次元のワードを3次元のベクトルに落とし込みます。(3ワードの”3”とは無関係です。短縮次元はいくつでもいいのです。オリジナルの論文では8で割って512次元を64次元にしていました)。それぞれの生成行列を定義します。初期値はランダムです。

上記処理を行ったあとは、ワードに対するそれぞれのquery, key, valueベクトルを用いて内積値の計算を実施します。まず、Youを例に取り、You, love, meのKeyベクトルとValueベクトルを使って結果zを取り出します。計算処理の内容については下記のとおりです。

これが、Self-attention層で行われている処理です。
なおこの処理を複数回行うのがMulti-head attentionという処理になります。今z1だけが出てきましたが、z2, z3 …と出力していきます。オリジナルの論文では8つのマルチヘッドに分割していました。最初に512次元から64次元に落とすために8で除算していたので、そのため8なのかもしれません。なお、8つのマルチヘッドで分割していた場合、それに伴って生成行列も8個ずつ出来ます。
出力されたz1 .. znを連結(concat)して、もとの入力次元に戻す(linear)作業をすることで、最終的な処理は終わりになります。例えば、今回16次元の入力から3次元になりました。8つのヘッドを使ってConcatするとz1 .. z8までの合計で24次元になります。それを入力次元の16次元に戻す処理、つまり射影マトリックスをかける処理(nn.Linear)をすることで処理が終わります。
ノーマライゼーション層(+ residual の説明 )
標準化、正規化、規格化など、呼ばれるノーマライゼーションにより安定した学習が出来るようになります。中で実装している内容ですが、値を平均から引き、標準偏差で割るような操作をしてやります。平均が0で、分散が1になります。まさにノーマライゼーションですね。
言葉だとわかりにくいので、どういう処理をしているのか、ある人[6]がソースコードを書いていました。参考にしてみてください
class Norm(nn.Module): def __init__(self, d_model, eps = 1e-6): super().__init__() self.size = d_model # create two learnable parameters to calibrate normalisation self.alpha = nn.Parameter(torch.ones(self.size)) self.bias = nn.Parameter(torch.zeros(self.size)) self.eps = eps def forward(self, x): norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) \ / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias return norm
residual(残差処理) の説明
残差処理といいますが、Encoderを構築する際に、Multi header attention で出てきた出力値に、元々の入力Inputを足し合わせます。そしてNormalization処理をします。ソースコード風で簡単に表すと
y = Normlization( MulitHeadderAttention( input ) + input )
これを残差処理といいます。下記のオリジナルの論文を見ると、Add&Norm内で入力値を持ち込んでいるので残差処理をしているのがわかります。

Feed Forward 層
Feed Fowardでは、2層の全結合層からなるニューラルネットワークを実装しています。また、Dropoutも定義することで過学習を抑制します。Dropoutに関しては、次の記事が明るいです。[14]
数式ではつぎのようなことを行っています。
FFN(x)=max(0,xW1+b1)W2+b2
図では次のとおりです。いえば、ただのニューラルネットワークです。

ソースコードだと次のような単純な処理をしているだけです。
class FeedForward(nn.Module): def __init__(self, d_model, d_ff=2048, dropout = 0.1): super().__init__() # We set d_ff as a default to 2048 self.linear_1 = nn.Linear(d_model, d_ff) self.dropout = nn.Dropout(dropout) self.linear_2 = nn.Linear(d_ff, d_model) def forward(self, x): x = self.dropout(F.relu(self.linear_1(x))) x = self.linear_2(x) return x
ここまでで、Encoderで使われている全てのパーツの説明がおわりました!実はEncoderもDecoderもほとんど同じパーツを使います。唯一,Encoder-decoder attention層だけがちがうのですが、その解説を行います。
Encoder-DecoderAttention
改めてTransformerの全体図をみてみると、次のようになっています。

ブロックを見てみると殆どの解説が終わっている部分ばかりです。唯一Encoder-DecoderAttentionだけ説明をしていませんでした。ただ、Encoder-Decoder Attentionは次の画像がとてもわかり易いかと思います。

やっていることはSelf-attentionと一緒です。ただし、Encoder-Decoder層では、最後のEoncoderパートで使われたValue, Keyベクトルを生成するマトリックスを渡し、Self-attention層で計算した方法と全く同じロジックで計算するのです。これだけとなります。
すべてのパーツの説明が終わりました。後は実装してみましょう!
実装
実装については、ブログの公式Githubにアップロード予定です。
最後に言い訳…
2021年頃にTransformerの記事が書きたいと思って、Twitterで書くと公言していました。しかし、コロナ渦でてんやわんやだったり、そもそもTransformerのボリュームが有りすぎて、画像等の用意を考えると時間的になかなか書けませんでした。そうこうしている間に2022年では素晴らしい記事が出てきてしまい、書くのは意味ないと思ってしまい、全く手が進みませんでした。しかし最近改めて日本語のサイトを見ると、日本語では全体的に網羅されていない、数学的すぎる、簡潔すぎる、論文の解説に集中している、というような本ブログのようなまとめ解説は有用なのかもしれない、と思い、思い切って書いてみました。皆さんの勉強の参考になってくれれば幸いです。
参考文献
[1] https://jalammar.github.io/visualizing-neural-machine-translation-mechanics-of-seq2seq-models-with-attention/
[2] https://towardsdatascience.com/transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853
[3] https://ai.googleblog.com/2017/08/transformer-novel-neural-network.html
[4] https://jalammar.github.io/illustrated-transformer/
[5] https://mchromiak.github.io/articles/2017/Sep/12/Transformer-Attention-is-all-you-need/#.XIWlzBNKjOR
[6]https://towardsdatascience.com/how-to-code-the-transformer-in-pytorch-24db27c8f9ec
[7]https://machinelearningmastery.com/a-gentle-introduction-to-positional-encoding-in-transformer-models-part-1/
[8]https://towardsdatascience.com/transformers-explained-visually-part-2-how-it-works-step-by-step-b49fa4a64f34
[9]https://towardsdatascience.com/how-to-code-the-transformer-in-pytorch-24db27c8f9ec
[10]https://pytorch.org/tutorials/beginner/nlp/word_embeddings_tutorial.html
[11]https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
[12]https://medium.com/swlh/elegant-intuitions-behind-positional-encodings-dc48b4a4a5d1
[13]https://data-science-blog.com/blog/2021/04/07/multi-head-attention-mechanism/
[14]https://medium.com/axinc/dropout%E3%81%AB%E3%82%88%E3%82%8B%E9%81%8E%E5%AD%A6%E7%BF%92%E3%81%AE%E6%8A%91%E5%88%B6-be5b9bba7e89
[15] https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
[16]https://kikaben.com/transformers-encoder-decoder/