Learning Visual Representations via Language-Guided Sampling
Mohamed El Banani, Karan Desai, and Justin Johnson
If you have any questions, please feel free to email me at [email protected].
We recommend using Anaconda or Miniconda. To set up the environment, follow the instructions below.
conda create -n lgssl python=3.8 --yes
conda activate lgssl
conda install pytorch=1.12.1 torchvision cudatoolkit=11.3 -c pytorch --yes
python -m pip install -r requirements.txt
python setup.py develop
We train our models on RedCaps and Conceptual Captions (CC3M and CC12M). Note that all three datasets can decay over time as the images they link to become unavailable, so you might end up with a different number of instances. Please refer to the original papers for dataset download instructions. In our case, the datasets had the following sizes:
Dataset | Size |
---|---|
RedCaps-2020 | 3273223 |
RedCaps | 12010494 |
CC3M | 2913035 |
CC12M | 10958691 |
We assume all training datasets are in data/datasets, which is set as the default data_root in the base dataset class. We expect each dataset to be in the format below: the dataset is subdivided into several directories, and each directory contains a set of instances, where each instance has an image file and a JSON caption file.
data/datasets/<dataset_name>
    |- directory_0
        |- <instance_0>.jpg     <- image for instance 0
        |- <instance_0>.json    <- caption for instance 0
        |- <instance_1>.jpg
        |- <instance_1>.json
        ...
        |- <instance_n>.jpg
        |- <instance_n>.json
    |- directory_1
    |- directory_2
    ...
    |- directory_m
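The layout above can be traversed with a short helper like the one below. This is a minimal sketch and not the repository's preprocessing code: the function name `iter_instances` is hypothetical, and it assumes only what the layout states, namely that each `.jpg` has a `.json` caption file with the same stem in the same directory. It makes no assumption about the keys inside the JSON file.

```python
import json
from pathlib import Path
from typing import Any, Iterator, Tuple


def iter_instances(dataset_root: Path) -> Iterator[Tuple[Path, Any]]:
    """Yield (image_path, caption_record) for every complete instance.

    Hypothetical helper: assumes <name>.jpg and <name>.json sit side by side.
    """
    for directory in sorted(p for p in dataset_root.iterdir() if p.is_dir()):
        for image_path in sorted(directory.glob("*.jpg")):
            caption_path = image_path.with_suffix(".json")
            if not caption_path.exists():
                # Skip images whose caption file is missing.
                continue
            with caption_path.open() as f:
                yield image_path, json.load(f)
```

For example, `sum(1 for _ in iter_instances(Path("data/datasets/<dataset_name>")))` would count the usable instances, which is a quick way to check how much of a dataset survived download.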
For RedCaps, the directory names are encoded as <subreddit>_<year>_<id>, e.g., crochet_2017_000001, where each directory contains only 10000 instances. We rely on this naming convention for some of the experiments: those with redcaps-2020 and those on sampling scope.
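As an illustration of how the <subreddit>_<year>_<id> convention can be used, the sketch below splits a directory name into its parts, for instance to keep only directories from a given year or to group instances by subreddit. The names `RedCapsDir` and `parse_redcaps_dir` are hypothetical and are not taken from the repository. The split is done from the right on the assumption that a subreddit name may itself contain underscores.

```python
from typing import NamedTuple


class RedCapsDir(NamedTuple):
    subreddit: str
    year: int
    shard_id: str


def parse_redcaps_dir(name: str) -> RedCapsDir:
    """Parse a directory name of the form <subreddit>_<year>_<id>."""
    # Split from the right so underscores inside the subreddit name survive.
    subreddit, year, shard_id = name.rsplit("_", 2)
    return RedCapsDir(subreddit, int(year), shard_id)
```

With this helper, `parse_redcaps_dir("crochet_2017_000001")` gives the subreddit `crochet`, the year `2017`, and the id `000001`.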
We create dataset-specific dictionaries that contain the information for each dataset (e.g., image paths and captions), which allows for easy sampling in subsequent steps. To generate a dataset dictionary, run the following code, where <dataset_name> is the name of the dataset directory in data/datasets.
cd preprocess
python make_imagecaption_dict.py <dataset_name>
Once we have the dataset dictionaries, we can easily sample nearest neighbor pairs. We provide the code for sampling using language or visual embeddings. We also provide code for sampling within dataset subsets, which we use for the experiments reported in the supplementary material. The commands below run language sampling based on SBERT, language sampling within each subreddit, and visual sampling based on an ImageNet-pretrained model.
python sample_language_nn.py <dataset_name> all-mpnet-base-v2 # Language - MPNet (SBERT)
python sample_language_nn_subsets.py <dataset_name> all-mpnet-base-v2 subreddit # Language Subset - MPNet (SBERT) on subreddits
python sample_visual_nn.py <dataset_name> vit_b_32 IMAGENET1K_V1 # Visual - ImageNet-supervised ViT-B/32
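To make the idea of nearest neighbor sampling concrete, the sketch below finds, for each instance, the most similar other instance under cosine similarity. It is an illustration under stated assumptions and not the code in the sampling scripts: the function name `nearest_neighbors` is hypothetical, the input is assumed to be a precomputed array of embeddings (caption embeddings for language sampling, image embeddings for visual sampling), and the search is brute force.

```python
import numpy as np


def nearest_neighbors(embeddings: np.ndarray) -> np.ndarray:
    """Return, for each row, the index of its most similar other row."""
    # L2-normalize so that the dot product equals cosine similarity.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    normalized = embeddings / np.clip(norms, 1e-12, None)
    similarity = normalized @ normalized.T
    # An instance must not be paired with itself.
    np.fill_diagonal(similarity, -np.inf)
    return similarity.argmax(axis=1)
```

The full similarity matrix needs memory quadratic in the number of instances, so this form is only practical for small subsets. At the scale of the datasets listed above, a batched or approximate nearest neighbor search would be needed instead.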
We use TensorFlow Datasets for our evaluations. This package provides all the evaluation datasets except FGVC Aircraft. Our code will automatically download and extract all the datasets to data/evaluation_datasets on the first run of the evaluation code, so the first evaluation run will be much slower than usual.
Note 1: We encountered a bug with SUN 397 where one image could not be decoded correctly. This is a known bug which has not been fixed yet in the stable version. To fix it, simply make the two changes outlined by this commit.
Note 2: TensorFlow Datasets requires you to download RESISC45 independently. Please follow the instructions provided here.
We use hydra configs for our training experiments. The configs can all be found here. To run an experiment, you can either define a new experiment config that overrides the default configs, or override individual config values directly in the command. We provide a few sample training commands for clarity:
python train.py +experiment=ours # LG SimCLR
python train.py +experiment=vis_baseline # SimCLR
python train.py +experiment=vis_baseline model=simsiam # SimSiam
We use two primary evaluations: linear probe using L-BFGS and few-shot evaluation. The configs for those evaluations can be found here.
Linear Probe: we train a single layer using logistic regression and sweep over regularizer weight values.
We provide an implementation of logistic regression using PyTorch's L-BFGS; however, you can easily use scikit-learn's implementation by setting the use_sklearn flag in the evaluation configs.
For datasets without a standard validation split, we randomly split the training set while maintaining the class distribution.
Few-Shot Evaluation: we also evaluate our frozen features on 5-shot, 5-way classification. The evaluation can be found here. We sample the training samples from the train/valid splits and the query samples from the test set.
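The few-shot protocol described above can be sketched as an episode sampler: choose 5 classes, draw 5 support examples per class from the train/valid pool, and draw the query examples for those classes from the test pool. This is a minimal sketch rather than the repository's evaluation code. The function name `sample_episode` is hypothetical, and the number of query examples per class (`n_query=15`) is an assumption, since the text does not specify it.

```python
import random
from collections import defaultdict


def sample_episode(train_labels, test_labels, n_way=5, n_shot=5, n_query=15, seed=0):
    """Return (classes, support, query) for one few-shot episode.

    support and query map each sampled class to a list of indices into
    train_labels and test_labels respectively. n_query is an assumed value.
    """
    rng = random.Random(seed)
    train_by_class = defaultdict(list)
    for idx, label in enumerate(train_labels):
        train_by_class[label].append(idx)
    test_by_class = defaultdict(list)
    for idx, label in enumerate(test_labels):
        test_by_class[label].append(idx)

    # Only classes with enough support and query examples are eligible.
    eligible = sorted(
        c for c in train_by_class
        if len(train_by_class[c]) >= n_shot
        and len(test_by_class.get(c, [])) >= n_query
    )
    classes = rng.sample(eligible, n_way)
    support = {c: rng.sample(train_by_class[c], n_shot) for c in classes}
    query = {c: rng.sample(test_by_class[c], n_query) for c in classes}
    return classes, support, query
```

The returned indices can then be used to select rows from precomputed frozen features, so that the backbone is never updated during evaluation.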
The following commands can be used to evaluate checkpoints or baselines. For example, you can evaluate our model or the pretrained SimCLR checkpoint on all the datasets by running the following commands:
python evaluate.py model.name=lgssl_checkpoints model.checkpoint=lgsimclr dataset.name=all
python evaluate.py model.name=simclr dataset.name=all
You can find all our pretrained checkpoints here. You should download them to data/checkpoints. Alternatively, you can use hubconf to get the relevant checkpoint, as shown in the code snippet below:
import torch
model = torch.hub.load("mbanani/lgssl", "lgsimclr")
For a list of released models, check hubconf.py.
If you find this code useful, please consider citing:
@inproceedings{elbanani2022languageguided,
title={{Learning Visual Representations via Language-Guided Sampling}},
author={El Banani, Mohamed and Desai, Karan and Johnson, Justin},
booktitle={CVPR},
year={2023},
}
We thank Richard Higgins, Ashkan Kazemi, and Santiago Castro for many helpful discussions. We also thank David Fouhey, Ziyang Chen, Chenhao Zheng, Fahad Kamran, and Dandan Shan for their feedback on early drafts. This project was funded under the Ford-UM Alliance partnership. We thank Alireza Rahimpour, Devesh Upadhyay, and Ali Hassani from Ford Research for their support and discussion.