
How to replace AllReduce with Reduce in parameter server mode? #467

Open
zhao1157 opened this issue Apr 17, 2020 · 4 comments

Comments

@zhao1157

@reedwm
In parameter server mode, I managed to replace grad = tf.add_n(grads) with an NCCL all-reduce:

```python
# NCCL all-reduce
sum_grads = allreduce.sum_grad_and_var_all_reduce(
    False, grad_and_vars, 1, 'nccl', gpu_indices)
# Get the index of the device holding the variable.
var_dev = grad_and_vars[0][1].device.split(':')[-1]
# Only use the summed tensor that lives on the variable's device.
for s_v in sum_grads:
  if s_v[0].device.split(':')[-1] == var_dev:
    grad = s_v[0]
# Make sure all of the NCCL all-reduce sum tensors are evaluated,
# otherwise the process will hang.
with tf.control_dependencies([s_v[0] for s_v in sum_grads]):
  grad = tf.identity(grad)
```

I tried to figure out a way to accomplish the sum without using all-reduce, since I only need one copy of the sum, not number-of-GPUs copies of it. Is there a reduce API in TF I can use? Thanks.

@reedwm
Member

reedwm commented Apr 17, 2020

grad = tf.add_n(grads) is a sum without an all-reduce. An all-reduce is simply an add_n except you get the output tensor on every device instead of just one device.

Note the reason parameter server mode takes a mean of the gradients instead of a sum is that use_mean=True is passed here. Passing use_mean=True causes the gradients to be divided by the number of replicas after summing them.
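
To make the distinction concrete, here is a minimal sketch (TF1-style graph code with hypothetical shapes and device names, not the tf_cnn_benchmarks code) of a plain reduction plus the division that use_mean=True effectively adds. An all-reduce would compute the same sum but leave a copy of it on every GPU instead of only on the device where add_n is placed.

```python
import tensorflow as tf

num_gpus = 2
# Hypothetical per-GPU gradient tensors (stand-ins for the real grads).
grads = []
for i in range(num_gpus):
  with tf.device('/gpu:%d' % i):
    grads.append(tf.ones([1024]) * float(i + 1))

# Plain reduction: the inputs are copied to one device and summed there,
# so only /gpu:0 holds the resulting tensor.
with tf.device('/gpu:0'):
  summed = tf.add_n(grads)
  mean_grad = summed / num_gpus  # what use_mean=True effectively does
```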

@zhao1157
Author

@reedwm Thanks for getting back to me about tf.add_n(). The reason I asked is that I heard NCCL all-reduce consumes far less memory than tf.add_n() does, so I wondered whether there is a function I could use to achieve the functionality of tf.add_n() with less memory.

@reedwm
Member

reedwm commented Apr 20, 2020

On a single device, add_n uses the least possible amount of memory, as it only allocates its output tensor. The reason add_n may use more memory than NCCL all-reduce is that if add_n is used to implement an all-reduce, all tensors must be copied to a single device, so that device must allocate all of them. I haven't checked, but it's likely NCCL all-reduce does not require a single device to allocate enough space for all the tensors.

However, with parameter server mode, add_n isn't used to implement an all-reduce, so the issue doesn't occur. There is no way to replace add_n with some other reduction to make it take less memory. All the per-GPU tensors are transferred to the parameter server device(s), so the parameter server device(s) must have enough memory to hold all the per-GPU tensors.
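
As a sketch of that last point (hypothetical devices and shapes, not the actual tf_cnn_benchmarks code): placing add_n on the parameter server device makes TensorFlow copy every per-GPU gradient to that device before summing, so the PS device has to allocate all of the inputs as well as the output.

```python
import tensorflow as tf

num_gpus = 4
ps_device = '/cpu:0'  # hypothetical parameter server device

# Hypothetical per-GPU gradients for one variable.
per_gpu_grads = []
for i in range(num_gpus):
  with tf.device('/gpu:%d' % i):
    per_gpu_grads.append(tf.ones([4096]))

# Summing on the PS device: each per-GPU tensor is copied to ps_device,
# so ps_device must hold num_gpus input tensors plus the output.
with tf.device(ps_device):
  aggregated = tf.add_n(per_gpu_grads)
```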

@zhao1157
Copy link
Author

zhao1157 commented Apr 21, 2020

@reedwm
Thanks for your very helpful clarification.

> I haven't checked, but it's likely NCCL all-reduce does not require a single device to allocate enough space for all the tensors.

I know ring all-reduce performs the sum incrementally across the devices instead of transferring all tensors to a single device, and thus requires less memory. Do you know the difference between NCCL all-reduce and ring all-reduce?
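
For what it's worth, here is a minimal numpy sketch (hypothetical sizes, not NCCL's or any library's actual implementation) of the reduce-scatter + all-gather decomposition that ring all-reduce is built on. Each device only ends up responsible for summing one chunk, which is why no single device has to hold all of the other devices' full tensors at once.

```python
import numpy as np

num_devices = 4
chunk = 8  # hypothetical chunk size

# Hypothetical per-device tensors, each logically split into num_devices chunks.
tensors = [np.full(num_devices * chunk, float(d + 1)) for d in range(num_devices)]
chunks = [np.split(t, num_devices) for t in tensors]

# Reduce-scatter: device d ends up owning the full sum of chunk d only.
reduced = [sum(chunks[src][d] for src in range(num_devices))
           for d in range(num_devices)]

# All-gather: every device then collects the num_devices reduced chunks.
result = np.concatenate(reduced)
assert np.allclose(result, np.sum(tensors, axis=0))
```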
