Implementing GPT#
Code details#
GPT-2 repository: openai/gpt-2: Code for the paper "Language Models are Unsupervised Multitask Learners"
minGPT (a simplified GPT): karpathy/minGPT: A minimal PyTorch re-implementation of the OpenAI GPT (Generative Pretrained Transformer) training
Blog post implementing minGPT: GPT from Scratch - Jake Tae
Reference (article in Japanese): Python(PyTorch)で自作して理解するTransformer(Encoder-Decoder方式のTransformer) (building and understanding an Encoder-Decoder Transformer from scratch in PyTorch)
import math

import torch
from torch import nn
import torch.nn.functional as F


class GPTConfig:
    attn_dropout = 0.1
    embed_dropout = 0.1
    ff_dropout = 0.1

    def __init__(self, vocab_size, max_len, **kwargs):
        self.vocab_size = vocab_size
        self.max_len = max_len
        # Any extra keyword arguments override the class-level defaults
        for key, value in kwargs.items():
            setattr(self, key, value)


class GPT1Config(GPTConfig):
    num_heads = 12
    num_blocks = 12
    embed_dim = 768


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config.embed_dim
        self.max_len = config.max_len
        # Input layers (token & positional embeddings)
        self.tok_embed = nn.Embedding(config.vocab_size, embed_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, config.max_len, embed_dim))
        # Dropout
        self.dropout = nn.Dropout(config.embed_dropout)
        # Transformer blocks
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.num_blocks)])
        # Layer normalization
        self.ln = nn.LayerNorm(embed_dim)
        # Output head: projects hidden states to vocabulary logits
        self.fc = nn.Linear(embed_dim, config.vocab_size)

    def forward(self, x, target=None):
        # target is accepted for API compatibility; loss computation is omitted here
        # batch_size = x.size(0)
        seq_len = x.size(1)
        assert seq_len <= self.max_len, "sequence longer than model capacity"

        tok_embedding = self.tok_embed(x)
        # tok_embedding.shape == (batch_size, seq_len, embed_dim)
        pos_embedding = self.pos_embed[:, :seq_len, :]
        # pos_embedding.shape == (1, seq_len, embed_dim)
        x = self.dropout(tok_embedding + pos_embedding)
        x = self.blocks(x)
        x = self.ln(x)
        x = self.fc(x)
        # x.shape == (batch_size, seq_len, vocab_size)
        return x


class Block(nn.Module):
    """Transformer block"""

    def __init__(self, config):
        super().__init__()
        embed_dim = config.embed_dim
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.attn = MultiheadAttention(config)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(config.ff_dropout),
        )

    def forward(self, x):
        # Pre-LayerNorm residual connections
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x


class MultiheadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config.embed_dim
        self.num_heads = config.num_heads
        assert embed_dim % self.num_heads == 0, "invalid heads and embedding dimension configuration"
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.query = nn.Linear(embed_dim, embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.attn_dropout = nn.Dropout(config.attn_dropout)
        self.proj_dropout = nn.Dropout(config.ff_dropout)
        # Causal (lower-triangular) mask, registered as a non-trainable buffer
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.max_len, config.max_len))
            .unsqueeze(0).unsqueeze(0)
        )

    def forward(self, x):
        batch_size = x.size(0)
        seq_len = x.size(1)
        # x.shape == (batch_size, seq_len, embed_dim)

        k_t = self.key(x).reshape(batch_size, seq_len, self.num_heads, -1).permute(0, 2, 3, 1)
        v = self.value(x).reshape(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
        q = self.query(x).reshape(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
        # q, v shape == (batch_size, num_heads, seq_len, head_dim)
        # k_t shape == (batch_size, num_heads, head_dim, seq_len), already transposed for matmul

        attn = torch.matmul(q, k_t) / math.sqrt(q.size(-1))
        # attn.shape == (batch_size, num_heads, seq_len, seq_len)

        # Apply the causal mask, then softmax, then dropout on the attention weights
        mask = self.mask[:, :, :seq_len, :seq_len]
        attn = attn.masked_fill(mask == 0, float("-inf"))
        attn = F.softmax(attn, dim=-1)
        # attn.shape == (batch_size, num_heads, seq_len, seq_len)
        attn = self.attn_dropout(attn)

        y = torch.matmul(attn, v)
        # y.shape == (batch_size, num_heads, seq_len, head_dim)
        y = y.transpose(1, 2)
        # y.shape == (batch_size, seq_len, num_heads, head_dim)
        y = y.reshape(batch_size, seq_len, -1)
        # y.shape == (batch_size, seq_len, embed_dim)
        y = self.proj_dropout(self.proj(y))
        return y
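The forward method above returns raw logits and leaves target unused. As a rough sketch of how those logits could feed a next-token cross-entropy loss (this usage is my own addition and is not part of the original code; the small num_blocks override is only to keep the example light), one might write:

# Hypothetical training-style usage (not part of the original code)
config = GPT1Config(vocab_size=10, max_len=12, num_blocks=2)  # kwargs override class defaults
model = GPT(config)

x = torch.randint(0, config.vocab_size, (2, 8))       # (batch_size, seq_len) input token IDs
target = torch.randint(0, config.vocab_size, (2, 8))  # next-token labels, same shape as x

logits = model(x)                                     # (batch_size, seq_len, vocab_size)
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))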
About the mask#
Masked Self-Attention: when self-attention is used for autoregressive generation (tasks that predict elements one by one, in order), each element must be prevented from attending to elements that lie in its future. To achieve this, a triangular mask is applied to the attention matrix so that no position can access future positions, and the model learns to predict future information from past and present information only.
max_len = 5
mask = torch.tril(torch.ones(max_len, max_len)).unsqueeze(0).unsqueeze(0)
mask
tensor([[[[1., 0., 0., 0., 0.],
[1., 1., 0., 0., 0.],
[1., 1., 1., 0., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1.]]]])
seq_len = 3
mask = mask[:, :, :seq_len, :seq_len]
mask
tensor([[[[1., 0., 0.],
[1., 1., 0.],
[1., 1., 1.]]]])
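Inside MultiheadAttention.forward, this mask is applied to the score matrix just before the softmax. A minimal sketch of that step, reusing the mask and seq_len defined above on a random score matrix (the values are arbitrary), looks like this: positions above the diagonal receive zero attention weight.

scores = torch.randn(1, 1, seq_len, seq_len)           # dummy attention scores
masked = scores.masked_fill(mask == 0, float("-inf"))  # hide future positions
weights = F.softmax(masked, dim=-1)                    # each row sums to 1 over visible positions
weights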
Model#
vocab_size = 10
max_len = 12
config = GPT1Config(vocab_size, max_len)
model = GPT(config)
model
GPT(
(tok_embed): Embedding(10, 768)
(dropout): Dropout(p=0.1, inplace=False)
(blocks): Sequential(
(0): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(1): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(2): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(3): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(4): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(5): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(6): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(7): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(8): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(9): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(10): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
(11): Block(
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(attn): MultiheadAttention(
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(query): Linear(in_features=768, out_features=768, bias=True)
(proj): Linear(in_features=768, out_features=768, bias=True)
(attn_dropout): Dropout(p=0.1, inplace=False)
(proj_dropout): Dropout(p=0.1, inplace=False)
)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Linear(in_features=3072, out_features=768, bias=True)
(3): Dropout(p=0.1, inplace=False)
)
)
)
(ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(fc): Linear(in_features=768, out_features=10, bias=True)
)
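As a quick sanity check (the token IDs below are random), a dummy batch can be passed through this model; the logits should come out with shape (batch_size, seq_len, vocab_size):

x = torch.randint(0, vocab_size, (2, max_len))  # dummy batch: 2 sequences of max_len token IDs
logits = model(x)
logits.shape
# -> torch.Size([2, 12, 10])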