最近、Jax/Flaxを触るようになりました。ここでは、Jax/Flaxを用いてVision Transformerを実装した方法と、Jax/Flaxによる学習の方法について紹介しようと思います。
Vision Transformerのおさらい
Vision Transformerを実装するにあたって、まずはこの図を頭に入れておきます。
併せて、ViTの処理を論文で把握しておきます。
- 入力画像からパッチ画像を切り出し、フラットなベクトルに変換。
- Transformer Encoderで扱う隠れベクトルの次元へ射影
- Positional embedを追加、CLSトークンの追加
- Transformer Encoderで処理
- CLS embedを使ってMLPで分類
処理ブロックの実装
ここから、Jax/Flaxを使ってそれぞれの処理ブロックを作っていきます。
1. パッチ画像生成とフラットベクトル化
入力画像をH x W x Cとすると、patch_size x patch_size x Cのパッチ画像を切り出して、フラットなベクトルにします(embed_dim次元)。ここでは、カーネルサイズがpatch_size x patch_size、重複しないようにstridesをpatch_size x patch_sizeにした畳み込みで実装します。
class Patches(nn.Module): patch_size: int embed_dim: int def setup(self): self.conv = nn.Conv( features=self.embed_dim, kernel_size=(self.patch_size, self.patch_size), strides=(self.patch_size, self.patch_size), padding='VALID' ) def __call__(self, images): patches = self.conv(images) b, h, w, c = patches.shape patches = jnp.reshape(patches, (b, h*w, c)) return patches
2と3. D次元ベクトルへの射影とCLSトークン及びPosition Embeddingの追加
Transformer Encoderでは、全ての層で同じ次元サイズhidden_dimを使用します。先ほど作ったフラットなパッチ画像のベクトルを、hidden_dim次元ベクトルに射影します。また、BERTと同じように、分類に使用する特別なトークンとして、CLSトークンをパッチ系列の先頭に追加します。さらに、位置情報を保持するため、学習可能な1D position embeddingも追加します。
class PatchEncoder(nn.Module): hidden_dim: int @nn.compact def __call__(self, x): assert x.ndim == 3 n, seq_len, _ = x.shape # Hidden dim x = nn.Dense(self.hidden_dim)(x) # Add cls token cls = self.param('cls_token', nn.initializers.zeros, (1, 1, self.hidden_dim)) cls = jnp.tile(cls, (n, 1, 1)) x = jnp.concatenate([cls, x], axis=1) # Add position embedding pos_embed = self.param( 'position_embedding', nn.initializers.normal(stddev=0.02), # From BERT (1, seq_len + 1, self.hidden_dim) ) return x + pos_embed
4. Transformer Encoder
上で貼り付けた図にあるように、Transformer EncoderはMulti Head Self Attention(MSA)とMLPが交互に接続された層で構成され、MSAとMLPブロックの前にLayernorm(LN)、ブロックの後にresidual接続を適用します。
class TransformerEncoder(nn.Module): embed_dim: int hidden_dim: int n_heads: int drop_p: float mlp_dim: int def setup(self): self.mha = MultiHeadSelfAttention(self.hidden_dim, self.n_heads, self.drop_p) self.mlp = MLP(self.mlp_dim, self.drop_p) self.layer_norm = nn.LayerNorm(epsilon=1e-6) def __call__(self, inputs, train=True): # Attention Block x = self.layer_norm(inputs) x = self.mha(x, train) x = inputs + x # MLP block y = self.layer_norm(x) y = self.mlp(y, train) return x + y
MLPは2層のネットワークです。活性化関数はGELUを用いています。論文に従い、DropoutをDense層の後に適用しています。
class MLP(nn.Module): mlp_dim: int drop_p: float out_dim: Optional[int] = None @nn.compact def __call__(self, inputs, train=True): actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim x = nn.Dense(features=self.mlp_dim)(inputs) x = nn.gelu(x) x = nn.Dropout(rate=self.drop_p, deterministic=not train)(x) x = nn.Dense(features=actual_out_dim)(x) x = nn.Dropout(rate=self.drop_p, deterministic=not train)(x) return x
Multi Head Self Attention(MSA)
MSAについては、以前のブログMulti Head Attentionの概要を掴む - stMindにもまとめました。
ViTの論文では、qkvを求めるのはとなっていますが、ここでは独立したDenseで実装しています。また、qkvは[B, N, T, D]の形にして、Single Headと同じようにWeightとAttentionを計算した後で、元の[B, T, C=N*D]に戻して出力するようにします。
class MultiHeadSelfAttention(nn.Module): hidden_dim: int n_heads: int drop_p: float def setup(self): self.q_net = nn.Dense(self.hidden_dim) self.k_net = nn.Dense(self.hidden_dim) self.v_net = nn.Dense(self.hidden_dim) self.proj_net = nn.Dense(self.hidden_dim) self.att_drop = nn.Dropout(self.drop_p) self.proj_drop = nn.Dropout(self.drop_p) def __call__(self, x, train=True): B, T, C = x.shape # batch_size, seq_length, hidden_dim N, D = self.n_heads, C // self.n_heads # num_heads, head_dim q = self.q_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3) # (B, N, T, D) k = self.k_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3) v = self.v_net(x).reshape(B, T, N, D).transpose(0, 2, 1, 3) # weights (B, N, T, T) weights = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(D) normalized_weights = nn.softmax(weights, axis=-1) # attention (B, N, T, D) attention = jnp.matmul(normalized_weights, v) attention = self.att_drop(attention, deterministic=not train) # gather heads attention = attention.transpose(0, 2, 1, 3).reshape(B, T, N*D) # project out = self.proj_drop(self.proj_net(attention), deterministic=not train) return out
5. CLS embedを使ってMLPで分類
最後にこれまで作ったブロックをまとめて、MLP head(分類ヘッド)を付け加えます。
class ViT(nn.Module): patch_size: int embed_dim: int hidden_dim: int n_heads: int drop_p: float num_layers: int mlp_dim: int num_classes: int def setup(self): self.patch_extracter = Patches(self.patch_size, self.embed_dim) self.patch_encoder = PatchEncoder(self.hidden_dim) self.dropout = nn.Dropout(self.drop_p) self.transformer_encoder = TransformerEncoder(self.embed_dim, self.hidden_dim, self.n_heads, self.drop_p, self.mlp_dim) self.cls_head = nn.Dense(features=self.num_classes) def __call__(self, x, train=True): x = self.patch_extracter(x) x = self.patch_encoder(x) x = self.dropout(x, deterministic=not train) for i in range(self.num_layers): x = self.transformer_encoder(x, train) # MLP head x = x[:, 0] # [CLS] token x = self.cls_head(x) return x
Jax/Flaxによる学習
モデルが作成できたので、Jax/Flaxを使って学習を組み立てていきます。
データセット
ここではtorchvisionのCIFAR10を使います。
def image_to_numpy(img): img = np.array(img, dtype=np.float32) img = (img / 255. - DATA_MEANS) / DATA_STD return img
def numpy_collate(batch): if isinstance(batch[0], np.ndarray): return np.stack(batch) elif isinstance(batch[0], (tuple, list)): transposed = zip(*batch) return [numpy_collate(samples) for samples in transposed] else: return np.array(batch)
test_transform = image_to_numpy train_transform = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomResizedCrop((IMAGE_SIZE, IMAGE_SIZE), scale=CROP_SCALES, ratio=CROP_RATIO), image_to_numpy ]) # Validation set should not use the augmentation. train_dataset = CIFAR10('data', train=True, transform=train_transform, download=True) val_dataset = CIFAR10('data', train=True, transform=test_transform, download=True) train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000], generator=torch.Generator().manual_seed(SEED)) _, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000], generator=torch.Generator().manual_seed(SEED)) test_set = CIFAR10('data', train=False, transform=test_transform, download=True) train_loader = torch.utils.data.DataLoader( train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2, persistent_workers=True, collate_fn=numpy_collate, ) val_loader = torch.utils.data.DataLoader( val_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=2, persistent_workers=True, collate_fn=numpy_collate, ) test_loader = torch.utils.data.DataLoader( test_set, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, num_workers=2, persistent_workers=True, collate_fn=numpy_collate, )
モデルの初期化
ViTクラスとして定義したモデルの初期化をします。FlaxのModule.initメソッドでPRNGキーとダミー入力を用いて初期化を実行し、戻り値として得られたパラメータを後で作成するTrainStateで管理、学習のループで更新していくという形で使用します。
def initialize_model( seed=42, patch_size=16, embed_dim=192, hidden_dim=192, n_heads=3, drop_p=0.1, num_layers=12, mlp_dim=768, num_classes=10 ): main_rng = jax.random.PRNGKey(seed) x = jnp.ones(shape=(5, 32, 32, 3)) # ViT model = ViT( patch_size=patch_size, embed_dim=embed_dim, hidden_dim=hidden_dim, n_heads=n_heads, drop_p=drop_p, num_layers=num_layers, mlp_dim=mlp_dim, num_classes=num_classes ) main_rng, init_rng, drop_rng = random.split(main_rng, 3) params = model.init({'params': init_rng, 'dropout': drop_rng}, x, train=True)['params'] return model, params, main_rng
vit_model, vit_params, vit_rng = initialize_model()
TrainStateの作成
Flaxでは、ステップ数やオプティマイザの状態、モデルのパラメータを含む学習の状態を管理するクラスを作成することが一般的なパターンのようです。そのため、基本的なユースケースに対応するTrainStateクラスが提供されています。また、apply_fnにモデルのforwardに相当するapplyを指定しておくことで、学習ループにおける関数の引数リストを少なくすることができます。
def create_train_state( model, params, learning_rate ): optimizer = optax.adam(learning_rate) return train_state.TrainState.create( apply_fn=model.apply, tx=optimizer, params=params )
state = create_train_state(vit_model, vit_params, 3e-4)
学習ループ
学習ループ自体は、Jax/Flaxに特有の記述はありません。
def train_model(train_loader, val_loader, state, rng, num_epochs=100): best_eval = 0.0 for epoch_idx in tqdm(range(1, num_epochs + 1)): state, rng = train_epoch(train_loader, epoch_idx, state, rng) if epoch_idx % 1 == 0: eval_acc = eval_model(val_loader, state, rng) logger.add_scalar('val/acc', eval_acc, global_step=epoch_idx) if eval_acc >= best_eval: best_eval = eval_acc save_model(state, step=epoch_idx) logger.flush() # Evaluate after training test_acc = eval_model(test_loader, state, rng) print(f'test_acc: {test_acc}')
def train_epoch(train_loader, epoch_idx, state, rng): metrics = defaultdict(list) for batch in tqdm(train_loader, desc='Training', leave=False): state, rng, loss, acc = train_step(state, rng, batch) metrics['loss'].append(loss) metrics['acc'].append(acc) for key in metrics.keys(): arg_val = np.stack(jax.device_get(metrics[key])).mean() logger.add_scalar('train/' + key, arg_val, global_step=epoch_idx) print(f'[epoch {epoch_idx}] {key}: {arg_val}') return state, rng
def eval_model(data_loader, state, rng): # Test model on all images of a data loader and return avg loss correct_class, count = 0, 0 for batch in data_loader: rng, acc = eval_step(state, rng, batch) correct_class += acc * batch[0].shape[0] count += batch[0].shape[0] eval_acc = (correct_class / count).item() return eval_acc
Train step
train_stepの中では、ロス関数を定義し、モデルのパラメータに対する勾配を求め、勾配に基づいてパラメータの更新を行います。value_and_gradsでstate.paramsに対する勾配を求め、apply_gradientsでTrainStateを更新します。ロス関数の中では、TrainStateの作成の時に指定したapply_fn(model.applyと同じ)でlogitsを計算して、cross entropy lossを求めます。
@jax.jit def train_step(state, rng, batch): loss_fn = lambda params: calculate_loss(params, state, rng, batch, train=True) # Get loss, gradients for loss, and other outputs of loss function (loss, (acc, rng)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params) # Update parameters and batch statistics state = state.apply_gradients(grads=grads) return state, rng, loss, acc
def calculate_loss(params, state, rng, batch, train): imgs, labels = batch rng, drop_rng = random.split(rng) logits = state.apply_fn({'params': params}, imgs, train=train, rngs={'dropout': drop_rng}) loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels).mean() acc = (logits.argmax(axis=-1) == labels).mean() return loss, (acc, rng)
学習結果
モデルの設定とハイパーパラメータはViT-Tinyを参考にして学習した結果です。Colab proの標準GPUで1.5h程度でした。
test_acc: 0.7704000473022461
まとめ
Jax/Flaxを用いてVision Transformerを実装する方法、学習と評価を行う方法を紹介しました。Test精度を上げるには、もう少し手を入れる必要がありそうですが、基本的なところは実現できたように思います。
Reference
- [2010.11929] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
- [2106.10270] How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers
- Getting Started
- Writing a Training Loop in JAX + FLAX
- Tutorial 15 (JAX): Vision Transformers — UvA DL Notebooks v1.2 documentation