forked from kingoflolz/mesh-transformer-jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
device_sample.py
105 lines (77 loc) · 3.42 KB
/
device_sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import argparse
import json
import time
import jax
import numpy as np
import optax
from mesh_transformer import util
from mesh_transformer.checkpoint import read_ckpt
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer
import transformers
from smart_open import open
from mesh_transformer.util import clip_by_global_norm
def parse_args():
    """Parse the script's command-line options and return the namespace.

    A single option is recognized: ``--config`` (a string, ``None`` when
    omitted), whose help text describes it as the config file location.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config",
        type=str,
        default=None,
        help="Config file location",
    )
    return cli.parse_args()
if __name__ == "__main__":
    # Interactive sampling driver: load the config, restore the latest
    # checkpoint onto the device mesh, then repeatedly read a prompt from
    # stdin and print one generated sample per batch entry.
    args = parse_args()
    if args.config is None:
        # Fail with a readable message instead of an error from open(None).
        raise SystemExit("error: --config is required (path to the model config JSON)")

    # `open` is smart_open.open (imported at file level); the context manager
    # makes sure the handle is closed once the config has been parsed.
    with open(args.config) as config_file:
        params = json.load(config_file)

    gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1)
    per_replica_batch = params["per_replica_batch"]
    cores_per_replica = params["cores_per_replica"]

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if cores_per_replica > 8:
        raise ValueError(f"cores_per_replica must be <= 8, got {cores_per_replica}")

    bucket = params["bucket"]
    model_dir = params["model_dir"]
    # The next lookups are not used further down in this script, but they make
    # a config that lacks one of these keys fail immediately with a KeyError.
    layers = params["layers"]
    d_model = params["d_model"]
    n_heads = params["n_heads"]
    n_vocab = params["n_vocab"]
    seq = params["seq"]
    norm = params["norm"]

    params["sampler"] = nucleaus_sample
    # NOTE(review): an optimizer is handed to CausalTransformer through
    # `params` even though this script only samples; its state is deleted
    # right after the checkpoint is loaded. Presumably the constructor
    # requires it -- confirm.
    opt = optax.chain(
        optax.scale(1 / gradient_accumulation_steps),
        clip_by_global_norm(1),
        optax.scale_by_adam(),
        optax.additive_weight_decay(0),
        optax.scale(-1),
        optax.scale_by_schedule(util.gpt3_schedule(0, 1, 0, 0))
    )

    params["optimizer"] = opt

    start = time.time()
    print(f"jax devices: {jax.device_count()}")
    print(f"jax runtime initialized in {time.time() - start:.06}s")

    # Arrange all devices in a 2-D grid: (replicas, cores per replica).
    mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
    devices = np.array(jax.devices()).reshape(mesh_shape)

    with open(f"gs://{bucket}/{model_dir}/meta.json", "r") as f:
        meta = json.load(f)

    # The last entry of meta["checkpoints"] is the one that gets restored.
    ckpt_step = meta["checkpoints"][-1]
    print(f"using checkpoint {ckpt_step}")

    total_batch = per_replica_batch * jax.device_count() // cores_per_replica

    # NOTE(review): only `import jax` appears at file level, so this relies on
    # `jax.experimental.maps` having been imported as a side effect of another
    # import -- confirm.
    with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
        network = CausalTransformer(params)

        start = time.time()
        network.state = read_ckpt(network.state, f"gs://{bucket}/{model_dir}/step_{ckpt_step}/", devices.shape[1])
        print(f"network loaded in {time.time() - start:.06}s")

        local_shards = max(jax.local_device_count() // mesh_shape[1], 1)
        # Optimizer state is dropped before the state is moved with move_xmap.
        del network.state["opt_state"]
        network.state = network.move_xmap(network.state, np.zeros(local_shards))

        tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

        while True:
            try:
                context = input("Type input:")
            except (EOFError, KeyboardInterrupt):
                # Ctrl-D / Ctrl-C at the prompt ends the session cleanly
                # instead of dumping a traceback.
                print()
                break

            tokens = tokenizer.encode(context)

            start = time.time()

            if len(tokens) > seq:
                # A negative pad width would make np.pad raise and kill the
                # session; keep only the most recent `seq` tokens instead.
                print(f"warning: input is {len(tokens)} tokens, keeping the last {seq}")
                tokens = tokens[-seq:]

            provided_ctx = len(tokens)
            pad_amount = seq - provided_ctx

            # Left-pad with zeros so the prompt occupies the end of the window.
            padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
            # Every batch entry receives the same prompt.
            batched_tokens = np.array([padded_tokens] * total_batch)
            length = np.ones(total_batch, dtype=np.uint32) * provided_ctx

            sampler_options = {
                "top_p": np.ones(total_batch) * 0.9,
                "temp": np.ones(total_batch) * 0.75,
            }
            output = network.generate(batched_tokens, length, 512, sampler_options)

            # NOTE(review): assumes the generated token ids live at
            # output[1][0] with a trailing axis of which index 0 is wanted --
            # confirm against CausalTransformer.generate.
            for idx, o in enumerate(output[1][0][:, :, 0]):
                print(f"sample {idx}: {repr(tokenizer.decode(o))}")

            # `.06` (precision) matches the other timing messages; the previous
            # `06` was a minimum width and printed the full float.
            print(f"completion done in {time.time() - start:.06}s")