CLIPは、画像とテキストがデータセット内でペアになっているかどうかを予測するように事前学習されています。図の(2)と(3)にあるように、ゼロショット分類では、データセット内のすべてのクラス名を含んだテキストを作成し、CLIPによって最も確率の高い(画像、テキスト)ペアを予測します。具体的には、画像の特徴埋め込みとテキストの特徴埋め込みを、それぞれのエンコーダーによって計算します。これらの埋め込みのコサイン類似度が計算され、温度パラメータτでスケーリングされ、softmaxによって確率分布に正規化されます。ゼロショット評価では、テキストエンコーダーによって計算されたゼロショット分類器をキャッシュし、その後のすべての予測で再利用します。
Prompt engineering and ensembling
CLIPの論文では、クラス名そのままではなく、クラス名を含んだ"A photo of a {label}"をデフォルトのプロンプトとして使用することが効果的であり、またタスクに合わせてプロンプトをカスタマイズすることも有効であったと述べられています。例えば、Oxford-IIIT Pet Datasetでは"A photo of a {label}, a type of pet"、Food101では"a type of food"、FGVC-Aircraftでは"a type of aircraft"としてカテゴリを指定することが効果的であったことが示されている。
また、パフォーマンス向上のための別の方法としてプロンプトのアンサンブルについても試しています。複数の異なるプロンプト、例えば"A photo of a big {label}"や"A photo of a small {label}"を用いて、テキスト埋め込みベクトルを平均化します。Imagenetでは80個のプロンプトをアンサンブルすることで+3.5%の向上が見られ、プロンプトエンジニアリングと組み合わせると約5%の向上になったと述べられています。
CLIPの事前学習とゼロショット分類におけるプロンプトアンサンブルの実験
CLIPの実装と事前学習を行い、ゼロショット分類のプロンプトアンサンブルを実験してみました。
CLIPの実装と事前学習
- Implementing CLIP With PyTorch Lightningをベースに、論文の疑似コードを実装したロス計算に置き換え
- Image EncoderはResNet50、Text EncoderはDistilBERT
- 事前学習はFlickr8k datasetを利用。trainとvalidationで80 : 20に分割。
- 学習は20エポック
ゼロショット分類
- fastai/imagenetteデータセット(10クラス)を利用
- OpenAIのCLIPレポジトリで公開されている80個のプロンプトテンプレートを使用
- クラスラベルごとに、80個のプロンプトのテキスト埋め込みを平均化(下記コードブロック)、画像の埋め込みベクトルとのコサイン類似度計算
def zeroshot_classifier(labels, templates, model, tokenizer, device, max_length=100): zeroshot_weights = [] for label in tqdm(labels): texts = [template.format(label) for template in templates] texts = tokenizer( text=texts, padding=True, truncation=True, max_length=max_length, return_tensors='pt').to(device) text_embeddings = model.encode_text(texts['input_ids'], texts['attention_mask']) text_embeddings = text_embeddings.mean(dim=0) text_embeddings /= text_embeddings.norm() text_embeddings = text_embeddings.cpu().detach().numpy() zeroshot_weights.append(text_embeddings) zeroshot_weights = np.stack(zeroshot_weights, axis=1) return zeroshot_weights
実験結果
クラス名 | アンサンブルなし | アンサンブルあり | 差分 |
---|---|---|---|
tench | 0.5 | 13.5 | +13.0 |
English springer | 43.0 | 50.4 | +7.4 |
cassette player | 39.1 | 42.2 | +3.1 |
chain saw | 22.5 | 58.2 | +35.7 |
church | 20.1 | 19.3 | -0.8 |
French horn | 76.1 | 79.3 | +3.2 |
garbage truck | 91.2 | 92.5 | +1.3 |
gas pump | 11.7 | 12.2 | +0.5 |
golf ball | 72.8 | 67.0 | -5.8 |
parachute | 87.9 | 87.6 | -0.3 |
全体 | 46.4 | 52.0 | +5.6 |
アンサンブルなしの場合は、デフォルトの"A photo of a {label}"を使用。プロンプトアンサンブルをすることで、全体のAccuracyは+5.6%となった。クラスごとのAccuracyを見てみると、アンサンブルありの場合に低下しているクラスもあるが、多くのクラスでプラスになっていて、改善幅も非常に大きい。
まとめ
CLIPのゼロショット分類において、プロンプトアンサンブルを実際に実験して、Accuracyが改善することを確かめた。論文には、コンテキストのないクラス名を使用するベースラインと比較して、プロンプトエンジニアリングとアンサンブルは、36のデータセットで平均してほぼ5ポイントのゼロショット分類性能を向上させるとも述べられていて、シンプルな方法ではあるが大きな効果がありそう。
参照
実装コード: GitHub - satojkovic/clip-pytorch
CLIP: Connecting text and images
Zero-shot Image Classification with OpenAI's CLIP