From cbddff5066176fb00692e56a56756ece40c119f1 Mon Sep 17 00:00:00 2001
From: Carol-gutianle <2311601671@qq.com>
Date: Fri, 28 Jul 2023 19:12:52 +0800
Subject: [PATCH 1/2] fix: local paths
---
README.md | 14 +-
README_EN.md | 1 +
examples/alpaca/train.py | 9 +-
examples/finetune_chatglm2_for_summary.py | 2 +-
examples/finetune_chatglm_for_translation.py | 9 +-
examples/finetune_llama_for_classification.py | 24 +--
examples/finetune_llama_for_summary.py | 5 +-
examples/finetune_llama_for_translation.py | 11 +-
examples/finetune_moss_for_training.py | 10 +-
.../further_pretrain_llama/expand_vocab.py | 9 +-
.../3d_parallelism.py | 10 +-
examples/peft/finetune_llama_ptuning.py | 150 ++++++++++++++++++
examples/server.py | 10 +-
13 files changed, 191 insertions(+), 73 deletions(-)
create mode 100644 examples/peft/finetune_llama_ptuning.py
diff --git a/README.md b/README.md
index e35d1105..577ed7b0 100644
--- a/README.md
+++ b/README.md
@@ -113,9 +113,21 @@ CoLLiE 基于 *DeepSpeed* 和 *PyTorch*,为大型语言模型提供协作式
注:在使用Adam优化器的情况下,各个模型需要的最少的GPU(A100)数量
## 安装
+Before installing, make sure that:
+* PyTorch >= 1.13
+* CUDA >= 11.6
+* Linux OS
+### Install from PyPI
+You can install it from PyPI with the following command:
```bash
-pip install git+https://github.com/OpenLMLab/collie.git
+pip install collie-lm
```
+### Install from source
+```bash
+git clone https://github.com/OpenLMLab/collie
+cd collie
+python setup.py install
+```
+
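+After installation, you can run a quick import check (a minimal sanity check, not an official step):
+```bash
+python -c "import collie"
+```
+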
## Docker安装
## 使用
diff --git a/README_EN.md b/README_EN.md
index 3f3a3f7f..57d15ff6 100644
--- a/README_EN.md
+++ b/README_EN.md
@@ -19,6 +19,7 @@ CoLLiE (Collaborative Tuning of Large Language Models in an Efficient Way) is a
## Latest News
+* [2023/07/18] Released the Python package collie-lm (1.0.2). You can find more details at this [link](https://pypi.org/project/collie-lm/#history).
## Table of Contents
diff --git a/examples/alpaca/train.py b/examples/alpaca/train.py
index a39a193b..cded3d73 100644
--- a/examples/alpaca/train.py
+++ b/examples/alpaca/train.py
@@ -63,14 +63,7 @@
eval_dataset = dataset[-32:]
# 5. 加载预训练模型
-model = LlamaForCausalLM(config)
-state_dict = LlamaForCausalLM.load_parallel_state_dict(
- path="hdd:s3://opennlplab_hdd/models/llama/llama-7b-hf",
- config=config,
- protocol="petrel",
- format="hf"
-)
-model.load_state_dict(state_dict)
+model = LlamaForCausalLM.from_config(config)
# 6. 设置优化器
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
diff --git a/examples/finetune_chatglm2_for_summary.py b/examples/finetune_chatglm2_for_summary.py
index 3e39e217..bfc6c13f 100644
--- a/examples/finetune_chatglm2_for_summary.py
+++ b/examples/finetune_chatglm2_for_summary.py
@@ -1,5 +1,5 @@
import sys
-sys.path.append("../")
+sys.path.append("..")
from typing import Dict
import argparse
diff --git a/examples/finetune_chatglm_for_translation.py b/examples/finetune_chatglm_for_translation.py
index a366ef03..9ac7c7f9 100644
--- a/examples/finetune_chatglm_for_translation.py
+++ b/examples/finetune_chatglm_for_translation.py
@@ -47,16 +47,14 @@
} for sample in load_dataset("iwslt2017", name="iwslt2017-fr-en", split="train[100:150]")
]
# Prepare model
-model = ChatGLMForCausalLM.from_pretrained(
- "/mnt/petrelfs/zhangshuo/model/chatglm-6b", config=config)
+model = ChatGLMForCausalLM.from_pretrained("THUDM/chatglm-6b", config=config)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
# lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
# optimizer=optimizer, T_max=config.train_epochs * len(train_dataset) / (config.train_micro_batch_size * config.gradient_accumulation_steps))
lr_scheduler = torch.optim.lr_scheduler.StepLR(
optimizer=optimizer, step_size=1, gamma=0.9
)
-tokenizer = AutoTokenizer.from_pretrained(
- "THUDM/chatglm-6b", trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
# 默认的tokenizer不会把[gMASK]当作一个token,所以需要手动添加
tokenizer.unique_no_split_tokens.append("[gMASK]")
# Convert to CoLLie Dataset
@@ -117,5 +115,4 @@
)),
evaluators=[evaluator_ppl, evaluator_bleu]
)
-trainer.train()
-# trainer.save_checkpoint(path="/mnt/petrelfs/zhangshuo/model/test_save_checkpoint", mode="model")
\ No newline at end of file
+trainer.train()
\ No newline at end of file
diff --git a/examples/finetune_llama_for_classification.py b/examples/finetune_llama_for_classification.py
index 1bf87e4b..2866a428 100644
--- a/examples/finetune_llama_for_classification.py
+++ b/examples/finetune_llama_for_classification.py
@@ -17,21 +17,12 @@
"fp16": {
"enabled": True
},
- # "monitor_config": {
- # "enabled": True,
- # "wandb": {
- # "enabled": True,
- # "team": "00index",
- # "project": "collie",
- # "group": "test_evaluator"
- # }
- # },
"zero_optimization": {
"stage": 3,
}
}
config.seed = 1024
-model = LlamaForCausalLM.from_pretrained("/mnt/petrelfs/zhangshuo/model/llama-7b-hf", config=config)
+model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-7b-hf", config=config)
# model = LlamaForCausalLM.from_config(config)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
@@ -42,11 +33,6 @@
"output": "positive." if sample["label"] else "negative."
} for sample in load_dataset("imdb", split="train")
]
-# train_dataset = [
-# {
-# "text": f"Comment: {sample['text']}. The sentiment of this comment is: {'positive.' if sample['label'] else 'negative.'}",
-# } for sample in load_dataset("imdb", split="train")
-# ]
### Prepare perplexity evaluation dataset
ratio = 0.01
eval_dataset_ppl, train_dataset = train_dataset[:int(len(train_dataset) * ratio)], train_dataset[int(len(train_dataset) * ratio):]
@@ -60,11 +46,11 @@
][:1000]
### Convert to CoLLie Dataset
traine_dataset = CollieDatasetForTraining(train_dataset,
- tokenizer=LlamaTokenizer.from_pretrained("/mnt/petrelfs/zhangshuo/model/llama-7b-hf", add_eos_token=True))
+ tokenizer=LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf", add_eos_token=True))
eval_dataset_ppl = CollieDatasetForTraining(eval_dataset_ppl,
- tokenizer=LlamaTokenizer.from_pretrained("/mnt/petrelfs/zhangshuo/model/llama-7b-hf", add_eos_token=True))
+ tokenizer=LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf", add_eos_token=True))
eval_dataset_cls = CollieDatasetForClassification(eval_dataset_cls,
- tokenizer=LlamaTokenizer.from_pretrained("/mnt/petrelfs/zhangshuo/model/llama-7b-hf", add_eos_token=True))
+ tokenizer=LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf", add_eos_token=True))
### Prepare Evaluator
evaluator_ppl = EvaluatorForPerplexity(
model=model,
@@ -101,7 +87,7 @@
MemoryMonitor(config),
LRMonitor(config)
],
- data_provider=GradioProvider(LlamaTokenizer.from_pretrained("/mnt/petrelfs/zhangshuo/model/llama-7b-hf"), port=12300, stream=True),
+ data_provider=GradioProvider(LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf"), port=12300, stream=True),
evaluators=[evaluator_ppl, evaluator_cls]
)
trainer.train()
diff --git a/examples/finetune_llama_for_summary.py b/examples/finetune_llama_for_summary.py
index de905122..a6c241b8 100644
--- a/examples/finetune_llama_for_summary.py
+++ b/examples/finetune_llama_for_summary.py
@@ -1,5 +1,5 @@
import sys
-sys.path.append("../")
+sys.path.append("..")
from typing import Dict
import argparse
@@ -78,8 +78,7 @@ def load_data(path_dict):
tokenizer=tokenizer)
# Prepare model
-model = LlamaForCausalLM.from_pretrained(
- "decapoda-research/llama-7b-hf", config=config)
+model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-7b-hf", config=config)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
optimizer=optimizer, step_size=1, gamma=0.9
diff --git a/examples/finetune_llama_for_translation.py b/examples/finetune_llama_for_translation.py
index dca12431..316cc1f6 100644
--- a/examples/finetune_llama_for_translation.py
+++ b/examples/finetune_llama_for_translation.py
@@ -5,7 +5,8 @@
from transformers import LlamaTokenizer, GenerationConfig
from collie import Trainer, EvaluatorForPerplexity, LlamaForCausalLM, CollieConfig, PPLMetric, AccuracyMetric, DecodeMetric, CollieDatasetForTraining, CollieDatasetForGeneration, \
LossMonitor, TGSMonitor, MemoryMonitor, EvalMonitor, GradioProvider, EvaluatorForGeneration, LRMonitor, BleuMetric, DashProvider
-config = CollieConfig.from_pretrained("/mnt/petrelfs/zhangshuo/model/llama-7b-hf")
+
+config = CollieConfig.from_pretrained("decapoda-research/llama-7b-hf")
config.pp_size = 8
config.train_micro_batch_size = 1
config.eval_batch_size = 1
@@ -46,16 +47,12 @@
} for sample in load_dataset("iwslt2017", name="iwslt2017-fr-en", split="train[100:150]")
]
# Prepare model
-model = LlamaForCausalLM.from_pretrained(
- "/mnt/petrelfs/zhangshuo/model/llama-7b-hf", config=config)
+model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-7b-hf", config=config)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
-# lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
-# optimizer=optimizer, T_max=config.train_epochs * len(train_dataset) / (config.train_micro_batch_size * config.gradient_accumulation_steps))
lr_scheduler = torch.optim.lr_scheduler.StepLR(
optimizer=optimizer, step_size=1, gamma=0.9
)
-tokenizer = LlamaTokenizer.from_pretrained(
- "/mnt/petrelfs/zhangshuo/model/llama-7b-hf", add_eos_token=False, add_bos_token=False)
+tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf", add_eos_token=False, add_bos_token=False)
# Convert to CoLLie Dataset
train_dataset = CollieDatasetForTraining(train_dataset,
tokenizer=tokenizer)
diff --git a/examples/finetune_moss_for_training.py b/examples/finetune_moss_for_training.py
index 33a949c5..8addb671 100644
--- a/examples/finetune_moss_for_training.py
+++ b/examples/finetune_moss_for_training.py
@@ -3,17 +3,11 @@
"""
import sys
sys.path.append('..')
-
-import os
-import json
-import torch
-
from transformers import AutoTokenizer
from collie.config import CollieConfig
from collie.data import CollieDatasetForTraining
-from collie.data import CollieDataLoader
from collie.optim.lomo import Lomo
@@ -23,11 +17,9 @@
from collie.models.moss_moon import Moss003MoonForCausalLM
from collie.utils.monitor import StepTimeMonitor, TGSMonitor, MemoryMonitor, LossMonitor, EvalMonitor
-from collie.metrics import DecodeMetric, PPLMetric, BleuMetric
+from collie.metrics import DecodeMetric, PPLMetric
from collie.module import GPTLMLoss
-from collie.utils import env
-
# 1. 设置路径
# 1.1 预训练模型路径
pretrained_model = "fnlp/moss-moon-003-sft"
diff --git a/examples/further_pretrain_llama/expand_vocab.py b/examples/further_pretrain_llama/expand_vocab.py
index a3cf3bb1..fb21ae09 100644
--- a/examples/further_pretrain_llama/expand_vocab.py
+++ b/examples/further_pretrain_llama/expand_vocab.py
@@ -30,13 +30,9 @@
"group": "further_pretrain_llama"
}
},
- # "zero_optimization": {
- # "stage": 1,
- # }
}
# 合并词表
-llama_tokenizer = LlamaTokenizer.from_pretrained(
- "/mnt/petrelfs/zhangshuo/model/llama-7b-hf")
+llama_tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
chinese_sp_model = spm.SentencePieceProcessor()
chinese_sp_model.Load("./chinese_sp.model")
llama_spm = sp_pb2_model.ModelProto()
@@ -72,8 +68,7 @@
eval_dataset, train_dataset = dataset[:int(
len(dataset) * ratio)], dataset[int(len(dataset) * ratio):]
# 准备模型并调整 embedding 层大小,设置只训练 embedding 和 lm_head 层,加速收敛
-model = LlamaForCausalLM.from_pretrained(
- "/mnt/petrelfs/zhangshuo/model/llama-7b-hf", config=config)
+model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-7b-hf", config=config)
model.resize_token_embeddings(len(llama_tokenizer) + 7) # 取个整
for p in model.parameters():
p.requires_grad = False
diff --git a/examples/one_sentence_overfitting/3d_parallelism.py b/examples/one_sentence_overfitting/3d_parallelism.py
index e772585e..4613cd48 100644
--- a/examples/one_sentence_overfitting/3d_parallelism.py
+++ b/examples/one_sentence_overfitting/3d_parallelism.py
@@ -1,5 +1,5 @@
import sys
-sys.path.append("../../")
+sys.path.append("../..")
from collie.models.llama.model import LlamaForCausalLM
from collie.controller import Trainer, EvaluatorForGeneration
from collie.metrics.decode import DecodeMetric
@@ -9,12 +9,10 @@
from transformers.generation.utils import GenerationConfig
import torch
-tokenizer = LlamaTokenizer.from_pretrained("/mnt/petrelfs/zhangshuo/model/llama-7b-hf",
- padding_side="left",
- add_eos_token=False)
+tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf", padding_side="left", add_eos_token=False)
tokenizer.bos_token_id = 1
tokenizer.eos_token_id = 2
-config = CollieConfig.from_pretrained("/mnt/petrelfs/zhangshuo/model/llama-7b-hf")
+config = CollieConfig.from_pretrained("decapoda-research/llama-7b-hf")
config.tp_size = 4
config.dp_size = 1
config.pp_size = 2
@@ -31,7 +29,7 @@
}
}
-model = LlamaForCausalLM.from_pretrained("/mnt/petrelfs/zhangshuo/model/llama-7b-hf", config=config)
+model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-7b-hf", config=config)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
train_sample = tokenizer("Collie is a python package for finetuning large language models.", return_tensors="pt").input_ids.squeeze(0)
eval_sample = tokenizer("Collie is", return_tensors="pt")
diff --git a/examples/peft/finetune_llama_ptuning.py b/examples/peft/finetune_llama_ptuning.py
new file mode 100644
index 00000000..0ee9a094
--- /dev/null
+++ b/examples/peft/finetune_llama_ptuning.py
@@ -0,0 +1,150 @@
+"""
+An example of prompt tuning a LLaMA base model with CoLLie.
+"""
+import os
+import sys
+sys.path.append('../..')
+import json
+import torch
+
+from transformers import LlamaTokenizer
+from transformers.generation.utils import GenerationConfig
+
+from collie.config import CollieConfig
+
+from collie.data import CollieDatasetForTraining
+from collie.data import CollieDataLoader
+
+from collie.controller.trainer import Trainer
+from collie.controller.evaluator import EvaluatorForPerplexity, EvaluatorForGeneration
+
+from collie.models.llama.model import LlamaForCausalLM
+from collie.utils.dist_utils import setup_distribution
+
+from collie.utils.monitor import StepTimeMonitor, TGSMonitor, MemoryMonitor, LossMonitor, EvalMonitor
+from collie.metrics import DecodeMetric, PPLMetric, BleuMetric
+from collie.module import GPTLMLoss
+
+from peft import (
+ get_peft_config,
+ get_peft_model,
+ PromptTuningInit,
+ PromptTuningConfig,
+ TaskType,
+ PromptEncoderConfig,
+ PeftType
+)
+
+# 1. Set up paths
+# 1.1 Path to the pretrained model
+pretrained_model = 'decapoda-research/llama-7b-hf'
+# 1.2 Path for saving decoded results during eval
+save_path = './result'
+
+# 2. Set up the config
+# 2.1 Load the config
+config = CollieConfig.from_pretrained(pretrained_model)
+# 2.2 Add extra config options
+config.tp_size = 1
+config.dp_size = 4
+config.pp_size = 1
+config.train_epochs = 1
+config.train_micro_batch_size = 1
+config.eval_batch_size = 32
+config.eval_per_n_steps = 100
+config.checkpointing = False
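+# Note: token_dim, num_attention_heads and num_layers below are set to match LLaMA-7B
+# (hidden size 4096, 32 heads, 32 layers); adjust them if you load a different base model.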
+config.peft_config = PromptTuningConfig(
+ task_type = TaskType.CAUSAL_LM,
+ prompt_tuning_init = PromptTuningInit.TEXT,
+ num_virtual_tokens = 8,
+ token_dim = 4096,
+ num_attention_heads = 32,
+ num_layers = 32,
+ num_transformer_submodules = None,
+ prompt_tuning_init_text="Classify if the tweet is a complaint or not:",
+ tokenizer_name_or_path=pretrained_model
+)
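+# Prompt tuning freezes the base model and only trains the 8 virtual-token embeddings
+# (roughly 8 x 4096 parameters); the init text merely seeds those embeddings, so it need not match the dataset below.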
+config.ds_config = {
+ "fp16": {"enabled": True},
+ "monitor_config": {
+ "enabled": True,
+ "tag": "sophia_alpaca",
+ "csv_monitor": {
+ "enabled": True,
+ "output_path": "./ds_logs/"
+ }
+ }
+}
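+# With monitor_config enabled, per-step metrics should be written as CSV files under ./ds_logs/.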
+
+# 3. Set up the tokenizer
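+# padding_side="left" so that batched generation pads prompts on the left, as decoder-only models expect.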
+tokenizer = LlamaTokenizer.from_pretrained(pretrained_model, padding_side="left")
+
+# 4. Load the dataset
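+# A toy dataset: the same sample repeated 100 times; the first 32 entries double as the eval set.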
+train_dataset = [
+ {
+ 'input': 'The movie is terrible. ',
+ 'output': 'Yes.'
+ } for _ in range(100)
+]
+train_dataset = CollieDatasetForTraining(train_dataset, tokenizer)
+eval_dataset = train_dataset[:32]
+
+# 5. Load the pretrained model
+model = LlamaForCausalLM.from_config(config)
+model = get_peft_model(model, config.peft_config)
+model.print_trainable_parameters()
+
+# 6. Set up the optimizer
+optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
+
+# 7. Add monitors
+monitors = [
+ StepTimeMonitor(config),
+ TGSMonitor(config),
+ MemoryMonitor(config),
+ LossMonitor(config),
+ EvalMonitor(config)
+]
+
+# 8. Add evaluators
+evaluator_ppl = EvaluatorForPerplexity(
+ model = model,
+ config = config,
+ dataset = eval_dataset,
+ monitors = [
+ EvalMonitor(config)
+ ],
+ metrics = {
+ 'ppl': PPLMetric()
+ }
+)
+evaluator_decode = EvaluatorForGeneration(
+ model = model,
+ config = config,
+ tokenizer = tokenizer,
+ dataset = eval_dataset,
+ monitors = [
+ EvalMonitor(config)
+ ],
+ metrics = {
+ 'decode': DecodeMetric(save_to_file = True, save_path = save_path)
+ }
+)
+
+# 9. Instantiate the trainer
+trainer = Trainer(
+ model = model,
+ config = config,
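+    # -100 follows the usual PyTorch convention of ignoring those label positions in the cross-entropy loss.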
+ loss_fn = GPTLMLoss(-100),
+ optimizer = optimizer,
+ train_dataset = train_dataset,
+ monitors = monitors,
+ evaluators = [evaluator_ppl, evaluator_decode]
+)
+
+# 10. Train / evaluate
+trainer.train()
+
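+# Note: dp_size = 4 above assumes 4 visible GPUs, matching --nproc_per_node=4 in the launch command below.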
+# Command: CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --rdzv_backend=c10d --rdzv_endpoint=localhost:29402 --nnodes=1 --nproc_per_node=4 finetune_llama_ptuning.py
\ No newline at end of file
diff --git a/examples/server.py b/examples/server.py
index bceb44e8..aae7623a 100644
--- a/examples/server.py
+++ b/examples/server.py
@@ -2,16 +2,14 @@
import torch
sys.path.append("..")
-from collie import Server, LlamaForCausalLM, DashProvider, CollieConfig, env, MossForCausalLM, ChatGLMForCausalLM
-from transformers import LlamaTokenizer, GenerationConfig, BitsAndBytesConfig
+from collie import Server, LlamaForCausalLM, DashProvider, CollieConfig
+from transformers import LlamaTokenizer, GenerationConfig
config = CollieConfig.from_pretrained("openlm-research/open_llama_13b", trust_remote_code=True)
config.pp_size = 1
config.tp_size = 1
-model = LlamaForCausalLM.from_pretrained(
- "/mnt/petrelfs/zhangshuo/model/llama-13b-hf", config=config).cuda()
-tokenizer = LlamaTokenizer.from_pretrained(
- "/mnt/petrelfs/zhangshuo/model/llama-13b-hf", add_eos_token=False)
+model = LlamaForCausalLM.from_pretrained("openlm-research/open_llama_13b", config=config).cuda()
+tokenizer = LlamaTokenizer.from_pretrained("openlm-research/open_llama_13b", add_eos_token=False)
data_provider = DashProvider(tokenizer=tokenizer)
data_provider.generation_config = GenerationConfig(max_new_tokens=250)
server = Server(model, data_provider, config=config)
From 13e0178dac39f1f2e84051ab98ea7d083159ae42 Mon Sep 17 00:00:00 2001
From: Carol-gutianle <2311601671@qq.com>
Date: Fri, 28 Jul 2023 19:14:13 +0800
Subject: [PATCH 2/2] add: news
---
README.md | 1 +
1 file changed, 1 insertion(+)
diff --git a/README.md b/README.md
index 577ed7b0..1e50c3aa 100644
--- a/README.md
+++ b/README.md
@@ -24,6 +24,7 @@ CoLLiE (Collaborative Tuning of Large Language Models in an Efficient Way),一
## 新闻
+* [2023/07/18] Released the Python package `collie-lm`. See this [link](https://pypi.org/project/collie-lm/#history) for more details!
## 目录