SwinTransformer/Feature-Distillation

Feature-Distillation

By Yixuan Wei*, Han Hu*, Zhenda Xie, Zheng Zhang, Yue Cao, Jianmin Bao, Dong Chen and Baining Guo.

This repo is the official implementation of "Contrastive Learning Rivals Masked Image Modeling in Fine-tuning via Feature Distillation".

Updates

11/30/2022

  1. Distilled and fine-tuned models on ImageNet-1K (ViT Large) are provided.

11/28/2022

Initial commits:

  1. Distilled and fine-tuned models on ImageNet-1K (Swin-Base and ViT-Base) are provided.
  2. Code for ImageNet-1K distillation and fine-tuning is provided.

Introduction

Feature Distillation (FD), first described on arXiv, is a simple framework for converting traditional pre-trained models, such as those trained with image classification (DeiT), instance contrastive learning (DINO), or image-text alignment (CLIP), into new models with better fine-tuning performance. Using a set of diagnostic tools, we find that models distilled on feature maps acquire the following properties, which are also observed in masked image modeling (MIM) models: 1) more diverse attention heads; 2) more diagonal attention patterns; 3) flatter loss landscapes.
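To make the idea of distilling on feature maps concrete, here is a minimal, stdlib-only sketch of a feature-map distillation objective: the student's per-token features are regressed onto the teacher's per-token features. The parameter-free per-token normalization of the teacher features and the smooth L1 penalty are assumptions chosen for illustration, as is the premise that the student features were already projected to the teacher's channel width. The function names are hypothetical; this is not the repository's `main_fd.py` implementation, which works on PyTorch tensors.

```python
import math
from typing import List

FeatureMap = List[List[float]]  # [num_tokens][channels]


def normalize_token(x: List[float], eps: float = 1e-6) -> List[float]:
    """Layer norm without learnable parameters over one token's channels (assumed)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    std = math.sqrt(var + eps)
    return [(v - mean) / std for v in x]


def smooth_l1(diff: float, beta: float = 1.0) -> float:
    """Smooth L1 penalty: quadratic below beta, linear above it."""
    a = abs(diff)
    return 0.5 * a * a / beta if a < beta else a - 0.5 * beta


def feature_distillation_loss(student: FeatureMap, teacher: FeatureMap,
                              beta: float = 1.0) -> float:
    """Mean smooth L1 distance between student features and normalized teacher features.

    Assumes the student features already have the teacher's channel count
    (in a real model a learned projection would handle this).
    """
    if not student or len(student) != len(teacher):
        raise ValueError("student and teacher need the same, non-zero number of tokens")
    total, count = 0.0, 0
    for s_tok, t_tok in zip(student, teacher):
        if len(s_tok) != len(t_tok):
            raise ValueError("channel mismatch: project the student features first")
        target = normalize_token(t_tok)  # the teacher acts as a fixed regression target
        for s, t in zip(s_tok, target):
            total += smooth_l1(s - t, beta)
            count += 1
    return total / count


if __name__ == "__main__":
    teacher = [[1.0, 2.0, 3.0], [0.0, 0.0, 3.0]]
    perfect_student = [normalize_token(tok) for tok in teacher]
    zero_student = [[0.0] * 3 for _ in teacher]
    print(feature_distillation_loss(perfect_student, teacher))  # 0.0
    print(feature_distillation_loss(zero_student, teacher))     # greater than 0
```

Normalizing the teacher per token keeps the regression target on a comparable scale across teachers as different as DeiT, DINO and CLIP; in this sketch that is a design assumption rather than a statement about the released code.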

Main Results on ImageNet

Swin Transformer

ImageNet-1K Distilled and Fine-tuned Models

name | distillation epochs | teacher model | image resolution | acc@1 | distilled model | fine-tuned model
--- | --- | --- | --- | --- | --- | ---
Swin-Base | 300 | EsViT-Base | 224x224 | 85.1 | google/config | google/config

Vision Transformer

ImageNet-1K Distilled and Fine-tuned Models

name | distillation epochs | teacher model | image resolution | acc@1 | distilled model | fine-tuned model
--- | --- | --- | --- | --- | --- | ---
ViT-Base | 300 | CLIP-Base | 224x224 | 84.9 | google/config | google/config
ViT-Base | 300 | DINO-Base | 224x224 | 83.8 | google/config | google/config
ViT-Base | 300 | DeiT-Base | 224x224 | 83.0 | google/config | google/config
ViT-Large | 300 | CLIP-Large | 224x224 | 87.7 | google/config | google/config

Citation

If you find our work useful in your research, please cite:

@article{wei2022FD,
  title={Contrastive Learning Rivals Masked Image Modeling in Fine-tuning via Feature Distillation},
  author={Yixuan Wei and Han Hu and Zhenda Xie and Zheng Zhang and Yue Cao and Jianmin Bao and Dong Chen and Baining Guo},
  journal={Tech Report},
  year={2022}
}

Getting Started

Installation

  • Install CUDA 11.3 with cuDNN 8 following the official CUDA and cuDNN installation guides.

  • Set up the conda environment:

# Create environment
conda create -n FD python=3.8 -y
conda activate FD

# Install PyTorch, torchvision and torchaudio built for CUDA 11.3
pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113

# Clone the code
git clone https://github.com/SwinTransformer/Feature-Distillation
cd Feature-Distillation

# Install other requirements
pip install -r requirements.txt
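After installation, a quick sanity check can confirm that the pinned builds were picked up. The sketch below is a hypothetical helper, not part of the repository. It relies only on `torch.__version__`, `torch.version.cuda`, `torch.backends.cudnn.version()` and `torch.cuda.is_available()`, and assumes the `+cu113` local version tag carried by wheels from the PyTorch index used above.

```python
from typing import Optional, Tuple


def parse_torch_version(version: str) -> Tuple[str, Optional[str]]:
    """Split a version like '1.12.0+cu113' into ('1.12.0', 'cu113')."""
    release, _, local = version.partition("+")
    return release, (local or None)


def check_install(expected_release: str = "1.12.0",
                  expected_cuda_tag: str = "cu113") -> bool:
    """Return True if the installed torch matches the versions pinned above."""
    try:
        import torch  # imported lazily so the parser works without torch
    except ImportError:
        print("torch is not installed in this environment")
        return False
    release, cuda_tag = parse_torch_version(str(torch.__version__))
    print(f"torch {release} (build tag: {cuda_tag}), CUDA runtime: {torch.version.cuda}")
    print(f"cuDNN: {torch.backends.cudnn.version()}, GPU available: {torch.cuda.is_available()}")
    return release == expected_release and cuda_tag == expected_cuda_tag


if __name__ == "__main__":
    ok = check_install()
    print("environment OK" if ok else "environment does not match the pinned versions")
```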

Feature-Distillation

To distill models, run:

python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_fd.py \
--cfg <config-file> --data-path <imagenet-path>/train [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]

For example, to distill CLIP-Base for 300 epochs on one DGX-2 server, run:

python -m torch.distributed.launch --nproc_per_node=16 main_fd.py --cfg configs/pretrain/fd_pretrain__clip_vit_base__img224__300ep.yaml --batch-size 128 --data-path <imagenet-path>/train [--output <output-directory> --tag <job-tag>]

To reduce GPU memory consumption, add --use-checkpoint (gradient checkpointing).
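The launch command can also be assembled programmatically, which helps when scripting several runs. The sketch below is a hypothetical helper, not part of the repository, and uses only the flags shown in this README. Because `--batch-size` is given per GPU, the total batch per step is the number of processes times the per-GPU batch: the DGX-2 example uses 16 x 128 = 2048 images. The `per_gpu_batch` helper assumes you want to keep that total on fewer GPUs, which may not fit in memory; whether the provided configs depend on that total is not stated here.

```python
import shlex
from typing import List, Optional


def build_fd_command(cfg: str, data_path: str, num_gpus: int,
                     batch_size: Optional[int] = None,
                     output: Optional[str] = None,
                     tag: Optional[str] = None,
                     use_checkpoint: bool = False) -> List[str]:
    """Assemble the main_fd.py launch command as an argv list."""
    cmd = ["python", "-m", "torch.distributed.launch",
           f"--nproc_per_node={num_gpus}", "main_fd.py",
           "--cfg", cfg, "--data-path", data_path]
    # Optional flags, mirroring the bracketed part of the README command.
    if batch_size is not None:
        cmd += ["--batch-size", str(batch_size)]  # per-GPU batch size
    if output is not None:
        cmd += ["--output", output]
    if tag is not None:
        cmd += ["--tag", tag]
    if use_checkpoint:
        cmd.append("--use-checkpoint")  # lower GPU memory at extra compute cost
    return cmd


def global_batch_size(num_gpus: int, batch_size_per_gpu: int) -> int:
    """Total images per optimization step across all processes."""
    return num_gpus * batch_size_per_gpu


def per_gpu_batch(target_global_batch: int, num_gpus: int) -> int:
    """Per-GPU batch needed to reach a target total batch (must divide evenly)."""
    if target_global_batch % num_gpus != 0:
        raise ValueError("target batch is not divisible by the number of GPUs")
    return target_global_batch // num_gpus


if __name__ == "__main__":
    cmd = build_fd_command(
        "configs/pretrain/fd_pretrain__clip_vit_base__img224__300ep.yaml",
        "<imagenet-path>/train", num_gpus=16, batch_size=128)
    print(shlex.join(cmd))
    print(global_batch_size(16, 128))  # 2048
    print(per_gpu_batch(2048, 8))      # 256 images per GPU on an 8-GPU machine
```

With a real dataset path substituted, the list can be passed to `subprocess.run(cmd, check=True)` from the repository root; building an argv list rather than a single string avoids shell quoting issues with paths.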

Fine-tuning distilled models

To fine-tune distilled models, run:

python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_finetune.py \
--cfg <config-file> --data-path <imagenet-path> --pretrained <pretrained-ckpt> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]

For example, to fine-tune Distilled-CLIP-Base on one DGX-2 server, run:

python -m torch.distributed.launch --nproc_per_node 16 main_finetune.py \
--cfg configs/finetune/fd_finetune__clip_vit_base__img224__300ep.yaml --batch-size 128 --data-path <imagenet-path> --pretrained <pretrained-ckpt> [--output <output-directory> --tag <job-tag>]
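The two example config paths follow a visible naming pattern: `configs/pretrain/fd_pretrain__<name>.yaml` for distillation and `configs/finetune/fd_finetune__<name>.yaml` for fine-tuning. The sketch below assumes that pattern holds for the other models too, which this README does not guarantee, and builds the matching fine-tuning command. The function and constant names are hypothetical. Note that fine-tuning takes the ImageNet root as `--data-path`, while distillation takes its `train` subfolder.

```python
import shlex
from typing import List, Optional

PRETRAIN_PREFIX = "configs/pretrain/fd_pretrain__"  # assumed naming pattern
FINETUNE_PREFIX = "configs/finetune/fd_finetune__"


def finetune_config_for(pretrain_cfg: str) -> str:
    """Map a distillation config path to its assumed fine-tuning counterpart."""
    if not pretrain_cfg.startswith(PRETRAIN_PREFIX):
        raise ValueError(f"unexpected config path: {pretrain_cfg}")
    return FINETUNE_PREFIX + pretrain_cfg[len(PRETRAIN_PREFIX):]


def build_finetune_command(cfg: str, data_path: str, pretrained: str,
                           num_gpus: int, batch_size: Optional[int] = None,
                           output: Optional[str] = None,
                           tag: Optional[str] = None) -> List[str]:
    """Assemble the main_finetune.py launch command as an argv list."""
    cmd = ["python", "-m", "torch.distributed.launch",
           f"--nproc_per_node={num_gpus}", "main_finetune.py",
           "--cfg", cfg, "--data-path", data_path,  # ImageNet root, not /train
           "--pretrained", pretrained]              # distilled checkpoint
    if batch_size is not None:
        cmd += ["--batch-size", str(batch_size)]  # per-GPU batch size
    if output is not None:
        cmd += ["--output", output]
    if tag is not None:
        cmd += ["--tag", tag]
    return cmd


if __name__ == "__main__":
    cfg = finetune_config_for(
        "configs/pretrain/fd_pretrain__clip_vit_base__img224__300ep.yaml")
    print(cfg)  # configs/finetune/fd_finetune__clip_vit_base__img224__300ep.yaml
    cmd = build_finetune_command(cfg, "<imagenet-path>", "<pretrained-ckpt>",
                                 num_gpus=16, batch_size=128)
    print(shlex.join(cmd))
```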