From 0fc492d3e2bb334f2a1304568ce0bf379c8ff5be Mon Sep 17 00:00:00 2001
From: RissyRan
Date: Wed, 5 Jun 2024 20:03:13 +0000
Subject: [PATCH] Add vanilla megablox to MoE

---
 MaxText/configs/base.yml                      |   1 +
 MaxText/layers/linears.py                     | 119 ++++++++++++++++--
 MaxText/layers/mistral.py                     |   2 +-
 end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh |   3 +
 4 files changed, 112 insertions(+), 13 deletions(-)

diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index ed4f5edd2..92f338513 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -81,6 +81,7 @@ head_dim: 128
 num_experts: 1
 num_experts_per_tok: 1
 moe_matmul: False
+megablox: False
 mlp_activations: ["silu", "linear"]
 dropout_rate: 0
 logits_via_embedding: False
diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py
index a6a105974..5cc418d5a 100644
--- a/MaxText/layers/linears.py
+++ b/MaxText/layers/linears.py
@@ -28,10 +28,19 @@ from layers import quantizations
 import numpy as np
 from jax.ad_checkpoint import checkpoint_name
+from jax.experimental import shard_map
+import max_logging
+
+try:
+  from jax.experimental.pallas.ops.tpu import megablox as mblx
+except ImportError:
+  max_logging.log("JAX megablox is available for TPU only.")
+  pass
 
 Array = common_types.Array
 Config = common_types.Config
 DType = common_types.DType
+Mesh = common_types.Mesh
 NdInitializer = initializers.NdInitializer
 
 nd_dense_init = initializers.nd_dense_init
@@ -263,6 +272,7 @@ class MoeBlock(nn.Module):
   Attributes:
     num_experts: Number of experts.
     num_experts_per_tok: Number of experts for each token.
+    mesh: Mesh, device mesh.
     kernel_init: Kernel function, passed to the dense layers.
     kernel_axes: Tuple with axes to apply kernel function.
     weight_dtype: Type for the weights.
@@ -272,6 +282,7 @@ class MoeBlock(nn.Module):
   config: Config
   num_experts: int
   num_experts_per_tok: int
+  mesh: Mesh
   kernel_init: NdInitializer
   kernel_axes: Tuple[str, ...]
  weight_dtype: DType = jnp.float32
@@ -315,6 +326,73 @@ def generate_kernels(self, num_experts, emb_dim, mlp_dim):
     wo_kernel = jnp.asarray(wo_kernel, self.dtype)
     return w0_kernel, w1_kernel, wo_kernel
 
+  def permute(self, inputs, gate_logits, emb_dim):
+    """Permute tokens so they are grouped by expert, as required by the gmm call."""
+
+    # reshape inputs (batch, sequence, emb) to (batch * sequence, emb)
+    inputs_2d = jnp.reshape(inputs, (-1, emb_dim))
+    weights, selected_experts = jax.lax.top_k(gate_logits, self.num_experts_per_tok)
+    weights = jax.nn.softmax(weights.astype(self.weight_dtype), axis=-1).astype(self.dtype)
+    flatten_selected_experts = jnp.ravel(selected_experts)
+    indices_to_sort_by_expert = jnp.argsort(flatten_selected_experts)
+    # repeat inputs once per selected expert
+    repeat_inputs = jnp.repeat(inputs_2d, self.num_experts_per_tok, axis=0)
+    # sort the replicated inputs by expert assignment
+    sorted_inputs = jnp.take(repeat_inputs, indices=indices_to_sort_by_expert, axis=0).astype(self.dtype)
+    group_size = jnp.bincount(flatten_selected_experts, length=self.num_experts)
+
+    return sorted_inputs, indices_to_sort_by_expert, weights, group_size
+
+  def unpermute(self, intermediate, inputs, indices_to_sort_by_expert, weights):
+    """Unpermute tokens to original order and combine weights."""
+
+    unsort_output = jnp.take(intermediate, indices=jnp.argsort(indices_to_sort_by_expert), axis=0)
+    flatten_weights = jnp.ravel(weights)
+    combined_output = jnp.multiply(unsort_output, flatten_weights[:, None])
+    groups = jnp.reshape(combined_output, (-1, self.num_experts_per_tok, combined_output.shape[1]))
+    return jnp.sum(groups, axis=1).reshape(inputs.shape).astype(self.dtype)
+
+  def call_gmm(self, inputs, group_sizes, mlp_activation, w0_kernel, w1_kernel, wo_kernel):
+    # TODO(ranran): megablox currently works well on a single host; proper
+    # sharding will be added to improve performance.
+    @functools.partial(
+        shard_map.shard_map,
+        mesh=self.mesh,
+        in_specs=(
+            (nn.logical_to_mesh_axes((None, None))),
+            (nn.logical_to_mesh_axes((None, None, None))),
+            (nn.logical_to_mesh_axes((None,))),
+        ),
+        out_specs=(nn.logical_to_mesh_axes((None, None))),
+        check_rep=False,
+    )
+    def gmm(inputs, kernel, group_sizes):
+      hs_shape = inputs.shape
+      # pad length is the first dimension of the tiling size in the gmm call
+      pad_length = 512
+      if hs_shape[0] % pad_length:
+        pad_length = pad_length - hs_shape[0] % pad_length
+        inputs = jax.lax.pad(inputs.astype(jnp.float32), 0.0, [(0, pad_length, 0), (0, 0, 0)])
+
+      inputs = inputs.astype(self.dtype)
+      kernel = kernel.astype(self.weight_dtype)
+
+      output = mblx.gmm(lhs=inputs,
+                        rhs=kernel,
+                        group_sizes=group_sizes,
+                        tiling=(512, 512, 512))
+
+      if output.shape[0] > hs_shape[0]:
+        output = output[:hs_shape[0]]
+      return output
+
+    layer_w0 = gmm(inputs, w0_kernel, group_sizes)
+    layer_w1 = gmm(inputs, w1_kernel, group_sizes)
+    layer_act = _convert_to_activation_function(mlp_activation)(layer_w0)
+    intermediate_layer = jnp.multiply(layer_act, layer_w1)
+    output = gmm(intermediate_layer, wo_kernel, group_sizes)
+    return output
+
   @nn.compact
   def __call__(self, inputs):
     cfg = self.config
@@ -339,16 +417,33 @@ def __call__(self, inputs):
     w0_kernel, w1_kernel, wo_kernel = self.generate_kernels(cfg.num_experts, cfg.emb_dim, cfg.mlp_dim)
-
-    with jax.named_scope("wi_0"):
-      layer_w0 = jnp.einsum("BLE,NEH -> BLNH", inputs, w0_kernel)
-    with jax.named_scope("wi_1"):
-      layer_w1 = jnp.einsum("BLE,NEH -> BLNH", inputs, w1_kernel)
-    layer_w0_act = _convert_to_activation_function(cfg.mlp_activations[0])(layer_w0)
-    layer_multiply = jnp.multiply(layer_w0_act, layer_w1)
-    with jax.named_scope("wo"):
-      intermediate_layer = jnp.einsum("BLNH,NHE -> BLNE", layer_multiply, wo_kernel)
-    with jax.named_scope("w_sum"):
-      output = jnp.einsum("BLNE,BLN -> BLE", intermediate_layer, weights)
-
+
+    if cfg.megablox:
+      max_logging.log("Running MoE megablox implementation.")
+      sorted_hidden_states, indices_to_sort_by_expert, weights, group_sizes = self.permute(inputs,
+                                                                                           gate_logits,
+                                                                                           cfg.emb_dim)
+      intermediate_output = self.call_gmm(sorted_hidden_states,
+                                          group_sizes,
+                                          cfg.mlp_activations[0],
+                                          w0_kernel,
+                                          w1_kernel,
+                                          wo_kernel)
+      output = self.unpermute(intermediate_output,
+                              inputs,
+                              indices_to_sort_by_expert,
+                              weights)
+    else:
+      max_logging.log("Running MoE matmul implementation.")
+      with jax.named_scope("wi_0"):
+        layer_w0 = jnp.einsum("BLE,NEH -> BLNH", inputs, w0_kernel)
+      with jax.named_scope("wi_1"):
+        layer_w1 = jnp.einsum("BLE,NEH -> BLNH", inputs, w1_kernel)
+      layer_w0_act = _convert_to_activation_function(cfg.mlp_activations[0])(layer_w0)
+      layer_multiply = jnp.multiply(layer_w0_act, layer_w1)
+      with jax.named_scope("wo"):
+        intermediate_layer = jnp.einsum("BLNH,NHE -> BLNE", layer_multiply, wo_kernel)
+      with jax.named_scope("w_sum"):
+        output = jnp.einsum("BLNE,BLN -> BLE", intermediate_layer, weights)
+
     return output
diff --git a/MaxText/layers/mistral.py b/MaxText/layers/mistral.py
index 5c7e4ba88..449797518 100644
--- a/MaxText/layers/mistral.py
+++ b/MaxText/layers/mistral.py
@@ -125,11 +125,11 @@ def __call__(
     if cfg.num_experts > 1:
       # TODO(ranran): remove for loop implementation after adding expert parallelism
       if cfg.moe_matmul:
-        max_logging.log("Running MoE matmul implementation.")
         mlp_lnx = linears.MoeBlock(
           config=cfg,
           num_experts=cfg.num_experts,
          num_experts_per_tok=cfg.num_experts_per_tok,
+          mesh=mesh,
           kernel_init=initializers.nd_dense_init(1.0, 'fan_in',
                                                  'truncated_normal'),
           kernel_axes=('embed', 'mlp'),
           dtype=cfg.dtype,
diff --git a/end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh b/end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh
index 36fa1129a..7faffc333 100644
--- a/end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh
+++ b/end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh
@@ -34,6 +34,9 @@ python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml bas
 
 # Test whether the forward pass logits match the golden logits - matmul implementation
 python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${MATMUL_SCANNED_CHECKPOINT} run_name=matmul_forward_pass_test per_device_batch_size=1 model_name=mixtral-8x7b tokenizer_path=gs://maxtext-external/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral ici_tensor_parallelism=4 ici_fsdp_parallelism=16 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=float32 moe_matmul=True --atol=3 --rtol=1 --token_size=4
 
+# Test whether the forward pass logits match the golden logits - megablox implementation
+python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${MATMUL_SCANNED_CHECKPOINT} run_name=megablox_forward_pass_test per_device_batch_size=1 model_name=mixtral-8x7b tokenizer_path=gs://maxtext-external/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral ici_tensor_parallelism=4 ici_fsdp_parallelism=16 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=float32 moe_matmul=True megablox=True --atol=3 --rtol=1 --token_size=4
+
 # Run fine-tuning
 python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} load_parameters_path=${SCANNED_CHECKPOINT} run_name=fine_tuning per_device_batch_size=1 model_name=mixtral-8x7b ici_tensor_parallelism=4 ici_fsdp_parallelism=16 steps=10 max_target_length=1024 async_checkpointing=false tokenizer_path=gs://maxtext-external/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral checkpoint_period=5
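
For reviewers who want to sanity-check the permute / grouped-matmul / unpermute bookkeeping outside of MaxText, the sketch below mirrors the same top_k / argsort / bincount logic used by MoeBlock.permute and MoeBlock.unpermute on toy shapes, using plain jax.numpy only. The mblx.gmm call is replaced by an identity stand-in so it runs without the Pallas megablox kernel; all names and sizes here are illustrative, not part of the patch.

# Illustrative sketch (not part of the patch): permute -> grouped matmul -> unpermute.
import jax
import jax.numpy as jnp

num_experts, num_experts_per_tok = 4, 2
tokens = jnp.arange(6 * 3, dtype=jnp.float32).reshape(6, 3)      # (batch * seq, emb)
gate_logits = jax.random.normal(jax.random.PRNGKey(0), (6, num_experts))

# permute: replicate each token once per selected expert, then sort rows by expert id
weights, selected_experts = jax.lax.top_k(gate_logits, num_experts_per_tok)
weights = jax.nn.softmax(weights, axis=-1)
flat_experts = jnp.ravel(selected_experts)
sort_idx = jnp.argsort(flat_experts)
sorted_tokens = jnp.repeat(tokens, num_experts_per_tok, axis=0)[sort_idx]
group_sizes = jnp.bincount(flat_experts, length=num_experts)     # rows owned by each expert

# here the patch calls mblx.gmm(lhs=sorted_tokens, rhs=kernel, group_sizes=group_sizes, ...);
# an identity stand-in keeps this sketch self-contained
expert_out = sorted_tokens

# unpermute: undo the sort, weight each expert copy, and sum the copies per token
unsorted = expert_out[jnp.argsort(sort_idx)]
combined = unsorted * jnp.ravel(weights)[:, None]
output = combined.reshape(-1, num_experts_per_tok, tokens.shape[1]).sum(axis=1)
assert output.shape == tokens.shape                              # (6, 3)

As the added 2_test_mixtral.sh command shows, the new path is enabled with moe_matmul=True megablox=True.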