stMind

about Tech, Computer vision and Machine learning

CNNによるテキスト分類で学習済みword2vec(fastText)を使う

CNNとテキスト分類で検索すると、一番最初に出てくる[WildMLのチュートリアル]。

チュートリアルではembedding layerも含めて学習するようになっていますが、embeddingのところはFacebookが公開しているfastTextの学習済みword2vecで置き換えてやってみました。

github.com

変えたところは大きく3つの部分。

1. vocabとword_vecをfastTextからロード

英語の学習済みword2vecは6GBあって、gensimを使ってロードするのですが、毎回ロードしてると時間がかかってしまうので(20分近くかかる)、事前にvocabularyとword vectorsを別々のファイルに保存しておいて、そのファイルから読み込むようにしました。 後は、オリジナルと同じようにvocabulary indexの作成と、入力文中の単語をindexに変換しました。

print('Load pre-trained word vectors')
with open('fasttext_vocab_en.dat', 'rb') as fr:
    vocab = pickle.load(fr)
embedding = np.load('fasttext_embedding_en.npy')

pretrain = vocab_processor.fit(vocab.keys())
x = np.array(list(vocab_processor.transform(x_text)))

2. Embedding layerはfastTextをfeed

オリジナルでは、ランダムな値で初期化したWをembedding_lookupのparamとして、feedされる入力テキストの単語をベクトルにしていくのですが、Wは0で初期化して、placeholderとして受け取ったfastTextのword vectorsを代入して単語ベクトルに置き換えていきました。

W_ = tf.Variable(
    tf.constant(0.0, shape=[vocab_size, embedding_size]),
                trainable=False,
                name='W')
self.embedding_placeholder = tf.placeholder(
    tf.float32, [vocab_size, embedding_size],
    name='pre_trained')
W = tf.assign(W_, self.embedding_placeholder)

3. 学習

最後は、embeddingをfeed_dictで与えて学習をしました。

feed_dict = {
    cnn.input_x: x_batch,
    cnn.input_y: y_batch,
    cnn.dropout_keep_prob: FLAGS.dropout_keep_prob,
    cnn.embedding_placeholder: embedding
}

ロスとAccはこんな感じで進みました。

# The logs around step 1000 are as follows.
...
2017-07-08T13:12:27.329179: step 990, loss 0.178512, acc 0.953125
2017-07-08T13:12:28.902815: step 991, loss 0.133091, acc 0.984375
2017-07-08T13:12:30.473521: step 992, loss 0.148561, acc 0.984375
2017-07-08T13:12:32.041047: step 993, loss 0.21213, acc 0.90625
2017-07-08T13:12:33.617257: step 994, loss 0.230192, acc 0.9375
2017-07-08T13:12:35.223648: step 995, loss 0.222954, acc 0.9375
2017-07-08T13:12:36.822623: step 996, loss 0.161116, acc 0.96875
2017-07-08T13:12:38.437168: step 997, loss 0.224385, acc 0.921875
2017-07-08T13:12:40.073519: step 998, loss 0.258734, acc 0.921875
2017-07-08T13:12:41.649018: step 999, loss 0.207504, acc 0.953125
2017-07-08T13:12:43.215527: step 1000, loss 0.211571, acc 0.921875

Evaluation:
2017-07-08T13:12:44.823491: step 1000, loss 0.647888, acc 0.681

学習のロスは十分下がっていますが、Testのロスはそれに比べるとあまり下がっていきませんでした。過学習している様子なので、チューニングが必要そうです。

まとめ

CNNでテキスト分類するチュートリアルを、fastTextでやってみました。まとめてしまえばそれほど多くの変更はなかったのですが、いろいろ調べるのに1週間強はかかっていると思います。何かを身につけるのに簡単な道はないですね。