Add Pipeline Parallel (and 2D PP+FSDP) support #161

Closed · wants to merge 34 commits

Conversation

@wconstab (Contributor) commented Mar 23, 2024:

Stack from ghstack (oldest at bottom):


  • uses pipeline tracer frontend to extract a graph and partition it into
    chunks per stage
  • hardcodes one schedule (1F1B) for now (need to expose option to switch
    schedule and test other schedules)
  • supports 2D parallelism currently; 3D (TP) is work in progress (a rough wiring sketch is included below)
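To make the description above concrete, here is a minimal wiring sketch pieced together from the code snippets discussed later in this PR. It is not the code that landed: the import paths, the Schedule1F1B class, the loss_fn argument, the split_spec kwarg, and the apply_pp helper name are assumptions for illustration; only pipeline(), SplitPoint, and the PipelineStage keywords mirror what the PR's hunks show.

```python
# Hedged sketch of the flow described above; import paths, Schedule1F1B, the
# split_spec kwarg, and the helper name are assumptions, not this PR's code.
from pippy import pipeline, SplitPoint                 # assumed tracer frontend imports
from pippy.PipelineStage import PipelineStage          # assumed
from pippy.PipelineSchedule import Schedule1F1B        # assumed 1F1B schedule class


def apply_pp(model, pp_mesh, parallel_dims, input_ids, device, loss_fn):
    pp_rank = pp_mesh.get_local_rank()

    # Cut the model into `pp` chunks at transformer-layer boundaries.
    layers_per_rank = len(model.layers) // parallel_dims.pp
    split_spec = {
        f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
        for i in range(1, parallel_dims.pp)
    }

    # Trace the whole model once and partition the graph into per-stage modules.
    pipe = pipeline(
        model, parallel_dims.pp, example_args=(input_ids,), split_spec=split_spec
    )

    # Build the runtime stage for this rank; FSDP/TP would be applied to its module.
    stage = PipelineStage(
        pipe=pipe, stage_index=pp_rank, device=device, group=pp_mesh.get_group()
    )

    # One hardcoded schedule (1F1B) for now, as noted above.
    schedule = Schedule1F1B(stage, n_microbatches=parallel_dims.pp, loss_fn=loss_fn)
    return stage, schedule
```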

wconstab added a commit that referenced this pull request Mar 23, 2024
ghstack-source-id: 14902407f0c573a4b4e9f615495b805af0ed8afc
Pull Request resolved: #161
facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Mar 23, 2024
torchtrain/parallelisms/parallelize_llama.py: outdated review thread (resolved)
train.py (outdated):
logger.info(
    f"{Color.blue}Extracting pipeline module for stage {pp_mesh.get_local_rank()}{Color.reset}"
)
model = pmod.get_stage_module(pp_mesh.get_local_rank())
Contributor:
nit: watch out for rank-stage inequality in case of Interleaved 1F1B.

wconstab (author):
Yeah, I need to switch to an interleaved schedule and clean this up.
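For readers unfamiliar with the rank-stage inequality mentioned above: under an interleaved schedule each pipeline rank owns several non-adjacent stages, so the stage index can no longer be read straight off pp_mesh.get_local_rank(). A self-contained sketch of the usual round-robin mapping (illustrative only, not code from this PR):

```python
def stage_ids_for_rank(pp_rank: int, pp_size: int, num_stages: int) -> list[int]:
    """Stages owned by one rank. With plain 1F1B, num_stages == pp_size and
    each rank owns exactly one stage (stage id == pp rank). With interleaved
    1F1B, num_stages is a multiple of pp_size and each rank owns
    num_stages // pp_size stages, assigned round-robin."""
    assert num_stages % pp_size == 0
    return [pp_rank + v * pp_size for v in range(num_stages // pp_size)]


# e.g. 4 pipeline ranks, 8 virtual stages: rank 1 owns stages 1 and 5.
assert stage_ids_for_rank(1, 4, 8) == [1, 5]
```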

@kwen2501 (Contributor) left a comment:
Thanks a lot for the demo! LGTM!

wconstab added a commit that referenced this pull request Mar 29, 2024
The traced module is burning in a 'meta' device arg for one 'ones' op, which breaks at runtime after moving the model to 'cuda'.

Haven't worked on loss fn yet.

ghstack-source-id: 47735f666b6086e179699b1bbfb06168b488d4d4
Pull Request resolved: #161
wconstab added a commit that referenced this pull request Apr 2, 2024
Haven't worked on loss fn yet.

ghstack-source-id: 4c438ddd2989e427489c4e2d5a9ddd35711bdb78
Pull Request resolved: #161
wconstab added a commit that referenced this pull request Apr 3, 2024
(fake) Loss now runs and propagates to logger

ghstack-source-id: b5a290878909ebc67bbcfda25809be439e222523
Pull Request resolved: #161
wconstab added a commit that referenced this pull request Apr 3, 2024
Loss now runs and propagates to logger, but optimizer isn't working

ghstack-source-id: 56b0ef0ed92d181126e6866a153316f00431c7e7
Pull Request resolved: #161
wconstab added a commit that referenced this pull request Apr 5, 2024
Loss now runs and propagates to logger, but optimizer isn't working

ghstack-source-id: 4ede08f5a9d1bc994448cb057bb491d24866d078
Pull Request resolved: #161
wconstab added a commit that referenced this pull request Apr 5, 2024
- uses pipeline tracer frontend to extract a graph and partition it into
  chunks per stage
- hardcodes one schedule (1F1B) for now
- supports 1D parallelism currently.

WIP: support 2D/3D parallel and clean up seed-checkpoint ux

ghstack-source-id: 7055ffe515b79fa6edad58a72543d9bc8e866f80
Pull Request resolved: #161
@wconstab changed the title from "WIP integrate pippy's tracer frontend" to "Add Pipeline Parallel support" on Apr 5, 2024
wconstab added a commit that referenced this pull request May 2, 2024
- uses pipeline tracer frontend to extract a graph and partition it into
  chunks per stage
- hardcodes one schedule (1F1B) for now (need to expose option to switch
  schedule and test other schedules)
- supports 2D parallelism currently, 3D (TP) is work in progress

ghstack-source-id: a6cb4c35ccd218219eabdd25d55f62743278ed81
Pull Request resolved: #161
wconstab added a commit that referenced this pull request May 2, 2024
- uses pipeline tracer frontend to extract a graph and partition it into
  chunks per stage
- hardcodes one schedule (1F1B) for now (need to expose option to switch
  schedule and test other schedules)
- supports 2D parallelism currently, 3D (TP) is work in progress

ghstack-source-id: 6bd801399be3f77a45d1dda11bc87e9a90b92df4
Pull Request resolved: #161
@kwen2501 (Contributor) left a comment:
LGTM!
Thanks for pulling PP in!

torchtitan/parallelisms/parallelize_llama.py: outdated review thread (resolved)
print("labels: ", labels.shape, labels.dtype)

# Create a pipeline representation from the model
pipe = pipeline(model, parallel_dims.pp, example_args=(input_ids,))
Contributor:
nit: strictly speaking, the second arg is the number of microbatches -- it is okay if you use the PP dim to represent it for now. Longer term I think it should be exposed as a field in the config file.
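One possible shape for the "expose it in the config" suggestion, sketched against the snippet above; the pipeline_parallel_microbatches field and its location on job_config are hypothetical:

```python
# Hypothetical config plumbing for the number of microbatches; the field name
# and location are assumptions, with the current behavior (PP dim) as fallback.
num_microbatches = (
    getattr(job_config.training, "pipeline_parallel_microbatches", None)
    or parallel_dims.pp
)

pipe = pipeline(model, num_microbatches, example_args=(input_ids,))
```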

torchtitan/parallelisms/parallelize_llama.py: outdated review thread (resolved)
train.py: outdated review thread (resolved)
requirements.txt: outdated review thread (resolved)
torchtitan/models/llama/model.py: outdated review thread (resolved)
torchtitan/parallelisms/parallelize_llama.py: outdated review thread (resolved)
)

# Get example input
label_shape = input_shape = (8, 2048) # TODO
Contributor:
Hmm, would PP work for inputs that are not this shape, or does it require the shape to exactly match the runtime input shape?

wconstab (author):
Need to double-check how this works and fix it.

# TODO(whc) need to fix PP + FSDP-mixed-precision
# tracer for PP assumes f32 and is caught off guard when runtime FSDP interacts using bf16 inputs
# param_dtype=torch.bfloat16, reduce_dtype=torch.float32
param_dtype=torch.float32,
Contributor:
We shouldn't change this by default; it would make the FSDP and FSDP + TP cases use fp32 instead of bf16.

Contributor:
I wonder if supporting bf16 should be a criterion for landing. I would imagine that training with FSDP + PP in fp32 is not really viable efficiency-wise (at least for larger jobs).

wconstab (author):
I think we should fix this before landing the PP change. I think there was a possible way to fix this in the tracer, but I lost track of it; I'll dig it up.
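For reference, the mixed-precision setup that the commented-out line above points at would look roughly like this once the tracer handles bf16; a sketch assuming the FSDP2 fully_shard API from torch.distributed._composable.fsdp (the module path and the surrounding model/dp_mesh variables are assumptions):

```python
import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

# Intended setup per the commented-out line above: bf16 params for compute,
# fp32 gradient reduction. Currently blocked because the PP tracer records
# the graph assuming fp32 activations.
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)
fully_shard(model, mesh=dp_mesh, mp_policy=mp_policy)  # model/dp_mesh from context
```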

torchtitan/parallelisms/parallelize_llama.py: outdated review thread (resolved)
torchtitan/parallelisms/parallelize_llama.py: outdated review thread (resolved)
train.py (outdated):
# there are virtual stages
if parallel_dims.pp_enabled:
    stage = PipelineStage(
        pipe=pipe_meta,
Contributor:
should this be pipe_meta or model?

wconstab (author):
It's correct. Ke proposed an alternative, but we'd still have to pass the pipe_info and the model into _PipelineStage in that case. I could make this change.

        pipe=pipe_meta,
        stage_index=pp_rank,
        device=device,
        group=pp_mesh.get_group(),
Contributor:
Wondering if we should put the stage creation into parallelize_llama; IMO we only need pp_schedule in train.py.

wconstab (author):
Yeah, I think this question and Ke's suggestion about returning a PipelineStage from parallelize_llama are better taken in the context of a follow-up PR that also adds support for looped schedules.

Looped schedules further complicate things because the PP logic first needs to chunk up the model, then apply the DP/TP portion of parallelize_llama on each chunk, and finally pass all the chunks into the schedule.

I think in the end I might prefer to separate PP out from parallelize_llama and have a flow where we take the return value of the PP apply function and iteratively call parallelize_llama on those chunks.
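A rough sketch of the flow described in that last paragraph; apply_pp_split and build_pp_schedule are hypothetical helper names, and the parallelize_llama signature is only assumed, so treat this as structure rather than an implementation:

```python
# Structural sketch only: helper names and signatures are assumptions.
def build_parallel_model(model, world_mesh, parallel_dims, job_config):
    # 1) PP split: return this rank's model chunks (one per local stage;
    #    several per rank once looped/interleaved schedules are supported).
    chunks, stage_infos = apply_pp_split(model, world_mesh["pp"], parallel_dims)

    # 2) Apply the DP/TP portion of parallelize_llama to each chunk.
    chunks = [
        parallelize_llama(chunk, world_mesh, parallel_dims, job_config)
        for chunk in chunks
    ]

    # 3) Hand all local stages to the (possibly looped) schedule.
    return build_pp_schedule(chunks, stage_infos, job_config)
```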

loss = (
    torch.mean(torch.stack(losses))
    if is_last_stage
    else torch.Tensor([-1.0])
Contributor:
Why do we need the default -1 value? Is it for logging purposes?

wconstab (author):
Oh yeah, I could make it a 'None', but then I'd have to update the logger to not log at all. Maybe that's actually a better way to do it; let me try that.

@wconstab (author) commented May 2, 2024:
OK, so what I could do is alter the metrics code so that on non-last-stage ranks we either omit printing loss, or print "loss: None" instead of -1.

The change will add more lines of code, since I need to deal with several places that expect loss and global_[avg/mean]_loss to be valid numbers:

  • avoid writing them into the metrics dict
  • replace their format string with a string value instead of a float value in the logger.info
  • avoid calling loss.item() in the first place

I agree in principle that's the "right" fix, but I'm not sure it's worth the LOC / complexity. I don't totally hate the -1 thing.

Another option I considered is to skip the whole '# log metrics' code block on non-last-stage ranks. I ruled this out, since it is still useful to log MFU and memory for the other ranks.

So let me know what you want to do here @wanchaol
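A minimal sketch of the first option (omit the loss entries on non-last-stage ranks while keeping the rest of the metrics); the metric names, variables, and logger call here are placeholders, not the actual torchtitan metrics code:

```python
# Placeholder metric names and variables; the is_last_stage branching is the point.
metrics = {
    "wps": wps,
    "memory/max_active_GiB": max_active_gib,
}
if is_last_stage:
    global_avg_loss = loss.item()
    metrics["loss_metrics/global_avg_loss"] = global_avg_loss
    loss_str = f"loss: {global_avg_loss:7.4f}"
else:
    # Non-last PP stages have no real loss; skip the entry instead of logging -1.
    loss_str = "loss: n/a (non-last PP stage)"

logger.info(f"step {train_state.step}: {loss_str}, memory: {max_active_gib:.2f}GiB")
```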

wconstab added a commit that referenced this pull request May 2, 2024
- uses pipeline tracer frontend to extract a graph and partition it into
  chunks per stage
- hardcodes one schedule (1F1B) for now (need to expose option to switch
  schedule and test other schedules)
- supports 2D parallelism currently, 3D (TP) is work in progress

ghstack-source-id: 205f8b08eac15bb7bee66ecdec439b9828b0949c
Pull Request resolved: #161
wconstab added a commit that referenced this pull request May 2, 2024
- uses pipeline tracer frontend to extract a graph and partition it into
  chunks per stage
- hardcodes one schedule (1F1B) for now (need to expose option to switch
  schedule and test other schedules)
- supports 2D parallelism currently, 3D (TP) is work in progress

ghstack-source-id: cbbb628fd823d579064a8038e6511ec77457ef19
Pull Request resolved: #161
wconstab added a commit that referenced this pull request May 3, 2024
- uses pipeline tracer frontend to extract a graph and partition it into
  chunks per stage
- hardcodes one schedule (1F1B) for now (need to expose option to switch
  schedule and test other schedules)
- supports 2D parallelism currently, 3D (TP) is work in progress

ghstack-source-id: 94f89f90787cca27310cb966a7edf7ea9bbc0098
Pull Request resolved: #161
wconstab added a commit that referenced this pull request May 3, 2024
- uses pipeline tracer frontend to extract a graph and partition it into
  chunks per stage
- hardcodes one schedule (1F1B) for now (need to expose option to switch
  schedule and test other schedules)
- supports 2D parallelism currently, 3D (TP) is work in progress

ghstack-source-id: ac8c37124f79f8246155e14da23c2f5cfd75c0de
Pull Request resolved: #161
wconstab added a commit that referenced this pull request May 3, 2024
- uses pipeline tracer frontend to extract a graph and partition it into
  chunks per stage
- hardcodes one schedule (1F1B) for now (need to expose option to switch
  schedule and test other schedules)
- supports 2D parallelism currently, 3D (TP) is work in progress

ghstack-source-id: feb45e115f7bbee37179887bb196c12d21d93b43
Pull Request resolved: #161
    for i in range(1, parallel_dims.pp)
}
# Get example input
label_shape = input_shape = (8, 2048)  # TODO
wconstab (author):
@kwen2501 any ideas for a clean way to do this in torchtrain? Do we expect people to get a batch out of their dataloader and then reset it, or do we expect people to hardcode it?

wconstab (author):
I think what I might do is directly pass input_shape from train.py, and in train.py I can set input_shape = (job_config.batch_size, job_config.seq_len) or something. Is that clean enough?

wconstab (author):
OK, pushed a variation on this.

Not sure if it's better to hide this inside parallelize (since we already have the job config) or to make it explicit from train.py that we are passing input_shape in for some reason.

Contributor:
Either way sounds okay to me -- eventually, the shape should come from the config.
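What "the shape comes from the config" might look like on the train.py side; the job_config field names and the input_shape keyword on parallelize_llama are assumptions for illustration:

```python
# Derive the example input shape from the job config instead of hardcoding
# (8, 2048). Field names and the input_shape kwarg are assumed.
input_shape = (job_config.training.batch_size, job_config.training.seq_len)

model = parallelize_llama(
    model,
    world_mesh,
    parallel_dims,
    job_config,
    input_shape=input_shape,  # consumed only by the PP tracer path
)
```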

wconstab added a commit that referenced this pull request May 3, 2024
- uses pipeline tracer frontend to extract a graph and partition it into
  chunks per stage
- hardcodes one schedule (1F1B) for now (need to expose option to switch
  schedule and test other schedules)
- supports 2D parallelism currently, 3D (TP) is work in progress

ghstack-source-id: 0616a1c0d40f8e51ddfc1b2d330dbddc491e00e2
Pull Request resolved: #161
layers_per_rank = len(model.layers) // parallel_dims.pp
split_spec = {
    f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
    for i in range(1, parallel_dims.pp)
Contributor:
I'm new to the PP API and have a question: if layers_per_rank = 5 and parallel_dims.pp = 2, what should the split_spec be? My straightforward thought was that there should be SplitPoint.BEGINNING entries for i = 1, 3, 5, but according to the code it's just i = 1.

@kwen2501 (Contributor) commented May 7, 2024:
parallel_dims.pp refers to the number of pipeline stages we split the model into. For example, if len(model.layers) = 10 and pp = 2, then layers_per_rank = 10 // 2 = 5, i.e. we put 5 layers per stage. Hence we make a single cut at model.layers.5 -- there are (nRanks - 1) split points in total.
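Working through the numbers in that example (a quick, self-contained check; SplitPoint is stubbed here so the snippet runs standalone):

```python
from enum import Enum


class SplitPoint(Enum):  # stand-in for the real SplitPoint enum
    BEGINNING = 1


num_layers = 10          # len(model.layers)
pp = 2                   # parallel_dims.pp = number of pipeline stages
layers_per_rank = num_layers // pp  # -> 5

split_spec = {
    f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
    for i in range(1, pp)
}

# One cut -- (pp - 1) split points -- placed at the start of layers.5:
assert split_spec == {"layers.5": SplitPoint.BEGINNING}
```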

@wconstab (author):
squashed

@wconstab closed this on May 10, 2024