Sharding the llama2 70b on v5e-16 more efficiently. #706
Conversation
Force-pushed from c5bbd85 to 0ef4504.
@@ -759,9 +762,6 @@ def kv_cache_autoregressive(
     assert cached_ar_key_var[0].value.shape == self.cached_kv_shape((batch, self.max_target_length - self.max_prefill_predict_length, heads, kv_head_size), self.ar_key_axis_order)
     assert cached_ar_value_var[0].value.shape == self.cached_kv_shape((batch, self.max_target_length - self.max_prefill_predict_length, heads, kv_head_size), self.ar_value_axis_order)

-    key = nn.with_logical_constraint(key, (BATCH, LENGTH, HEAD, D_KV))
-    value = nn.with_logical_constraint(value, (BATCH, LENGTH, HEAD, D_KV))
-
Does removing the logical constraint have any impact? In some cases the compiler is not able to propagate the sharding, which is why hard-coded nn.with_logical_constraint calls are in place in some spots.
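For context, a minimal sketch of what such an explicit constraint does, assuming hypothetical logical axis names and logical-to-physical rules (MaxText derives the real mapping from its sharding config):

```python
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.sharding import Mesh

# Build a tiny ("data", "model") mesh over whatever devices are available.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, ("data", "model"))

# Hypothetical logical-to-physical rules; treat these pairs as placeholders.
rules = (("activation_batch", "data"),
         ("activation_length", None),
         ("heads", "model"),
         ("activation_kv", None))

key = jnp.zeros((8, 128, 4, 64))  # (batch, length, heads, d_kv)

with mesh, nn.logical_axis_rules(rules):
    # Pins this intermediate to the annotated layout instead of leaving
    # the choice to XLA's sharding propagation.
    key = nn.with_logical_constraint(
        key, ("activation_batch", "activation_length", "heads", "activation_kv"))
```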
They are already annotated in https://github.com/google/maxtext/blob/main/MaxText/layers/attentions.py#L1057-L1062.
Instead of removing them, could you replace them with the more specific axis names?
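As an illustration, the suggested replacement might look like the sketch below; the cache-specific axis names are hypothetical stand-ins, since the actual constants live in MaxText/layers/attentions.py:

```python
# Hypothetical sketch: constrain the autoregressive KV-cache entries with
# cache-specific logical axis names rather than the generic activation ones.
# The real constant names in MaxText may differ.
key = nn.with_logical_constraint(
    key, ("cache_batch", "cache_sequence", "cache_heads", "cache_kv"))
value = nn.with_logical_constraint(
    value, ("cache_batch", "cache_sequence", "cache_heads", "cache_kv"))
```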
If they are already annotated in https://github.com/google/maxtext/blob/main/MaxText/layers/attentions.py#L1057-L1062, I don't think we need to annotate them again. (Just as we don't annotate them in kv_cache_prefill, we shouldn't need to annotate them in kv_cache_autoregressive either.)
I think these two lines are safe to delete given https://github.com/google/maxtext/blob/main/MaxText/layers/attentions.py#L1057-L1062.
Force-pushed from 0ef4504 to a9951b0.
Force-pushed from 77c181e to 5122191.
Force-pushed from 5122191 to 247e3a1.
Looks good; I have similar concerns as Morgan and Matthew regarding configs vs. model-specific code.
We talked offline; George is OOO soon and doesn't have time right now. He can make these changes once he is back in two weeks. I'm OK with merging this as a short-term fix and will let @gobbleturk decide.
Force-pushed from b44df5e to 5e6f7a4.
Thanks George!
Thanks for the review, Matt, Vipan and Morgan!
Force-pushed from d1c9cc7 to e8a4961.
Force-pushed from e8a4961 to 8bf9f8e.
References:
- Efficiently Scaling Transformer Inference: https://arxiv.org/pdf/2211.05102
- Megatron-LM: https://arxiv.org/pdf/1909.08053
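For a rough sense of why KV-cache sharding on this topology needs care, here is a back-of-the-envelope sketch; the head counts are Llama 2 70B's published GQA configuration, but the framing is an illustration, not the PR's analysis:

```python
# Llama 2 70B uses grouped-query attention: 64 query heads share 8 KV heads.
# On a 16-chip v5e slice, query heads divide evenly across chips, but the
# KV heads do not, so a purely head-sharded KV cache would need replication
# or a different sharding axis.
num_query_heads = 64
num_kv_heads = 8
chips = 16

print(num_query_heads % chips == 0)  # True  -> 4 query heads per chip
print(num_kv_heads % chips == 0)     # False -> 8 KV heads can't split 16 ways
```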