Sequence To Sequence( Seq2Seq )

Sequence to sequence ( Seq2Seq )の技術を紹介します。

機械学習界隈ではブレイクスルーな技術としてGANVAEが話題となりました。今回解説するSeq2Seqもブレイクスルーといわれる技術の1つです。

近年、劇的に日本語の翻訳技術が向上しました。その技術の背景にはSeq2Seqが使われています。Seq2Seqの得意分野は翻訳もそうですが、自動字幕技術、あるいはチャットボットなどのQuestions answeringの技術であり、連続値などのシーケンシャルなデータの取扱が得意です。

Seq2SeqはGoogleにより2014年に発表されました。Googleは2015年頃の論文でSpeech recognition[1]やビデオ字幕[2]で劇的に精度が向上したと発表しています。

今回はそんなSeq2Seqの技術を解説し、最後にPyTorchでの実装例を紹介します。

Sequence to Sequence の構造

Seq2Seqは2つのパートに別れます。EncoderとDecoderです。形的にはAutoEncoder(例えばVAE)に少し似ています。ただ、中の実装方式がまるで違います。

Seq2SeqはRNNを利用しているため時系列データに強いという所が特徴です。 そのため翻訳や音声認識の分野で力を発揮しているのです。

Encoder&Decoderの2つのパートに分かれる。入力xをいれ、出力yを受け取る。SOSはStart of String, EOSはEnd of stringを表す。Encoder部分で出力されるものは無視される。Decoder部分では、入力として解答をいれ、出力としてもその回答が出るようにする。2つのパートに分かれるが、厳密には真ん中のContextを含めた3パートであるのが正しい。

Encoder、ならびにDecoderの内部では横矢印でデータをそれぞれ前段のデータを数珠つなぎに渡しています。RNNの特徴で、学習時に前段の特徴をわたすことにより時系列データに強くなるのです。

Seq2Seqを理解するに辺りRNNの概念の理解が必要です。本ブログのRNNの記事の概念の部分を読んで把握しておいてください。

それではEncoderとDecoderを実際にどういう仕組みや役割を担っているのか見ていきます。

EncoderとDecoder

Seq2Seqで作られた翻訳エンジンを例にEncoderとDecoderの役割を解説します。翻訳とは「こいつは犬。」という入力に対し、「 This is a dog.」と解答が得ることと定義します。

Encoder部は、日本語で「こいつは犬。」を入力するパートになります。

Decoder部は 、「This is a dog.」という入力、並びにEncoderから出る横矢印の隠れ層のデータ(Context)を受け取って、入力と同じ「This is a dog.」という出力を得るパートです。やや複雑ですが、下の図を見て、理解をしてください。

Encoderでは日本語を入力として受け取っている。入力は単語単位で行う。単語数は素子数を上回ることは出来ない。Decoderでは入力として英単語、ならびにEncoderからのContextを受け取る。出力は1段ずつ遅れて出力するモデルを考える。

上の図を見ると面白いことに気づきます。「こいつ は 犬」と3つ単語に対し、「This is a dog」と4つの単語が出力されます。 Seq2Seqの特徴ですが入力数と出力数が一致していなくてもいいのがユニークなところです。上の図の説明ではEncoder&Decoderともに5つの素子で構成しています。実際には最大の文章の単語数以上の素子を用意しておくことになります。

Decoder部分に注目してみます。入力値として英語の答えを入力し、出力値として、一弾ずつずれて同じものを期待するように設計しています。

<SOS>,< EOS>はStart of string, End of stringの略ですが、学習時にはこれをつけて学習していきます。これを伝えることで、文の始まりと終わりを伝えるのです。

Encoderの出力値については捨てられます。ただし、使うことに依り精度を上げる方法(+Attention法)もあります。アルゴリズムの選択により使ったり使わなかったりするのです。

さて、DecoderとEncoderをつなげるのは、Encoderで学習したHiddenベクトル(Context)となります。

以上が簡単な説明になります。世界を圧巻したSeq2Seqですが意外と簡単な構造をしていたんだな、と思うと思います。それではいよいよPyTorchで実装していきます。

実装

今回はPytorchの公式のSeq2Seqを参考にソースコード解説をします。本家はやや説明に冗長なコードがありますので、Seq2seqを理解するためだけのコードにしました。

下準備(学習データ)

学習には次のファイルを使いましょう。

実装する上では学習データを用意しないと学習できません。残念ながらPyTorchでは標準で日本語データサポートしていないので、他社サイトからデータを取得します。

今回はこちらのサイトからデータを取得しました。流れとしては、そして日本語と英語の2つのファイルに分けました。日本語は英語のようにスペースで分けられていないので、分かち書き(形態素解析)によって分割しました。日本語には半角全角といった表記ゆれもあるのでそうしたものを正規化処理します。具体的にはそれぞれMecab, unicodedata.normalizeなどを使うのですがその辺りは今回のseq2seq技術と全く関係ないのでここでは説明しません。

ファイルの1行1行の日本語と英語が1対1対応しています。

他の言語でもテストしてみたい場合は、こうした学習データを作り、独自に学習させてみてください。

また、import文と、言語の処理クラスLangをインポートします。Lang クラスではwordをindex化したりするクラスです。

import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

SOS_token = 0
EOS_token = 1

device = "cuda" # torch.device("cuda" if torch.cuda.is_available() else "cpu")                                                                                                                                                                                                                                                                                                                                                                                                                                                                  

class Lang:
    def __init__( self, filename ):
        self.filename = filename
        self.word2index = {}
        self.word2count = {}
        self.sentences = []
        self.index2word = { 0: "SOS", 1: "EOS" }
        self.n_words = 2  # Count SOS and EOS                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   

        with open( self.filename ) as fd:
            for i, line in enumerate( fd.readlines() ):
                line = line.strip()
                self.sentences.append( line )
        self.allow_list = [ True ] * len( self.sentences )
        self.target_sentences = self.sentences[ :: ]

    def get_sentences( self ):
        return self.sentences[ :: ]

    def get_sentence( self, index ):
        return self.sentences[ index ]

    def choice( self ):
        while True:
            index = random.randint( 0, len( self.allow_list ) - 1 )
            if self.allow_list[ index ]:
                break
        return self.sentences[ index ], index

    def get_allow_list( self, max_length ):
        allow_list = []
        for sentence in self.sentences:
            if len( sentence.split() ) < max_length:
                allow_list.append( True )
            else:
                allow_list.append( False )
        return allow_list

    def load_file( self, allow_list = [] ):
        if allow_list:
            self.allow_list = [x and y for (x,y) in zip( self.allow_list, allow_list ) ]
        self.target_sentences = []
        for i, sentence in enumerate( self.sentences ):
            if self.allow_list[ i ]:
                self.addSentence( sentence )
                self.target_sentences.append( sentence )

    def addSentence( self, sentence ):
        for word in sentence.split():
            self.addWord(word)


    def addWord( self, word ):
        if word not in self.word2index:
            self.word2index[ word ] = self.n_words
            self.word2count[ word ] = 1
            self.index2word[ self.n_words ] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1
def tensorFromSentence( lang, sentence ):
    indexes = [ lang.word2index[ word ] for word in sentence.split(' ') ]
    indexes.append( EOS_token )
    return torch.tensor( indexes, dtype=torch.long ).to( device ).view(-1, 1)

def tensorsFromPair( input_lang, output_lang ):
    input_sentence, index = input_lang.choice()
    output_sentence       = output_lang.get_sentence( index )

    input_tensor  = tensorFromSentence( input_lang, input_sentence )
    output_tensor = tensorFromSentence( output_lang, output_sentence )
    return (input_tensor, output_tensor)

Encoder

Encoderの実装は次のとおりです。

class Encoder( nn.Module ):
    def __init__( self, input_size, embedding_size, hidden_size ):
        super().__init__()
        self.hidden_size = hidden_size
        # 単語をベクトル化する。1単語はembedding_sie次元のベクトルとなる                                                                                                                                                          
        self.embedding   = nn.Embedding( input_size, embedding_size )
        # GRUに依る実装.                                                                                                                                                                  
        self.gru         = nn.GRU( embedding_size, hidden_size )

    def initHidden( self ):
        return torch.zeros( 1, 1, self.hidden_size ).to( device )

    def forward( self, _input, hidden ):
        # 単語のベクトル化                                                                                                                                                                                                        
        embedded        = self.embedding( _input ).view( 1, 1, -1 )
        # ベクトル化したデータをGRUに噛ませる。通常のSeq2Seqでは出力outは使われることはない。                                                                                                                                     
        # ただしSeq2Seq + Attentionをする場合にはoutの値を使うことになるので、リターンする                                                                                                                                        
        out, new_hidden = self.gru( embedded, hidden )
        return out, new_hidden

Encoderでは文字列をEmbeddingします。Embeddingとは単語をベクトル化することです。例えばDogを5次元にEmbeddingするとするとDog–>[0.9, 0.5 0.4, 0.7, 0.1] のようにすることを意味します。

Embeddingは実はword2vecを用いたほうが精度が良いようですが、とりあえず今はSeq2Seqの実装とは関係ないので標準のライブラリを使います。精度を上げたい人はこの部分を改良してみても面白いでしょう。

Embeddingされた単語をRNNのネットワークに入れるのですが、PyTorchではRNN系として、nn.LTSM, nn.RNN, nn.GRUというものが既にあり、自分で数珠つなぎのRNN素子を定義してネットワークを書く必要はありません。RNNモジュールの入力次元はEmbeddingする次元になります。

実装ではGRUかLSTMで世の中では良く取り沙汰されています。今回はPyTorchの公式ドキュメントでGRUであったのと、日本語のドキュメントサイトでLSTMが多かったのでGRUで説明します。LSTMで実装したいなどあれば適時、ソースコードを書き換えてみてください。(nn.LSTMは出力がGRUとことなるので注意が必要です。view関数などを使って出力数を変更する必要もあります。)

Decoder

Decoderの実装は次のとおりです。

class Decoder( nn.Module ):
    def __init__( self, hidden_size, embedding_size, output_size ):
        super().__init__()
        self.hidden_size = hidden_size
        # 単語をベクトル化する。1単語はembedding_sie次元のベクトルとなる                                                                                                                                                          
        self.embedding   = nn.Embedding( output_size, embedding_size )
        # GRUによる実装(RNN素子の一種)                                                                                                                                                                                          
        self.gru         = nn.GRU( embedding_size, hidden_size )
        # 全結合して1層のネットワークにする                                                                                                                                                                                      
        self.linear         = nn.Linear( hidden_size, output_size )
        # softmaxのLogバージョン。dim=1で行方向を確率変換する(dim=0で列方向となる)                                                                                                                                                
        self.softmax     = nn.LogSoftmax( dim = 1 )

    def forward( self, _input, hidden ):
        # 単語のベクトル化。GRUの入力に合わせ三次元テンソルにして渡す。                                                                                                                                                           
        embedded           = self.embedding( _input ).view( 1, 1, -1 )
        # relu活性化関数に突っ込む( 3次元のテンソル)                                                                                                                                                                             
        relu_embedded      = F.relu( embedded )
        # GRU関数( 入力は3次元のテンソル )                                                                                                                                                                                       
        gru_output, hidden = self.gru( relu_embedded, hidden )
        # softmax関数の適用。outputは3次元のテンソルなので2次元のテンソルを渡す                                                                                                                                                 
        result             = self.softmax( self.linear( gru_output[ 0 ] ) )
        return result, hidden

    def initHidden( self ):
        return torch.zeros( 1, 1, self.hidden_size ).to( device )

ほとんどEncoderと一緒です。ただRelu活性関数を適用したり、最後に全結合してSoftmax関数を噛ませているところに違いがあります。また、Decoderでは入力値として、前段のEncoderからのHiddenベクトルをもらうところが違いがあります。

メイン関数

EncoderとDecoderを用いたメイン関数は次のとおりです。

def main():
    n_iters       = 75000
    learning_rate = 0.01 * 0.8
    embedding_size = 256
    hidden_size   = 256
    max_length    = 30

    input_lang  = Lang( 'jpn.txt' )
    output_lang = Lang( 'eng.txt')
    # 英単語数がmax_lengthより多い場合は計算しない。(時間がかかるため。)                                                                                                                                                                                                                                                                                                                                  
    allow_list = [x and y for (x,y) in zip( input_lang.get_allow_list( max_length ), output_lang.get_allow_list( max_length ) ) ]
    # allow_listに従って、英語、日本語のファイルをロードする                                                                                                                                                                                                                                                                                                                                                
    input_lang.load_file( allow_list )
    output_lang.load_file( allow_list )
    # Encoder & Decoderの定義                                                                                                                                                                                                                                                                                                                                                                               
    encoder           = Encoder( input_lang.n_words, embedding_size, hidden_size ).to( device )
    decoder           = Decoder( hidden_size, embedding_size, output_lang.n_words ).to( device )
    # Optimizerの設定                                                                                                                                                                                                                                                                                                                                                                                       
    encoder_optimizer = optim.SGD( encoder.parameters(), lr=learning_rate )
    decoder_optimizer = optim.SGD( decoder.parameters(), lr=learning_rate )
    # 学習用のペアデータの作成. He is a dog, 彼は犬だ みたいなペアをエポック数分用意する                                                                                                                                                                                                                                                                                                                    
    training_pairs = [ tensorsFromPair( input_lang, output_lang ) for i in range( n_iters ) ]
    # LOSS関数                                                                                                                                                                                                                                                                                                                                                                                              
    criterion      = nn.NLLLoss()

    for epoch in range( 1, n_iters + 1):
        # 学習用のペア単語の取り出し。                                                                                                                                                                                                                                                                                                                                                                      
        input_tensor, output_tensor = training_pairs[ epoch - 1 ]
        #初期化                                                                                                                                                                                                                                                                                                                                                                                             
        encoder_hidden              = encoder.initHidden()
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        input_length  = input_tensor.size(0)
        output_length = output_tensor.size(0)

        # Encoder phese                                                                                                                                                                                                                                                                                                                                                                                     
        for i in range( input_length ):
            encoder_output, encoder_hidden = encoder( input_tensor[ i ], encoder_hidden )

        # Decoder phese                                                                                                                                                                                                                                                                                                                                                                                     
        loss = 0
        decoder_input  = torch.tensor( [ [ SOS_token ] ] ).to( device )
        decoder_hidden = encoder_hidden
        for i in range( output_length ):
            decoder_output, decoder_hidden = decoder( decoder_input, decoder_hidden )
            # 次の入力野取り出し                                                                                                                                                                                                                                                                                                                                                                            
            decoder_input = output_tensor[ i ]
            # 学習では一定の確率(ここでは50%)で、自身が前に出力した単語を次の入力とする。                                                                                                                                                                                                                                                                                                              
            if random.random() < 0.5:
                # 確率が最も高い単語を抽出                                                                                                                                                                                                                                                                                                                                                                  
                topv, topi                     = decoder_output.topk( 1 )
                # 確率が一番高かった単語を次段の入力とする                                                                                                                                                                                                                                                                                                                                                  
                decoder_input                  = topi.squeeze().detach()

            # Loss関数                                                                                                                                                                                                                                                                                                                                                                                      
            loss += criterion( decoder_output, output_tensor[ i ] )
            # EOSに当たった場合は終わる。                                                                                                                                                                                                                                                                                                                                                                   
            if decoder_input.item() == EOS_token: break
        loss.backward()
        encoder_optimizer.step()
        decoder_optimizer.step()
        # 進捗状況の表示                                                                                                                                                                                                                                                                                                                                                                                    
        if epoch % 50 == 0:
            print( "[epoch num %d (%d)] [ loss: %f]" % ( epoch, n_iters, loss.item() / output_length ) )

流れに関してはコメントアウトにて記載しました。

実装をする上で、言語を管理するLangクラスを定義しました。内容はSeq2seqの技術と関係ないので割愛しますが、Langクラスについては、Githubにあるmain.pyを参考にしてください。

評価関数

学習がきちんとできたか、実際確かめる評価関数は次のとおりです。

def evaluate( sentence, max_length ):
    input_lang  = Lang( 'jpn.txt')
    output_lang = Lang( 'eng.txt' )
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                  
    allow_list = [x and y for (x,y) in zip( input_lang.get_allow_list( max_length ), output_lang.get_allow_list( max_length ) ) ]

    input_lang.load_file( allow_list )
    output_lang.load_file( allow_list )

    hidden_size = 256
    embedding_size =256
    encoder = Encoder( input_lang.n_words, embedding_size, hidden_size ).to( device )
    decoder = Decoder( hidden_size, embedding_size, output_lang.n_words ).to( device )


    enfile = "OUTPUT_FILE_FROM_ENCODER"
    defile = "OUTPUT_FILE_FROM_DECODER"
    encoder.load_state_dict( torch.load( enfile ) )
    decoder.load_state_dict( torch.load( defile ) )

    with torch.no_grad():
        input_tensor   = tensorFromSentence(input_lang, sentence)
        input_length   = input_tensor.size()[0]
        encoder_hidden = encoder.initHidden()

        for ei in range(input_length):
            encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)

        decoder_input      = torch.tensor([[SOS_token]], device=device)  # SOS                                                                                                                                                                                                                                                                                                                                                                                                                                                                  
        decoder_hidden     = encoder_hidden
        decoded_words      = []
        decoder_attentions = torch.zeros(max_length, max_length)

        for di in range(max_length):
            decoder_output, decoder_hidden = decoder( decoder_input, decoder_hidden )

            topv, topi = decoder_output.data.topk(1)
            if topi.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            else:
                decoded_words.append(output_lang.index2word[topi.item()])

            decoder_input = topi.squeeze().detach()
        return decoded_words, decoder_attentions[:di + 1]

if __name__ == '__main__':
    import MeCab
    import unicodedata
    wakati = MeCab.Tagger("-Owakati")
    sentence = 'とても悲しいです.'
    sentence = unicodedata.normalize( "NFKC", sentence.strip() )
    a=wakati.parse( sentence.strip() ).split()
    ret =" ".join( a )

    print(evaluate( ret, 30 ) )

enfile, defileは学習時したデータのPATHを記載します。

注意ですが、hidden_sizeとmax_lengthは学習時と同じ値を使うようにします。

encoder.load_state_dict(torch.load(FILE_PATH))により、実際に学習したデータをロードしています。

結果

60をmax_length、epoch数を150,000として学習させた結果は次の通りです。学習には40分程かかります。(GPUを利用した場合)

海外 に 旅行 に 行き たい .
([‘i’, ‘want’, ‘to’, ‘go’, ‘to’, ‘the’, ‘trip’, ‘.’, ”],

この 映画 は 面白い です か ?
([‘how’, ‘movie’, ‘is’, ‘this’, ‘movie’, ‘?’, ”]

これ は 料理 です .
([‘this’, ‘is’, ‘a’, ‘good’, ‘cook’, ‘.’, ”]

この 机 は 私 の 一番 の お気に入り です .
([‘this’, ‘is’, ‘is’, ‘most’, ‘of’, ‘mine’, ‘.’, ”],

彼 は とても いい 人 です .
([‘he’, ‘is’, ‘a’, ‘good’, ‘person’, ‘.’, ”]

さて、この結果考察をどう思うでしょうか。たかだか40分の学習時間、かつ愚直なデータでこの精度までいきました。正直驚きです。機会翻訳の分野で一生懸命やっていた人は更に驚くのではないでしょうか?

色々と問題があるものの、学習数を多くする、登録単語数を増やす、同じ学習データを何度も流し込む、など色々なアプローチで、明らかな文法ミスに対してペナルティを高くするなどしていくと劇的な変化が見られるのではないかと思います。又学習速度に関してもバッチ化することで高速化が見込めます。

興味がある人は是非トライしてみてください。多くのデータサイエンティスト、或いは機械学習のエンジニアがやっていく作業がこういう泥臭い作業になっていきます。

最後に

今回のソースコードはgithubにあげてあります。[Seq2Seq Github]

Githubでは、本稿では取り上げていない+attention法も実装してあります。

Attention法とは長文になると精度が悪くなるという弱点を補強したアルゴリズムです。そのため、短文で10単語くらいの簡単なものに関しては精度が劇的に上がることはありません。

Seq2Seqでは、EncoderやDecoderを多層にしたり、Bidirectionalにしたり、GRUの代わりにLSTMを使ったり、というような色々な工夫があります。それぞれの方法でどれが良いかについてはここで述べられています[3]

実際に実装してみて、どれが精度が良いのかなど試してみると面白いと思いでしょう。

参考文献

  1. https://www.isca-speech.org/archive/Interspeech_2017/pdfs/0233.PDF
  2. https://arxiv.org/pdf/1505.00487.pdf
  3. https://arxiv.org/pdf/1908.04332.pdf
  4. https://towardsdatascience.com/understanding-encoder-decoder-sequence-to-sequence-model-679e04af4346
  5. https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

おすすめの記事

Variational AutoEncoder( VAE )

近年、ディープラーニング業界、はたまた画像処理業界ではGANとVAEの2つの技術で話題が持ちきりとなりました。GANについては前回解説しました。今回はそんなVAEのアルゴリズムについて解説をしていきます。GAN同様に最後にPyTorchによる実装例を紹介していきます。

VAEとは

VAEは情報を圧縮して圧縮した情報から元の情報に戻す、というような仕組みをもった、AE(Auto encoder)と言われるものの一種です。AEはただ単にデータの圧縮と再構築をするだけでしたが、VAEは確率モデルの理論を導入しています。VAEは確率分布を使ったモデルということは、未知のデータを確率的に作成できることになります。VAEはGenerative model(生成モデル)と言われています。トレーニングデータからProbability Density Function (PDF)を推定するモデルであるためす。 

GANと違って何ができるのか

VAEの一つの特徴は次元削減です。次元削減といえばPCASVDなどを思い浮かべるかもしれません。VAEは同様にLatent space(潜在空間)に次元圧縮し、またそこから復元するということをしています。

PCAのように圧縮した次元において直行している必要はありません。この辺も普通の次元圧縮とのアプローチとは異なっています。

赤はAEの軸。青はPCAの軸。PCAは軸同士は直行しているが、AEに関しては必ずしも直行している必要はない。

AEとVAEは違いがあまり無く感じますが、前述の通り確率を用いたところに違いがあります。AEは点(Single point)としてデータを潜在空間にマッピングしていましたが、VAEでは潜在空間に落とし込むときにガウス分布に従って落とし込ませます。デコーダーはその点を拾ってデコードするため、見方を変えれば、確率的にサンプリングしてzを拾っている、ということになるのです。このzをデコードして元の入力に近づけるようにするのです。

さて、VAEは確率モデリングと言われます。確率モデリングとは、例えば何かのデータxが分布していたとして、その分布を確率で表現するというモデルです。あるデータxが正規分布に従っていたとしたら、データxの確率分布は正規分布となり、P(X)=正規分布( μ, σ )と言う感じで表現します。P(X)と書くと思わずPは関数と勘違いするかもしれませんが、確率を表すものなので注意が必要です。

数式

VAEはデータを確率モデル化をすることを目標とします。データをX、その確率分布をP(X)と定義すれば、P(X)を見つけること、すなわちP(X)の最大化がVAEの目的です。そしてP(X)に関しては次のように表現することができます。[1]

P(X) = \int P(X \vert z) P(z) dz = \int P(X,z)

上記の式はXが入力画像と考え、その画像を表現する潜在ベクトルzを少なくとも1つ見つけたいという意味があります。

また、事前分布P(z)から\{z_i\}_{i=1}^nをサンプルした際に、P(X)を次のように近似できます。

P(X) \approx \frac{1}{n}\sum_{i=1}^n P(x|z_i)

P(X)を求めればいいのですが、一筋縄には行きません。それはXが高次元であることもそうですが、その確率を求めるのにはすべてのzの空間を舐め回すような非常にたくさんのサンプリングが必要なだけでなく、組み合わせをチェックするような処理が必要となり、総当りでP(X)を求めることは現実的ではありません。

そのため、P(X)を求めるには別のアプローチを考えます。もし事後確率のP(z|X)がもとめられれば、p(x|X)=\int P(x|z)p(z|X) dz より未知の画像xを作り出すような確率分布を求めることができそうです。そうなると 事後確率のP(z|X)を求めるだけとなり、簡単に思えますが事後確率はベイズの定理により次の式になります。

P(z|X) = P(X|z)\cfrac{P(z)}{ P(X)}

よく見るとあの厄介なP(X)が分母にあります。やはり求めることができません。

ここで、諦めずに別のアプローチで P(z|X) を求めていくようにしていきます。

P(z|X) を求める

P(z|X)を求めるにあたり少し工夫します。ここでVAEの名前の由来ともなっていますが、 Variational Inference(VI) という手法を使います。P(z|X)を推定するためにQ(z|X)という分布を考えます。Q(z|X)はP(z|X)の近似で、ガウス分布などの簡単な分布関数の組み合わせによりP(z|X)を近似します。

関数の組み合わせにより、近似していく。青の点線へ緑の分布関数で近似していく様子。 https://towardsdatascience.com/bayesian-inference-problem-mcmc-and-variational-inference-25a8aa9bce29

どれだけ似ているのか、という指標については類似尺度としてKLダイバージェンスを使います。同じ分布であれば0に、異なれば値が大きくなっていくような関数です。KLダイバージェンス次のように書くことが出来ます。

KL(P||Q) = E_{x \sim P(x)} log \cfrac{P(x)}{Q(x)}=\int_{-\infty}^{\infty}P(x)\ln \cfrac{P(x)}{Q(x)}dx

ここから式が多く出てくるので、見やすくするために単純にP(z|X )をP, Q( z|X)をQと表示します。Qの近似は次のようにかくことができます。

\begin{aligned} D_{KL}[Q\Vert P] &= \sum_z Q \log (\cfrac{Q}{P}) \\ &=E [ \log (\cfrac{Q}{P}) ] \\ &= E[\log Q - \log P] \end{aligned}

最後はlogの変換公式により、割り算を引き算にしただけです。さてP である P(z|X ) は P(z|X) = P(X|z)P(z) / P(X) のように表現できたので、先の式に当てはめてみましょう。

\begin{aligned} D_{KL}[Q\Vert P] &= E[\log Q - (\log (P(X \vert z) \cfrac{P(z)}{P(X)})] \\ &= E[\log Q- \log P(X \vert z) - \log P(z) + \log P(X)] \end{aligned}

ここでzに関する期待値でくくっている項に注目するとP(X)があります。P(X)はzに関係ありません。そのため括弧の外に出すことができます。

\begin{aligned} D_{KL}[Q\Vert P] &= E[\log Q- \log P(X \vert z) - \log P(z)] + \log P(X) \\ D_{KL}[Q\Vert P] - \log P(X) &= E[\log Q- \log P(X \vert z) - \log P(z)] \end{aligned}

ここからがトリッキーなのですが、右辺に注目するともう一つのKLダイバージェンスを見つけることができます。先の式を両辺にマイナス倍します。

\begin{aligned} \log P(X) - D_{KL}[Q\Vert P] &= E[-\log Q + \log P(X \vert z) + \log P(z)] \\ &= E[\log P(X \vert z) -( \log Q - \log P(z))] \\ &= E[\log P(X \vert z)] -E[( \log Q - \log P(z))] \\ &= E[\log P(X \vert z)]- D_{KL}[Q\Vert P(z) ] \end{aligned}

なんとPとQを近づけようとした結果、本来の目的であったP(X)に関する式がでてきてしまいました。PとQを縮約して記載していたので、正しく展開してかいてみます。

\log P(X) - D_{KL}[Q(z \vert X) \Vert P(z \vert X)] = E[\log P(X \vert z)] - D_{KL}[Q(z \vert X) \Vert P(z)]

本当の目的はP(X)の最大化でした。結局ですが、これを最終的なVAEの目的の関数として定義することになります。

最終的に得た式は非常に興味深い構造になっています。

  1. Q(z|X) はデータXを受取り、潜在空間にzを投影
  2. zは潜在変数
  3. P(X|z)は潜在変数zからデータ生成

つまりQ(z|X)はエンコーダ、zは潜在変数(エンコードされたデータ),そしてP(X|z)はデコーダです。まさにオートエンコーダーそのものとなりました。

算出された式を観察して、左辺と右辺がありますが、左辺にあるlog(p)最大化することが目的でした。そのためには右辺を最大化していくこと、つまり E[\log P(X \vert z)] を大きくして D_{KL}[Q(z \vert X) \Vert P(z)]を小さくしていくことで、左辺は大きくなっていきます。

そのため、今度は目的を右辺の最大化に絞っていきます。

右辺を最大化する。

右辺には以下の2つの式があります。

  1. E[\log P(X \vert z)]
  2. D_{KL}[Q(z \vert X) \Vert P(z)]

右辺=数式1-数式2、ですので数式1を最大化、数式2は最小化していけば、右辺は大きくなり目的が達成されます。

まず 数式 1についてですが、よく見ると潜在変数zを受取りXを生成する教師つきの学習そのものです。ですので、学習によりなんとかなりそうです。

さて、厄介なのが 数式 2です。ここで一つの仮定を起きます。P(z)は正規分布N(0, 1)と仮定するのです。そして、Xからzを生成する分布もパラメータ\mu(X), \sigma(X)付きの正規分布となります。 平均と分散はXを中心としたという意味です。そして、KLダイバージェンスは次のように表されます。

D_{KL}[N(\mu(X), \Sigma(X)) \Vert N(0, 1)] = \frac{1}{2} \, \left( \textrm{tr}(\Sigma(X)) + \mu(X)^T\mu(X) - k - \log \, \det(\Sigma(X)) \right)

kはガウシアンの次元数、traceは対角要素の和を表します。そして、detは対角要素の積\det \left({\mathbf A}\right) = \prod_{i \mathop = 1}^n a_{ii}です。

導出に関しては、ここでは重要でないため割愛しますが最終的には次のようになります。(導出に関して興味ある方は[2]を参照してください。)

D_{KL}[N(\mu(X), \Sigma(X)) \Vert N(0, 1)] = \frac{1}{2} \sum_k \left( \exp(\Sigma(X)) + \mu^2(X) - 1 - \Sigma(X) \right)

この項目を実装するときにはロス関数の中で利用します。実際との差分を計算するためですね。

全ての要素の説明ができました。あとはやるだけです。

実装

エンコーダーとデコーダーの実装は次のとおりです。

エンコーダー

class Encoder( nn.Module ):
    def __init__( self ):
        super().__init__()
	self.common = nn.Sequential(
            nn.Linear( 784, 400 ),
            nn.ReLU(),
            )
	self.model1 = nn.Sequential(
            self.common,
            nn.Linear( 400, 20 )
            )
        self.model2 = nn.Sequential(
            self.common,
            nn.Linear( 400, 20 )
            )
    def forward( self, img ):
	img_flat = img.view( img.size( 0 ), -1 )
        return self.model1( img_flat ), self.model2( img_flat )

デコーダー

class Decoder( nn.Module ):
    def __init__( self ):
	super().__init__()
	self.model = nn.Sequential(
            nn.Linear( 20, 400 ),
            nn.ReLU(),
            nn.Linear( 400, 784 ),
            nn.Sigmoid(),
            )
    def forward( self, z ):
        return self.model( z )

エンコーダーとデコーダーを用いいたVAEを次のように実装していきます。

VAE

class VAE( nn.Module ):
    def __init__( self ):
        super().__init__()
        self.encoder = Encoder()
	self.decoder = Decoder()

    def _reparameterization_trick( self, mu, logvar ):
        std = torch.exp( 0.5 * logvar )
        eps = torch.randn_like( std )
        return mu + eps * std

    def forward( self, _input ):
        mu, sigma = self.encoder( _input )
	z         = self._reparameterization_trick( mu, sigma )
        return self.decoder( z ), mu, sigma

説明ではzの取得は正規分布でのサンプリングを仮定しました。ところが実際にサンプリングするとバックプロパゲーションが出来ないため学習が出来ません。そこでreparameterization trickというのを用いいます。zを次の式で近似します。

z = \mu(X) + \Sigma^{\frac{1}{2}}(X) \, \epsilon

where, \epsilon \sim N(0, 1)

左図はZを表現するために本来のサンプリングで実装した図。左ではバックプロパゲーションができない。そのため、足し算と掛け算に依る表現に正規分布に従うノイズの足しこみを行い誤差伝搬を可能にする。[3]

最後に損失関数とVAEを使ったコードは次のとおりです。

# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014                                                                                                                                                                 
# 入力画像をどのくらい正確に復元できたか?                                                                                                                                                                                        
def VAE_LOSS( recon_x, x, mu, logvar ):
    # 数式では対数尤度の最大化だが交差エントロピーlossの最小化と等価                                                                                                                                                              
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), size_average=False)
    # 潜在空間zに対する正則化項. # P(z|x) が N(0, I)に近くなる(KL-distanceが小さくなる)ようにする                                                                                                                               
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

def main():
    epoch_size = 50
    vae = VAE()
    vae.cuda()
    Tensor = torch.cuda.FloatTensor
    dataloader=get_dataloader()
    optimizer = torch.optim.Adam( vae.parameters(), lr=1e-3 )

    for epoch in range( epoch_size ):
        for i, ( imgs, _ ) in enumerate(dataloader):
            optimizer.zero_grad()
            real_images          = Variable( imgs.type( Tensor ) )
            gen_imgs, mu, logvar = vae( real_images )
            loss                 = VAE_LOSS( gen_imgs, real_images, mu, logvar ).cuda()
            loss.backward()
            optimizer.step()

VAE_LOSSでは得た画像が目的とした式と同じになるような差分を計算しています。

長くなりましたがこれがVAEの全貌です。今回のソースコードはGithubにあげてあります。

https://github.com/octopt/techblog/blob/master/vae/main.py

参考文献

  1. https://en.wikipedia.org/wiki/Law_of_total_probability
  2. https://wiseodd.github.io/techblog/2016/12/10/variational-autoencoder/
  3. https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73

おすすめの記事

敵対的生成ネットワーク(GAN)

GANは画像処理分野でセンセーショナルな話題を巻き起こした技術です。例えば馬をシマウマにしたり、色々な人の顔画像を作ったりすることが事が行えるようになります。画像処理界を相当ザワつかせた技術を今回は解説いたします。(どれくらい話題になったかはこちらをみると解ると思います)

GANは比較的難しい概念&技術です。表面的な解説は他のWEBサイトやQiitaなどでも取り上げられています。今回はオリジナルの論文から数式やアルゴリズムにどういう意味があるのかということについて解説し、最後にPyTorchによる実装例を紹介します。

GANとは

GANには2つのモジュールGenerator(生成者)、Discriminator(識別者)のがあります。これをニューラルネットワークで作っていくのです。説明のために画像処理の技術として話を進めていきます。Generatorは画像を作る、Discriminatorは画像を識別する技術を表します。

例えばシマウマ画像をGeneratorは作る、Discrimanatorは本物のシマウマの画像かを識別する器を作っていきます。

GANは Generative Adversarial Nets の略ですが、 Generative (生成する)、 Adversarial (敵対者)という言葉が入っています。これは、互いに騙し合うモデルを作る、というコンセプトから来ています。絶対に騙そうとする者絶対に判別してやる者という、いわば矛と盾のようなものを学習により作っていくのです。GANはGANsと表現する事がありますが、これは語尾がNetsとなっているからであり、どちらも変わりません。

、騙す側はGeneratorと呼ばれます。後述しますが、学習が進むにつれて相手を騙すほどの画像が作れるようになるのです。GANの目的は素晴らしいGenereatorを作ることです。

、判別側はDiscriminatorと呼ばれます。日本語では判別器、識別機などとも呼ばれています。こちらは最終的に本物か偽物かを判断します。実装での出力値は確率で出てきますが、一番最後の最後にシグモイド関数をかませて0か1か(正か偽か)を出力します。

GANはこうしたGenerator & Discriminatorというコンセプトを用いた学習方法です。今様々なGANがありますが、なぜ沢山あるのかと言えば、 目的や実装方法によって名前が変えるためです。したがって DCGAN, LAPGAN , SRGAN, StackGAN なども全ては広くGANの一種、というような言い方が可能です。 GeneratorとDiscriminator というコンセプトを使って学習していくモデルがGANということなのです。

察しのいい人は気づいたかもしれませんが、盾であるDiscriminatorの学習はGeneratorよりもずっと簡単です。ラベル付き画像の学習であり伝統的なNNそのものとなります。

GANの目的はGeneratorを作ることです。それを忘れないようにしましょう。学習したGeneratorを使うことで人間も騙せる画像を作っていけるのです。

GANの学習の流れ

Generator

Generatorは適当な入力値をランダム値でもらい、ターゲットとなる画像を生成します。100次元程度のランダムな値が入力値として良く利用されます。生成した画像をDiscriminatorに渡し、正か偽かを判定してもらいます。Discriminatorの判定結果を受け、騙せたか騙せていないかの2値からバイナリクロスエントロピーによりロスを計算し誤差伝搬をして、Generator内部のネットワークの重みを更新していきます。

  1. ランダムノイズを作成する
  2. ランダムノイズから画像を作成する
  3. Discriminatorから真か偽かの判定をもらう
  4. Discriminatorからの判定をもとにロスを計算する
  5. Discriminator&Generatorを通して更新すべき重みの値を受け取る
  6. Generatorのみネットワークの重みを更新する
https://towardsdatascience.com/understanding-generative-adversarial-networks-gans-cd6e4651a29

Discriminator

Discriminatorは、Generatorが出力した偽画像と、予め用意してある本物の画像を次々と入力して学習します。最終的にはシグモイド関数を利用してTrue もしくはFalseを判定結果として出力し、正解画像、偽画像が正しく識別できたかどうかを比較します。出力値は2値なのでGenerator同様にloss関数としてbinary cross-entropyを利用して精度をあげていきます。流れとしては次のとおりです。

  1. 本物の画像とGeneratorが作成した偽画像を仕分ける
  2. 真画像を偽と判定した、あるいは偽画像を真と判定したロスを計算し、ペナルティを与える
  3. ネットワークの重みを更新する。

Note(注意)

Discriminatorの学習は簡単です。Generatorから偽画像を生成してもらって、それと正解画像を入力して正答率を上げるだけですので、見てみればよくある典型的なニューラルネットです。

Generatorの学習でユニークなところはDiscriminatorを噛ませて出力しているところです。GeneratorはDiscriminatorの出力値を見ながら、騙せるような画像を作っていくのです。Generatorの学習注意上で重要なのは、誤差伝搬(Back propagation)の際にはDiscriminatorの重みは更新しないようにする事です。ただし、学習時には連結しているのでGeneratorへの重み伝搬の際にはDiscriminator内部も通っていくことになります。

https://www.freecodecamp.org/news/an-intuitive-introduction-to-generative-adversarial-networks-gans-7a2264a81394/

学習はDiscriminatorとGeneratorをループでぶん回して学習していくことになります。エポック数と言われるものが、全体のループを決めるものとなります。

最初Generatorは全然学習できていないのでノイズっぽいデータが出てくることになるでしょう。Discriminatorも判定が出来ないので全く判別できないはずです。エポック数が多くになるにつれて判別ができるようになってきます。

オリジナルの論文で言及していますが、Generatorの学習は十分なDiscriminatorの学習が出来ないのであればしないほうが良いということを行っています。the Helvetica scenarioを避けるためというような、つまらないイギリシアンジョークをいれていますが、要するにDiscriminatorの学習精度を上げるためGeneratorよりも多く学習することが大事です。

さて、今までの説明をもとに疑似コードを見て全体像を掴んでみましょう。

擬似コード

# 200回Generatorを学習                                                                                                                                                                                                                                                        
for epoch in range( 200 ):                                                                                                                                                                                       
    for j in range( 20 ): # Discriminator。Generatorよりも多く学習する。              
        # 偽画像をGから生成する。最初はそれこそランダムっぽい画像が出るが、学習に従って段々と精度が上がる                                                                                                                                                                       
        fake_images = G.generate_images()
        # 本物の画像を取得。                                                                                                                                                                                                                                                    
        true_images = get_correct_images()

        # 偽物の画像を取り出していき学習                                                                                                                                                                                                                                        
        for f_img in fake_images:
            # False or True(0,1)で結果が帰ってくる。                                                                                                                                                                                                                            
            answer = D.check( f_img )
            # 偽画像と判定するべきなので、正解であるFalseを教えて重みを更新                                                                                                                                                                                                        
            D.update( answer, False )
        # 本物の画像をとりだしていく(上記と同じ事を正解画像でやる)                                                                                                                                                                                                        
        for t_img in true_images:
            answer = D.check( t_img )
            D.update( answer, True ) #正解画像なのでTrueを渡す。
    # ----------------                                                                                                                                                                                                                                                          
    # Discriminatorの学習が終わったので、Generatorの学習をする。                                                                                                                                                                                                                        
    # ----------------                                                                                                                                                                                                                                                          
    # UpdateされたDを用いて学習する                                                                                                                                                                                                                                             
    G.set_discriminator( D )
    # 画像を生成するためのシードを作ってやる                                                                                                                                                                                                                                    
    random_images = get_random_image()
    # シードから画像を取り出す                                                                                                                                                                                                                                                  
    for r_img in random_images:
        # G.check内ではFake画像を生成し、Dに判別させ、結果を得る作業が入っている                                                                                                                                                                                  
        answer = G.check( r_img )
        # 得た結果からGの重みを更新。なお、この際にセットしたDの重みは更新しない。                                                                                                                                                                                                              
        G.update( answer )

上記は実装よりの疑似コードですが、概要を知るためにじっくりと眺めてください。そして、上記を見ないで自分で擬似コードを書いてみましょう。

機械学習で最も大事なのはコンセプトの理解です。第三者が書いたコードを写経やコピペしてすぐに走らせたくなる気持ちはわかりますが、こうした一見複雑な仕組みの理解には自分で疑似コードを書くことがおすすめです。他の人のソースコードのコピペでは理解することは難しいでしょう。

GANの数式

GANでは次の式が利用されています。

\underset{G}{\text{min}} \underset{D}{\text{max}}V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}\big[logD(x)\big] + \mathbb{E}_{z\sim p_{z}(z)}\big[log(1-D(G(z)))\big]

一見摩訶不思議ですが、実は非常に簡単です。Gは最小化になるように、Dは値が最大になるようにしたいという意味が込められています

G(z)はFake画像です。zという潜在変数をGという関数にいれて画像を生成することを意味しています。さらにそれをDiscriminatorの関数に噛ませたものがD( G( z ) )と表現されます。

Generatorにとってみると、D(G(z))が1になる、つまり本物と誤解するようにしたい、という意味です。

Discriminatorにとっては左辺は無視します。log( D(x) )は本物のデータを用いる、当然1になるためです。右辺に注目して、log( 1 – D( G(x) ) )において、 D( G(x) )が0になるように頑張るのです。

Generatorにとっては最小化、Discrimanatorにとっては最大化する、というのはこのためです。

EはExpected Lossなのですが、添字のx〜pdata(x)とは確率分布pdataからxを独立的にサンプリングするという意味になります。Eは期待値ですので、イメージ的には全てを足し合わせてサンプル数で割った値となります。例えば200人の人の身長の高さの期待値は次のようになります。

{\Bbb E}[h]=(\sum_{n=1}^{200}h_n)/200

pdata(x) の範囲からxとしてサンプリングしていき、適用して、期待値として最大化する(最小化する)というような意味が込められています。

このような数式の表現の方法、ルール、解説についてhttps://www.hellocybernetics.tech/entry/2018/07/16/234815 に詳しく書いてあります。わからない方は読んでみてください。

PyTorchに依る実装

Pytorchによる実装を示します。PyTorchはプリファード社(Chainerを作っていた会社)が推奨した後継のライブラリです。kerasもいいですが、私は PyTorchのほうがPythonライクで好きです。

まずは画像の取得関数を定義します。

画像取得関数

# 画像取り出し。
import os
from torchvision import datasets
import torchvision.transforms as transforms

def get_dataloader():
    location = "data/mnist"
    os.makedirs(location, exist_ok=True)
    dataloader = torch.utils.data.DataLoader(
	datasets.MNIST(
	    location,
            train=True,
            download=True,
            transform=transforms.Compose(
                [ transforms.Resize( 28 ), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        ),
        batch_size=64,
        shuffle=True,
    )
    return dataloader

単に画像をdata/mnistに保存すると言うだけの画像取得ローダーです。今回は本質にかかわらないので詳しくは説明しません。ただ、今後自分で何か学習用の画像を手に入れた際はローダーを自分で定義していくことになるので、どこかで使い方をマスターする必要があります。

続いて、GeneratorとDiscriminatorのクラスを定義します。最初はGeneratorクラスです。

Generatorクラス

import torch.nn as nn
class Generator( nn.Module ):
    def __init__( self, z_dim = 100, channel = 1, w = 28, h = 28 ):
        super().__init__()
        self.latent_dim = z_dim
        self.img_channels = channel
        self.img_width = w
        self.img_height = h
        self.img_shape = ( self.img_channels, self.img_width, self.img_height )

        def _block( in_feat, out_feat, normalize ):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append( nn.BatchNorm1d( out_feat ) )
            layers.append( nn.LeakyReLU( 0.2 ) )
            return layers

        self.model = nn.Sequential(
            *_block( self.latent_dim, 128, normalize=False ),
            *_block( 128, 256, normalize=True ),
            *_block( 256, 512, normalize=True ),
            *_block( 512, 1024, normalize=True ),
            nn.Linear( 1024, int( np.prod( self.img_shape ) ) ),
            nn.Tanh()
        )
    def forward( self, z ):
        img = self.model( z )
        img = img.view( img.size( 0 ), self.img_channels, self.img_width, self.img_height )
        return img

初期化関数__init__内部ではmodelを作成していきます。Generatorは潜在変数zを受け取って疑似画像を作成しますので、潜在変数zの次元数を受け取れるように設定しています。

nn.Sequentialの内部を見るとわかりますが、まずは100次元を受取、128次元に、128次元からノーマライズして…と繰り返し1024次元に変更した後、にnp.prodを用いいて画像の最終的な次元に展開しています。normalized処理がないと値が安定しないので必ず入れるようにします。

np.prodというのは要素内の全てを掛け合わせるという意味です。img.size( 0 )はバッチ数を意味していて、最初に説明したdata-loaderの設定にもよるのですが64を返します。 ミニバッチ数とは、例えば入力を1枚1枚でなく、64枚の画像単位(バッチ単位)で学習するという意味です。

続いてDiscriminatorの実装例を表示します。

Discriminatorクラス

class Discriminator( nn.Module ):
    def __init__(self, channel = 1, w = 28, h = 28):
        super().__init__()

        self.img_channels = channel
        self.img_width = w
        self.img_height = h

        self.img_shape = ( self.img_channels, self.img_width, self.img_height )

        self.model = nn.Sequential(
            nn.Linear( int( np.prod( self.img_shape ) ), 512),
            nn.LeakyReLU( 0.2 ),
            nn.Linear( 512, 256 ),
            nn.LeakyReLU( 0.2 ),
            nn.Linear( 256, 1 ),
            nn.Sigmoid(),
        )

    def forward( self, img ):
        img_flat = img.view( img.size( 0 ), -1 )
        validity = self.model( img_flat )
        return validity

Generator同様にimg_shapeは画像のチェネル数(RGBなら3チャンネル、グレースケールなら1チャンネル)、画像縦、画像横サイズを保持しています。こいつをnp.prodすることにより、例えば3チャンネル16×16の画像なら768次元に全て展開されます。その次元をだんだんと落とし込んでいき、最後には1次元にしてSigmoid関数に噛ませていきます。

Discriminatorは本物かどうかをYes/Noで判定する学習機でした。そのため、次元を少なくしていきます。最後にシグモイド関数をいれて強制的に0か1にします。シグモイド関数をいれない直前に関しては確率密度関数といわれます。(ここでは割愛)

さて、forward関数ではimg.size( 0 )でミニバッチ数をとりだし、-1を渡して自動展開しています。それをモデルに突っ込んで0か1を受け取っています。

これで準備は整いました。それでは学習プログラムをつくります。メインのプログラムを下に記します。

メイン関数

import torch

def main()    
    batch_size = 64
    # 色々と初期化                                                                                                                                                                                                                                                                                                                                                                                          
    Tensor = torch.cuda.FloatTensor # Tensor = torch.FloatTensor
    generator     = Generator().cuda()
    optimizer_G   = torch.optim.Adam( generator.parameters(), lr=0.0002, betas=( 0.5, 0.999 ) )
    discriminator = Discriminator().cuda()
    optimizer_D   = torch.optim.Adam( discriminator.parameters(), lr=0.0002, betas=( 0.5, 0.999 ) )
    # ロス関数の初期化                                                                                                                                                                                                                                                                                                                                                                                      
    adversarial_loss = torch.nn.BCELoss().cuda()

    epoch_size = 200 # 普通は100-200くらい。                                                                                                                                                                                                                                                                                                                                                                
    for epoch in range( epoch_size ):
        dataloader = get_dataloader()
        for i, ( real_images, some ) in enumerate( dataloader ):
            batch_size = real_images.size( 0 )
            # 正解と不正解のラベルを作る                                                                                                                                                                                                                                                                                                                                                                    
            valid = torch.ones( (batch_size,1), requires_grad=False ).cuda()                                                                                                                                                                                                                                                                                                   
            fake = torch.zeros( (batch_size,1), requires_grad=False ).cuda()
            # ---------------------                                                                                                                                                                                                                                                                                                                                                                         
            #  Dの学習                                                                                                                                                                                                                                                                                                                                                                                      
            # ---------------------                                                                                                                                                                                                                                                                                                                                                                         
            # DはGより20回多く学習をさせる。( オリジナルの論文より)                                                                                                                                                                                                                                                                                                                                      
            for j in range( 20 ):
                # まず初期化                                                                                                                                                                                                                                                                                                                                                                                
                optimizer_D.zero_grad()                                                                                                                                                                                                                                                                                                                                    
                # 偽画像の作成                                                                                                                                                                                                                                                                                                                                                                              
                # ランダムな潜在変数を作成                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         
                z = torch.empty( real_images.shape[0], 100,requires_grad=False ).normal_( mean = 0, std = 1 ).cuda()
                # fake imageを取得                                                                                                                                                                                                                                                                                                                                                                          
                fake_images = generator( z )
                # ロスの計算.                                                                                                                                                                                                                                                                                                                                                                               
                real_loss = adversarial_loss( discriminator( real_images.type( Tensor ) ), valid )
                fake_loss = adversarial_loss( discriminator( fake_images.detach() ), fake )
                d_loss = (real_loss + fake_loss) / 2
                # 勾配を計算                                                                                                                                                                                                                                                                                                                                                                                
                d_loss.backward()
                # 伝搬処理。Dにだけ誤差伝搬される                                                                                                                                                                                                                                                                                                                                                           
                optimizer_D.step()
            # ---------------------                                                                                                                                                                                                                                                                                                                                                                         
            #  Gの学習                                                                                                                                                                                                                                                                                                                                                                                      
            # ---------------------                                                                                                                                                                                                                                                                                                                                                                         
            # まず初期化                                                                                                                                                                                                                                                                                                                                                                                    
            optimizer_G.zero_grad()
            # ランダムな潜在変数を作成                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               
            z = torch.empty( real_images.shape[0], 100,requires_grad=False ).normal_( mean = 0, std = 1 ).cuda()
            # fake imageを取得                                                                                                                                                                                                                                                                                                                                                                              
            fake_images = generator( z )
            # discriminatorを利用して結果を取得する                                                                                                                                                                                                                                                                                                                                                         
            g_loss = adversarial_loss(discriminator( fake_images ), valid )
            # 勾配を計算                                                                                                                                                                                                                                                                                                                                                                                    
            g_loss.backward()
            # 重みを更新する。Gのみにだけ勾配伝搬処理がされる                                                                                                                                                                                                                                                                                                                                               
            optimizer_G.step()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, epoch_size, i, len(dataloader), d_loss.item(), g_loss.item())
                )

            batches_done = epoch * len(dataloader) + i
            if batches_done % 400 == 0:
                save_image(fake_images.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

optimizerはそれぞれ、Discriminator とGeneratorで設定します。

ロス関数はバイナリ値(0もしくは1)なのでその関数をセットします。

Valid, Fakeは単に正解ラベルとして差分を計算するために出力しているだけです。

Pseudo-codeで記載したとおりまずはDの学習を先行します。Discriminatorの学習率をアップしたほうが学習結果が良いためです。

Generatorが吐き出した学習結果は次のとおりとなりました。

段々と精度が上がってきているかわかります。見ていると、どうも7,9,1が多いです。こうした減少はよく知られている現象(モード崩壊)で、それを避けるためのテクニックも随所論文で見られます。

ソースコード

https://github.com/octopt/techblog/blob/master/gan/main.py

Githubに上げておりますので参考にしてください。

最後に

GANを作っていくと、学習していくと面白いことに気づきます。例えばGeneratorはよく出来た文字を作り出します。人間でも間違うくらいです。というよりは間違えます。あたかも人間が間違えた文字を拒否する構造は正しいのでしょうか?たしかにそれはGeneratorが作ったものですが、人間が作ったものと同じかもしれません。こうしたことにはどう対処していくべきでしょうか?こんなことも考えながら色々と工夫をしていくと面白いと思います。

手書き文字ではかなりシンプルでした。人間の顔などで作っていくとまた面白い結果が出るでしょう。

参考文献

  1. https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf
  2. https://towardsdatascience.com/understanding-generative-adversarial-networks-gans-cd6e4651a29
  3. https://medium.com/deeper-learning/glossary-of-deep-learning-batch-normalisation-8266dcd2fa82
  4. https://github.com/eriklindernoren/PyTorch-GAN

次におすすめの記事