-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Kernel] Correctly invoke prefill & decode kernels for cross-attention (towards eventual encoder/decoder model support) #4888
base: main
Are you sure you want to change the base?
Conversation
…rom parent metadata struct to child metadata structs; cross-attn test runs without functional errors but fails all_close
… metadata structure!
… seq len was unused
…type annotations; formatting
…enc/dec supported feature checks
…n all relevant locations
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would be good if you could share some pseudo code about how an encoder-decoder model would orchestrate encoder/decoder and cross-attention computation. I think clarifying that would be good to make sure that the current attention backend changes are able to support the overall flow.
|
||
# (batch_size,). The sequence length per sequence. Sequence length means | ||
# the computed tokens + new tokens None if it is a decoding. | ||
seq_lens: Optional[List[int]] = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Can you move seq_lens back to L:71 with seq_lens_tensor.
_cached_prefill_metadata: Optional["XFormersMetadata"] = None | ||
_cached_decode_metadata: Optional["XFormersMetadata"] = None | ||
|
||
# Begin encoder attn & enc/dec cross-attn fields... | ||
|
||
# If True, prefill_metadata() and decode_metadata() will return |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please update the comment to reflect _attn_type is an enum and not a bool.
# otherwise, self-attention data structures will be returned. | ||
_attn_type: AttentionType = AttentionType.DECODER | ||
|
||
# (batch_size,). The "cross-sequence-length" per sequence,i.e. the key/value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please update the comment here.
|
||
@property | ||
def prefill_metadata(self) -> Optional["XFormersMetadata"]: | ||
if self.num_prefills == 0: | ||
return None | ||
|
||
if self._cached_prefill_metadata is not None: | ||
self._cached_prefill_metadata.attention_type = \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we need to set the attention_type separately here? Isn't it already getting set as part of defining self._cached_prefill_metadata?
@@ -154,28 +249,134 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: | |||
return None | |||
|
|||
if self._cached_decode_metadata is not None: | |||
self._cached_decode_metadata.attention_type = \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we need to set the attention_type separately here? Isn't it already getting set as part of defining self._cached_decode_metadata ?
* {query,key,value}: packed (number_of_tokens x num_heads | ||
x head_size) attention inputs | ||
* q_seq_lens: list of query start locations within packed tensor | ||
* kv_seq_lens: shared list of key/value start locations within |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
kv_seq_lens -> kv_start_loc_list
kv_cache: Optional[torch.Tensor], | ||
attn_metadata: "XFormersMetadata", | ||
kv_scale: float = 1.0, | ||
) -> torch.Tensor: | ||
"""Forward pass with xFormers and PagedAttention. | ||
|
||
For decoder-only models: query, key and value must be non-None. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not related directly to this pr but can you share the rough pseudo code of an encoder-decoder model and how it would be structured. I wanted to understand how we are going to pass orchestrate the encoder/decoder/cross attention computation in the model forward pass.
-> PhaseTestParameters: | ||
(num_heads, head_size, _, batch_size, _, _, max_q_seq_len, _) = test_pt | ||
|
||
scale = test_rsrcs.scale |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we move this line below the function doc?
@pytest.mark.parametrize("block_size", BLOCK_SIZES) | ||
@pytest.mark.parametrize("max_dec_seq_len", MAX_DEC_SEQ_LENS) | ||
@pytest.mark.parametrize("max_enc_seq_len", MAX_ENC_SEQ_LENS) | ||
def test_backend_fails_for_chunked_prefill_enc_dec(num_heads: int, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As mentioned in a previous comment if we can decide to remove these checks from the kernel and move to higher layers then we can remove these tests.
Construct fake attention metadata for a given test phase | ||
(prefill-phase or decode-phase). | ||
encoder_test_params and cross_test_params arguments all encoder |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you clarify this comment? Its not clear what we are trying to say here.
This PR is a step towards encoder/decoder model support. This PR modifies the xFormers backend* such that (1) the attention impl can implement cross-attention, and (2) the attention metadata data structure can represent the necessary metadata for invoking cross-attention.
* FlashAttention backend support for encoder/decoder models is left as future work
A quick overview of the plan for supporting encoder/decoder models in vLLM:
Prefill phase: (1) Non-autoregressive encoder inference yields encoder hidden states in a single pass; no KV caching occurs. (2) decoder prefill yields first-token-prediction & cached KVs. Within the decoder, cross-attention layers cache the KVs derived from encoder hidden states:
Key_{cross-attn, layer-n} = W_{K, cross-attn, layer-n} x (Encoder hidden states)
Value_{cross-attn, layer-n} = W_{V, cross-attn, layer-n} x (Encoder hidden states)
Note that all cross-attention layers consume the same encoder hidden states; however each cross-attention layers' keys and values differ because each layer has unique W_{K, cross-attn, layer-n} and W_{V, cross-attn, layer-n}. Therefore, the cross-attention KV cache must store KVs for each decoder layer, even though these KVs are all derived from a single set of encoder hidden states.
Note that self-attention layer behavior is unchanged compared to what it would be in a decoder-only model (cache KVs computed from the previous decoder layer outputs.)
Decode phase: during each iteration of the autoregressive decode process,
To implement the above encoder/decoder inference process, the following functionality will be added to vLLM over the course of multiple PRs:
Note 1: because this PR makes an incremental contribution (cross-attention KV-caching and memory management), this PR will not enable end-to-end encoder/decoder support (this will rely on later PRs.)
Note 2: the best effort is being made to ensure that encoder/decoder models are compatible with existing vLLM features. At this time, encoder/decoder models are unlikely to be compatible with the following vLLM features:
INCREMENTAL FIX TOWARDS #187
BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE
PR Checklist (Click to Expand)
Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.
PR Title and Classification
Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:
[Bugfix]
for bug fixes.[CI/Build]
for build or continuous integration improvements.[Doc]
for documentation fixes and improvements.[Model]
for adding a new model or improving an existing model. Model name should appear in the title.[Frontend]
For changes on the vLLM frontend (e.g., OpenAI API server,LLM
class, etc.)[Kernel]
for changes affecting CUDA kernels or other compute kernels.[Core]
for changes in the core vLLM logic (e.g.,LLMEngine
,AsyncLLMEngine
,Scheduler
, etc.)[Hardware][Vendor]
for hardware-specific changes. Vendor name should appear in the prefix (e.g.,[Hardware][AMD]
).[Misc]
for PRs that do not fit the above categories. Please use this sparingly.Note: If the PR spans more than one category, please include all relevant prefixes.
Code Quality
The PR need to meet the following code quality standards:
format.sh
to format your code.docs/source/
if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.Notes for Large Changes
Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with
rfc-required
and might not go through the PR.What to Expect for the Reviews
The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:
action-required
label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.Thank You
Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!