Allow Different Compute Layout for Attention #709

Merged: 1 commit, merged Jun 18, 2024
36 changes: 35 additions & 1 deletion .vscode/launch.json
@@ -31,6 +31,40 @@
"dataset_path=gs://test-maxtext-dataset",
"steps=2",
"enable_checkpointing=false"]
}
},
{
"name": "Debug MaxText Inference Microbenchmark",
"type": "python",
"request": "launch",
"console": "integratedTerminal",
"justMyCode": false,
"python": "python3",
"program": "${workspaceFolder}/MaxText/inference_microbenchmark.py",
"args": [
"MaxText/configs/base.yml",
"model_name=llama2-7b",
"tokenizer_path=assets/tokenizer.llama2",
"weight_dtype=bfloat16",
"scan_layers=false",
"attention=dot_product",
"max_prefill_predict_length=1024",
"max_target_length=2048",
"ici_fsdp_parallelism=1",
"ici_tensor_parallelism=-1",
"ici_autoregressive_parallelism=1",
"inference_microbenchmark_prefill_lengths=32,64,128,256,512,1024",
"inference_microbenchmark_stages=generate",
"inference_microbenchmark_loop_iters=1",
"run_name=runner_$(date +%Y-%m-%d-%H-%M)",
"base_output_directory=gs://test-maxtext-output",
"prefill_cache_axis_order=0,2,1,3",
"ar_cache_axis_order=0,2,1,3",
"compute_axis_order=0,2,1,3",
"reshape_q=true",
"per_device_batch_size=24",
"quantization=int8",
"quantize_kvcache=True",
]
},
Comment on lines +36 to +68

Collaborator:
n00b question -- Is this to auto-run for every local commit/amend in VS Code? Or just for convenience?

Collaborator (author) @morgandu, Jun 17, 2024:
No, I added this while debugging in my local VS Code, and I think it will be helpful for other engineers too. You may need to change the flags depending on your run, though.

]
}
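
For anyone reproducing this run outside VS Code, the launch configuration above maps directly onto a command-line invocation. Below is a minimal, illustrative Python launcher (not part of this PR); it reuses flag values from the launch config and computes the dated run_name itself, because subprocess does not perform the $(date ...) shell expansion that launch.json relies on:

# Illustrative sketch: run the same microbenchmark as the VS Code config above.
# The gs:// bucket is the placeholder value from launch.json, not a real resource.
import datetime
import subprocess

run_name = "runner_" + datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")
subprocess.run(
    [
        "python3", "MaxText/inference_microbenchmark.py",
        "MaxText/configs/base.yml",
        "model_name=llama2-7b",
        "tokenizer_path=assets/tokenizer.llama2",
        "attention=dot_product",
        "prefill_cache_axis_order=0,2,1,3",
        "ar_cache_axis_order=0,2,1,3",
        "compute_axis_order=0,2,1,3",
        "reshape_q=true",
        f"run_name={run_name}",
        "base_output_directory=gs://test-maxtext-output",
    ],
    check=True,  # raise if the benchmark exits non-zero
)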
12 changes: 8 additions & 4 deletions MaxText/configs/base.yml
@@ -317,10 +317,14 @@ inference_microbenchmark_log_file_path: ""
# KV Cache layout control
# Logical layout: 0,1,2,3 ; CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV
# Default layout: 1,2,0,3 ; CACHE_SEQUENCE, CACHE_HEADS, CACHE_BATCH, CACHE_KV
prefill_key_axis_order: "1,2,0,3"
prefill_value_axis_order: "1,2,0,3"
ar_key_axis_order: "1,2,0,3"
ar_value_axis_order: "1,2,0,3"
prefill_cache_axis_order: "1,2,0,3"
ar_cache_axis_order: "1,2,0,3"

# Compute layout control
# Default layout: 0,1,2,3 ; BATCH, LENGTH, HEAD, D_KV
# Currently only compute layouts 0,1,2,3 and 0,2,1,3 are supported
compute_axis_order: "0,1,2,3"
morgandu marked this conversation as resolved.

reshape_q: False

# Maxengine Metrics
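The *_axis_order knobs in this config section all describe the same idea: a permutation mapping the logical (BATCH, SEQUENCE, HEADS, KV) axes onto physical positions. A minimal JAX sketch of that mapping (illustrative only, not MaxText's actual helper; the parse function mirrors the split-and-cast pattern in llama2.py and models.py below):

# Illustrative: parse an axis-order string and permute a logical
# (BATCH, LENGTH, HEAD, D_KV) tensor into that physical layout.
import jax.numpy as jnp

def parse_axis_order(s: str) -> tuple:
    return tuple(int(i) for i in s.split(","))

def to_physical(x, axis_order):
    # Position i of axis_order names the logical axis stored at
    # physical position i, so a plain transpose realizes the layout.
    return jnp.transpose(x, axis_order)

x = jnp.zeros((2, 128, 8, 64))  # BATCH, LENGTH, HEAD, D_KV
print(to_physical(x, parse_axis_order("0,2,1,3")).shape)  # (2, 8, 128, 64)
print(to_physical(x, parse_axis_order("1,2,0,3")).shape)  # (128, 8, 2, 64)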
138 changes: 78 additions & 60 deletions MaxText/layers/attentions.py

Large diffs are not rendered by default.

7 changes: 3 additions & 4 deletions MaxText/layers/llama2.py
@@ -106,10 +106,9 @@ def __call__(
name="self_attention",
quant=self.quant,
quantize_kvcache=cfg.quantize_kvcache,
prefill_key_axis_order=tuple([int(i) for i in cfg.prefill_key_axis_order.split(",")]),
prefill_value_axis_order=tuple([int(i) for i in cfg.prefill_value_axis_order.split(",")]),
ar_key_axis_order=tuple([int(i) for i in cfg.ar_key_axis_order.split(",")]),
ar_value_axis_order=tuple([int(i) for i in cfg.ar_value_axis_order.split(",")]),
prefill_cache_axis_order=tuple([int(i) for i in cfg.prefill_cache_axis_order.split(",")]),
ar_cache_axis_order=tuple([int(i) for i in cfg.ar_cache_axis_order.split(",")]),
compute_axis_order=tuple([int(i) for i in cfg.compute_axis_order.split(",")]),
reshape_q=cfg.reshape_q,
)

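The split-and-cast pattern above (repeated in models.py below) just turns the YAML strings into integer tuples before they reach the Attention layer. A quick round-trip with the base.yml defaults, purely for illustration:

# What the layer receives for the base.yml defaults shown earlier.
prefill_cache_axis_order = tuple(int(i) for i in "1,2,0,3".split(","))
compute_axis_order = tuple(int(i) for i in "0,1,2,3".split(","))
print(prefill_cache_axis_order)  # (1, 2, 0, 3)
print(compute_axis_order)        # (0, 1, 2, 3)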
8 changes: 4 additions & 4 deletions MaxText/layers/models.py
@@ -93,10 +93,10 @@ def __call__(
name="self_attention",
quant=self.quant,
quantize_kvcache=cfg.quantize_kvcache,
prefill_key_axis_order=tuple([int(i) for i in cfg.prefill_key_axis_order.split(",")]),
prefill_value_axis_order=tuple([int(i) for i in cfg.prefill_value_axis_order.split(",")]),
ar_key_axis_order=tuple([int(i) for i in cfg.ar_key_axis_order.split(",")]),
ar_value_axis_order=tuple([int(i) for i in cfg.ar_value_axis_order.split(",")]),
prefill_cache_axis_order=tuple([int(i) for i in cfg.prefill_cache_axis_order.split(",")]),
ar_cache_axis_order=tuple([int(i) for i in cfg.ar_cache_axis_order.split(",")]),
compute_axis_order=tuple([int(i) for i in cfg.compute_axis_order.split(",")]),
reshape_q=cfg.reshape_q,
)

attention_lnx = attention_layer(
6 changes: 6 additions & 0 deletions MaxText/pyconfig.py
@@ -52,6 +52,11 @@ def string_to_bool(s: str) -> bool:
_yaml_types_to_parser = {str: str, int: int, float: float, bool: string_to_bool}


def validate_compute_axis_order(s: str) -> None:
  valid_compute_axis_order = ("0,1,2,3", "0,2,1,3")
Collaborator:

Why not allow other layouts? Does the code break with others, or just run slow? I'd allow other layouts, similar to prefill_cache_axis_order/ar_cache_axis_order. Maybe remove the exception and just print a warning that others are untested?

Collaborator:

nvm, I noticed in attention.py you specifically look for those two.
  if s not in valid_compute_axis_order:  # currently supported compute_axis_order
    raise ValueError("Invalid compute_axis_order was passed. Valid options ", valid_compute_axis_order)

def validate_attention_type(s: str) -> None:
  valid_attention_types = ("autoselected", "dot_product", "flash", "cudnn_flash_te")
  if s not in valid_attention_types:  # currently supported attention
@@ -66,6 +71,7 @@ def validate_profiler_type(s: str) -> None:
def validate_keys(keys):
  validate_attention_type(keys["attention"])
  validate_profiler_type(keys["profiler"])
  validate_compute_axis_order(keys["compute_axis_order"])

assert (keys["load_parameters_path"] == "" and keys["load_full_state_path"] == "") or keys[
"enable_checkpointing"
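A quick illustration of how the new validator behaves (assumes MaxText/ is on sys.path, matching the bare `import pyconfig` used in the tests below; the calls are illustrative, not part of this PR):

# The validator accepts only the two compute layouts attention.py handles.
import pyconfig

pyconfig.validate_compute_axis_order("0,2,1,3")      # accepted, returns None
try:
    pyconfig.validate_compute_axis_order("1,2,0,3")  # valid for caches, not for compute
except ValueError as err:
    print(err)  # ('Invalid compute_axis_order was passed. Valid options ', ('0,1,2,3', '0,2,1,3'))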
188 changes: 103 additions & 85 deletions MaxText/tests/attention_test.py
@@ -14,6 +14,8 @@

"""Tests for Attentions."""

import itertools
import random
import sys
import unittest

Expand All @@ -29,7 +31,6 @@
import pyconfig

from layers import attentions
from layers import embeddings

Mesh = jax.sharding.Mesh
Attention = attentions.Attention
@@ -257,71 +258,73 @@ def tpu_kernel_attention_helper(self, num_kv_heads):
)

@pytest.mark.tpu
def test_dot_product_1203_1203(self):
self.dot_product_attention_helper(
prefill_cache_axis_order=(1,2,0,3),
ar_cache_axis_order=(1,2,0,3)
)
def test_dot_product_cache_axis_order(self):
Collaborator:

Thanks for adding unit tests!

Collaborator (author):

Of course!
all_axis_orders = [axis_order for axis_order in itertools.permutations(range(4))]
for axis_order in random.choices(all_axis_orders, k=4):
self.dot_product_attention_helper(
prefill_cache_axis_order=axis_order,
ar_cache_axis_order=axis_order
)
print(f"passed test for {axis_order=}")

@pytest.mark.tpu
def test_dot_product_1203_2130(self):
self.dot_product_attention_helper(
prefill_cache_axis_order=(1,2,0,3),
ar_cache_axis_order=(2,1,3,0)
)
def dot_product_attention_helper(self, prefill_cache_axis_order, ar_cache_axis_order):
for compute_axis_order in [(0,1,2,3), (0,2,1,3)]:
self._dot_product_attention(
prefill_cache_axis_order,
ar_cache_axis_order,
compute_axis_order=compute_axis_order,
)
print(f"passed subtest for {compute_axis_order=}")

def _dot_product_attention(
self,
prefill_cache_axis_order,
ar_cache_axis_order,
compute_axis_order,
):
"""Test equalvant between different layout control in dot_product"""

@pytest.mark.tpu
def test_dot_product_2130_1203(self):
self.dot_product_attention_helper(
prefill_cache_axis_order=(2,1,3,0),
ar_cache_axis_order=(1,2,0,3)
)
rtol, atol = 1e-02, 1e-02

@pytest.mark.tpu
def test_dot_product_2130_2130(self):
self.dot_product_attention_helper(
prefill_cache_axis_order=(2,1,3,0),
ar_cache_axis_order=(2,1,3,0),
pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
per_device_batch_size=1.0,
run_name="test",
enable_checkpointing=False,
max_target_length=128,
max_prefill_predict_length=16,
attention="dot_product",
)
config = pyconfig.config

def dot_product_attention_helper(self, prefill_cache_axis_order, ar_cache_axis_order):
self._dot_product_attention(prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache=False, rtol=1e-02, atol=1e-01)
self._dot_product_attention(prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache=True, rtol=1e-01, atol=1e-01)

def _dot_product_attention(self, prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache, rtol, atol):
"""Test equalvant between different layout control in dot_product"""
prefill_length = self.max_prefill_predict_length
decode_total_length = self.max_target_length
lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype)
prefill_length = config.max_prefill_predict_length
decode_total_length = config.max_target_length
lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(config.dtype)

lnx_prefill = lnx[:, 0:prefill_length, :]
decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length]
decoder_positions_prefill = decoder_positions[:, 0:prefill_length]

attention_w_layout = Attention(
config=self.cfg,
num_query_heads=self.num_query_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
max_target_length=self.max_target_length,
max_prefill_predict_length=self.max_prefill_predict_length,
mesh=self.mesh,
attention_kernel="dot_product",
dtype=self.dtype,
prefill_key_axis_order=prefill_cache_axis_order,
prefill_value_axis_order=prefill_cache_axis_order,
ar_key_axis_order=ar_cache_axis_order,
ar_value_axis_order=ar_cache_axis_order,
quantize_kvcache=quantize_kvcache,
config=config,
num_query_heads=config.num_query_heads,
num_kv_heads=config.num_kv_heads,
head_dim=config.head_dim,
max_target_length=config.max_target_length,
max_prefill_predict_length=config.max_prefill_predict_length,
attention_kernel=config.attention,
dtype=config.dtype,
prefill_cache_axis_order=prefill_cache_axis_order,
ar_cache_axis_order=ar_cache_axis_order,
compute_axis_order=compute_axis_order,
)

attention_w_layout_variable = attention_w_layout.init(
{"params": self.rng, "aqt": self.rng},
jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
jnp.ones((self.global_batch_size, self.max_target_length)),
jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
jnp.ones((self.global_batch_size, config.max_target_length)),
)

attention_w_layout_full = attention_w_layout.apply(
attention_w_layout_variable,
lnx,
@@ -371,59 +374,74 @@ def _dot_product_attention(self, prefill_cache_axis_order, ar_cache_axis_order,

@pytest.mark.tpu
def test_dot_product_reshape_q(self):
self._dot_product_attention_reshape_q(quantize_kvcache=True, rtol=1e-01, atol=1e-01)
self._dot_product_attention_reshape_q(quantize_kvcache=False, rtol=1e-02, atol=1e-02)
for compute_axis_order in [(0,1,2,3), (0,2,1,3)]:
self._dot_product_attention_reshape_q(
compute_axis_order=compute_axis_order,
)
print(f"test passed for compute_axis_order: {compute_axis_order}")

def _dot_product_attention_reshape_q(self, quantize_kvcache, rtol, atol):
def _dot_product_attention_reshape_q(self, compute_axis_order):
"""Test equalvant between q and reshape q in dot_product"""
prefill_length = self.max_prefill_predict_length
decode_total_length = self.max_target_length
lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype)

rtol, atol = 1e-02, 1e-02

pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
per_device_batch_size=1.0,
run_name="test",
enable_checkpointing=False,
max_target_length=128,
max_prefill_predict_length=16,
attention="dot_product",
)
config = pyconfig.config

prefill_length = config.max_prefill_predict_length
decode_total_length = config.max_target_length
lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(config.dtype)

lnx_prefill = lnx[:, 0:prefill_length, :]
decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length]
decoder_positions_prefill = decoder_positions[:, 0:prefill_length]

attention_wo_reshape_q = Attention(
config=self.cfg,
num_query_heads=self.num_query_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
max_target_length=self.max_target_length,
max_prefill_predict_length=self.max_prefill_predict_length,
mesh=self.mesh,
attention_kernel="dot_product",
dtype=self.dtype,
config=config,
num_query_heads=config.num_query_heads,
num_kv_heads=config.num_kv_heads,
head_dim=config.head_dim,
max_target_length=config.max_target_length,
max_prefill_predict_length=config.max_prefill_predict_length,
attention_kernel=config.attention,
dtype=config.dtype,
compute_axis_order=compute_axis_order,
reshape_q=False,
quantize_kvcache=quantize_kvcache,
)
attention_wo_reshape_q_variable = attention_wo_reshape_q.init(
{"params": self.rng, "aqt": self.rng},
jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
jnp.ones((self.global_batch_size, config.max_target_length)),
)

attention_w_reshape_q = Attention(
config=self.cfg,
num_query_heads=self.num_query_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
max_target_length=self.max_target_length,
max_prefill_predict_length=self.max_prefill_predict_length,
mesh=self.mesh,
attention_kernel="dot_product",
dtype=self.dtype,
config=config,
num_query_heads=config.num_query_heads,
num_kv_heads=config.num_kv_heads,
head_dim=config.head_dim,
max_target_length=config.max_target_length,
max_prefill_predict_length=config.max_prefill_predict_length,
attention_kernel=config.attention,
dtype=config.dtype,
compute_axis_order=compute_axis_order,
reshape_q=True,
quantize_kvcache=quantize_kvcache,
)

attention_wo_reshape_q_variable = attention_wo_reshape_q.init(
{"params": self.rng, "aqt": self.rng},
jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
jnp.ones((self.global_batch_size, self.max_target_length)),
)

attention_w_reshape_q_variable = attention_w_reshape_q.init(
{"params": self.rng, "aqt": self.rng},
jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
jnp.ones((self.global_batch_size, self.max_target_length)),
jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
jnp.ones((self.global_batch_size, config.max_target_length)),
)

attention_wo_reshape_q_full = attention_wo_reshape_q.apply(
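
Although the diff is truncated here, each of these tests ends the same way: apply the layer under both configurations and compare the outputs numerically. A minimal sketch of that comparison pattern, using the rtol/atol from the new tests (the argument names are placeholders, not the PR's exact code):

# Sketch of the equivalence check that closes each test above.
import jax.numpy as jnp

def assert_equivalent(reference, candidate, rtol=1e-02, atol=1e-02):
    # Outputs should agree within loose float tolerances, since only the
    # memory/compute layout differs between the two runs, not the math.
    assert jnp.allclose(reference, candidate, rtol=rtol, atol=atol)

# e.g. assert_equivalent(attention_wo_reshape_q_full, attention_w_reshape_q_full)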