This repository contains the implementation of the NeurIPS 2023 paper:
Parameter and Computation Efficient Transfer Learning for Vision-Language Pre-trained Models [Paper]
Qiong Wu¹², Wei Yu¹², Yiyi Zhou¹², Shubin Huang¹, Xiaoshuai Sun¹², Rongrong Ji¹²
¹Media Analytics and Computing Lab, Department of Artificial Intelligence, School of Informatics, Xiamen University
²Institute of Artificial Intelligence, Xiamen University
In this paper, we aim at parameter and computation efficient transfer learning (PCETL) for VLP models. In particular, PCETL not only limits the number of trainable parameters in VLP models, but also reduces computational redundancy during inference, thus enabling more efficient transfer. To approach this target, we propose a novel Dynamic Architecture Skipping (DAS) approach towards effective PCETL. DAS first estimates the significance of the VLP model's modules to the downstream task via a reinforcement learning (RL) based process, and then skips the redundant ones, replacing them with lightweight networks, i.e., adapters, according to the obtained rewards.
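To make the RL-based selection concrete, here is a toy, stdlib-only sketch of the idea (not the authors' implementation): each module gets a keep-probability, skip decisions are sampled, a reward trades off task quality against kept computation, and the policy is updated REINFORCE-style. The reward function and all constants below are illustrative assumptions.

```python
import math
import random

random.seed(0)
N_MODULES = 12   # e.g. transformer blocks in the VLP backbone
LR = 0.5

logits = [0.0] * N_MODULES   # policy parameters, one per module

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def toy_reward(keep):
    # Hypothetical reward: pretend later layers matter more, and
    # penalize kept computation. A real run would use validation
    # performance of the adapted model instead.
    importance = [0.1 + 0.9 * i / (N_MODULES - 1) for i in range(N_MODULES)]
    return sum(k * w for k, w in zip(keep, importance)) - 0.4 * sum(keep)

baseline = 0.0
for _ in range(300):
    probs = [sigmoid(l) for l in logits]
    keep = [1.0 if random.random() < p else 0.0 for p in probs]
    r = toy_reward(keep)
    # REINFORCE update for independent Bernoulli skip decisions
    for i in range(N_MODULES):
        logits[i] += LR * (r - baseline) * (keep[i] - probs[i])
    baseline = 0.9 * baseline + 0.1 * r

# Modules whose keep-probability fell below 0.5 are replaced by adapters
skip_module = [i for i, l in enumerate(logits) if sigmoid(l) < 0.5]
print(skip_module)
```

The final `skip_module` list plays the role of the search result that the training scripts below consume.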
pip install -r requirements.txt
pip install -e .
We follow ViLT and use pyarrow
to serialize the datasets. See this link for details.
cd LaVIN-DAS
conda create -n lavin python=3.8 -y
conda activate lavin
# install pytorch
conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 -c pytorch
# install dependency and lavin
pip install -r requirements.txt
pip install -e .
Obtain the LLaMA weights via the official request form, or download LLaMA-7B.
For ScienceQA, please prepare the dataset from the official repo.
For BoolQ, CommonSenseQA and GSM8K, please run:
pip install datasets
python OrgBoolQ.py
python OrgCommonSenseQA.py
python OrgGSM8K.py
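Each `Org*.py` script dumps a `*_0_shot_test.json` file for its benchmark. The sketch below shows the general shape of such a conversion on a hard-coded toy example; the field names (`instruction`, `output`) are assumptions, not the repository's exact schema.

```python
import json

# Toy BoolQ-style example; the real scripts load the benchmark via the
# `datasets` library instead of hard-coding rows.
examples = [
    {"question": "is the sky blue?",
     "passage": "On a clear day the sky appears blue.",
     "answer": True},
]

records = []
for ex in examples:
    records.append({
        "instruction": f"{ex['passage']}\nQuestion: {ex['question']}\nAnswer:",
        "output": "yes" if ex["answer"] else "no",
    })

with open("boolq_0_shot_test.json", "w") as f:
    json.dump(records, f, indent=2)
print(len(records))
```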
The file structure should look like:
LaVIN-DAS/
|-- das
|-- scripts
|-- train.py
|-- eval.py
|-- ......
data/
|-- problem.json
|-- pid_splits.json
|-- captions.json
|-- all_data.json
|-- images
|   |-- train    # ScienceQA train images
|   |-- val      # ScienceQA val images
|   |-- test     # ScienceQA test images
|-- weights
|   |-- tokenizer.model
|   |-- 7B
|   |   |-- params.json
|   |   |-- consolidated.00.pth
|   |   |-- ......
|-- BoolQ
|   |-- boolq_0_shot_test.json
|-- GSM8K
|   |-- gsm8k_0_shot_test.json
|-- CommonSenseQA
|   |-- commonsense_qa_0_shot_test.json
To work on METER:
cd METER
To work on ViLT:
cd ViLT
VQA:
sh script/vqa_search.sh
Add the search result to vqa_train.sh via the additional parameter 'skip_module'.
sh script/vqa_train.sh
Add the checkpoint path and 'skip_module' to vqa_eval.sh.
sh script/vqa_eval.sh
Flickr30K retrieval:
sh script/F30K_search.sh
Add the search result to F30K_train.sh via the additional parameter 'skip_module'.
sh script/F30K_train.sh
Add the checkpoint path and 'skip_module' to F30K_eval.sh.
sh script/F30K_eval.sh
NLVR2:
sh script/nlvr_search.sh
Add the search result to nlvr_train.sh via the additional parameter 'skip_module'.
sh script/nlvr_train.sh
Add the checkpoint path and 'skip_module' to nlvr_eval.sh.
sh script/nlvr_eval.sh
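The 'skip_module' parameter that the scripts above pass around can be pictured as a comma-separated list of layer indices whose blocks are replaced by adapters. The following is a hypothetical sketch of such plumbing; the flag format and the `parse_skip_module` helper are illustrative, not the repository's exact code.

```python
import argparse

def parse_skip_module(value):
    # "2,5,7" -> [2, 5, 7]; an empty string means skip nothing.
    return [int(i) for i in value.split(",") if i != ""]

parser = argparse.ArgumentParser()
parser.add_argument("--skip_module", type=parse_skip_module, default=[])
args = parser.parse_args(["--skip_module", "2,5,7"])

# Replace the selected blocks of a toy 12-layer backbone with adapters.
layers = [f"block_{i}" for i in range(12)]
compressed = ["adapter" if i in args.skip_module else layers[i]
              for i in range(len(layers))]
print(compressed)
```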
We also evaluate the experimental results on ScienceQA, following LaVIN.
Table 1: Comparison of DAS and PETL methods on ScienceQA for LLaMA.
Method | Updated Params | Inference Time | Subject NAT | Subject SOC | Subject LAN | Context TXT | Context IMG | Context None | Grade 1-6 | Grade 7-12 | Avg |
---|---|---|---|---|---|---|---|---|---|---|---|
LaVIN-7B | 3.8M | 3.70s | 89.25 | 94.94 | 85.24 | 88.51 | 87.46 | 88.08 | 90.16 | 88.07 | 89.41 |
DAS2-7B | 4.2M | 3.44s | 88.68 | 94.94 | 86.45 | 88.03 | 86.81 | 88.92 | 90.20 | 88.00 | 89.41 |
DAS4-7B | 4.6M | 3.23s | 88.99 | 94.60 | 85.09 | 87.88 | 86.51 | 88.36 | 89.72 | 88.13 | 89.15 |
DAS6-7B | 5.0M | 3.06s | 87.30 | 93.36 | 82.36 | 86.12 | 85.97 | 85.71 | 88.18 | 85.70 | 87.29 |
To search and fine-tune on the LLaMA-based tasks, run:
cd LaVIN-DAS
sh scripts/{task_type}_{benchmark}_7b.sh
The task_type includes evaluate, finetuning and search.
The benchmark includes boolq (BoolQ), csqa (CommonSenseQA), gsm8k (GSM8K) and sqa (ScienceQA).
The code is based on ViLT (Apache 2.0 license) and METER (MIT license); some of the code is borrowed from CLIP and Swin-Transformer.