stMind

about Tech, Computer vision and Machine learning

CLIPのゼロショット分類におけるプロンプトアンサンブル

CLIPは、画像とテキストがデータセット内でペアになっているかどうかを予測するように事前学習されています。図の(2)と(3)にあるように、ゼロショット分類では、データセット内のすべてのクラス名を含んだテキストを作成し、CLIPによって最も確率の高い(画像、テキスト)ペアを予測します。具体的には、画像の特徴埋め込みとテキストの特徴埋め込みを、それぞれのエンコーダーによって計算します。これらの埋め込みのコサイン類似度が計算され、温度パラメータτでスケーリングされ、softmaxによって確率分布に正規化されます。ゼロショット評価では、テキストエンコーダーによって計算されたゼロショット分類器をキャッシュし、その後のすべての予測で再利用します。

Prompt engineering and ensembling

CLIPの論文では、クラス名そのままではなく、クラス名を含んだ"A photo of a {label}"をデフォルトのプロンプトとして使用することが効果的であり、またタスクに合わせてプロンプトをカスタマイズすることも有効であったと述べられています。例えば、Oxford-IIIT Pet Datasetでは"A photo of a {label}, a type of pet"、Food101では"a type of food"、FGVC-Aircraftでは"a type of aircraft"としてカテゴリを指定することが効果的であったことが示されている。

また、パフォーマンス向上のための別の方法としてプロンプトのアンサンブルについても試しています。複数の異なるプロンプト、例えば"A photo of a big {label}"や"A photo of a small {label}"を用いて、テキスト埋め込みベクトルを平均化します。Imagenetでは80個のプロンプトをアンサンブルすることで+3.5%の向上が見られ、プロンプトエンジニアリングと組み合わせると約5%の向上になったと述べられています。

CLIPの事前学習とゼロショット分類におけるプロンプトアンサンブルの実験

CLIPの実装と事前学習を行い、ゼロショット分類のプロンプトアンサンブルを実験してみました。

CLIPの実装と事前学習

ゼロショット分類

def zeroshot_classifier(labels, templates, model, tokenizer, device, max_length=100):
    zeroshot_weights = []
    for label in tqdm(labels):
        texts = [template.format(label) for template in templates]
        texts = tokenizer(
            text=texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt').to(device)
        text_embeddings = model.encode_text(texts['input_ids'], texts['attention_mask'])
        text_embeddings = text_embeddings.mean(dim=0)
        text_embeddings /= text_embeddings.norm()
        text_embeddings = text_embeddings.cpu().detach().numpy()
        zeroshot_weights.append(text_embeddings)
    zeroshot_weights = np.stack(zeroshot_weights, axis=1)
    return zeroshot_weights

実験結果

クラス名 アンサンブルなし アンサンブルあり 差分
tench 0.5 13.5 +13.0
English springer 43.0 50.4 +7.4
cassette player 39.1 42.2 +3.1
chain saw 22.5 58.2 +35.7
church 20.1 19.3 -0.8
French horn 76.1 79.3 +3.2
garbage truck 91.2 92.5 +1.3
gas pump 11.7 12.2 +0.5
golf ball 72.8 67.0 -5.8
parachute 87.9 87.6 -0.3
全体 46.4 52.0 +5.6

アンサンブルなしの場合は、デフォルトの"A photo of a {label}"を使用。プロンプトアンサンブルをすることで、全体のAccuracyは+5.6%となった。クラスごとのAccuracyを見てみると、アンサンブルありの場合に低下しているクラスもあるが、多くのクラスでプラスになっていて、改善幅も非常に大きい。

まとめ

CLIPのゼロショット分類において、プロンプトアンサンブルを実際に実験して、Accuracyが改善することを確かめた。論文には、コンテキストのないクラス名を使用するベースラインと比較して、プロンプトエンジニアリングとアンサンブルは、36のデータセットで平均してほぼ5ポイントのゼロショット分類性能を向上させるとも述べられていて、シンプルな方法ではあるが大きな効果がありそう。

参照

実装コード: GitHub - satojkovic/clip-pytorch

CLIP: Connecting text and images

Zero-shot Image Classification with OpenAI's CLIP

Prompt_Engineering_for_ImageNet.ipynb

Implementing CLIP With PyTorch Lightning