
[CAAI AIR'24] Bilateral Reference for High-Resolution Dichotomous Image Segmentation


ZhengPeng7/BiRefNet


Bilateral Reference for High-Resolution Dichotomous Image Segmentation

Peng Zheng 1,4,5,6,  Dehong Gao 2,  Deng-Ping Fan 1*,  Li Liu 3,  Jorma Laaksonen 4,  Wanli Ouyang 5,  Nicu Sebe 6
1 Nankai University  2 Northwestern Polytechnical University  3 National University of Defense Technology 
4 Aalto University  5 Shanghai AI Laboratory  6 University of Trento 

This repo is the official implementation of "Bilateral Reference for High-Resolution Dichotomous Image Segmentation" (CAAI AIR 2024).

Note

We need more GPU resources to push BiRefNet's performance further, especially on video tasks and on more efficient model designs for higher-resolution images. If you would like to cooperate, please contact me at [email protected].

News 📰

  • Jan 6, 2025: Validated FP16 inference with almost no performance drop and better efficiency: the standard BiRefNet runs at 17 FPS at 1024x1024 resolution with 3.45 GB of GPU memory on a single RTX 4090. See the model efficiency part of the Model Zoo section below for details.
  • Dec 5, 2024: Fixed the bug when using torch.compile with recent PyTorch versions (2.5.1) and the slow iterations in FP16 training with accelerate (now the default).
  • Nov 28, 2024: Congratulations to the students at Nankai University who used BiRefNet in their project and won the provincial gold medal and the national bronze medal at the China International College Students' Innovation Competition 2024.
  • Oct 26, 2024: We added a guideline for fine-tuning on custom data with existing weights.
  • Oct 6, 2024: We uploaded the BiRefNet-matting model for general trimap-free matting use.
  • Sep 24, 2024: We uploaded the BiRefNet_lite-2K model, which takes inputs in a much higher resolution (2560x1440). We also added the notebook for inference on videos.
  • Sep 7, 2024: Thanks to Freepik for supporting me with GPUs for more extensive experiments, especially on BiRefNet for 2K inference!
  • Aug 30, 2024: We uploaded notebooks in tutorials to run the inference and ONNX conversion locally.
  • Aug 23, 2024: BiRefNet is now officially published online in the CAAI AIR journal. Thanks also for the press release.
  • Aug 19, 2024: We uploaded the ONNX model files of all weights in the GitHub release and GDrive folder. Check out the ONNX conversion part in model zoo for more details.
  • Jul 30, 2024: Thanks to @not-lain for his kind efforts in adding BiRefNet to the official huggingface.js repo.
  • Jul 28, 2024: We released the Colab demo for box-guided segmentation.
  • Jul 15, 2024: We deployed BiRefNet on Hugging Face Models so users can load it with one line of code.
  • Jun 21, 2024: We released and uploaded the Chinese version of our original paper to my GDrive.
  • May 28, 2024: We set up a model zoo with well-trained BiRefNet weights in different sizes and for different tasks, including general use, matting segmentation, DIS, HRSOD, COD, etc.
  • May 7, 2024: We also released the Colab demo for multiple images inference. Many thanks to @rishabh063 for his support on it.
  • Apr 9, 2024: Thanks to Features and Labels Inc. for deploying a cool online BiRefNet inference API and providing me with strong GPU resources for 4 months on more extensive experiments!
  • Mar 7, 2024: We released the BiRefNet code, the well-trained weights for all tasks in the original paper, and all related files in my GDrive folder. We also deployed BiRefNet on Hugging Face Spaces for easier online use and released the Colab demo for inference and evaluation.
  • Jan 7, 2024: We released our paper on arXiv.

🚀 Load BiRefNet in ONE LINE with Hugging Face; see BiRefNet for more:

from transformers import AutoModelForImageSegmentation
birefnet = AutoModelForImageSegmentation.from_pretrained('zhengpeng7/BiRefNet', trust_remote_code=True)
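The one-liner above only loads the model. As a minimal sketch of the steps that usually surround a forward pass, the NumPy code below normalizes an already-resized RGB image into a batch and turns a logit map back into an 8-bit mask. It assumes ImageNet mean/std normalization and a single-channel logit output; `preprocess` and `postprocess` are hypothetical helper names, and the model call itself is omitted so the snippet runs without downloading weights.

```python
import numpy as np

# Assumed normalization statistics (ImageNet mean/std, common for Swin backbones).
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess(image_rgb: np.ndarray) -> np.ndarray:
    """Turn an HxWx3 uint8 RGB image (already resized, e.g. to 1024x1024)
    into a 1x3xHxW float32 batch normalized with mean/std."""
    x = image_rgb.astype(np.float32) / 255.0        # scale pixels to [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD          # per-channel normalization
    return np.transpose(x, (2, 0, 1))[None, ...]    # HWC -> 1xCxHxW


def postprocess(logits: np.ndarray) -> np.ndarray:
    """Map an HxW logit map to an HxW uint8 mask through a sigmoid."""
    probs = 1.0 / (1.0 + np.exp(-logits))           # sigmoid -> [0, 1]
    return np.round(probs * 255.0).astype(np.uint8) # usable as an alpha channel


if __name__ == "__main__":
    batch = preprocess(np.full((1024, 1024, 3), 128, dtype=np.uint8))
    print(batch.shape, batch.dtype)  # (1, 3, 1024, 1024) float32
```

In a real pipeline the batch would be converted to a tensor, fed to `birefnet`, and the mask resized back to the original image size before being used as an alpha channel.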

🛬 Inference Partner:

You can access the inference API service of BiRefNet on FAL or click the Deploy button on our HF model page to set up your own deployment.

BiRefNet has achieved state-of-the-art (SOTA) results on several related high-resolution (HR) tasks:

DIS: Papers with Code (PWC) ranking badges
Figure: comparison on DIS in Papers with Code (at the time of this work).

COD: Papers with Code (PWC) ranking badges
Figure: comparison on COD in Papers with Code (at the time of this work).

HRSOD: Papers with Code (PWC) ranking badges
Figure: comparison on HRSOD in Papers with Code (at the time of this work).

Try our online demos for inference:

  • Inference and evaluation of your given weights: Open In Colab
  • Online Inference with GUI with adjustable resolutions: Hugging Face Spaces
  • Online Multiple Images Inference on Colab: Open In Colab

Model Zoo

To make BiRefNet more broadly useful, I extended the original academic models to more general ones for real-life applications.

We suggest downloading datasets and backbones from their official pages, but you can also download our packaged versions: DIS, HRSOD, COD, Backbones.

The performance (almost all metrics) of every model can be found in the exp-TASK_SETTINGS folders in [stuff].

Models in the original paper, for comparison on benchmarks:

| Task | Training Sets | Backbone | Download |
|---|---|---|---|
| DIS | DIS5K-TR | swin_v1_large | google-drive |
| COD | COD10K-TR, CAMO-TR | swin_v1_large | google-drive |
| HRSOD | DUTS-TR | swin_v1_large | google-drive |
| HRSOD | DUTS-TR, HRSOD-TR | swin_v1_large | google-drive |
| HRSOD | DUTS-TR, UHRSD-TR | swin_v1_large | google-drive |
| HRSOD | HRSOD-TR, UHRSD-TR | swin_v1_large | google-drive |
| HRSOD | DUTS-TR, HRSOD-TR, UHRSD-TR | swin_v1_large | google-drive |

Models trained with custom data (general, matting), for general use in practical applications:

| Task | Training Sets | Backbone | Test Set | Metric (S, wF[, HCE]) | Download |
|---|---|---|---|---|---|
| general use | DIS5K-TR, DIS-TEs, DUTS-TR_TE, HRSOD-TR_TE, UHRSD-TR_TE, HRS10K-TR_TE, TR-P3M-10k, TE-P3M-500-NP, TE-P3M-500-P, TR-humans | swin_v1_large | DIS-VD | 0.911, 0.875, 1069 | google-drive |
| general use | DIS5K-TR, DIS-TEs, DUTS-TR_TE, HRSOD-TR_TE, UHRSD-TR_TE, HRS10K-TR_TE, TR-P3M-10k, TE-P3M-500-NP, TE-P3M-500-P, TR-humans | swin_v1_tiny | DIS-VD | 0.882, 0.830, 1175 | google-drive |
| general use | DIS5K-TR, DIS-TEs | swin_v1_large | DIS-VD | 0.907, 0.865, 1059 | google-drive |
| general matting | P3M-10k (except TE-P3M-500-NP), TR-humans, AM-2k, AIM-500, Human-2k (synthesized with BG-20k), Distinctions-646 (synthesized with BG-20k), HIM2K, PPM-100 | swin_v1_large | TE-P3M-500-NP | 0.979, 0.988 | google-drive |
| portrait matting | P3M-10k, humans | swin_v1_large | P3M-500-P | 0.983, 0.989 | google-drive |

Segmentation with box guidance:
  • Given box guidance: Open In Colab
Model efficiency:

Screenshot from the original paper. All tests there were conducted on a single A100 GPU.

The devices used in the tables below differ from the one in the original paper (the standard setting), so these numbers are for reference only.

| Runtime | FP32 | FP16 |
|---|---|---|
| A100 | 86.8 ms | 69.4 ms |
| 4090 | 95.8 ms | 57.7 ms |
| V100 | 384 ms | 152 ms |

| GPU Memory | FP32 | FP16 |
|---|---|---|
| Inference | 4.76 GB | 3.45 GB |
| Training (#GPU=1, batch_size=2, compile=False, PyTorch==2.5.1) | 36.3 GB | 30.4 GB |
| Training (#GPU=1, batch_size=2, compile=True, PyTorch==2.5.1) | 35.9 GB | 24.9 GB |

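The latency and memory figures can be turned into throughput and savings with simple arithmetic. The Python sketch below copies the values from the tables above (the `fps` helper name is ours) and computes frames per second, the FP16 speedup per GPU, and the FP16 inference memory saving; for example, 57.7 ms on the 4090 works out to roughly 17 FPS, the figure quoted in the news.

```python
# Per-image latency (ms) and inference memory (GB), copied from the tables above.
latency_ms = {
    "A100": {"FP32": 86.8, "FP16": 69.4},
    "4090": {"FP32": 95.8, "FP16": 57.7},
    "V100": {"FP32": 384.0, "FP16": 152.0},
}
inference_mem_gb = {"FP32": 4.76, "FP16": 3.45}


def fps(ms: float) -> float:
    """Convert per-image latency in milliseconds to frames per second."""
    return 1000.0 / ms


for gpu, t in latency_ms.items():
    speedup = t["FP32"] / t["FP16"]  # how much faster FP16 is than FP32
    print(f"{gpu}: FP32 {fps(t['FP32']):.1f} FPS, "
          f"FP16 {fps(t['FP16']):.1f} FPS, speedup {speedup:.2f}x")

# Fraction of inference memory saved by switching from FP32 to FP16.
saving = 1.0 - inference_mem_gb["FP16"] / inference_mem_gb["FP32"]
print(f"FP16 inference memory saving: {saving:.1%}")
```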
ONNX conversion:

We converted the .pth weight files to .onnx files, drawing heavily on Kazuhito00/BiRefNet-ONNX-Sample; many thanks to @Kazuhito00.

  • See our Colab demo for ONNX conversion, or the notebook for running it locally, where you can do the conversion/inference yourself and find all relevant info.
  • As tested, BiRefNet with SwinL (the default backbone) takes ~90% more time with ONNX files (~165 ms per inference on an A100 GPU), while BiRefNet with SwinT (lightweight) takes ~75% more time (~93.8 ms per inference on an A100 GPU). The default input resolution is 1024x1024.
  • The outputs of the original .pth files and the converted .onnx files differ slightly, which is acceptable.
  • Pay attention to the compatibility among onnxruntime-gpu, CUDA, and cuDNN (we use torch==2.0.1 and CUDA 11.8 here).
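To quantify "slightly different" for your own conversion, you could run the .pth model and the .onnx model on the same input and compare their probability maps. The NumPy helper below is a hypothetical sketch, not part of this repo; it assumes both maps have already gone through a sigmoid (values in [0, 1]) and share the same shape, and it reports the absolute differences plus how many pixels change label after thresholding at an assumed 0.5.

```python
import numpy as np


def compare_masks(ref: np.ndarray, other: np.ndarray, threshold: float = 0.5) -> dict:
    """Compare two probability maps of the same shape.

    Returns the max and mean absolute difference, and the fraction of pixels
    whose binarized label (at `threshold`) disagrees between the two maps.
    """
    if ref.shape != other.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {other.shape}")
    diff = np.abs(ref.astype(np.float64) - other.astype(np.float64))
    flips = (ref >= threshold) != (other >= threshold)  # per-pixel label changes
    return {
        "max_abs_diff": float(diff.max()),
        "mean_abs_diff": float(diff.mean()),
        "label_flip_ratio": float(flips.mean()),
    }
```

A small mean difference with a near-zero label flip ratio suggests the conversion is behaving as expected; a large flip ratio is a hint to recheck preprocessing or opset settings.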

Third-Party Creations

We have found several third-party applications based on BiRefNet. Many thanks for their contributions to the community!
Pick one you like and try it with clicks instead of code:

  1. Applications:

    • Thanks to tin2tin/2D_Asset_Generator: this project combines BiRefNet and FLUX in a Blender add-on for "AI generating 2D cutout assets for ex. previz".

    • Thanks to camenduru/text-behind-tost: this project uses BiRefNet to extract foreground subjects and place text between the subjects and the background, which looks amazing, especially in videos. Check their tweets for more examples.

    • Thanks to briaai/RMBG-2.0: this project trained BiRefNet on their high-quality private data, which improves results on the DIS task. Note that their weights are for non-commercial use only and do not handle transparency, because they were trained in the DIS task setting, which focuses only on predicting binary masks.

    • Thanks to lldacing/ComfyUI_BiRefNet_ll: this project further upgrades the ComfyUI node for BiRefNet with both our latest weights and the legacy ones.

    • Thanks to MoonHugo/ComfyUI-BiRefNet-Hugo: this project further upgrades the ComfyUI node for BiRefNet with our latest weights.

    • Thanks to lbq779660843/BiRefNet-Tensorrt and yuanyang1991/birefnet_tensorrt: both provide projects that convert BiRefNet to TensorRT, which is faster and better suited for deployment. Their repos offer a solid local setup (Windows and Linux) and a Colab demo, respectively. @yuanyang1991 also kindly provided a comparison of inference efficiency among naive PyTorch, ONNX, and TensorRT on an RTX 4080S:

| Methods | PyTorch | ONNX | TensorRT |
|---|---|---|---|
| First inference time | 0.71 s | 5.32 s | 0.17 s |
| Avg. inference time (excluding the 1st) | 0.15 s | 4.43 s | 0.11 s |

  2. More Visual Comparisons

    video-from_twitter_toyxyz3_2.mp4
    video-from_twitter_toyxyz3_1.mp4
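The PyTorch/ONNX/TensorRT comparison above reports the first inference separately from the average of later ones. A minimal, framework-agnostic sketch of that measurement scheme is shown below; `time_inference` is a hypothetical helper, and `run` stands for any callable that wraps a PyTorch, ONNX Runtime, or TensorRT forward pass.

```python
import statistics
import time
from typing import Callable, Sequence


def time_inference(run: Callable[[object], object],
                   inputs: Sequence[object]) -> tuple[float, float]:
    """Return (first-call seconds, mean seconds over the remaining calls).

    The first call is reported separately because it usually includes one-off
    costs such as graph building, kernel selection, or memory allocation.
    """
    if len(inputs) < 2:
        raise ValueError("need at least two inputs")
    durations = []
    for x in inputs:
        start = time.perf_counter()
        run(x)
        durations.append(time.perf_counter() - start)
    return durations[0], statistics.mean(durations[1:])
```

For GPU frameworks, `run` should block until results are ready (for instance by calling `torch.cuda.synchronize()` in PyTorch); otherwise asynchronous kernel launches make the measured times too small.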

Usage

Environment Setup

# PyTorch==2.5.1+CUDA12.4 (or 2.0.1+CUDA11.8) is used for faster training (~40%) with compilation.
conda create -n birefnet python=3.10 -y && conda activate birefnet
pip install -r requirements.txt

Dataset Preparation

Download the combined training/test sets I have organized from DIS--COD--HRSOD, the individual official ones in the single_ones folder, or their official pages. The same sets are also available on my BaiduDisk: DIS--COD--HRSOD.
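Before training, it can help to check that each dataset folder follows the layout described in the fine-tuning guideline, ${data_root_dir}/TASK_NAME/DATASET_NAME with an im and a gt folder inside. The sketch below is a hypothetical checker, not part of this repo; it assumes images and masks are paired by file name (ignoring the extension) and that only the listed extensions matter.

```python
from pathlib import Path

IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}  # assumed set of accepted extensions


def check_dataset(data_root_dir: str, task_name: str, dataset_name: str) -> list[str]:
    """Return a list of problems found in <data_root_dir>/<task>/<dataset>.

    Assumes an `im` folder with images and a `gt` folder with masks whose
    file stems (names without extension) match one-to-one.
    """
    root = Path(data_root_dir) / task_name / dataset_name
    im_dir, gt_dir = root / "im", root / "gt"
    problems = [f"missing folder: {d}" for d in (im_dir, gt_dir) if not d.is_dir()]
    if problems:
        return problems
    im_stems = {p.stem for p in im_dir.iterdir() if p.suffix.lower() in IMAGE_EXTS}
    gt_stems = {p.stem for p in gt_dir.iterdir() if p.suffix.lower() in IMAGE_EXTS}
    problems += [f"no mask for image: {s}" for s in sorted(im_stems - gt_stems)]
    problems += [f"no image for mask: {s}" for s in sorted(gt_stems - im_stems)]
    return problems
```

For example, `check_dataset(data_root_dir, "DIS5K", "DIS-TR")` returns an empty list when every image has a matching mask.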

Weights Preparation

Download backbone weights from my google-drive folder or their official pages.

Run

# Train & Test & Evaluation
./train_test.sh RUN_NAME GPU_NUMBERS_FOR_TRAINING GPU_NUMBERS_FOR_TEST
# Example: ./train_test.sh tmp-proj 0,1,2,3,4,5,6,7 0

# See train.sh / test.sh to run only training / only test-evaluation.
# After evaluation, run `gen_best_ep.py` to select the best checkpoint by a chosen metric: Sm, wFm, or HCE (DIS only).
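When picking the best checkpoint, the direction of comparison depends on the metric: Sm (S-measure) and wFm (weighted F-measure) are higher-is-better, while HCE (human correction efforts) is lower-is-better. The sketch below is not gen_best_ep.py; it is a hypothetical illustration of that selection rule, assuming per-epoch scores have already been collected into a dictionary.

```python
# Whether a larger value of each metric is better.
HIGHER_IS_BETTER = {"Sm": True, "wFm": True, "HCE": False}


def best_epoch(scores: dict[int, dict[str, float]], metric: str) -> int:
    """Return the epoch with the best value of `metric`.

    `scores` maps epoch -> {metric name: value}. Sm and wFm are maximized,
    HCE (DIS only) is minimized. Ties go to the first epoch encountered.
    """
    if metric not in HIGHER_IS_BETTER:
        raise ValueError(f"unknown metric: {metric}")
    candidates = {ep: s[metric] for ep, s in scores.items() if metric in s}
    if not candidates:
        raise ValueError(f"no scores recorded for {metric}")
    pick = max if HIGHER_IS_BETTER[metric] else min
    return pick(candidates, key=candidates.get)
```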

🖊️ Fine-tuning on Custom Data

Guideline:

If you have custom data, fine-tuning on it tends to bring improvements.

  1. Prerequisites: put your datasets under ${data_root_dir}/TASK_NAME/DATASET_NAME, e.g., ${data_root_dir}/DIS5K/DIS-TR and ${data_root_dir}/General/TR-HRSOD, with both an im folder and a gt folder inside each dataset folder.
  2. Change an existing task to your custom one: replace every 'General' (including the single quotes) in the whole project with your custom task name, as shown in the VS Code screenshot below:
  3. Adapt settings:
    • sys_home_dir: path to the root folder that contains the code, datasets, weights, etc. The project folder, data folder, and backbone weights folder are ${sys_home_dir}/codes/dis/BiRefNet, ${sys_home_dir}/datasets/dis/General, and ${sys_home_dir}/weights/cv/swin_xxx, respectively.
    • testsets: your validation set.
    • training_set: your training set.
    • lambdas_pix_last: adjust the weights of the different losses if needed, especially to account for the difference between segmentation (a classification task) and matting (a regression task).
  4. Use existing weights: to fine-tune from existing weights, refer to the resume argument in train.py. Note: training resumes from the epoch indicated in the weights file name (e.g., 244 in BiRefNet-general-epoch_244.pth), not from 1. So, to fine-tune for 50 more epochs, set the number of epochs to 294. The number of epochs, the number of last epochs used for validation, and the validation step are set in train.sh.
  5. Good luck with your training :) If you still have questions, feel free to open an issue (recommended) or contact me.
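The epoch arithmetic in step 4 is easy to get wrong. The hypothetical helper below (not part of the repo) reads the epoch number from a weights file name following the BiRefNet-general-epoch_244.pth pattern and adds the number of extra epochs you want, giving the value to set as the total number of epochs.

```python
import re

# Matches names ending in "epoch_<N>.pth", e.g. BiRefNet-general-epoch_244.pth.
_EPOCH_RE = re.compile(r"epoch_(\d+)\.pth$")


def target_epochs(weights_file: str, extra_epochs: int) -> int:
    """Total epoch count to configure so that `extra_epochs` more epochs are
    trained after resuming from `weights_file`."""
    m = _EPOCH_RE.search(weights_file)
    if m is None:
        raise ValueError(f"cannot find an epoch number in {weights_file!r}")
    return int(m.group(1)) + extra_epochs
```

For the example in step 4, `target_epochs("BiRefNet-general-epoch_244.pth", 50)` gives 294.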

Well-trained weights:

Download BiRefNet-{TASK}-{EPOCH}.pth from [stuff] or from the release page of this repo. Info on the corresponding weights (predicted_maps/performance/training_log) can also be found in folders like exp-BiRefNet-{TASK_SETTINGS} in the same directory.

The results may differ slightly from those in the original paper; you can find them in the eval_results-BiRefNet-{TASK_SETTINGS} folder in each exp-xx folder, and we will update them in the following days. Because the original training setup (A100-80G x 8) is very expensive and many people cannot afford it (including myself...), I re-trained BiRefNet on a single A100-40G only and achieved performance at the same level (or even better). This means you can train the model directly on a single GPU with 36.5 GB+ of memory. By the way, 5.5 GB of GPU memory is needed for inference at 1024x1024. (I personally paid a lot to rent an A100-40G to re-train BiRefNet on the three tasks... T_T. Hope it helps.)

If you have more (or more powerful) GPUs, you can set the GPU IDs and increase the batch size in config.py to accelerate training. The scripts adapt automatically, so you can switch seamlessly between single-GPU and multi-GPU training. Enjoy :)

Some of my messages:

This project was originally built for DIS only, but with each update I made it larger, with many functions embedded together. As a result, you can use it for any binary image segmentation task, such as DIS/COD/SOD, medical image segmentation, anomaly segmentation, etc. You can easily switch the following on or off (usually in config.py):

  • Multi-GPU training: open/close with one variable.
  • Backbone choices: Swin_v1, PVT_v2, ConvNets, ...
  • Weighted losses: BCE, IoU, SSIM, MAE, Reg, ...
  • Training tricks: multi-scale supervision, freezing backbone, multi-scale input...
  • Data collator: loading all in memory, smooth combination of different datasets for combined training and test.
  • ... I really hope you enjoy this project and use it in more works to achieve new SOTAs.
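The weighted-loss option above (and lambdas_pix_last in the fine-tuning guideline) boils down to a weighted sum of per-pixel losses. Below is a minimal NumPy sketch of that idea, assuming predictions and ground truths are arrays in [0, 1]; it covers only BCE, IoU, and MAE in common textbook forms, and the dictionary keys are hypothetical, not necessarily the exact names or formulas used in this repo.

```python
import numpy as np

EPS = 1e-7  # numerical guard for log and division


def bce_loss(pred: np.ndarray, gt: np.ndarray) -> float:
    """Binary cross-entropy averaged over pixels."""
    p = np.clip(pred, EPS, 1.0 - EPS)
    return float(-np.mean(gt * np.log(p) + (1.0 - gt) * np.log(1.0 - p)))


def iou_loss(pred: np.ndarray, gt: np.ndarray) -> float:
    """Soft IoU loss: 1 - intersection / union."""
    inter = np.sum(pred * gt)
    union = np.sum(pred + gt - pred * gt)
    return float(1.0 - inter / (union + EPS))


def mae_loss(pred: np.ndarray, gt: np.ndarray) -> float:
    """Mean absolute error, a regression-style term."""
    return float(np.mean(np.abs(pred - gt)))


LOSSES = {"bce": bce_loss, "iou": iou_loss, "mae": mae_loss}


def weighted_pixel_loss(pred: np.ndarray, gt: np.ndarray,
                        lambdas: dict[str, float]) -> float:
    """Sum of the enabled losses, each scaled by its weight (0 disables it)."""
    return sum(w * LOSSES[name](pred, gt) for name, w in lambdas.items() if w)
```

A segmentation setup might weight BCE and IoU heavily, while a matting setup might shift weight toward a regression-style term such as MAE; the right balance is a tuning choice.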

Quantitative Results

Qualitative Results

Acknowledgement:

Many thanks to the companies/institutes below.

Citation

@article{zheng2024birefnet,
  title={Bilateral Reference for High-Resolution Dichotomous Image Segmentation},
  author={Zheng, Peng and Gao, Dehong and Fan, Deng-Ping and Liu, Li and Laaksonen, Jorma and Ouyang, Wanli and Sebe, Nicu},
  journal={CAAI Artificial Intelligence Research},
  volume = {3},
  pages = {9150038},
  year={2024}
}

Contact

For any questions, discussions, or even complaints, feel free to open an issue here (recommended), send me an e-mail ([email protected]), or book a meeting with me: calendly.com/zhengpeng0108/30min. You can also join the Discord group (https://discord.gg/d9NN5sgFrq) if you want to talk publicly.