言語モデルとRNN#

言語モデル#

embeddingを直接推定するのではなく、単語予測モデル(言語モデル)を構築して副次的にembeddingを取得することになる。

言語モデルとは、尤もらしい文章を生成できるような確率分布を習得するために、文章の確率を推定するモデルのこと。

ある文章\(S\)をトークン化したのを\((w_1, w_2, \cdots, w_n)\)と表記するならば、

\[ P(S)=P(w_1, w_2, \cdots, w_n) \]

を求めたいということになる。これは条件付き確率の積として表せる

\[\begin{split} \begin{align} P(w_1, w_2, \cdots, w_n) &= P(w_1) \times P(w_2|w_1) \times P(w_3|w_1, w_2) \times \cdots\\ &= \prod_{i=1}^n p(w_i|\boldsymbol{c}_i) \end{align} \end{split}\]

ここで\(\boldsymbol{c}_i\)\(w_i\)より前のトークン列\(\boldsymbol{c}_i=(w_1,w_2,\cdots,w_{i-1})\)で、文脈(context)と呼ばれる

Note

同時確率の分解

これは確率の乗法定理

\[ P(A,B) = P(A|B) P(B) \]

に基づく。

\(w_m\)までの単語を\(C_m\)とすると、

\[ P(\underbrace{w_1, \dots, w_{m-1}}_{C_m}, w_m) = P(C_m, w_m) = P(w_m | C_m) P(C_m) \]

さらに

\[ P(C_m) = P(\underbrace{w_1, \dots, w_{m-2}}_{C_{m-1}}, w_{m-1}) = P(C_{m-1}, w_{m-1}) = P(w_{m-1} | C_{m-1}) P(C_{m-1}) \]

となる。これを繰り返すことで上記のようになる

#

言語モデルは文脈をもとに次の単語を予測する。例えば

Alice is reading a book in the room. Bob comes into the room and says hi to ?

の?に入る語を予測する

語順の問題#

Word2Vecに使われたcontinuous bag-of words (CBOW) のようなFeed-Forward Networkによる言語モデルでは、コンテキストの語順が考慮されない

RNN型ニューラル言語モデル(Mikolov + 2010)#

Mikolov et al. (2010). Recurrent neural network based language model.

RNN#

RNNは前のトークンまでの情報を次のトークンの出力に渡すパスが存在する。 トークンの系列を時系列モデルになぞらえて時刻と表現すると、時刻\(t\)の出力\(\boldsymbol{h}_t\)

\[ \newcommand{\b}[1]{\boldsymbol{#1}} \b{h}_t = \text{tanh}( \b{W}_h \b{h}_{t-1} + \b{W}_x \b{x}_t + \b{b} ) \]

となる。ここで\(\boldsymbol{x}_t\)は入力で、\(\boldsymbol{h}_{t-1}\)は1時刻前の出力、\(\b{W}_h, \b{W}_x\)は重みで\(\b{b}\)はバイアスである。

なお、出力\(\b{h}\)隠れ状態(hidden state)と呼ばれる事が多い

Note

双曲線正接(hyperbolic tangent: tanh)関数

RNNの場合、活性化関数にはtanh

\[ \text{tanh}(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}} \]

を用いることが多い。tanhはシグモイド関数と形状が似ているが、値域は\([-1, 1]\)と負の値を許しており、また二次微分の減衰がゆっくりとゼロになるため勾配消失が起きにくいという特性がある。

../_images/493c0c1521ee2e52cd33d5a8d613ff1e2c53e5b3d3abec673616ebe9ab7e00e5.png

Fig. 1 tanh#

Truncated BPTT#

RNNの誤差逆伝播法は、時間方向への逆伝播法ということで**Backproagation Through Time(BPTT)**と呼ばれる。

しかし長い文章を扱う場合、すべてのトークンを使うと学習の際にメモリに乗り切らない問題や勾配が不安定になる問題がある。 そこで、逆伝播のときはトークン系列を分割して学習する(順伝播は全部つながるようにする)方法があり、これをTruncated BPTTという。

実装(PyTorch)#

\[ h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1}W_{hh}^T + b_{hh}) \]
# 最小構成
import torch
from torch import nn

n_input = 1
n_hidden = 3
n_layer = 2
rnn = nn.RNN(n_input, n_hidden, n_layer)
x = torch.randn(5, 3, n_input)
h0 = torch.randn(2, 3, n_hidden)
output, hn = rnn(x, h0)
output
tensor([[[-0.1668,  0.0216,  0.9021],
         [-0.3862, -0.0949,  0.9085],
         [ 0.0981, -0.6771,  0.0224]],

        [[-0.7750, -0.2505,  0.2387],
         [-0.7632, -0.1634,  0.4631],
         [-0.6494, -0.6738,  0.2008]],

        [[-0.4302, -0.4284,  0.6875],
         [-0.5634, -0.3066,  0.6524],
         [-0.5123, -0.5392,  0.6363]],

        [[-0.6737, -0.3976,  0.4973],
         [-0.7377, -0.3168,  0.4417],
         [-0.6590, -0.4172,  0.5490]],

        [[-0.5468, -0.4108,  0.6417],
         [-0.5659, -0.3634,  0.6552],
         [-0.6862, -0.3765,  0.5160]]], grad_fn=<StackBackward0>)

実装(Python)#

参考:ゼロから作るDeep Learning 2

import numpy as np
h = 2
w = 3
Wh = np.random.normal(size = (h, w))
import numpy as np

class RNN:
    def __init__(self, Wx, Wh, b):
        self.params = [Wx, Wh, b]
        self.grads = [np.zeros_like(Wx),
                      np.zeros_like(Wh),
                      np.zeros_like(b)]
        self.chache = None

    def forward(self, x, h_prev):
        Wx, Wh, b = self.params
        t = Wh @ h_prev + Wx @ x + b
        h_nrex = np.tanh(t)
        self.cache = (x, h_prev, h_next)
        return h_next

    def backword(self, dh_next):
        Wx, Wh, b = self.params
        x, h_prev, h_next = self.cache
        dt = dh_next * (1 - h_next ** 2)
        db = np.sum(dt, axis=0)
        dWh = h_prev.T @ dt
        dh_prev = Wh @ dt
        dWx = x @ dt
        dx = Wx @ dt
        self.grads[0][...] = dWx
        self.grads[1][...] = dWh
        self.grads[2][...] = db
        return dx, dh_prev

RWKV#

RNNを利用しつつも並列計算を可能とした

RWKVを論文と実装から読み解く

従来の大規模言語モデルの制約だった「入力量の限界」を取り払った「RWKV」は一体どんな言語モデルなのか? - GIGAZINE