Scenicは、TransformerベースのモデルにフォーカスしたオープンソースのJAXライブラリ。 最近、Transformerを適用した動画認識モデルの論文(ViViT, MTV, ObjectViViT)を読んでいる中で見かけていました。
研究のコードであっても、構造化され、実験しやすいことが、色々なアイデアを素早く検証できるベースになることを実感していて、 Scenicが気になっていました。 そこで、arxivに公開されているScenicの論文を読んでみたので、ここで内容をメモしておきます。
Abstract
Scenicの目的は、新しいビジョンアーキテクチャやモデルの素早い実験、プロトタイピング、リサーチを促進すること。 Scenicは、マルチホスト、マルチデバイスの大規模学習のためのGPU/TPUサポートとともに、多様なビジョンタスク(分類、セグメンテーション、検出など)に対応し、マルチモーダルな問題に対する作業を容易にする。 また、幅広いモダリティのSOTAなモデルの最適化実装も提供する。
Introduction
Scenicとは何かを一言で表すと以下の通りとなる。
- ビジョン分野を中心として、大規模なモデルを学習する際に遭遇するタスクを解決するための軽量な共有ライブラリ
- これらのライブラリを利用した固有の問題に対応する多数のプロジェクト
Scenicは、様々な抽象化レベルを提供するように設計されている。例えば、ハイパラの変更のみのプロジェクト、入力パイプラインから、モデルのアーキテクチャ、ロスやメトリクス、学習ループまでカスタマイズが必要なプロジェクトなど。
これを実現するために、Scenicは大きく二つのレベルに整理されている。
- project-level code
- 特定のプロジェクトやベースライン用にカスタマイズされたコード
- library-level code
- 多くのプロジェクトに共通する機能や一般的なパターンのコード
philosophy
Scenicは大規模モデルの素早いプロトタイピングを促進することを目的としている。コードを理解、拡張しやすくするために、複雑さを加えたり、抽象度を上げるよりも、フォークやコピーを好む。 複数のモデルやタスクに広く有用である場合のみ、library-levelに機能を加える。library-levelで固有のユースケースのサポートを最小限にすることで、複雑で理解しづらくなる一般化を避ける。一方で、project-levelでは、複雑さや抽象化を加えることができる。
Design
Library-level code
目標は、ライブラリレベルのコードを最小限かつ十分にテストされたものに保ち、マイナーなユースケースをサポートするために余分な抽象化を導入しないようにすること。 共有ライブラリはFigure 1にあるように4つに分割されている。
- dataset_lib
- 一般的なタスクやベンチマークのデータをロードし、前処理するためのIOパイプラインを実装。
- model_lib
- タスクに特化したロスとメトリクスを持つ、いくつかの抽象モデル・インターフェース(例:model_lib/base_modelsのClassificationModelやSegmentationModel)
- attentionとtransformerの効率的な実装に焦点を当てたNN Layer(model_lib/layers)
- アクセラレーターフレンドリーなbipartite matching algorithmの実装(model_lib/matchers)
- train_lib
- 学習ループを構築するためのツールを提供
- common_lib
- ロギングやデバッグモジュール、Raw データを処理するための機能などの共通ユーティリティ
Project-level code
Project-level codeは、"プロジェクト "という概念によって、個別タスクやデータのためにカスタマイズされたソリューションの開発をサポート。プロジェクトは、設定ファイルのみで共通のモデルや学習器などのlibrary-level codeを使うこともできるし、フォークして再定義することもできる。 ResNetやViT、DETRはprojects/baselinesに実装されている。
Scenic BaseModel
ソリューションは通常、データやタスクのパイプライン、モデルアーキテクチャ、ロスとメトリクス、学習と評価などのパーツに分かれている。Scenicで行われる研究の多くが異なるアーキテクチャを試していることから、プラグイン/プラグアウトでの実験を容易にするための「model」という概念を導入。「model」は、ネットワークのアーキテクチャ、ロス、評価メトリクスとして定義され、BaseModelとして実装されている。
BaseModelは抽象クラスで、3つのメンバを持つ。
- build_flax_model
- loss_fn
- get_metrics_fn
Scenicのモデルを定義する抽象クラスは、model_lib/base_modelsにあり、BaseModelの他に、BaseModelを継承したClassificationModel、MulitLabelClassificationModel、EncoderDecoderModel、SegmentationModelも含まれる。
これらのデザインパターンは推奨であり、様々なプロジェクトでうまく機能するが、強制ではなく、プロジェクト内でこの構造から逸脱しても問題はない。
一言
Library-levelのコードと、Project-levelのコードの二つに分けて、複数のモデルやタスクに広く有用である場合のみ、Library-levelに機能を追加するという考えは分かりやすいと思いました。bipartite matchingはattentionと同じ階層でLibrary-levelなんですね。OpenPoseで使われていたと記憶してますが、他にも幅広く使われているということなのか。