From f5ede5719e1388da638f5ec670d362c21b85d84b Mon Sep 17 00:00:00 2001
From: Raymond Zou
Date: Tue, 11 Jun 2024 22:49:38 +0000
Subject: [PATCH] Add llama2 70b training config for v5e

---
 MaxText/configs/v5e/llama2_70b.sh | 46 +++++++++++++++++++++++++++++++
 1 file changed, 46 insertions(+)
 create mode 100644 MaxText/configs/v5e/llama2_70b.sh

diff --git a/MaxText/configs/v5e/llama2_70b.sh b/MaxText/configs/v5e/llama2_70b.sh
new file mode 100644
index 000000000..750db1e0f
--- /dev/null
+++ b/MaxText/configs/v5e/llama2_70b.sh
@@ -0,0 +1,46 @@
+# Llama2 70B model.
+# This config will work out of the box for any number of v5e-256 slices.
+#
+# Command Flags:
+# OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml)
+# DATASET_PATH (Required, unless dataset_path is already set in base.yml)
+# RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE)
+# PLATFORM (Optional, can be "gke" or "gce", default is "gce")
+#
+# Example to invoke this script:
+# bash MaxText/configs/v5e/llama2_70b.sh RUN_NAME="<your_run_name>" OUTPUT_PATH="gs://<your_output_path>" DATASET_PATH="gs://<your_dataset_path>" PLATFORM="gke"
+#
+# Example to AOT compile:
+# bash MaxText/configs/v5e/llama2_70b.sh EXECUTABLE=train_compile.py M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2
+
+
+# Stop execution if any command exits with an error
+set -e
+
+export PLATFORM="gce" # Can be "gke" or "gce"
+export EXECUTABLE="train.py" # or train_compile.py
+
+# Set environment variables from KEY=VALUE command-line arguments
+for ARGUMENT in "$@"; do
+    IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
+    export "$KEY"="$VALUE"
+done
+
+# The setup accommodates two cases:
+# 1) Passing the 'RUN_NAME' variable at runtime
+# 2) Propagating the 'M_RUN_NAME' variable within an Airflow sweeping workflow
+if [ -n "$RUN_NAME" ];
+then
+    export M_RUN_NAME=$RUN_NAME
+fi
+
+# Set up network
+bash preflight.sh PLATFORM=$PLATFORM
+
+# Train
+export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
+
+python MaxText/$EXECUTABLE MaxText/configs/base.yml model_name=llama2-70b \
+  base_output_directory=$OUTPUT_PATH dataset_path=${DATASET_PATH} \
+  tokenizer_path=assets/tokenizer.llama2 per_device_batch_size=2 remat_policy=qkv_proj_offloaded \
+  steps=30 enable_checkpointing=false
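
Note (reviewer aside, not part of the patch): the script forwards RUN_NAME through the M_RUN_NAME environment variable, and the AOT example passes M_COMPILE_TOPOLOGY the same way. To my understanding, MaxText's pyconfig treats M_-prefixed environment variables as overrides for base.yml keys, so other keys not pinned on the python command line can be overridden the same way. A minimal sketch, assuming that convention; the key chosen here is illustrative:

# Assumed M_<KEY> override convention; max_target_length is used only as an example key.
export M_MAX_TARGET_LENGTH=2048
bash MaxText/configs/v5e/llama2_70b.sh RUN_NAME="<your_run_name>" OUTPUT_PATH="gs://<your_output_path>" DATASET_PATH="gs://<your_dataset_path>"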
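
For PLATFORM="gke", the invoke example is usually submitted as an XPK workload rather than run directly on the VMs. A minimal sketch, assuming the xpk CLI (github.com/google/xpk); the cluster name, workload name, and exact flag spellings are assumptions and should be checked against your xpk version:

# Hypothetical xpk launch of this config on two v5e-256 slices.
xpk workload create --cluster <your_cluster> --workload llama2-70b-v5e \
  --tpu-type=v5litepod-256 --num-slices=2 \
  --command "bash MaxText/configs/v5e/llama2_70b.sh OUTPUT_PATH=gs://<your_output_path> DATASET_PATH=gs://<your_dataset_path> PLATFORM=gke"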