
[v1] Support multiple KV cache groups in GPU model runner #17945


Open

heheda12345 wants to merge 8 commits into main from multi_group_worker

Conversation

heheda12345 (Collaborator) commented May 10, 2025

Should be merged after #17483

This PR finishes the hybrid allocator support on the worker side. It does the following things:

  1. Change block_ids in SchedulerOutput to list[list[int]], where the outer list indexes the KV cache groups and the inner list holds the blocks of one group.
  2. Create a BlockTable instance for each KV cache group.
  3. Build separate attention metadata for each KV cache group.
  4. The TPU backend still supports only one KV cache group after this PR.

Split from #16101.
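
To make the first three items concrete, here is a minimal, self-contained sketch of the per-group bookkeeping; SimpleBlockTable and the example block ids are illustrative stand-ins, not the actual BlockTable class or vLLM code:

```python
# Sketch: block_ids goes from list[int] to list[list[int]].
# The outer index selects the KV cache group; the inner list holds the
# block ids allocated to a request within that group.
block_ids: list[list[int]] = [
    [0, 1, 2],  # group 0, e.g. full-attention layers
    [7, 8],     # group 1, e.g. sliding-window layers
]


class SimpleBlockTable:
    """Toy stand-in for the per-group BlockTable introduced in this PR."""

    def __init__(self) -> None:
        self.rows: dict[int, list[int]] = {}  # request index -> block ids

    def add_request(self, req_index: int, blocks: list[int]) -> None:
        self.rows[req_index] = list(blocks)


# One block table per KV cache group; attention metadata is then built
# separately for each group from its own table.
block_tables = [SimpleBlockTable() for _ in block_ids]
for group_id, group_blocks in enumerate(block_ids):
    block_tables[group_id].add_request(req_index=0, blocks=group_blocks)
```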

github-actions bot commented May 10, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run the other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

mergify bot added the v1 and tpu (Related to Google TPUs) labels May 10, 2025
mergify bot commented May 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @heheda12345.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

heheda12345 force-pushed the multi_group_worker branch from 5ef5bed to f65b904 on May 11, 2025 at 02:26
mergify bot removed the needs-rebase label May 11, 2025
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
Signed-off-by: Chen Zhang <[email protected]>
Comment on lines 87 to 92
# Some layers may be treated as full attention layers by the KV cache manager
# (blocks are allocated for all tokens), while being computed as sliding window
# attention in the model runner. In that case, we use FullAttentionSpec and
# record the sliding window size. Defaults to None when sliding window
# attention is not used.
sliding_window: Optional[int] = None
Collaborator:
Is this for the case where the hybrid allocator is disabled? If so, please leave a comment.

Collaborator (Author):
Yeah. I've updated the comment.
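
For readers following along, a hedged sketch of the case being discussed; FullAttentionSpecSketch and its fields (other than sliding_window, which comes from the diff above) are illustrative, not the real vLLM classes. When the hybrid allocator is disabled, a sliding-window layer is still managed as full attention (blocks allocated for all tokens), but the window size is recorded so the model runner can still apply it:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FullAttentionSpecSketch:
    """Toy version of a full-attention KV cache spec, for illustration."""
    block_size: int
    # None: a genuine full-attention layer.
    # Set: the layer is computed as sliding-window attention in the model
    # runner, but (with the hybrid allocator disabled) the KV cache manager
    # still allocates blocks for all tokens.
    sliding_window: Optional[int] = None


full_layer = FullAttentionSpecSketch(block_size=16)
swa_layer = FullAttentionSpecSketch(block_size=16, sliding_window=4096)
```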

Comment on lines +271 to +280
batch_reordered = self.attn_metadata_builders[0].reorder_batch(
    self.input_batch, scheduler_output)

# For models with multiple KV cache groups, the groups should agree on
# the same order of requests. We ensure this by only allowing the first
# group to reorder the batch and asserting that all other groups do not
# reorder the batch.
for i in range(1, len(self.kv_cache_config.kv_cache_groups)):
    assert not self.attn_metadata_builders[i].reorder_batch(
        self.input_batch, scheduler_output)
Collaborator:
QQ: What if the first group is full attn and the second group is MLA? IIUC, the current code will fail in this case. Is this intended?

Collaborator (Author):
You are right, but it's fine because no model currently contains both full attention and MLA. I'd prefer to raise an error here and find a solution when such a model is released.
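
A minimal sketch of what that explicit error could look like; the standalone function and its error message are illustrative, and the builders are only assumed to expose the reorder_batch interface used in the diff above, so this is not the actual implementation:

```python
def reorder_batch_across_groups(attn_metadata_builders, input_batch,
                                scheduler_output) -> bool:
    """Sketch: only the first KV cache group may reorder the batch."""
    batch_reordered = attn_metadata_builders[0].reorder_batch(
        input_batch, scheduler_output)
    for i, builder in enumerate(attn_metadata_builders[1:], start=1):
        if builder.reorder_batch(input_batch, scheduler_output):
            # e.g. group 0 is full attention and group i is MLA: the MLA
            # builder wants its own request order, which is unsupported.
            raise NotImplementedError(
                f"KV cache group {i} requested a batch reordering that "
                "differs from group 0; mixing such backends is not "
                "supported yet.")
    return batch_reordered
```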

Comment on lines +59 to +67
@classmethod
def merge(cls, specs: list[Self]) -> Self:
    """
    Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
    """
    assert all(spec.type_id == specs[0].type_id for spec in specs[1:]), (
        "All layers in the same KV cache group must share the same "
        "type_id.")
    return copy.deepcopy(specs[0])
Collaborator:
Do we really want to inherit and override this? What about defining this as a utility function outside the class?

Collaborator (Author):
I prefer to keep the function inside the class. If it were a utility function, it is quite likely that people would forget to update it when extending the KVCacheSpec subclasses.

Comment on lines +102 to +119
@classmethod
def merge(cls, specs: list[Self]) -> Self:
    """
    Merge a list of FullAttentionSpec objects into a single
    FullAttentionSpec object.
    """
    merged_spec = super().merge(specs)
    sliding_window = set(spec.sliding_window for spec in specs
                         if spec.sliding_window is not None)
    if len(sliding_window) == 0:
        merged_spec.sliding_window = None
    elif len(sliding_window) == 1:
        merged_spec.sliding_window = sliding_window.pop()
    else:
        raise ValueError(
            "All sliding window layers in the same KV cache group "
            "must have the same window size.")
    return merged_spec
Collaborator:
Don't we need similar logic in SlidingWindowSpec as well?

Collaborator (Author):
We don't need it, as SlidingWindowSpec.type_id contains the sliding window size, which ensures that layers with different sliding window sizes end up in different KV cache groups.
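
For illustration, a type_id that bakes in the window size might look like the sketch below; the exact string format and field set are assumptions for this sketch, not the actual SlidingWindowSpec definition:

```python
from dataclasses import dataclass


@dataclass
class SlidingWindowSpecSketch:
    """Toy sliding-window KV cache spec, for illustration only."""
    block_size: int
    sliding_window: int

    @property
    def type_id(self) -> str:
        # Because the window size is part of type_id, layers with different
        # window sizes can never share a KV cache group, so merge() needs
        # no extra window-size check for this spec.
        return f"sliding_window_{self.sliding_window}_{self.block_size}"
```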

Signed-off-by: Chen Zhang <[email protected]>
WoosukKwon added the ready (ONLY add when PR is ready to merge/full CI is needed) label May 11, 2025
mergify bot commented May 12, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @heheda12345.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Labels: needs-rebase, ready (ONLY add when PR is ready to merge/full CI is needed), tpu (Related to Google TPUs), v1
Projects: None yet
2 participants