Commit

add compute_axis_order

morgandu committed Jun 17, 2024
1 parent 75b3a5e commit c5ee451
Showing 7 changed files with 236 additions and 158 deletions.
36 changes: 35 additions & 1 deletion .vscode/launch.json
@@ -31,6 +31,40 @@
"dataset_path=gs://test-maxtext-dataset",
"steps=2",
"enable_checkpointing=false"]
}
},
{
"name": "Debug MaxText Inference Microbenchmark",
"type": "python",
"request": "launch",
"console": "integratedTerminal",
"justMyCode": false,
"python": "python3",
"program": "${workspaceFolder}/MaxText/inference_microbenchmark.py",
"args": [
"MaxText/configs/base.yml",
"model_name=llama2-7b",
"tokenizer_path=assets/tokenizer.llama2",
"weight_dtype=bfloat16",
"scan_layers=false",
"attention=dot_product",
"max_prefill_predict_length=1024",
"max_target_length=2048",
"ici_fsdp_parallelism=1",
"ici_tensor_parallelism=-1",
"ici_autoregressive_parallelism=1",
"inference_microbenchmark_prefill_lengths=32,64,128,256,512,1024",
"inference_microbenchmark_stages=generate",
"inference_microbenchmark_loop_iters=1",
"run_name=runner_$(date +%Y-%m-%d-%H-%M)",
"base_output_directory=gs://test-maxtext-output",
"prefill_cache_axis_order=0,2,1,3",
"ar_cache_axis_order=0,2,1,3",
"compute_axis_order=0,2,1,3",
"reshape_q=true",
"per_device_batch_size=24",
"quantization=int8",
"quantize_kvcache=True",
]
},
]
}
11 changes: 7 additions & 4 deletions MaxText/configs/base.yml
@@ -312,10 +312,13 @@ inference_microbenchmark_log_file_path: ""
# KV Cache layout control
# Logical layout: 0,1,2,3 ; CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV
# Default layout: 1,2,0,3 ; CACHE_SEQUENCE, CACHE_HEADS, CACHE_BATCH, CACHE_KV
prefill_key_axis_order: "1,2,0,3"
prefill_value_axis_order: "1,2,0,3"
ar_key_axis_order: "1,2,0,3"
ar_value_axis_order: "1,2,0,3"
prefill_cache_axis_order: "1,2,0,3"
ar_cache_axis_order: "1,2,0,3"

# Compute layout control
# Default layout: 0,1,2,3 ; BATCH, LENGTH, HEAD, D_KV
compute_axis_order: "0,1,2,3"

reshape_q: False

# Maxengine Metrics
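How these axis-order strings are consumed can be illustrated with a short sketch: parse the string the same way the llama2.py and models.py changes in this commit do, then apply it as a permutation. The jnp.transpose call, the parse_axis_order helper name, and the cache shape below are illustrative assumptions, not code from this commit.

import jax.numpy as jnp

# Parse an axis-order string such as "1,2,0,3" into a tuple of ints,
# mirroring the cfg.*_axis_order parsing added in llama2.py / models.py.
def parse_axis_order(s: str) -> tuple:
    return tuple(int(i) for i in s.split(","))

# Hypothetical KV cache in the logical layout 0,1,2,3:
# (CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV)
logical_cache = jnp.zeros((4, 1024, 8, 128))

# Permute to the default physical layout 1,2,0,3:
# (CACHE_SEQUENCE, CACHE_HEADS, CACHE_BATCH, CACHE_KV)
axis_order = parse_axis_order("1,2,0,3")
physical_cache = jnp.transpose(logical_cache, axes=axis_order)
print(physical_cache.shape)  # (1024, 8, 4, 128)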
138 changes: 78 additions & 60 deletions MaxText/layers/attentions.py

Large diffs are not rendered by default.

7 changes: 3 additions & 4 deletions MaxText/layers/llama2.py
@@ -96,10 +96,9 @@ def __call__(
name="self_attention",
quant=self.quant,
quantize_kvcache=cfg.quantize_kvcache,
prefill_key_axis_order=tuple([int(i) for i in cfg.prefill_key_axis_order.split(",")]),
prefill_value_axis_order=tuple([int(i) for i in cfg.prefill_value_axis_order.split(",")]),
ar_key_axis_order=tuple([int(i) for i in cfg.ar_key_axis_order.split(",")]),
ar_value_axis_order=tuple([int(i) for i in cfg.ar_value_axis_order.split(",")]),
prefill_cache_axis_order=tuple([int(i) for i in cfg.prefill_cache_axis_order.split(",")]),
ar_cache_axis_order=tuple([int(i) for i in cfg.ar_cache_axis_order.split(",")]),
compute_axis_order=tuple([int(i) for i in cfg.compute_axis_order.split(",")]),
reshape_q=cfg.reshape_q,
)

8 changes: 4 additions & 4 deletions MaxText/layers/models.py
@@ -93,10 +93,10 @@ def __call__(
name="self_attention",
quant=self.quant,
quantize_kvcache=cfg.quantize_kvcache,
prefill_key_axis_order=tuple([int(i) for i in cfg.prefill_key_axis_order.split(",")]),
prefill_value_axis_order=tuple([int(i) for i in cfg.prefill_value_axis_order.split(",")]),
ar_key_axis_order=tuple([int(i) for i in cfg.ar_key_axis_order.split(",")]),
ar_value_axis_order=tuple([int(i) for i in cfg.ar_value_axis_order.split(",")]),
prefill_cache_axis_order=tuple([int(i) for i in cfg.prefill_cache_axis_order.split(",")]),
ar_cache_axis_order=tuple([int(i) for i in cfg.ar_cache_axis_order.split(",")]),
compute_axis_order=tuple([int(i) for i in cfg.compute_axis_order.split(",")]),
reshape_q=cfg.reshape_q,
)

attention_lnx = attention_layer(
6 changes: 6 additions & 0 deletions MaxText/pyconfig.py
@@ -52,6 +52,11 @@ def string_to_bool(s: str) -> bool:
_yaml_types_to_parser = {str: str, int: int, float: float, bool: string_to_bool}


def validate_compute_axis_order(s: str) -> None:
valid_compute_axis_order = ("0,1,2,3", "0,2,1,3")
if s not in valid_compute_axis_order: # currently supported compute_axis_order
raise ValueError("Invalid compute_axis_order was passed. Valid options ", valid_compute_axis_order)

def validate_attention_type(s: str) -> None:
valid_attention_types = ("autoselected", "dot_product", "flash", "cudnn_flash_te")
if s not in valid_attention_types: # currently supported attention
@@ -66,6 +71,7 @@ def validate_profiler_type(s: str) -> None:
def validate_keys(keys):
validate_attention_type(keys["attention"])
validate_profiler_type(keys["profiler"])
validate_compute_axis_order(keys["compute_axis_order"])

assert (keys["load_parameters_path"] == "" and keys["load_full_state_path"] == "") or keys[
"enable_checkpointing"
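As a quick sanity check of the new validator, the snippet below exercises both the accepted and the rejected path. The bare "from pyconfig import ..." is an assumption that MaxText/ is on the Python path, as in the tests further down.

from pyconfig import validate_compute_axis_order

validate_compute_axis_order("0,2,1,3")      # accepted: one of the two supported orders
try:
    validate_compute_axis_order("1,2,0,3")  # rejected: only "0,1,2,3" and "0,2,1,3" are supported
except ValueError as err:
    print(err)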
188 changes: 103 additions & 85 deletions MaxText/tests/attention_test.py
@@ -14,6 +14,8 @@

"""Tests for Attentions."""

import itertools
import random
import sys
import unittest

@@ -29,7 +31,6 @@
import pyconfig

from layers import attentions
from layers import embeddings

Mesh = jax.sharding.Mesh
Attention = attentions.Attention
@@ -257,71 +258,73 @@ def tpu_kernel_attention_helper(self, num_kv_heads):
)

@pytest.mark.tpu
def test_dot_product_1203_1203(self):
self.dot_product_attention_helper(
prefill_cache_axis_order=(1,2,0,3),
ar_cache_axis_order=(1,2,0,3)
)
def test_dot_product_cache_axis_order(self):
all_axis_orders = [axis_order for axis_order in itertools.permutations(range(4))]
for axis_order in random.choices(all_axis_orders, k=4):
self.dot_product_attention_helper(
prefill_cache_axis_order=axis_order,
ar_cache_axis_order=axis_order
)
print(f"passed test for {axis_order=}")

@pytest.mark.tpu
def test_dot_product_1203_2130(self):
self.dot_product_attention_helper(
prefill_cache_axis_order=(1,2,0,3),
ar_cache_axis_order=(2,1,3,0)
)
def dot_product_attention_helper(self, prefill_cache_axis_order, ar_cache_axis_order):
for compute_axis_order in [(0,1,2,3), (0,2,1,3)]:
self._dot_product_attention(
prefill_cache_axis_order,
ar_cache_axis_order,
compute_axis_order=compute_axis_order,
)
print(f"passed subtest for {compute_axis_order=}")

def _dot_product_attention(
self,
prefill_cache_axis_order,
ar_cache_axis_order,
compute_axis_order,
):
"""Test equalvant between different layout control in dot_product"""

@pytest.mark.tpu
def test_dot_product_2130_1203(self):
self.dot_product_attention_helper(
prefill_cache_axis_order=(2,1,3,0),
ar_cache_axis_order=(1,2,0,3)
)
rtol, atol = 1e-02, 1e-02

@pytest.mark.tpu
def test_dot_product_2130_2130(self):
self.dot_product_attention_helper(
prefill_cache_axis_order=(2,1,3,0),
ar_cache_axis_order=(2,1,3,0),
pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
per_device_batch_size=1.0,
run_name="test",
enable_checkpointing=False,
max_target_length=128,
max_prefill_predict_length=16,
attention="dot_product",
)
config = pyconfig.config

def dot_product_attention_helper(self, prefill_cache_axis_order, ar_cache_axis_order):
self._dot_product_attention(prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache=False, rtol=1e-02, atol=1e-01)
self._dot_product_attention(prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache=True, rtol=1e-01, atol=1e-01)

def _dot_product_attention(self, prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache, rtol, atol):
"""Test equalvant between different layout control in dot_product"""
prefill_length = self.max_prefill_predict_length
decode_total_length = self.max_target_length
lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype)
prefill_length = config.max_prefill_predict_length
decode_total_length = config.max_target_length
lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(config.dtype)

lnx_prefill = lnx[:, 0:prefill_length, :]
decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length]
decoder_positions_prefill = decoder_positions[:, 0:prefill_length]

attention_w_layout = Attention(
config=self.cfg,
num_query_heads=self.num_query_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
max_target_length=self.max_target_length,
max_prefill_predict_length=self.max_prefill_predict_length,
mesh=self.mesh,
attention_kernel="dot_product",
dtype=self.dtype,
prefill_key_axis_order=prefill_cache_axis_order,
prefill_value_axis_order=prefill_cache_axis_order,
ar_key_axis_order=ar_cache_axis_order,
ar_value_axis_order=ar_cache_axis_order,
quantize_kvcache=quantize_kvcache,
config=config,
num_query_heads=config.num_query_heads,
num_kv_heads=config.num_kv_heads,
head_dim=config.head_dim,
max_target_length=config.max_target_length,
max_prefill_predict_length=config.max_prefill_predict_length,
attention_kernel=config.attention,
dtype=config.dtype,
prefill_cache_axis_order=prefill_cache_axis_order,
ar_cache_axis_order=ar_cache_axis_order,
compute_axis_order=compute_axis_order,
)

attention_w_layout_variable = attention_w_layout.init(
{"params": self.rng, "aqt": self.rng},
jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
jnp.ones((self.global_batch_size, self.max_target_length)),
jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
jnp.ones((self.global_batch_size, config.max_target_length)),
)

attention_w_layout_full = attention_w_layout.apply(
attention_w_layout_variable,
lnx,
@@ -371,59 +374,74 @@ def _dot_product_attention(self, prefill_cache_axis_order, ar_cache_axis_order,

@pytest.mark.tpu
def test_dot_product_reshape_q(self):
self._dot_product_attention_reshape_q(quantize_kvcache=True, rtol=1e-01, atol=1e-01)
self._dot_product_attention_reshape_q(quantize_kvcache=False, rtol=1e-02, atol=1e-02)
for compute_axis_order in [(0,1,2,3), (0,2,1,3)]:
self._dot_product_attention_reshape_q(
compute_axis_order=compute_axis_order,
)
print(f"test passed for compute_axis_order: {compute_axis_order}")

def _dot_product_attention_reshape_q(self, quantize_kvcache, rtol, atol):
def _dot_product_attention_reshape_q(self, compute_axis_order):
"""Test equalvant between q and reshape q in dot_product"""
prefill_length = self.max_prefill_predict_length
decode_total_length = self.max_target_length
lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype)

rtol, atol = 1e-02, 1e-02

pyconfig.initialize(
[sys.argv[0], "configs/base.yml"],
per_device_batch_size=1.0,
run_name="test",
enable_checkpointing=False,
max_target_length=128,
max_prefill_predict_length=16,
attention="dot_product",
)
config = pyconfig.config

prefill_length = config.max_prefill_predict_length
decode_total_length = config.max_target_length
lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(config.dtype)

lnx_prefill = lnx[:, 0:prefill_length, :]
decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length]
decoder_positions_prefill = decoder_positions[:, 0:prefill_length]

attention_wo_reshape_q = Attention(
config=self.cfg,
num_query_heads=self.num_query_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
max_target_length=self.max_target_length,
max_prefill_predict_length=self.max_prefill_predict_length,
mesh=self.mesh,
attention_kernel="dot_product",
dtype=self.dtype,
config=config,
num_query_heads=config.num_query_heads,
num_kv_heads=config.num_kv_heads,
head_dim=config.head_dim,
max_target_length=config.max_target_length,
max_prefill_predict_length=config.max_prefill_predict_length,
attention_kernel=config.attention,
dtype=config.dtype,
compute_axis_order=compute_axis_order,
reshape_q=False,
quantize_kvcache=quantize_kvcache,
)
attention_wo_reshape_q_variable = attention_wo_reshape_q.init(
{"params": self.rng, "aqt": self.rng},
jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
jnp.ones((self.global_batch_size, config.max_target_length)),
)

attention_w_reshape_q = Attention(
config=self.cfg,
num_query_heads=self.num_query_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
max_target_length=self.max_target_length,
max_prefill_predict_length=self.max_prefill_predict_length,
mesh=self.mesh,
attention_kernel="dot_product",
dtype=self.dtype,
config=config,
num_query_heads=config.num_query_heads,
num_kv_heads=config.num_kv_heads,
head_dim=config.head_dim,
max_target_length=config.max_target_length,
max_prefill_predict_length=config.max_prefill_predict_length,
attention_kernel=config.attention,
dtype=config.dtype,
compute_axis_order=compute_axis_order,
reshape_q=True,
quantize_kvcache=quantize_kvcache,
)

attention_wo_reshape_q_variable = attention_wo_reshape_q.init(
{"params": self.rng, "aqt": self.rng},
jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
jnp.ones((self.global_batch_size, self.max_target_length)),
)

attention_w_reshape_q_variable = attention_w_reshape_q.init(
{"params": self.rng, "aqt": self.rng},
jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
jnp.ones((self.global_batch_size, self.max_target_length)),
jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
jnp.ones((self.global_batch_size, config.max_target_length)),
)

attention_wo_reshape_q_full = attention_wo_reshape_q.apply(
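The property these tests assert can be illustrated outside MaxText: plain dot-product attention gives the same result whether the contraction runs in the 0,1,2,3 layout (BATCH, LENGTH, HEAD, D_KV) or in the 0,2,1,3 layout (BATCH, HEAD, LENGTH, D_KV). The shapes and the plain-softmax attention below are illustrative assumptions, not the MaxText kernel or its KV cache handling.

import jax
import jax.numpy as jnp
import numpy as np

# Random q, k, v in the logical compute layout (BATCH, LENGTH, HEAD, D_KV).
batch, length, heads, d_kv = 2, 16, 4, 8
rng = np.random.default_rng(0)
q = jnp.asarray(rng.normal(size=(batch, length, heads, d_kv)), dtype=jnp.float32)
k = jnp.asarray(rng.normal(size=(batch, length, heads, d_kv)), dtype=jnp.float32)
v = jnp.asarray(rng.normal(size=(batch, length, heads, d_kv)), dtype=jnp.float32)

def attn_0123(q, k, v):
    # Contract directly in BATCH, LENGTH, HEAD, D_KV order.
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)

def attn_0213(q, k, v):
    # Transpose to BATCH, HEAD, LENGTH, D_KV, contract, transpose back.
    qt, kt, vt = (jnp.transpose(x, (0, 2, 1, 3)) for x in (q, k, v))
    logits = jnp.einsum("bhqd,bhkd->bhqk", qt, kt)
    weights = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("bhqk,bhkd->bhqd", weights, vt)
    return jnp.transpose(out, (0, 2, 1, 3))

# Both layouts agree within the tolerances the tests above use.
np.testing.assert_allclose(attn_0123(q, k, v), attn_0213(q, k, v), rtol=1e-2, atol=1e-2)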
