Bug in alpha optimizer in MTSAC #2303

Open
lucasosouza opened this issue Oct 12, 2021 · 1 comment
Labels: algos, bug, pytorch
Assignee: krzentner

lucasosouza commented Oct 12, 2021

There is a potential bug in how the alpha optimizer is initialized in MTSAC. During __init__ we have:

            self._log_alpha = torch.Tensor([self._initial_log_entropy] *
                                           self._num_tasks).requires_grad_()
            self._alpha_optimizer = optimizer([self._log_alpha] *
                                              self._num_tasks,
                                              lr=self._policy_lr)

Since log_alpha is a single tensor, what is being passed to the optimizer is the same tensor repeated num_tasks times. I don't think that is the intended behavior. In the to() function below, it is overridden with the correct initialization for the optimizer:

            self._alpha_optimizer = self._optimizer([self._log_alpha],
                                                    lr=self._policy_lr) 

PyTorch recognizes that the parameters are duplicates and emits a warning:

UserWarning: optimizer contains a parameter group with duplicate parameters; in future, this will cause an error; see github.com/pytorch/pytorch/issues/40967 for more information. 

But as github.com/pytorch/pytorch/issues/40967 explains, the net effect is that the tensor log_alpha gets updated num_tasks times on every optimizer step, since all copies belong to the same param_group.

A quick test demonstrates this:

import torch

# Incorrect version (as in __init__): the same tensor is passed num_tasks times,
# so a single optimizer.step() applies num_tasks Adam updates to it.
num_tasks = 50
log_alpha = torch.Tensor([0.] * num_tasks).requires_grad_()
log_alpha.grad = torch.ones(num_tasks)
alpha_optimizer = torch.optim.Adam([log_alpha] * num_tasks, lr=0.1)
alpha_optimizer.step()
torch.allclose(log_alpha, torch.ones(num_tasks) * -5.)   # True: moved by ~num_tasks * lr

# Correct version (as in to()): the tensor is passed once and moves by ~lr.
num_tasks = 50
log_alpha = torch.Tensor([0.] * num_tasks).requires_grad_()
log_alpha.grad = torch.ones(num_tasks)
alpha_optimizer = torch.optim.Adam([log_alpha], lr=0.1)
alpha_optimizer.step()
torch.allclose(log_alpha, torch.ones(num_tasks) * -0.1)  # True: moved by ~lr

@abhi-iyer

krzentner (Contributor) commented:

Thanks for reporting this. I suppose the overall effect is that the alpha learning rate is essentially multiplied by the number of tasks being trained. This should be about as simple as just replacing that one line of code, so we should definitely fix this.
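
For reference, a minimal sketch of what that one-line replacement could look like in __init__, reusing the single-tensor initialization already shown in to() (attribute names as in the snippet from the issue description; this is a sketch, not a reviewed patch):

            self._log_alpha = torch.Tensor([self._initial_log_entropy] *
                                           self._num_tasks).requires_grad_()
            # Pass the single log_alpha tensor to the optimizer exactly once,
            # rather than num_tasks references to the same tensor, so each
            # optimizer.step() applies only one update to it.
            self._alpha_optimizer = optimizer([self._log_alpha],
                                              lr=self._policy_lr)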

@krzentner added the algos, bug, and pytorch labels on Oct 13, 2021
@krzentner self-assigned this on Oct 13, 2021