
Allow Different Compute Layout for Attention #709

Merged. 1 commit merged into main on Jun 18, 2024.

Conversation

@morgandu (Collaborator) commented Jun 17, 2024

Checklist

This PR introduces compute layout control, allowing a different compute layout for attention.

  • Attention Unit Tests for different compute layout
  • Microbenchmark - Performance
  • E2E Serving - Accuracy and Performance

Setup

Results and Analysis

The goal of introducing the new compute layout is to potentially avoid cache layout tuning, though we can still tune the cache layout to seek out and verify the best performance.

Annotation

  • b: batch
  • t: query_length
  • h: query_heads
  • d: kv_dimension
  • s: kv_length
  • k: kv_heads

Layout

  • 0123 | bthd | bskd
  • 0213 | bhtd | bksd
  • 1203 | thbd | skbd
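For concreteness, here is a small JAX sketch (illustration only, not from the PR) showing how an axis-order permutation maps the default bthd query layout to the layouts listed above:

```python
import jax.numpy as jnp

b, t, h, d = 2, 4, 8, 16
query = jnp.zeros((b, t, h, d))  # default 0123 layout: bthd

# An axis order is a permutation of the default axes (0=b, 1=t, 2=h, 3=d).
print(jnp.transpose(query, (0, 2, 1, 3)).shape)  # 0213 -> bhtd: (2, 8, 4, 16)
print(jnp.transpose(query, (1, 2, 0, 3)).shape)  # 1203 -> thbd: (4, 8, 2, 16)
```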

Summary

The existing attention compute layout is 0123. We introduced a different compute layout, 0213, which is TPU-friendly.

We introduced the 0213 compute layout to verify:
- if and how much 0213 directly impacts performance with the default cache layout, i.e. the cache layout matching the compute layout
- if and how much a different compute layout combined with different cache layouts has a composite impact on performance

Performance

Existing compute layout 0123 and its history

Cache layout 1203-1203

With the existing cache layout 1203-1203, throughput was 2591.642232 tokens/s, about a 3x improvement over the default cache layout 0123-0123.

Cache layout 2013-2013

After layout tuning, we got the optimal prefill-ar cache layout, 2013-2013, with throughput 3347.180221 tokens/s, a 29% improvement.

New compute layout 0213

Cache layout 0213-0213

With the two caches in the same layout as the compute, i.e. 0213-0213 (xprof: https://xprof.corp.google.com/overview_page/morgandu-12159058496322304249), we got 3273.96 tokens/s, which is close to the top performance we verified with layout tuning.

Cache layout 0213-0132

The tuned cache layout that gives us the best throughput, 3329.45 tokens/s, is 0213-0132 (xprof: https://xprof.corp.google.com/overview_page/morgandu-5743582688063478644).

Accuracy

No regression on ROUGE scores between 0123 and 0213:

{'rouge1': 42.1738, 'rouge2': 19.6973, 'rougeL': 26.9088, 'rougeLsum': 39.6794, 'gen_len': 1144204, 'gen_num': 995}
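The PR does not say which tool produced these scores; below is a minimal sketch of such a comparison, assuming the Hugging Face `evaluate` package (an assumption, not confirmed by the PR):

```python
# Illustration only: score generated outputs against references with ROUGE.
import evaluate  # pip install evaluate rouge_score

rouge = evaluate.load("rouge")
scores = rouge.compute(
    predictions=["the cat sat on the mat"],       # e.g. decoded outputs from the 0213 run
    references=["a cat was sitting on the mat"],  # e.g. reference summaries
)
# Returns the same metric keys as above: rouge1, rouge2, rougeL, rougeLsum
# (as fractions in recent versions; the PR reports them scaled by 100).
print(scores)
```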

@morgandu force-pushed the mor--compute-axis-order branch 2 times, most recently from 2e5e8ac to c5ee451, on June 17, 2024 20:29

@morgandu changed the title from "Allow Different Compute Layout" to "Allow Different Compute Layout for Attention" on Jun 17, 2024
@vipannalla (Collaborator) left a comment:

Looks good to me

Comment on lines +36 to +68
"name": "Debug MaxText Inference Microbenchmark",
"type": "python",
"request": "launch",
"console": "integratedTerminal",
"justMyCode": false,
"python": "python3",
"program": "${workspaceFolder}/MaxText/inference_microbenchmark.py",
"args": [
"MaxText/configs/base.yml",
"model_name=llama2-7b",
"tokenizer_path=assets/tokenizer.llama2",
"weight_dtype=bfloat16",
"scan_layers=false",
"attention=dot_product",
"max_prefill_predict_length=1024",
"max_target_length=2048",
"ici_fsdp_parallelism=1",
"ici_tensor_parallelism=-1",
"ici_autoregressive_parallelism=1",
"inference_microbenchmark_prefill_lengths=32,64,128,256,512,1024",
"inference_microbenchmark_stages=generate",
"inference_microbenchmark_loop_iters=1",
"run_name=runner_$(date +%Y-%m-%d-%H-%M)",
"base_output_directory=gs://test-maxtext-output",
"prefill_cache_axis_order=0,2,1,3",
"ar_cache_axis_order=0,2,1,3",
"compute_axis_order=0,2,1,3",
"reshape_q=true",
"per_device_batch_size=24",
"quantization=int8",
"quantize_kvcache=True",
]
},
Collaborator:

n00b question -- Is this set to auto-run for every local commit/amend in VS Code, or is it just for convenience?

@morgandu (Collaborator, Author) replied Jun 17, 2024:

No, I added this while I was debugging in my local VS Code; I think it will be helpful for other engineers too. You may need to change the flags depending on your run, though.
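As an aside (not part of the PR), the launch config above maps one-to-one onto a CLI invocation of the same entry point, so a cache-layout tuning sweep like the one described in the PR description could be scripted, e.g.:

```python
# Hypothetical tuning sweep (illustration only, not from this PR): run the
# microbenchmark entry point once per cache-layout permutation.
import itertools
import subprocess

BASE_CMD = [
    "python3", "MaxText/inference_microbenchmark.py", "MaxText/configs/base.yml",
    "model_name=llama2-7b",
    "compute_axis_order=0,2,1,3",
    # ...remaining flags exactly as in the launch config above...
]

for perm in itertools.permutations("0123"):
    layout = ",".join(perm)  # e.g. "2,0,1,3"
    subprocess.run(
        BASE_CMD + [f"prefill_cache_axis_order={layout}", f"ar_cache_axis_order={layout}"],
        check=True,
    )
```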

```python
@@ -52,6 +52,11 @@ def string_to_bool(s: str) -> bool:
_yaml_types_to_parser = {str: str, int: int, float: float, bool: string_to_bool}


def validate_compute_axis_order(s: str) -> None:
  valid_compute_axis_order = ("0,1,2,3", "0,2,1,3")
```
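The hunk above is truncated at the tuple of valid values; a plausible completion is sketched below (the exact message is an assumption, though the review comment that follows confirms an exception is raised for other layouts):

```python
def validate_compute_axis_order(s: str) -> None:
  valid_compute_axis_order = ("0,1,2,3", "0,2,1,3")
  if s not in valid_compute_axis_order:
    # Assumed wording; the diff confirms only the two valid orders and,
    # per the review below, that anything else raises an exception.
    raise ValueError(f"Invalid compute_axis_order was passed. Valid options: {valid_compute_axis_order}")
```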
Collaborator:

Why not allow other layouts? Does the code break with others, or just run slowly? I'd allow other layouts, similar to prefill_cache_axis_order/ar_cache_axis_order. Maybe remove the exception and just print a warning that others are untested?

Collaborator:

nvm, I noticed in attention.py you specifically look for those two.

```python
    prefill_cache_axis_order=(1,2,0,3),
    ar_cache_axis_order=(1,2,0,3)
)

def test_dot_product_cache_axis_order(self):
```
Collaborator:

Thanks for adding unit tests!

@morgandu (Collaborator, Author) replied:

of course!

```python
@@ -68,6 +68,12 @@
# pytype: disable=attribute-error


def validate_compute_axis_order(s: str) -> None:
```
Collaborator:

Duplicate, remove. I see this method in pyconfig.py as well, which is the right place to validate at the beginning of the run.

@morgandu (Collaborator, Author) replied Jun 17, 2024:

Actually, this is a different validation. base.yml only allows strings, so pyconfig validates the flag's string value. In the attention module, the axis orders are actually used as AxisIdxes (tuples), and they can be hard-coded / overwritten in different model class initializations, e.g. https://github.com/google/maxtext/blob/main/MaxText/layers/gpt3.py#L229

I don't want to risk possibly hard-coded axis orders being passed in from somewhere other than the YAML flags.
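To make the distinction concrete, here is a minimal sketch (assumed, not from the diff) of the conversion implied by this reply: pyconfig validates the raw YAML string, while the attention module works with the parsed tuple:

```python
# Hypothetical converter illustrating the two representations discussed above:
# base.yml carries "0,2,1,3" as a string; attention consumes an AxisIdxes tuple.
def axis_order_from_flag(s: str) -> tuple:
  return tuple(int(i) for i in s.split(","))

assert axis_order_from_flag("0,2,1,3") == (0, 2, 1, 3)
```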

Comment on lines +369 to +383

```python
if model_mode == common_types.MODEL_MODE_TRAIN or self.compute_axis_order == (0,1,2,3):
  query = jnp.reshape(query, (b, t, n_kv, n // n_kv, d))
  if self.reshape_q and q_seq_len == 1:
    query = jnp.broadcast_to(query, (b, 2, n_kv, n // n_kv, d))
  result = jnp.einsum("btkgd,bskd->bkgts", query, key)
elif self.compute_axis_order == (0,2,1,3):
  query = jnp.transpose(query, axes=self.compute_axis_order)
  key = jnp.transpose(key, axes=self.compute_axis_order)
  query = jnp.reshape(query, (b, n_kv, n // n_kv, t, d))
  if self.reshape_q and q_seq_len == 1:
    query = jnp.broadcast_to(query, (b, n_kv, n // n_kv, 2, d))
  result = jnp.einsum("bkgtd,bksd->bkgts", query, key)
```
Collaborator:

This lgtm for now. Since we already know 0 = b, 1 = t, 2 = n, etc., can this code be made generic to support all layouts?

@morgandu (Collaborator, Author) replied:

A lot of other places would need to change to support more layouts; I ran into both loud and silent bugs when I tried to add one more layout, until I gave up.

Also, I don't think it's necessary to support all layouts, since bh** was recommended as one of the friendly layouts. If we end up really needing new layouts, let's revisit it!

MaxText/configs/base.yml (review thread resolved)
@copybara-service bot merged commit fe4bfdc into main on Jun 18, 2024. 13 checks passed.

@copybara-service bot deleted the mor--compute-axis-order branch on June 18, 2024 14:59.