train.py
import tensorflow as tf
from environment import Environment
from agent import Agent
from logger import Logger

if __name__ == '__main__':
    tf.compat.v1.enable_eager_execution()

    # place TensorFlow ops on the first GPU
    with tf.device("/gpu:0"):
        # create environment and logger objects
        env = Environment(space_sleep=0.55, no_action_sleep=0.05)
        logger = Logger(fp=None)

        memory_fp = 'C:/Users/ryano/rl_projects/trex_memory/memory.pkl'
        save_path = 'C:/Users/ryano/repos/deep-rl-trex/model/model'
        mem_length = 150000

        # build the dueling DQN agent
        agent = Agent(env,
                      tf.keras.optimizers.Adam(learning_rate=0.0001),
                      loss='mse',
                      memory_length=mem_length,
                      dueling=True,
                      noisy_net=False,
                      egreedy=False,
                      save_memory=memory_fp,
                      save_weights=save_path,
                      verbose_action=True)

        # resume from previously saved weights and replay memory
        agent.load_weights(save_path)
        agent.load_memory(memory_fp)

        # annealing schedules for the replay beta and exploration epsilon
        agent.set_beta_schedule(beta_start=.99999, beta_max=1, annealed_samplings=30000)
        agent.set_epsilon_decay_schedule(epsilon=0.001, epsilon_min=0.0001, annealed_steps=10000)

        agent.pretraining_steps = 0
        print(f'pretraining for {agent.pretraining_steps} steps...')

        # start the game and run training episodes
        env.init_game()
        for episode in range(10000000):
            env.run(episode, agent, batch_size=32, logger=logger)
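

# Illustrative sketch only: `linear_anneal` below is a hypothetical helper, not
# part of this repo's Agent API. It shows the kind of linear annealing that the
# beta/epsilon settings above suggest, e.g. beta rising from 0.99999 to 1.0 over
# 30000 sampled batches and epsilon decaying from 0.001 to 0.0001 over 10000
# steps; the Agent's actual schedule implementation may differ.
def linear_anneal(step, start, end, annealed_steps):
    """Interpolate linearly from `start` to `end` over `annealed_steps` steps, then hold `end`."""
    frac = min(step / annealed_steps, 1.0)
    return start + frac * (end - start)
# e.g. linear_anneal(15000, 0.99999, 1.0, 30000) -> 0.999995 (halfway through the anneal)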