Commit
Change norm sharding for llama2-7b to fsdp.
PiperOrigin-RevId: 640498890
golechwierowicz authored and maxtext authors committed Jun 5, 2024
1 parent 2a6154f commit cb2c69b
Showing 1 changed file with 23 additions and 0 deletions.

MaxText/configs/llama2_7b_gpu.yml
@@ -17,3 +17,26 @@ logits_dot_in_fp32: False

per_device_batch_size: 4
max_target_length: 4096

logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose',]],
['activation_heads', ['tensor','sequence']],
['activation_length', 'sequence'],
['activation_embed', 'tensor'],
['activation_mlp', 'tensor'],
['activation_kv', 'tensor'],
['activation_vocab', ['tensor', 'sequence']],
['activation_vocab', 'tensor'],
['activation_vocab', 'sequence'],
['mlp', ['fsdp_transpose', 'tensor', 'autoregressive']],
['vocab', ['tensor', 'autoregressive']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence']],
['embed', ['fsdp', 'sequence']],
['norm', 'fsdp'],
['heads', ['tensor', 'autoregressive']],
['kv', []],
['cache_batch', []],
['cache_heads', ['autoregressive', 'tensor']],
['cache_kv', []],
['cache_sequence', []],
]
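
A minimal sketch of how rules like these are consumed, assuming Flax's logical partitioning helper flax.linen.logical_to_mesh_axes (which MaxText builds on); the rules below are a subset copied from the config above, and the logical axis names passed in ('norm', 'embed', 'heads') stand in for the annotations on the model's parameters. This is illustrative only, not MaxText's actual sharding code.

# Sketch: resolve logical parameter axes to mesh axes via the rules above.
from flax import linen as nn

rules = (
    ('embed', ('fsdp', 'fsdp_transpose', 'sequence')),
    ('embed', ('fsdp', 'sequence')),          # fallback when the first rule's axes are taken
    ('norm', 'fsdp'),                         # the rule added in this commit
    ('heads', ('tensor', 'autoregressive')),
)

# A norm scale parameter annotated with the logical axis 'norm' now maps to
# the 'fsdp' mesh axis.
print(nn.logical_to_mesh_axes(('norm',), rules))
# -> PartitionSpec('fsdp')

# For a 2-D weight annotated ('embed', 'heads'), each dimension takes the first
# matching rule whose mesh axes are not already used by another dimension.
print(nn.logical_to_mesh_axes(('embed', 'heads'), rules))
# -> PartitionSpec(('fsdp', 'fsdp_transpose', 'sequence'), ('tensor', 'autoregressive'))

Duplicate rules for the same logical name (e.g. 'embed', 'activation_vocab') act as ordered fallbacks: a later rule is only applied if an earlier match could not be, so the ordering in the config matters.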
