-
Notifications
You must be signed in to change notification settings - Fork 14
/
run_legacy_evaluation.py
49 lines (38 loc) · 1.32 KB
/
run_legacy_evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import click
import torch
import run_clm
from models.lora_utils import transform_lora_adapters_nf8
from experiments.legacy_evaluation_utils import legacy_evaluation
# Maps a trainer output_dir to the base directory that holds its saved
# checkpoint. Intentionally empty here — populate it per deployment, or set
# the CHECKPOINT_BASE_DIR environment variable to bypass this lookup entirely.
CHECKPOINT_BASE_DIR_DICT = {
}
if __name__ == "__main__":
    # Build the trainer (model + tokenizer) via the standard CLM entry point.
    trainer = run_clm.main(return_trainer=True)

    # Resolve the checkpoint base directory: the CHECKPOINT_BASE_DIR
    # environment variable takes precedence; otherwise fall back to the
    # per-output_dir lookup table defined at module level.
    checkpoint_base_dir = os.getenv("CHECKPOINT_BASE_DIR", default=None)
    if checkpoint_base_dir is None:
        try:
            checkpoint_base_dir = CHECKPOINT_BASE_DIR_DICT[
                trainer.args.output_dir]
        except KeyError as err:
            # The dict ships empty; fail with an actionable message instead
            # of a bare KeyError.
            raise KeyError(
                f"No checkpoint base directory registered for output_dir "
                f"{trainer.args.output_dir!r}. Set the CHECKPOINT_BASE_DIR "
                f"environment variable or add an entry to "
                f"CHECKPOINT_BASE_DIR_DICT.") from err

    checkpoint_path = os.path.join(
        checkpoint_base_dir,
        trainer.args.output_dir,
        "full_model.pth")
    # SECURITY NOTE(review): torch.load unpickles arbitrary objects — only
    # load checkpoints from trusted sources.
    state_dict = torch.load(
        checkpoint_path,
        map_location=torch.device("cpu"))
    trainer.model.load_state_dict(state_dict)
    click.secho(f"Loaded model from {checkpoint_path}", fg="green")

    # Optionally transform the LoRA adapters to NF8. Setting the variable to
    # ANY value (even an empty string) enables the transform — identical to
    # the original `getenv(..., default=False) is not False` check, written
    # idiomatically.
    if os.getenv("TRANSFORM_ADAPTERS") is not None:
        transform_lora_adapters_nf8(trainer.model)

    # Run the legacy evaluation suite and report per-dataset results.
    results = legacy_evaluation(
        model=trainer.model,
        tokenizer=trainer.tokenizer,
        device="cuda")
    for dataset_name, result in results.items():
        click.secho(f"{dataset_name}: {result}", fg="green")