
Sharding the llama2 70b on v5e-16 more efficiently. #706

Merged: 1 commit merged into main from zhihaoshan_dev on Jun 18, 2024

Conversation

zhihaoshan-google (Collaborator) opened this pull request. The review comments below refer to this hunk from kv_cache_autoregressive:
@@ -759,9 +762,6 @@ def kv_cache_autoregressive(
assert cached_ar_key_var[0].value.shape == self.cached_kv_shape((batch, self.max_target_length - self.max_prefill_predict_length, heads, kv_head_size), self.ar_key_axis_order)
assert cached_ar_value_var[0].value.shape == self.cached_kv_shape((batch, self.max_target_length - self.max_prefill_predict_length, heads, kv_head_size), self.ar_value_axis_order)

- key = nn.with_logical_constraint(key, (BATCH, LENGTH, HEAD, D_KV))
- value = nn.with_logical_constraint(value, (BATCH, LENGTH, HEAD, D_KV))

Reviewer (Collaborator):

Does removing the logical constraint have any impact? I know that in some cases the compiler is not able to propagate the sharding, which is why hard-coded nn.with_logical_constraint calls are in place.

Reviewer (Collaborator):

Instead of removing them, could you replace them with the more specific axis names?

zhihaoshan-google (Collaborator, Author):

If they are already annotated in https://github.com/google/maxtext/blob/main/MaxText/layers/attentions.py#L1057-L1062, I don't think we need to annotate them again. (Just as we don't annotate them in kv_cache_prefill, I don't think we need to annotate them in kv_cache_autoregressive either.)

Resolved review threads:
MaxText/layers/llama2.py (outdated)
MaxText/layers/attentions.py (outdated)
MaxText/layers/llama2.py (outdated)
MaxText/max_utils.py
getting_started/Data_Input_Pipeline.md (outdated)
@vipannalla (Collaborator) left a review:

Looks good. I have similar concerns as Morgan and Matthew regarding configs vs. model-specific code.

Resolved review threads:
MaxText/layers/attentions.py (outdated)
MaxText/layers/llama2.py (outdated)
MaxText/max_utils.py (outdated)
@vipannalla (Collaborator):

Talked offline. George is OOO soon and doesn't have time right now; he can make these changes once he is back in two weeks. I'm OK with merging this as a short-term fix and will let @gobbleturk decide.

zhihaoshan-google force-pushed the zhihaoshan_dev branch 2 times, most recently from b44df5e to 5e6f7a4 (June 17, 2024, 21:20).
@gobbleturk (Collaborator) left a review:

Thanks George!

@zhihaoshan-google (Collaborator, Author):

Thanks for the review, Matt, Vipan and Morgan!

zhihaoshan-google force-pushed the zhihaoshan_dev branch 6 times, most recently from d1c9cc7 to e8a4961 (June 18, 2024, 00:59).
copybara-service bot merged commit 180a780 into main on Jun 18, 2024 (13 checks passed).
copybara-service bot deleted the zhihaoshan_dev branch on June 18, 2024, 03:12.