[NeurIPS 2023] Official implementations of "Parameter and Computation Efficient Transfer Learning for Vision-Language Pre-trained Models"

DoubtedSteam/DAS

This repository contains the implementation of the NeurIPS 2023 paper:

Parameter and Computation Efficient Transfer Learning for Vision-Language Pre-trained Models [Paper]
Qiong Wu (1,2), Wei Yu (1,2), Yiyi Zhou (1,2), Shubin Huang (1), Xiaoshuai Sun (1,2), Rongrong Ji (1,2)
(1) Media Analytics and Computing Lab, Department of Artificial Intelligence, School of Informatics, Xiamen University
(2) Institute of Artificial Intelligence, Xiamen University

In this paper, we aim at parameter and computation efficient transfer learning (PCETL) for vision-language pre-trained (VLP) models. PCETL must not only limit the number of trainable parameters in VLP models but also reduce computational redundancy during inference, thus enabling more efficient transfer. To this end, we propose a novel dynamic architecture skipping (DAS) approach. DAS first observes the significance of the model's modules to the downstream task via a reinforcement learning (RL) based process, and then, according to the obtained rewards, skips the redundant modules and substitutes them with lightweight networks, i.e., adapters.
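
The search stage described above can be pictured with a toy example. The sketch below is not the repository's implementation: it keeps one preference score per module, repeatedly samples a few modules to skip, scores each trial with a stand-in reward, and finally marks the modules with the highest preference as redundant. The names `toy_reward` and `search_redundant_modules`, the update rule, and all hyperparameters are assumptions made for illustration; in a real search the reward would come from the downstream-task loss of the sampled sub-network.

```python
import random


def toy_reward(skipped, importance):
    """Stand-in reward (hypothetical): higher when the skipped modules matter less.

    A real search would derive this from the downstream-task loss of the
    network in which the `skipped` modules are replaced by adapters.
    """
    return -sum(importance[i] for i in skipped)


def search_redundant_modules(importance, num_skip, steps=200, lr=0.1, seed=0):
    """Return the sorted indices of the `num_skip` modules judged most redundant."""
    rng = random.Random(seed)
    n = len(importance)
    preference = [0.0] * n  # higher = better candidate for skipping
    baseline = 0.0          # exponential moving average of the reward

    for _ in range(steps):
        # Sample a candidate sub-network by choosing which modules to skip.
        skipped = rng.sample(range(n), num_skip)
        reward = toy_reward(skipped, importance)
        advantage = reward - baseline
        # Credit (or blame) every skipped module for this trial's outcome.
        for i in skipped:
            preference[i] += lr * advantage
        baseline += 0.1 * (reward - baseline)

    ranked = sorted(range(n), key=lambda i: preference[i], reverse=True)
    return sorted(ranked[:num_skip])
```

With made-up importance values such as `[1.0, 0.9, 0.05, 0.8, 0.1, 0.95]` and `num_skip=2`, the two low-importance modules (indices 2 and 4) should end up with the highest preference and be selected for skipping.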



Setup

Install for ViLT and METER

pip install -r requirements.txt
pip install -e .

Dataset Preparation for ViLT and METER

We follow ViLT and use pyarrow to serialize the datasets. See this link for details.

Install for LaVIN

cd LaVIN-DAS
conda create -n lavin python=3.8 -y
conda activate lavin

# install pytorch
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 -c pytorch

# install dependency and lavin
pip install -r requirements.txt
pip install -e .

Preparation for LaVIN

Obtain the LLaMA weights from this form (official), or download LLaMA-7B.

For ScienceQA, please prepare the dataset from the official repo.

For BoolQ, CommonSenseQA and gsm8k, please run:

pip install datasets
python OrgBoolQ.py
python OrgCommonSenseQA.py
python OrgGSM8K.py

The file structure should look like:

LaVIN-DAS/
  |-- das
  |-- scripts
  |-- train.py
  |-- eval.py
  ......
data/
  |-- problem.json
  |-- pid_splits.json
  |-- captions.json
  |-- all_data.json
  |-- images
      |-- train          # ScienceQA train image
      |-- val            # ScienceQA val image
      |-- test           # ScienceQA test image
  |-- weights
      |-- tokenizer.model
      |-- 7B
          |-- params.json
          |-- consolidated.00.pth
      ......
  |-- BoolQ
      |-- boolq_0_shot_test.json
  |-- GSM8K
      |-- gsm8k_0_shot_test.json
  |-- CommonSenseQA
      |-- commonsense_qa_0_shot_test.json
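
Before launching a run, it can help to confirm that the data directory matches the layout above. The following is a minimal sketch, not part of the repository: the helper `missing_entries` and the list `EXPECTED_ENTRIES` are hypothetical names, and the list covers only a subset of the tree shown above, so extend it to match your setup.

```python
from pathlib import Path

# Subset of the layout listed above (illustrative, not exhaustive).
EXPECTED_ENTRIES = [
    "problem.json",
    "pid_splits.json",
    "captions.json",
    "images/train",
    "images/val",
    "images/test",
    "weights/tokenizer.model",
    "BoolQ/boolq_0_shot_test.json",
    "GSM8K/gsm8k_0_shot_test.json",
    "CommonSenseQA/commonsense_qa_0_shot_test.json",
]


def missing_entries(data_root, expected=EXPECTED_ENTRIES):
    """Return the expected relative paths that do not exist under `data_root`."""
    root = Path(data_root)
    return [rel for rel in expected if not (root / rel).exists()]
```

Assuming `data/` sits next to `LaVIN-DAS/` as in the tree above, calling `missing_entries("../data")` from inside `LaVIN-DAS` returns an empty list when every listed entry is present.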

Fine-tuning on Downstream Tasks

To work with METER:

cd METER

To work with ViLT:

cd ViLT

VQAv2

Search

sh script/vqa_search.sh

Train

Add the search result to vqa_train.sh via the additional parameter 'skip_module'.

sh script/vqa_train.sh

Evaluate

Add the checkpoint path and 'skip_module' to vqa_eval.sh.

sh script/vqa_eval.sh

Flickr30k IR/TR

Search

sh script/F30K_search.sh

Train

Add the search result to F30K_train.sh via the additional parameter 'skip_module'.

sh script/F30K_train.sh

Evaluate

Add the checkpoint path and 'skip_module' to F30K_eval.sh.

sh script/F30K_eval.sh

NLVR2

Search

sh script/nlvr_search.sh

Train

Add the search result to nlvr_train.sh via the additional parameter 'skip_module'.

sh script/nlvr_train.sh

Evaluate

Add the checkpoint path and 'skip_module' to nlvr_eval.sh.

sh script/nlvr_eval.sh

ScienceQA

We also evaluate DAS on ScienceQA, following LaVIN.

Experimental results

Table 1: Comparison of DAS and PETL methods on ScienceQA for LLaMA.

                                         Modality                   Context              Grade
Method    Update Params  Inference Time  Natural  Social  Language  Text   Image  No     G1-6   G7-12  Avg
LaVIN-7B  3.8M           3.70s           89.25    94.94   85.24     88.51  87.46  88.08  90.16  88.07  89.41
DAS2-7B   4.2M           3.44s           88.68    94.94   86.45     88.03  86.81  88.92  90.20  88.00  89.41
DAS4-7B   4.6M           3.23s           88.99    94.60   85.09     87.88  86.51  88.36  89.72  88.13  89.15
DAS6-7B   5.0M           3.06s           87.30    93.36   82.36     86.12  85.97  85.71  88.18  85.70  87.29
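
To read the trade-off in Table 1 at a glance, the inference-time saving and the change in average accuracy relative to LaVIN-7B can be computed directly from the reported numbers. The snippet below is only a worked calculation on the table's values; the helper name `compare_to_baseline` is an illustrative assumption.

```python
# (inference time in seconds, average accuracy) copied from Table 1
RESULTS = {
    "LaVIN-7B": (3.70, 89.41),
    "DAS2-7B": (3.44, 89.41),
    "DAS4-7B": (3.23, 89.15),
    "DAS6-7B": (3.06, 87.29),
}


def compare_to_baseline(results, baseline="LaVIN-7B"):
    """Map each method to (time saved in %, change in average accuracy)."""
    base_time, base_acc = results[baseline]
    summary = {}
    for name, (time_s, acc) in results.items():
        if name == baseline:
            continue
        time_saved_pct = 100.0 * (base_time - time_s) / base_time
        summary[name] = (round(time_saved_pct, 1), round(acc - base_acc, 2))
    return summary
```

On these numbers, DAS2-7B saves about 7.0% of inference time with no change in average accuracy, DAS4-7B saves about 12.7% for a 0.26-point drop, and DAS6-7B saves about 17.3% for a 2.12-point drop.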

LLaMA Based Tasks

To search, fine-tune, or evaluate on the LLaMA-based tasks, run:

cd LaVIN-DAS
sh scripts/{task_type}_{benchmark}_7b.sh

The task_type is one of evaluate, finetuning, or search.

The benchmark is one of boolq (BoolQ), csqa (CommonSenseQA), gsm8k (GSM8K), or sqa (ScienceQA).
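
The naming pattern above can be expanded programmatically, for example when scripting several runs. This is a small sketch under the assumption that script names follow `scripts/{task_type}_{benchmark}_7b.sh` exactly; the helper `script_path` is hypothetical, and whether a script exists for every combination has not been verified against the repository.

```python
TASK_TYPES = ("evaluate", "finetuning", "search")
BENCHMARKS = ("boolq", "csqa", "gsm8k", "sqa")


def script_path(task_type, benchmark):
    """Build the script path for a task type and benchmark, validating both."""
    if task_type not in TASK_TYPES:
        raise ValueError(f"unknown task_type: {task_type!r}")
    if benchmark not in BENCHMARKS:
        raise ValueError(f"unknown benchmark: {benchmark!r}")
    return f"scripts/{task_type}_{benchmark}_7b.sh"
```

For instance, `script_path("search", "sqa")` yields `scripts/search_sqa_7b.sh`, which would then be run with `sh` from inside `LaVIN-DAS`.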

Acknowledgements

The code is based on ViLT (licensed under Apache 2.0) and METER (licensed under MIT). Some of the code is borrowed from CLIP and Swin-Transformer.
