Allow Different Compute Layout for Attention #709
@@ -52,6 +52,11 @@ def string_to_bool(s: str) -> bool:
_yaml_types_to_parser = {str: str, int: int, float: float, bool: string_to_bool}


def validate_compute_axis_order(s: str) -> None:
  valid_compute_axis_order = ("0,1,2,3", "0,2,1,3")
Reviewer comment: Why not allow other layouts? Does the code break with others, or just run slower? I'd allow other layouts, similar to prefill_cache_axis_order/ar_cache_axis_order.

Reviewer follow-up: Never mind, I noticed that in attention.py you specifically look for those two.
  if s not in valid_compute_axis_order:  # currently supported compute_axis_order
    raise ValueError("Invalid compute_axis_order was passed. Valid options ", valid_compute_axis_order)


def validate_attention_type(s: str) -> None:
  valid_attention_types = ("autoselected", "dot_product", "flash", "cudnn_flash_te")
  if s not in valid_attention_types:  # currently supported attention
@@ -66,6 +71,7 @@ def validate_profiler_type(s: str) -> None:
def validate_keys(keys):
  validate_attention_type(keys["attention"])
  validate_profiler_type(keys["profiler"])
  validate_compute_axis_order(keys["compute_axis_order"])

  assert (keys["load_parameters_path"] == "" and keys["load_full_state_path"] == "") or keys[
    "enable_checkpointing"
@@ -14,6 +14,8 @@

"""Tests for Attentions."""

import itertools
import random
import sys
import unittest

@@ -29,7 +31,6 @@
import pyconfig

from layers import attentions
from layers import embeddings

Mesh = jax.sharding.Mesh
Attention = attentions.Attention

@@ -257,71 +258,73 @@ def tpu_kernel_attention_helper(self, num_kv_heads):
    )

  @pytest.mark.tpu
  def test_dot_product_1203_1203(self):
    self.dot_product_attention_helper(
        prefill_cache_axis_order=(1,2,0,3),
        ar_cache_axis_order=(1,2,0,3)
    )
  def test_dot_product_cache_axis_order(self):
Reviewer comment: Thanks for adding unit tests!

Author reply: Of course!
    all_axis_orders = [axis_order for axis_order in itertools.permutations(range(4))]
    for axis_order in random.choices(all_axis_orders, k=4):
      self.dot_product_attention_helper(
          prefill_cache_axis_order=axis_order,
          ar_cache_axis_order=axis_order
      )
      print(f"passed test for {axis_order=}")

  @pytest.mark.tpu
  def test_dot_product_1203_2130(self):
    self.dot_product_attention_helper(
        prefill_cache_axis_order=(1,2,0,3),
        ar_cache_axis_order=(2,1,3,0)
    )
  def dot_product_attention_helper(self, prefill_cache_axis_order, ar_cache_axis_order):
    for compute_axis_order in [(0,1,2,3), (0,2,1,3)]:
      self._dot_product_attention(
          prefill_cache_axis_order,
          ar_cache_axis_order,
          compute_axis_order=compute_axis_order,
      )
      print(f"passed subtest for {compute_axis_order=}")

  def _dot_product_attention(
      self,
      prefill_cache_axis_order,
      ar_cache_axis_order,
      compute_axis_order,
  ):
    """Test equivalence between different layout controls in dot_product"""

  @pytest.mark.tpu
  def test_dot_product_2130_1203(self):
    self.dot_product_attention_helper(
        prefill_cache_axis_order=(2,1,3,0),
        ar_cache_axis_order=(1,2,0,3)
    )
    rtol, atol = 1e-02, 1e-02

  @pytest.mark.tpu
  def test_dot_product_2130_2130(self):
    self.dot_product_attention_helper(
        prefill_cache_axis_order=(2,1,3,0),
        ar_cache_axis_order=(2,1,3,0),
    pyconfig.initialize(
        [sys.argv[0], "configs/base.yml"],
        per_device_batch_size=1.0,
        run_name="test",
        enable_checkpointing=False,
        max_target_length=128,
        max_prefill_predict_length=16,
        attention="dot_product",
    )
    config = pyconfig.config

  def dot_product_attention_helper(self, prefill_cache_axis_order, ar_cache_axis_order):
    self._dot_product_attention(prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache=False, rtol=1e-02, atol=1e-01)
    self._dot_product_attention(prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache=True, rtol=1e-01, atol=1e-01)

  def _dot_product_attention(self, prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache, rtol, atol):
    """Test equivalence between different layout controls in dot_product"""
    prefill_length = self.max_prefill_predict_length
    decode_total_length = self.max_target_length
    lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype)
    prefill_length = config.max_prefill_predict_length
    decode_total_length = config.max_target_length
    lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(config.dtype)

    lnx_prefill = lnx[:, 0:prefill_length, :]
    decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length]
    decoder_positions_prefill = decoder_positions[:, 0:prefill_length]

    attention_w_layout = Attention(
        config=self.cfg,
        num_query_heads=self.num_query_heads,
        num_kv_heads=self.num_kv_heads,
        head_dim=self.head_dim,
        max_target_length=self.max_target_length,
        max_prefill_predict_length=self.max_prefill_predict_length,
        mesh=self.mesh,
        attention_kernel="dot_product",
        dtype=self.dtype,
        prefill_key_axis_order=prefill_cache_axis_order,
        prefill_value_axis_order=prefill_cache_axis_order,
        ar_key_axis_order=ar_cache_axis_order,
        ar_value_axis_order=ar_cache_axis_order,
        quantize_kvcache=quantize_kvcache,
        config=config,
        num_query_heads=config.num_query_heads,
        num_kv_heads=config.num_kv_heads,
        head_dim=config.head_dim,
        max_target_length=config.max_target_length,
        max_prefill_predict_length=config.max_prefill_predict_length,
        attention_kernel=config.attention,
        dtype=config.dtype,
        prefill_cache_axis_order=prefill_cache_axis_order,
        ar_cache_axis_order=ar_cache_axis_order,
        compute_axis_order=compute_axis_order,
    )

    attention_w_layout_variable = attention_w_layout.init(
        {"params": self.rng, "aqt": self.rng},
        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
        jnp.ones((self.global_batch_size, self.max_target_length)),
        jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
        jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
        jnp.ones((self.global_batch_size, config.max_target_length)),
    )

    attention_w_layout_full = attention_w_layout.apply(
        attention_w_layout_variable,
        lnx,
@@ -371,59 +374,74 @@ def _dot_product_attention(self, prefill_cache_axis_order, ar_cache_axis_order,

  @pytest.mark.tpu
  def test_dot_product_reshape_q(self):
    self._dot_product_attention_reshape_q(quantize_kvcache=True, rtol=1e-01, atol=1e-01)
    self._dot_product_attention_reshape_q(quantize_kvcache=False, rtol=1e-02, atol=1e-02)
    for compute_axis_order in [(0,1,2,3), (0,2,1,3)]:
      self._dot_product_attention_reshape_q(
          compute_axis_order=compute_axis_order,
      )
      print(f"test passed for compute_axis_order: {compute_axis_order}")

  def _dot_product_attention_reshape_q(self, quantize_kvcache, rtol, atol):
  def _dot_product_attention_reshape_q(self, compute_axis_order):
"""Test equalvant between q and reshape q in dot_product""" | ||
prefill_length = self.max_prefill_predict_length | ||
decode_total_length = self.max_target_length | ||
lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype) | ||
|
||
rtol, atol = 1e-02, 1e-02 | ||
|
||
pyconfig.initialize( | ||
[sys.argv[0], "configs/base.yml"], | ||
per_device_batch_size=1.0, | ||
run_name="test", | ||
enable_checkpointing=False, | ||
max_target_length=128, | ||
max_prefill_predict_length=16, | ||
attention="dot_product", | ||
) | ||
config = pyconfig.config | ||
|
||
prefill_length = config.max_prefill_predict_length | ||
decode_total_length = config.max_target_length | ||
lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(config.dtype) | ||
|
    lnx_prefill = lnx[:, 0:prefill_length, :]
    decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length]
    decoder_positions_prefill = decoder_positions[:, 0:prefill_length]

    attention_wo_reshape_q = Attention(
        config=self.cfg,
        num_query_heads=self.num_query_heads,
        num_kv_heads=self.num_kv_heads,
        head_dim=self.head_dim,
        max_target_length=self.max_target_length,
        max_prefill_predict_length=self.max_prefill_predict_length,
        mesh=self.mesh,
        attention_kernel="dot_product",
        dtype=self.dtype,
        config=config,
        num_query_heads=config.num_query_heads,
        num_kv_heads=config.num_kv_heads,
        head_dim=config.head_dim,
        max_target_length=config.max_target_length,
        max_prefill_predict_length=config.max_prefill_predict_length,
        attention_kernel=config.attention,
        dtype=config.dtype,
        compute_axis_order=compute_axis_order,
        reshape_q=False,
        quantize_kvcache=quantize_kvcache,
    )
    attention_wo_reshape_q_variable = attention_wo_reshape_q.init(
        {"params": self.rng, "aqt": self.rng},
        jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
        jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
        jnp.ones((self.global_batch_size, config.max_target_length)),
    )

    attention_w_reshape_q = Attention(
        config=self.cfg,
        num_query_heads=self.num_query_heads,
        num_kv_heads=self.num_kv_heads,
        head_dim=self.head_dim,
        max_target_length=self.max_target_length,
        max_prefill_predict_length=self.max_prefill_predict_length,
        mesh=self.mesh,
        attention_kernel="dot_product",
        dtype=self.dtype,
        config=config,
        num_query_heads=config.num_query_heads,
        num_kv_heads=config.num_kv_heads,
        head_dim=config.head_dim,
        max_target_length=config.max_target_length,
        max_prefill_predict_length=config.max_prefill_predict_length,
        attention_kernel=config.attention,
        dtype=config.dtype,
        compute_axis_order=compute_axis_order,
        reshape_q=True,
        quantize_kvcache=quantize_kvcache,
    )

    attention_wo_reshape_q_variable = attention_wo_reshape_q.init(
        {"params": self.rng, "aqt": self.rng},
        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
        jnp.ones((self.global_batch_size, self.max_target_length)),
    )

    attention_w_reshape_q_variable = attention_w_reshape_q.init(
        {"params": self.rng, "aqt": self.rng},
        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
        jnp.ones((self.global_batch_size, self.max_target_length)),
        jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
        jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
        jnp.ones((self.global_batch_size, config.max_target_length)),
    )

    attention_wo_reshape_q_full = attention_wo_reshape_q.apply(
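The diff is truncated before the assertions, but since the helpers set rtol/atol tolerances, the checks presumably compare the full-sequence output against the stitched prefill-plus-decode output element-wise. A minimal sketch of that style of comparison, with hypothetical variable names and the tolerances used above:

import jax.numpy as jnp

# Hedged sketch of the comparison the truncated assertions presumably perform;
# variable names are hypothetical, tolerances mirror the ones defined in the tests.
def outputs_close(full_output, stitched_output, rtol=1e-2, atol=1e-2):
  return jnp.allclose(full_output, stitched_output, rtol=rtol, atol=atol)

# e.g. inside the test:
#   self.assertTrue(outputs_close(attention_w_layout_full, stitched_output, rtol, atol))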
Reviewer comment: n00b question -- Is this to auto-run for every local commit/amend on VS Code, or just for convenience?

Author reply: No, I added this while I was debugging in my local VS Code; I think it will be helpful for other engineers too. You may need to change the flags depending on your run, though.