-
Notifications
You must be signed in to change notification settings - Fork 4
/
main.py
54 lines (37 loc) · 1.22 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import gym
import torch
import numpy as np
from mpo import MPO
from mpo_nets import CategoricalActor, Critic
# Pin the global random generators so that torch- and numpy-side randomness
# drawn from the default generators is reproducible between runs. The two
# generators are independent, so the order of these calls does not matter.
# NOTE(review): the gym environments are not seeded anywhere in this file.
torch.manual_seed(0)
np.random.seed(0)
if __name__ == '__main__':
    # Script entry point: train (default) or visualise (`--eval`) an MPO agent
    # on LunarLander-v2 using a categorical actor and a critic network.
    import sys

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    num_envs = 5
    env_name = 'LunarLander-v2'
    # Pass num_envs rather than repeating the literal so the two cannot drift apart.
    vec_env = gym.vector.make(env_name, num_envs=num_envs)
    # Per-environment observation size: last axis of the batched observation space.
    obs_shape = vec_env.observation_space.shape[-1]
    # Number of discrete actions, read from the first sub-environment's action space.
    action_shape = vec_env.action_space[0].n
    actor = CategoricalActor(obs_shape, action_shape)
    critic = Critic(obs_shape, action_shape)
    if device == 'cuda':
        actor.cuda()
        critic.cuda()

    def train():
        """Build an MPO agent on the vectorised env, load saved weights, and train it."""
        mpo = MPO(vec_env, actor, critic, obs_shape, action_shape, device=device)
        # NOTE(review): presumably resumes from a previously saved checkpoint;
        # confirm how MPO.load_model behaves when no checkpoint exists yet.
        mpo.load_model()
        mpo.train()

    def evaluate():
        """Load saved weights and render the policy in a single env until interrupted."""
        env = gym.make(env_name)
        # Built the same way as in train(); only its actor is used below.
        mpo = MPO(vec_env, actor, critic, obs_shape, action_shape, device=device)
        mpo.load_model()
        obs = env.reset()
        try:
            while True:
                # Inference only: avoid building an autograd graph on every step.
                with torch.no_grad():
                    # Add a leading batch axis of size 1 for the actor.
                    batch = torch.Tensor(np.array([obs])).to(device)
                    act, _ = mpo.actor.action(batch)
                act = act.cpu().detach().numpy()[0]
                obs, _reward, done, _info = env.step(act)
                env.render()
                if done:
                    obs = env.reset()
        finally:
            # The loop only ends via an exception (e.g. Ctrl-C), so release the
            # environment (and its render window) here.
            env.close()

    try:
        # Default is training; run `python main.py --eval` to watch the saved policy.
        if '--eval' in sys.argv[1:]:
            evaluate()
        else:
            train()
    finally:
        # gym.vector.make may start worker processes; shut them down on exit.
        vec_env.close()