dqn_algorithm.py
# Copyright (c) 2022 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DQN Algorithm."""

import torch
import torch.distributions as td
from typing import Callable, Optional, Union

import alf
from alf.algorithms.config import TrainerConfig
from alf.algorithms.sac_algorithm import SacAlgorithm, ActionType, \
    SacState as DqnState, SacCriticState as DqnCriticState, \
    SacActionState as DqnActionState, \
    SacInfo as DqnInfo, SacCriticInfo as DqnCriticInfo, \
    SacLossInfo as DqnLossInfo
from alf.algorithms.td_loss import TDLoss
from alf.data_structures import AlgStep, LossInfo, TimeStep
from alf.environments.alf_environment import AlfEnvironment
from alf.networks import QNetwork
from alf.optimizers import AdamTF
from alf.tensor_specs import TensorSpec, BoundedTensorSpec
from alf.utils import common, dist_utils
from alf.utils.schedulers import as_scheduler, Scheduler


@alf.configurable
class DqnAlgorithm(SacAlgorithm):
    r"""DQN/DDQN algorithm:

    ::

        Mnih et al "Playing Atari with Deep Reinforcement Learning", arXiv:1312.5602
        Hasselt et al "Deep Reinforcement Learning with Double Q-learning", arXiv:1509.06461

    The difference from DQN/DDQN is that a minimum is taken over the two
    critics (similar to TD3), instead of choosing the maximizing action with
    the Q network and evaluating its value with the target Q network.
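
    Schematically (a sketch of the intended target values, not a verbatim
    transcript of the loss code, which lives in ``SacAlgorithm``/``TDLoss``)::

        DDQN target:  r + gamma * Q_target(s', argmax_a Q(s', a))
        here:         r + gamma * min_i Q_target_i(s', a'), with a' the greedy
                      action from ``_predict_action`` (see
                      ``target_net_target_action``)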

    The implementation is based on the SAC algorithm.
    """

    def __init__(self,
                 observation_spec: alf.tensor_specs.NestedTensorSpec,
                 action_spec: alf.tensor_specs.BoundedTensorSpec,
                 reward_spec: TensorSpec = TensorSpec(()),
                 q_network_cls: Callable[..., QNetwork] = QNetwork,
                 q_optimizer: Optional[torch.optim.Optimizer] = None,
                 rollout_epsilon_greedy: Union[float, Scheduler] = 0.1,
                 target_net_target_action: bool = True,
                 num_critic_replicas: int = 2,
                 env: Optional[AlfEnvironment] = None,
                 config: Optional[TrainerConfig] = None,
                 critic_loss_ctor: Optional[Callable[..., TDLoss]] = None,
                 checkpoint=None,
                 debug_summaries: bool = False,
                 name: str = "DqnAlgorithm"):
"""
Args:
observation_spec (nested TensorSpec): representing the observations.
action_spec (BoundedTensorSpec): Only one discrete action allowed.
reward_spec (TensorSpec): a rank-1 or rank-0 tensor spec representing
the reward(s).
q_network: is used to construct QNetwork for estimating ``Q(s,a)``
given that the action is discrete. Its output spec must be consistent with
the discrete action in ``action_spec``.
q_optimizer: A custom optimizer for the q network.
Uses the enclosing algorithm's optimizer if None.
rollout_epsilon_greedy: epsilon greedy policy for rollout.
Together with the following two parameters, the SAC algorithm
can be converted to a DQN or DDQN algorithm when e.g.
``rollout_epsilon_greedy=0.3``, ``max_target_action=True``, and
``use_entropy_reward=False``.
target_net_target_action: when ``True`` uses target critic network to get
target action (similar as DDPG). When ``False``, uses
critic network to get target action (similar as DDQN/SAC).
num_critic_replicas: number of critics to be used. Default is 2.
env: The environment to interact with. ``env`` is a
batched environment, which means that it runs multiple simulations
simultateously. ``env` only needs to be provided to the root
algorithm.
config: config for training. It only needs to be
provided to the algorithm which performs ``train_iter()`` by
itself.
critic_loss_ctor: a critic loss
constructor. If ``None``, a default ``OneStepTDLoss`` will be used.
checkpoint (None|str): a string in the format of "prefix@path",
where the "prefix" is the multi-step path to the contents in the
checkpoint to be loaded. "path" is the full path to the checkpoint
file saved by ALF. Refer to ``Algorithm`` for more details.
debug_summaries (bool): True if debug summaries should be created.
name (str): The name of this algorithm.
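
        Example (a minimal sketch; ``env`` and ``trainer_config`` below are
        assumed to be created elsewhere and are not provided by this file)::

            algorithm = DqnAlgorithm(
                observation_spec=env.observation_spec(),
                action_spec=env.action_spec(),
                rollout_epsilon_greedy=0.1,
                env=env,
                config=trainer_config)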
"""
        self._rollout_epsilon_greedy = as_scheduler(rollout_epsilon_greedy)
        self._target_net_target_action = target_net_target_action

        # Disable alpha learning:
        alpha_optimizer = AdamTF(lr=0)

        super().__init__(
            observation_spec=observation_spec,
            action_spec=action_spec,
            reward_spec=reward_spec,
            actor_network_cls=None,
            critic_network_cls=None,
            q_network_cls=q_network_cls,
            # Do not use entropy reward:
            use_entropy_reward=False,
            num_critic_replicas=num_critic_replicas,
            env=env,
            config=config,
            critic_loss_ctor=critic_loss_ctor,
            # Allow custom optimizer for q_network:
            critic_optimizer=q_optimizer,
            alpha_optimizer=alpha_optimizer,
            checkpoint=checkpoint,
            debug_summaries=debug_summaries,
            name=name)

        assert self._act_type == ActionType.Discrete

    # Copied and modified from sac_algorithm (discrete actions).
    def _predict_action(self,
                        observation,
                        state: DqnActionState,
                        epsilon_greedy=None,
                        eps_greedy_sampling=False):
        new_state = DqnActionState()
        critic_network_inputs = (observation, None)
        # NOTE: This block departs from SAC:
        if eps_greedy_sampling or not self._target_net_target_action:
            # SAC always uses critic_networks to obtain action,
            # even during training
            nets = self._critic_networks
        else:
            nets = self._target_critic_networks
        q_values, critic_state = self._compute_critics(
            nets, *critic_network_inputs, state.critic)
        new_state = new_state._replace(critic=critic_state)

        # NOTE: This block departs from SAC:
        size = q_values.shape  # [B, actions]
        if eps_greedy_sampling:
            # Epsilon greedy for rollout or evaluation.
            rand_act_prob = epsilon_greedy / size[-1]
            probs = torch.ones_like(q_values) * rand_act_prob
            if epsilon_greedy >= 1:
                # Uniform random action distribution
                greedy_act_prob = rand_act_prob
            else:
                # Epsilon greedy
                greedy_act_prob = 1 - epsilon_greedy + rand_act_prob
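            # Worked example (illustrative numbers, not taken from the code):
            # with epsilon_greedy=0.1 and 4 actions, rand_act_prob = 0.025, so
            # the categorical built below assigns 0.025 to each non-greedy
            # action and 1 - 0.1 + 0.025 = 0.925 to the greedy action; the
            # probabilities sum to 1.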
        else:
            # Greedy for train_step to obtain target value from target network.
            # The greedy action here is the maximizer of q values, and will be
            # used to obtain target value using the target network.
            probs = torch.zeros_like(q_values)
            greedy_act_prob = 1
        greedy_action = torch.argmax(q_values, dim=-1)
        probs[torch.arange(size[0]), greedy_action] = greedy_act_prob
        action_dist = td.Categorical(probs=probs)
        action = dist_utils.sample_action_distribution(action_dist)
        return action_dist, action, q_values, new_state

    # Copied and modified from sac_algorithm (discrete actions).
    def rollout_step(self, inputs: TimeStep, state: DqnState):
        """``rollout_step()`` basically predicts actions like what is done by
        ``predict_step()``. Additionally, if states are to be stored in a replay
        buffer, then this function also calls ``_critic_networks`` and
        ``_target_critic_networks`` to maintain their states.
        """
        action_dist, action, _, action_state = self._predict_action(
            inputs.observation,
            state=state.action,
            # NOTE: This is the only departure from SAC.
            epsilon_greedy=self._rollout_epsilon_greedy(),
            eps_greedy_sampling=True)

        if self.need_full_rollout_state():
            _, critics_state = self._compute_critics(
                self._critic_networks, inputs.observation, action,
                state.critic.critics)
            _, target_critics_state = self._compute_critics(
                self._target_critic_networks, inputs.observation, action,
                state.critic.target_critics)
            critic_state = DqnCriticState(
                critics=critics_state, target_critics=target_critics_state)
            actor_state = ()
        else:
            actor_state = state.actor
            critic_state = state.critic

        new_state = DqnState(
            action=action_state, actor=actor_state, critic=critic_state)
        return AlgStep(
            output=action,
            state=new_state,
            info=DqnInfo(action=action, action_distribution=action_dist))

    def calc_loss(self, info: DqnInfo):
        # Adapted from SAC: Removes irrelevant losses and logging.
        critic_loss = self._calc_critic_loss(info)
        return LossInfo(
            loss=critic_loss.loss,
            priority=critic_loss.priority,
            extra=DqnLossInfo(critic=critic_loss.extra, actor=(), alpha=()))