stMind

about Tech, Computer vision and Machine learning

ゼロから作るVision Transformer (JAX/Flax)

最近、Jax/Flaxを触るようになりました。ここでは、Jax/Flaxを用いてVision Transformerを実装した方法と、Jax/Flaxによる学習の方法について紹介しようと思います。

Vision Transformerのおさらい

Vision Transformerを実装するにあたって、まずはこの図を頭に入れておきます。

併せて、ViTの処理を論文で把握しておきます。

  1. 入力画像からパッチ画像を切り出し、フラットなベクトルに変換。
  2. Transformer Encoderで扱う隠れベクトルの次元へ射影
  3. Positional embedを追加、CLSトークンの追加
  4. Transformer Encoderで処理
  5. CLS embedを使ってMLPで分類

処理ブロックの実装

ここから、Jax/Flaxを使ってそれぞれの処理ブロックを作っていきます。

1. パッチ画像生成とフラットベクトル化

入力画像をH x W x Cとすると、patch_size x patch_size x Cのパッチ画像を切り出して、フラットなベクトルにします(embed_dim次元)。ここでは、カーネルサイズがpatch_size x patch_size、重複しないようにstridesをpatch_size x patch_sizeにした畳み込みで実装します。

class Patches(nn.Module):
  patch_size: int
  embed_dim: int

  def setup(self):
    self.conv = nn.Conv(
        features=self.embed_dim,
        kernel_size=(self.patch_size, self.patch_size),
        strides=(self.patch_size, self.patch_size),
        padding='VALID'
    )

  def __call__(self, images):
    patches = self.conv(images)
    b, h, w, c = patches.shape
    patches = jnp.reshape(patches, (b, h*w, c))
    return patches

2と3. D次元ベクトルへの射影とCLSトークン及びPosition Embeddingの追加

Transformer Encoderでは、全ての層で同じ次元サイズhidden_dimを使用します。先ほど作ったフラットなパッチ画像のベクトルを、hidden_dim次元ベクトルに射影します。また、BERTと同じように、分類に使用する特別なトークンとして、CLSトークンをパッチ系列の先頭に追加します。さらに、位置情報を保持するため、学習可能な1D position embeddingも追加します。

class PatchEncoder(nn.Module):
  hidden_dim: int

  @nn.compact
  def __call__(self, x):
    assert x.ndim == 3
    n, seq_len, _ = x.shape
    # Hidden dim
    x = nn.Dense(self.hidden_dim)(x)
    # Add cls token
    cls = self.param('cls_token', nn.initializers.zeros, (1, 1, self.hidden_dim))
    cls = jnp.tile(cls, (n, 1, 1))
    x = jnp.concatenate([cls, x], axis=1)
    # Add position embedding
    pos_embed = self.param(
        'position_embedding', 
        nn.initializers.normal(stddev=0.02), # From BERT
        (1, seq_len + 1, self.hidden_dim)
    )
    return x + pos_embed

4. Transformer Encoder

上で貼り付けた図にあるように、Transformer EncoderはMulti Head Self Attention(MSA)とMLPが交互に接続された層で構成され、MSAとMLPブロックの前にLayernorm(LN)、ブロックの後にresidual接続を適用します。

class TransformerEncoder(nn.Module):
  embed_dim: int
  hidden_dim: int
  n_heads: int
  drop_p: float
  mlp_dim: int

  def setup(self):
    self.mha = MultiHeadSelfAttention(self.hidden_dim, self.n_heads, self.drop_p)
    self.mlp = MLP(self.mlp_dim, self.drop_p)
    self.layer_norm = nn.LayerNorm(epsilon=1e-6)
  
  def __call__(self, inputs, train=True):
    # Attention Block
    x = self.layer_norm(inputs)
    x = self.mha(x, train)
    x = inputs + x
    # MLP block
    y = self.layer_norm(x)
    y = self.mlp(y, train)

    return x + y

MLPは2層のネットワークです。活性化関数はGELUを用いています。論文に従い、DropoutをDense層の後に適用しています。

class MLP(nn.Module):
  mlp_dim: int
  drop_p: float
  out_dim: Optional[int] = None

  @nn.compact
  def __call__(self, inputs, train=True):
    actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
    x = nn.Dense(features=self.mlp_dim)(inputs)
    x = nn.gelu(x)
    x = nn.Dropout(rate=self.drop_p, deterministic=not train)(x)
    x = nn.Dense(features=actual_out_dim)(x)
    x = nn.Dropout(rate=self.drop_p, deterministic=not train)(x)
    return x
Multi Head Self Attention(MSA)

MSAについては、以前のブログMulti Head Attentionの概要を掴む - stMindにもまとめました。 ViTの論文では、qkvを求めるのは \rm{U_{qkv}}となっていますが、ここでは独立したDenseで実装しています。また、qkvは[B, N, T, D]の形にして、Single Headと同じようにWeightとAttentionを計算した後で、元の[B, T, C=N*D]に戻して出力するようにします。

class MultiHeadSelfAttention(nn.Module):
  hidden_dim: int
  n_heads: int
  drop_p: float

  def setup(self):
    self.q_net = nn.Dense(self.hidden_dim)
    self.k_net = nn.Dense(self.hidden_dim)
    self.v_net = nn.Dense(self.hidden_dim)

    self.proj_net = nn.Dense(self.hidden_dim)

    self.att_drop = nn.Dropout(self.drop_p)
    self.proj_drop = nn.Dropout(self.drop_p)

  def __call__(self, x, train=True):
    B, T, C = x.shape # batch_size, seq_length, hidden_dim
    N, D = self.n_heads, C // self.n_heads # num_heads, head_dim
    q = self.q_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3) # (B, N, T, D)
    k = self.k_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3)
    v = self.v_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3)

    # weights (B, N, T, T)
    weights = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(D)
    normalized_weights = nn.softmax(weights, axis=-1)

    # attention (B, N, T, D)
    attention = jnp.matmul(normalized_weights, v)
    attention = self.att_drop(attention, deterministic=not train)

    # gather heads
    attention = attention.transpose(0, 2, 1, 3).reshape(B, T, N*D)

    # project
    out = self.proj_drop(self.proj_net(attention), deterministic=not train)

    return out

5. CLS embedを使ってMLPで分類

最後にこれまで作ったブロックをまとめて、MLP head(分類ヘッド)を付け加えます。

class ViT(nn.Module):
    patch_size: int
    embed_dim: int
    hidden_dim: int
    n_heads: int
    drop_p: float
    num_layers: int
    mlp_dim: int
    num_classes: int

    def setup(self):
        self.patch_extracter = Patches(self.patch_size, self.embed_dim)
        self.patch_encoder = PatchEncoder(self.hidden_dim)
        self.transformer_blocks = [
            Transformer(
                self.embed_dim, self.hidden_dim, self.n_heads, self.drop_p, self.mlp_dim
            )
            for _ in range(self.num_layers)
        ]
        self.mlp_head = MLP(self.mlp_dim, self.drop_p)
        self.cls_head = nn.Dense(features=self.num_classes)

    def __call__(self, x, train=True):
        x = self.patch_extracter(x)
        x = self.patch_encoder(x)
        for block in self.transformer_blocks:
            x = block(x, train)
        # MLP head
        x = x[:, 0]  # [CLS] token
        x = self.mlp_head(x, train)
        x = self.cls_head(x)
        return x

Jax/Flaxによる学習

モデルが作成できたので、Jax/Flaxを使って学習を組み立てていきます。

データセット

ここではtorchvisionのCIFAR10を使います。

def image_to_numpy(img):
  img = np.array(img, dtype=np.float32)
  img = (img / 255. - DATA_MEANS) / DATA_STD
  return img
def numpy_collate(batch):
  if isinstance(batch[0], np.ndarray):
    return np.stack(batch)
  elif isinstance(batch[0], (tuple, list)):
    transposed = zip(*batch)
    return [numpy_collate(samples) for samples in transposed]
  else:
    return np.array(batch)
test_transform = image_to_numpy
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop((IMAGE_SIZE, IMAGE_SIZE), scale=CROP_SCALES, ratio=CROP_RATIO),
    image_to_numpy
])

# Validation set should not use the augmentation.
train_dataset = CIFAR10('data', train=True, transform=train_transform, download=True)
val_dataset = CIFAR10('data', train=True, transform=test_transform, download=True)
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(SEED))
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000], generator=torch.Generator().manual_seed(SEED))
test_set = CIFAR10('data', train=False, transform=test_transform, download=True)

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
val_loader = torch.utils.data.DataLoader(
    val_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=2, persistent_workers=True, collate_fn=numpy_collate,
)

モデルの初期化

ViTクラスとして定義したモデルの初期化をします。FlaxのModule.initメソッドでPRNGキーとダミー入力を用いて初期化を実行し、戻り値として得られたパラメータを後で作成するTrainStateで管理、学習のループで更新していくという形で使用します。

def initialize_model(
    seed=42,
    patch_size=16, embed_dim=192, hidden_dim=192,
    n_heads=3, drop_p=0.1, num_layers=12, mlp_dim=768, num_classes=10
):
  main_rng = jax.random.PRNGKey(seed)
  x = jnp.ones(shape=(5, 32, 32, 3))
  # ViT
  model = ViT(
      patch_size=patch_size,
      embed_dim=embed_dim,
      hidden_dim=hidden_dim,
      n_heads=n_heads,
      drop_p=drop_p,
      num_layers=num_layers,
      mlp_dim=mlp_dim,
      num_classes=num_classes
  )
  main_rng, init_rng, drop_rng = random.split(main_rng, 3)
  params = model.init({'params': init_rng, 'dropout': drop_rng}, x, train=True)['params']
  return model, params, main_rng
vit_model, vit_params, vit_rng = initialize_model()

TrainStateの作成

Flaxでは、ステップ数やオプティマイザの状態、モデルのパラメータを含む学習の状態を管理するクラスを作成することが一般的なパターンのようです。そのため、基本的なユースケースに対応するTrainStateクラスが提供されています。また、apply_fnにモデルのforwardに相当するapplyを指定しておくことで、学習ループにおける関数の引数リストを少なくすることができます。

def create_train_state(
    model, params, learning_rate
):
  optimizer = optax.adam(learning_rate)
  return train_state.TrainState.create(
      apply_fn=model.apply,
      tx=optimizer,
      params=params
  )
state = create_train_state(vit_model, vit_params, 3e-4)

学習ループ

学習ループ自体は、Jax/Flaxに特有の記述はありません。

def train_model(train_loader, val_loader, state, rng, num_epochs=100):
  best_eval = 0.0
  for epoch_idx in tqdm(range(1, num_epochs + 1)):
    state, rng = train_epoch(train_loader, epoch_idx, state, rng)
    if epoch_idx % 1 == 0:
      eval_acc = eval_model(val_loader, state, rng)
      logger.add_scalar('val/acc', eval_acc, global_step=epoch_idx)
      if eval_acc >= best_eval:
        best_eval = eval_acc
        save_model(state, step=epoch_idx)
      logger.flush()
  # Evaluate after training
  test_acc = eval_model(test_loader, state, rng)
  print(f'test_acc: {test_acc}')
def train_epoch(train_loader, epoch_idx, state, rng):
  metrics = defaultdict(list)
  for batch in tqdm(train_loader, desc='Training', leave=False):
    state, rng, loss, acc = train_step(state, rng, batch)
    metrics['loss'].append(loss)
    metrics['acc'].append(acc)
  for key in metrics.keys():
    arg_val = np.stack(jax.device_get(metrics[key])).mean()
    logger.add_scalar('train/' + key, arg_val, global_step=epoch_idx)
    print(f'[epoch {epoch_idx}] {key}: {arg_val}')
  return state, rng
def eval_model(data_loader, state, rng):
  # Test model on all images of a data loader and return avg loss
  correct_class, count = 0, 0
  for batch in data_loader:
    rng, acc = eval_step(state, rng, batch)
    correct_class += acc * batch[0].shape[0]
    count += batch[0].shape[0]
  eval_acc = (correct_class / count).item()
  return eval_acc

Train step

train_stepの中では、ロス関数を定義し、モデルのパラメータに対する勾配を求め、勾配に基づいてパラメータの更新を行います。value_and_gradsでstate.paramsに対する勾配を求め、apply_gradientsでTrainStateを更新します。ロス関数の中では、TrainStateの作成の時に指定したapply_fn(model.applyと同じ)でlogitsを計算して、cross entropy lossを求めます。

@jax.jit
def train_step(state, rng, batch):
  loss_fn = lambda params: calculate_loss(params, state, rng, batch, train=True)
  # Get loss, gradients for loss, and other outputs of loss function
  (loss, (acc, rng)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
  # Update parameters and batch statistics
  state = state.apply_gradients(grads=grads)
  return state, rng, loss, acc
def calculate_loss(params, state, rng, batch, train):
  imgs, labels = batch
  rng, drop_rng = random.split(rng)
  logits = state.apply_fn({'params': params}, imgs, train=train, rngs={'dropout': drop_rng})
  loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels).mean()
  acc = (logits.argmax(axis=-1) == labels).mean()
  return loss, (acc, rng)

学習結果

モデルの設定とハイパーパラメータはViT-Tinyを参考にして学習した結果です。Colab proの標準GPUで1.5h程度でした。

test_acc: 0.7704000473022461

まとめ

Jax/Flaxを用いてVision Transformerを実装する方法、学習と評価を行う方法を紹介しました。Test精度を上げるには、もう少し手を入れる必要がありそうですが、基本的なところは実現できたように思います。

github.com

Reference

pytorchとtensorflowのチュートリアル写経で見つけた小さなバグをPull Requestしてマージされるまで

正確にいうと、torchvisionとTensorflow Textのチュートリアル

torchvisionの方は、ビデオファイルを読み込んでビデオフレームとオーディオフレームを返すVideo APIのチュートリアルで、ビデオのptsはvideo_ptsにappendして、オーディオのptsはaudio_ptsにappendするべきところが、オーディオの方もvideo_ptsにappendされていた。単にvideo_ptsをaudio_ptsに置き換えるだけのPRを作成して、3日ほどでマージされました。

github.com

Tensorflow Textの方は、Transformerを使った機械翻訳のチュートリアルで、ポルトガル語を処理するTokenizerがtokenizers.enとなっていたので、tokenizers.ptと修正してPRを作成。こちらも2日ほどでマージされました。現在、このチュートリアルは結構書き換えられていて、修正した箇所を含むブロックは無くなっています:)

github.com

今回はチュートリアル写経の副産物としてのPRでしたが、このような貢献の形もあるよという紹介でした。

stackoverflowで今年回答した内容のまとめ

今年の目標の一つとして、stackoverflowで回答を増やそうと考えていたのだけど、現時点で11件ほど行なうことが出来たので、ここでまとめておこうと思う。回答がAcceptされるとベターだとは思うものの、回答後のリアクションは質問者次第なので、コミュニティに貢献できたと考えれば良いかな。

一応、スタッツとしては以下のような感じ。

  • Accepted : 1
  • Upvote : 6

カテゴリとしては、Tutorialなどをやっていたこともあって、transformer系の質問への回答が多めになった。自分自身の勉強にもなるし、コミュニティ貢献としてstackoverflowの回答をゆるりと続けていこうと思う。

1. Masking layer vs attention_mask parameter in MultiHeadAttention

Acceptされた回答。MultiHeadAttentionを使うときattention_maskでマスク指定しているが、MHAの前にMaskingレイヤーを使ったらattention_maskは不要なのか?それとも両方使わないといけないのか?という質問。 調べてみると、Maskingレイヤーで生成されたマスクは、対応しているレイヤーでは入力に対応するマスクの自動取得して使う機能が入っているが、MHAはTF2.10.0で対応ということだったので、そのように回答。

stackoverflow.com

2. Can't install tensorflow-io on m1

一番Upvoteされた回答。m1 macにtensorflow-ioに入れるには?ということで、git cloneしてwheelをローカルで作ってpip installする方法を回答。結構、こういうのが一番困っているのかもしれない。

stackoverflow.com

3. Meaning of the array returned by the activation function GELU (Vision Transformer)

これも一つUpvoteが付いた回答。ViTで推論するとき、GELUの出力の最も高い値をつかって分類しているが問題ないか?という質問。そのまま使うか、確率的な値が必要であればsoftmaxを使うこともできると回答。

stackoverflow.com

4. Transformer with multi input

二つの異なる系列データに対して、Cross Attentionを行なう方法を知りたいという質問。具体的なサンプルはMultiHeadAttentionの公式APIドキュメントにあったので、それを回答。

stackoverflow.com

5. Visualizing ViT Attention maps after fine tuning on medical dataset

vit-kerasを使って独自モデルのAttentionを可視化したいけどエラーになる、attention_mapへの引数の渡し方が間違っている様子だったので、修正方法を回答。

stackoverflow.com

6. Patch Encoder for ViT implementation in Python

PatchEncoderクラスで実行されている処理について知りたいということだったので、Dense適用、position embeddingの追加をcifar10を例に回答。

stackoverflow.com

7. ViVIT PyTorch: RuntimeError: multi-target not supported at /pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:15

ViViTを実行した時に、RuntimeError: multi-target not supported at /pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:15が出るようなので調べてみて、Pytorchでは1.10.0から対応してることと、実際に実行できるcolabのノートを回答。

stackoverflow.com

8. In transformers of ViT model, last_hidden_state is not equal to hidden_states[-1]

huggingfaceのViTで、last_hidden_stateとhidden_states[-1]は同じに思うけれど、実際の値が異なるのはなぜか?という質問。コードを読んでみて、last_hidden_stateの方はLayernormが適用されている点が違いで、実際にhidden_states[-1]にLayernormを適用すれば同じ値になることを確認してcolabのノートと合わせて回答。

stackoverflow.com

9. Vision Transformer models in vit-keras

vit-kerasでvit-b32とvit-b16以外の使用可能なモデルと入力画像サイズを知りたいということで、調べて回答。

stackoverflow.com

10. no fine_tune_checkpoint field in pipeline.config

TF OD APIでpipeline.configのfine_tune_checkpointの指定方法を回答。

stackoverflow.com

11. Google Kickstart 2014 Round D Sort a scrambled itinerary - Do I need to bring the input in a ready-to-use array format?

これは少し違う話で、GoogleのCodingコンペティションKickstartで入力をどうやって受け取ればいいのか?ということだったので、サンプルを提示して回答。

stackoverflow.com

DALL-E 2, Imagen, Parti

テキストから画像を生成するモデルが話題ですが、代表的なモデルであるDALL-E2とImagen、Partiのアーキテクチャを比較しているTwitterのスレッドを紹介。

以下は、アーキテクチャの構成ブロックの簡単な比較を表にまとめたもの。

DALL-E2 Imagen Parti
Text Encoder CLIP T5-XXL Transformer Encoder
Text Embeddings to Image (64x64 or 256x256) Diffusion Model x 2 (Prior + Decoder) Diffusion Model Transformer Decoder + ViT-VQGAN
Upsample the Image (256x256->1024x1024) Diffusion Model x 2 Diffusion Model x 2 Convolution

それぞれの論文は画像がたくさん含まれてはいますが、DALL-E2が27ページ、Imagenが46ページ、Partiは49ページもあって、読むぞ!とは気軽に言えない分量ですね...

TD OD APIでデータ拡張のオプションを追加する

データ拡張のオプションは、TrainConfigの中でdata_augmentation_optionsとして定義されている。

data_augmentation_optionsは、repeatedフィールドで、PreprocessingStepに指定されているデータ拡張を任意の数だけ指定することができる。

// Message for configuring DetectionModel training jobs (train.py).
// Next id: 31
message TrainConfig {
  // Effective batch size to use for training.
  // For TPU (or sync SGD jobs), the batch size per core (or GPU) is going to be
  // `batch_size` / number of cores (or `batch_size` / number of GPUs).
  optional uint32 batch_size = 1 [default=32];

  // Data augmentation options.
  repeated PreprocessingStep data_augmentation_options = 2;

例えば、random_horizontal_flipは、水平方向の反転を行う処理で、デフォルトでは50%の確率で行われる。

pipeline.config内では、フォーマットに従って使いたい拡張を指定すれば良い。

train_config: {
  ...
  data_augmentation_options {
    random_horizontal_flip {
    }
   random_image_scale {
       min_scale_ratio: 0.9
       max_scale_ratio: 1.1
    }
  }
}

一方、ファイル編集ではなくコードの中で直接データ拡張を追加したい場合には、以下の様に行う。

例えば、ssd_resnet50_v1_fpn_640x640_coco17_tpu-8のrandom_horizontal_flipとrandom_crop_imageに対して、random_image_scaleを追加する場合。

from google.protobuf import text_format

from object_detection.builders import preprocessor_builder
from object_detection.core import preprocessor
from object_detection.protos import preprocessor_pb2
from object_detection.utils import config_util

configs = config_util.get_configs_from_pipeline_file(/path/to/pipeline_config)
train_config = configs['train_config']

# Random image scale
preprocessor_text_proto = """
random_image_scale {
  min_scale_ratio: 0.8
  max_scale_ratio: 2.2
}
"""
preprocessor_proto = preprocessor_pb2.PreprocessingStep()
text_format.Merge(preprocessor_text_proto, preprocessor_proto)
train_config.data_augmentation_options.append(preprocessor_proto)

データ拡張オプションが追加されて、3つの処理が指定されている状態。

print(train_config.data_augmentation_options)
"""
[random_horizontal_flip {
}
, random_crop_image {
  min_object_covered: 0.0
  min_aspect_ratio: 0.75
  max_aspect_ratio: 3.0
  min_area: 0.75
  max_area: 1.0
  overlap_thresh: 0.0
}
, random_image_scale {
  min_scale_ratio: 0.800000011920929
  max_scale_ratio: 2.200000047683716
}
]
"""

参考

discuss.tensorflow.org

BERTの概要を掴む

一回目のMulti Head Attention、二回目のGPTに続いて、三回目はBERT。

Multi Head Attentionの概要を掴む - stMind

GPTの概要を掴む - stMind

以下、メモを残す。

3/n. Bi-directional

GPTでは、未来の単語の予測を目的として、言語モデルをpre-train出来ることを確認してきた。BERTでは、未来の単語の予測の代わりに、空欄を埋める目的を使用する。GPTと異なり、過去と未来のトークンを同時に見るため、双方向と呼ばれる。

4/n. BERT

BERTはどのように動作するか?アーキテクチャは以下の通り。GPTと二つの項目を除いて、ほぼ同じように動作する。一つ目の違いは、Causalマスクの代わりにランダムマスクを使うこと、二つ目はシーケンスの先頭に[CLS]を追加すること。この[CLS]トークンとは何か?

5/n. [CLS]トーク

例えば、Sentiment Analysisを行いたいとする。Transformerの出力shapeは [B, T, D]であり、これを[B, D]に圧縮し、分類器に入力する集約的な表現にしたい。最初に思いつく方法として、Tについて平均をとるのはどうか?

6/n. Aggregate representation

これでも機能するが、全てのトークンは分類に等しく有用であることを想定している。Attentionは、関連性に基づいてトークンを重みづけすることが重要だったはず。入力に新しいトークンを追加して、他のトークンをAttentionで集約したらどうか?それが[CLS]。

7/n. 分類器への入力としての[CLS]

分類タスクに対してBERTをFinetuneする場合、[CLS]トークンの最後の隠れ状態を分類器への入力として使用する。この時、MLPはshapeが[B, D]の隠れ状態を[B, クラス数]の出力に変換する。

8/n. [MASK]と[SEP]

BERTでは、[CLS]の他にも特別なトークンがある。マスクされたトークンを置き換える[MASK]と、センテンスの区切りとなる[SEP]。Tokenizationはデータの読み込み時に行われ、その後、positional embeddingが加えられる。

9/n. Transformer block

BERTがランダムマスク、GPTがCausalマスクを使用する違いはあるが、Transformer blockの実装はほぼ同じである。

10/n. generate_random_mask, mlm_loss

BERTでは、pretrainのためにデータのバッチを渡すたびに、新しいランダムなマスクをサンプリングする。これは、GPU上で直接マスクを作成すれば、効率的に行うことができる。以下が、マスクの作成とmasked language modelingの目的の実装。

11/n. tl;dr

BERT=GPT、ただしランダムマスクと他の特別なトークンを使用する違いがある。Transformerはとてもシンプルで汎用的であることは驚きである。

次回は、コンピュータビジョンにTransformerを適用したViTとMAE lossについて。

GPTの概要を掴む

前回は、Multi Head Attentionに関するTwitterの一連のスレッドを紹介した。

Multi Head Attentionの概要を掴む - stMind

今回はGPTについて。

以下、前回同様に自分が理解したメモを残す。

2/n. Multi Head Attentionの振り返り

Attentionは、ネットワークが入力に含まれる全ての単語とその関係を捉えることを可能にする。結果として、ネットワークはその目的を最適化するための最も重要な単語に注意を向ける様になる。

3/n. MHAの最適化

これまでのところ、MHAが最適化する目的を定義していない。GPTでは、非常にシンプルな、次の単語の予測というunsupervisedな目的を用いる。直前までの単語が与えられたら、次の単語を予測する。この目的であればラベルは不要なので、unsupervisedと呼ばれる。

4/n. Causal構造

未来の単語を予測する場合には、直前までの単語だけ参照するCausal構造を強制する必要がある。Causal Attention行列において、0は「関係性なし」を意味し、現在の単語と未来の単語間のAttentionは全て0にする必要がある。そのために、Weights行列( QK^{T})において、未来の単語を-infとする。

5/n. Causal Attention

Weights行列にSoftmaxを適用した後で、未来の単語をマスクして0にすると正規化されなくなる。そのため、Weights行列で未来の単語を-infにしてからSoftmaxを適用する。

6/n. Masked Causal Attention

Masked Causal AttentionがGPTのメインアイデア。GPTのTransformer BlockはMHA→LayerNorm→MLP→LayerNrom。入力shapeは(B, T, D_in)、出力shapeは(B, T, D_out)で、大抵はD_out=D_inとなる。

7/n. Loss

GPTの目的は次の単語の予測であった。英語には約100万語あり、文字通りに単語を予測する場合、100万クラスの分類をすることになる。GPTが最適化するロスを記述する。

loss = cross_entropy(preds, targets) # (B, T, 1e-7)

8/n. Tokens

ロスにおけるクラス数を削減するために、トークンを用いる。トークンは文字とベクトル間のmapで、例えばアルファベットの文字は26個の一意なベクトルで表すことができる(one-hot ベクトル)。文字列からユニークなベクトルへの変換はトークン化(Tokenization)と呼ばれる。

9/n. Byte Pair Encoding

GPTでは、頻度の高い文字グループ(ペア)をトークンにするByte Pair Encoding(BPE)を用いる。以下は、Attention is not too shabbyにBPEを適用して、11個のユニークなトークンになった例(ユニークなCharは12個)。

10/n. The order of words

最後の、しかし重要な問題は、今のモデルには語順を知る方法がないということ。

11/n. Positional tokens

語順をエンコードするために、Positional tokensを使用する。文字トークンと同じように、位置をユニークなベクトルで表す。文字と位置のトークンをそれぞれ線形層で変換した後、それらを足し合わせる。その後、Transformerブロックに入力される。

12/n. 実装