Counting how many forward passes/steps were done when using PAIN #14

Open
jivanph opened this issue Jan 30, 2024 · 4 comments



jivanph commented Jan 30, 2024

I wanted to ask if there's a way to count how many forward passes/steps are done when using PAIN, to contrast it with standard decoding.

jivanph (Author) commented Jan 30, 2024

On a different note, what are the parameters for the tree object? How many branches are made and how deep does the tree go?

zheyishine (Collaborator) commented

You can count the steps in two ways. One is to turn on debug_lookahead, which outputs debug info for each step so you can count the steps manually. The other is to turn on return_dict_in_generate in the model.generate method; the kwargs of the outputs will contain a decoding summary, and len(kwargs['dls']) is the step count.
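
(A minimal sketch of the second method, assuming a HuggingFace-style generate() call; the model/tokenizer setup is omitted, and the outputs.kwargs['dls'] access follows the description above rather than a documented API.)

```python
# Hypothetical sketch: count lookahead decoding steps via the generation summary.
# Assumes a lookahead-enabled model with a HuggingFace-style generate() and that
# the returned object exposes kwargs['dls'] with one entry per decoding step,
# as described in the comment above.
outputs = model.generate(
    input_ids,
    max_new_tokens=256,
    return_dict_in_generate=True,  # needed so the decoding summary is returned
)

dls = outputs.kwargs["dls"]  # per-step decoding lengths
print(f"forward passes / steps: {len(dls)}")
```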

We use different parameters for different tasks. As mentioned in the README of our repo, we use decoding_length=128 (i.e., forward token count) and branch_length=32 (i.e., tree depth) for RAG tasks, and decoding_length=64 and branch_length=8 for the Dolly and GSM8K tasks. We do not use a branch count parameter, as we care more about the actual token count in a forward pass than about the number of logical branches.
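
(For reference, the task-specific settings quoted from the README, collected in one place; how they are passed to generate() may vary by version, so treat this as illustrative.)

```python
# Lookahead settings per task, as quoted from the repo README in the reply above.
LOOKAHEAD_PARAMS = {
    "rag":   {"decoding_length": 128, "branch_length": 32},  # forward token count / tree depth
    "dolly": {"decoding_length": 64,  "branch_length": 8},
    "gsm8k": {"decoding_length": 64,  "branch_length": 8},
}
```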

jivanph (Author) commented Feb 1, 2024

Thank you so much for your response; this helped me greatly. If I understand correctly, if I want to count how many draft tokens in total were used with PAIN, I could just compute sum(kwargs['dls'])?

zheyishine (Collaborator) commented

It should be sum(kwargs['dls']) - len(kwargs['dls']): each step's decoding length (i.e., each entry of dls) is composed of the next token plus the draft tokens, so we should subtract 1 per step.
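
(Putting the correction together: a hedged helper for the bookkeeping, assuming dls is the per-step list described above. The function name is hypothetical.)

```python
# Hypothetical helper: derive step and draft-token counts from the per-step
# decoding summary. Each entry of dls counts the accepted "next token" plus
# that step's draft tokens, hence the -1 per step.
def lookahead_stats(dls: list[int]) -> dict:
    steps = len(dls)                 # number of forward passes
    draft_tokens = sum(dls) - steps  # total draft tokens used
    return {
        "steps": steps,
        "draft_tokens": draft_tokens,
        "avg_tokens_per_step": sum(dls) / steps if steps else 0.0,
    }

# Example: three steps that emitted 5, 3, and 1 tokens respectively.
print(lookahead_stats([5, 3, 1]))  # {'steps': 3, 'draft_tokens': 6, 'avg_tokens_per_step': 3.0}
```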
