Merge pull request #689 from google:vanilla_megablox
PiperOrigin-RevId: 642798191
maxtext authors committed Jun 13, 2024
2 parents d0701b0 + 0fc492d commit e898606
Showing 4 changed files with 112 additions and 13 deletions.
1 change: 1 addition & 0 deletions MaxText/configs/base.yml
@@ -81,6 +81,7 @@ head_dim: 128
num_experts: 1
num_experts_per_tok: 1
moe_matmul: False
megablox: False
mlp_activations: ["silu", "linear"]
dropout_rate: 0
logits_via_embedding: False
119 changes: 107 additions & 12 deletions MaxText/layers/linears.py
@@ -28,10 +28,19 @@
from layers import quantizations
import numpy as np
from jax.ad_checkpoint import checkpoint_name
from jax.experimental import shard_map
import max_logging

try:
from jax.experimental.pallas.ops.tpu import megablox as mblx
except ImportError:
max_logging.log("JAX megablox is available for TPU only.")
pass

Array = common_types.Array
Config = common_types.Config
DType = common_types.DType
Mesh = common_types.Mesh
NdInitializer = initializers.NdInitializer

nd_dense_init = initializers.nd_dense_init
@@ -263,6 +272,7 @@ class MoeBlock(nn.Module):
Attributes:
num_experts: Number of experts.
num_experts_per_tok: Number of experts for each token.
mesh: Mesh, device mesh.
kernel_init: Kernel function, passed to the dense layers.
kernel_axes: Tuple with axes to apply kernel function.
weight_dtype: Type for the weights.
@@ -272,6 +282,7 @@
config: Config
num_experts: int
num_experts_per_tok: int
mesh: Mesh
kernel_init: NdInitializer
kernel_axes: Tuple[str, ...]
weight_dtype: DType = jnp.float32
@@ -315,6 +326,73 @@ def generate_kernels(self, num_experts, emb_dim, mlp_dim):
wo_kernel = jnp.asarray(wo_kernel, self.dtype)
return w0_kernel, w1_kernel, wo_kernel

def permute(self, inputs, gate_logits, emb_dim):
"""Permute tokens to group by expert to fit gmm call."""

# reshape inputs (batch, sequence, emb) to (batch * sequence, emb)
inputs_2d = jnp.reshape(inputs, (-1, emb_dim))
weights, selected_experts = jax.lax.top_k(gate_logits, self.num_experts_per_tok)
weights = jax.nn.softmax(weights.astype(self.weight_dtype), axis=-1).astype(self.dtype)
flatten_selected_experts = jnp.ravel(selected_experts)
indices_to_sort_by_expert = jnp.argsort(flatten_selected_experts)
# replicate each token once per selected expert
repeat_inputs = jnp.repeat(inputs_2d, self.num_experts_per_tok, axis=0)
# sort the replicated tokens so that rows routed to the same expert are contiguous
sorted_inputs = jnp.take(repeat_inputs, indices=indices_to_sort_by_expert, axis=0).astype(self.dtype)
group_size = jnp.bincount(flatten_selected_experts, length=self.num_experts)

return sorted_inputs, indices_to_sort_by_expert, weights, group_size

def unpermute(self, intermediate, inputs, indices_to_sort_by_expert, weights):
"""Unpermute tokens to original order and combine weights."""

# invert the expert-grouping permutation by argsorting the original sort indices
unsort_output = jnp.take(intermediate, indices=jnp.argsort(indices_to_sort_by_expert), axis=0)
flatten_weights = jnp.ravel(weights)
combined_output = jnp.multiply(unsort_output, flatten_weights[:, None])
groups = jnp.reshape(combined_output, (-1, self.num_experts_per_tok, combined_output.shape[1]))
return jnp.sum(groups, axis=1).reshape(inputs.shape).astype(self.dtype)

def call_gmm(self, inputs, group_sizes, mlp_activation, w0_kernel, w1_kernel, wo_kernel):
# TODO(ranran): megablox currently works well on a single host; proper
# sharding will be added later to improve performance.
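# Note: the in_specs/out_specs below map every logical axis to None, so the
# activations and kernels are fully replicated across the mesh inside this
# shard_map call rather than sharded.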
@functools.partial(
shard_map.shard_map,
mesh=self.mesh,
in_specs=(
(nn.logical_to_mesh_axes((None, None))),
(nn.logical_to_mesh_axes((None, None, None))),
(nn.logical_to_mesh_axes((None,))),
),
out_specs=(nn.logical_to_mesh_axes((None, None))),
check_rep=False,
)
def gmm(inputs, kernel, group_sizes):
hs_shape = inputs.shape
# pad the token dimension up to a multiple of the first tiling dimension used in the gmm call below
pad_length = 512
if hs_shape[0] % pad_length:
pad_length = pad_length - hs_shape[0] % pad_length
inputs = jax.lax.pad(inputs.astype(jnp.float32), 0.0, [(0, pad_length, 0), (0,0,0)])

inputs = inputs.astype(self.dtype)
kernel = kernel.astype(self.weight_dtype)

output = mblx.gmm(lhs=inputs,
rhs=kernel,
group_sizes=group_sizes,
tiling=(512, 512, 512))

if output.shape[0] != hs_shape[0]:
output = output[:hs_shape[0]]
return output

layer_w0 = gmm(inputs, w0_kernel, group_sizes)
layer_w1 = gmm(inputs, w1_kernel, group_sizes)
layer_act = _convert_to_activation_function(mlp_activation)(layer_w0)
intermediate_layer = jnp.multiply(layer_act, layer_w1)
output = gmm(intermediate_layer, wo_kernel, group_sizes)
return output

@nn.compact
def __call__(self, inputs):
cfg = self.config
@@ -339,16 +417,33 @@ def __call__(self, inputs):
w0_kernel, w1_kernel, wo_kernel = self.generate_kernels(cfg.num_experts,
cfg.emb_dim,
cfg.mlp_dim)

with jax.named_scope("wi_0"):
layer_w0 = jnp.einsum("BLE,NEH -> BLNH", inputs, w0_kernel)
with jax.named_scope("wi_1"):
layer_w1 = jnp.einsum("BLE,NEH -> BLNH", inputs, w1_kernel)
layer_w0_act = _convert_to_activation_function(cfg.mlp_activations[0])(layer_w0)
layer_multiply = jnp.multiply(layer_w0_act, layer_w1)
with jax.named_scope("wo"):
intermediate_layer = jnp.einsum("BLNH,NHE -> BLNE", layer_multiply, wo_kernel)
with jax.named_scope("w_sum"):
output = jnp.einsum("BLNE,BLN -> BLE", intermediate_layer, weights)


if cfg.megablox:
max_logging.log("Running MoE megablox implementation.")
sorted_hidden_states, indices_to_sort_by_expert, weights, group_sizes = self.permute(inputs,
gate_logits,
cfg.emb_dim)
intermediate_output = self.call_gmm(sorted_hidden_states,
group_sizes,
cfg.mlp_activations[0],
w0_kernel,
w1_kernel,
wo_kernel)
output = self.unpermute(intermediate_output,
inputs,
indices_to_sort_by_expert,
weights)
else:
max_logging.log("Running MoE matmul implementation.")
with jax.named_scope("wi_0"):
layer_w0 = jnp.einsum("BLE,NEH -> BLNH", inputs, w0_kernel)
with jax.named_scope("wi_1"):
layer_w1 = jnp.einsum("BLE,NEH -> BLNH", inputs, w1_kernel)
layer_w0_act = _convert_to_activation_function(cfg.mlp_activations[0])(layer_w0)
layer_multiply = jnp.multiply(layer_w0_act, layer_w1)
with jax.named_scope("wo"):
intermediate_layer = jnp.einsum("BLNH,NHE -> BLNE", layer_multiply, wo_kernel)
with jax.named_scope("w_sum"):
output = jnp.einsum("BLNE,BLN -> BLE", intermediate_layer, weights)

return output
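
Editorial note, not part of the commit: the sketch below is a minimal, single-host illustration of the permute, grouped-matmul, unpermute flow that the new MoeBlock code implements. A plain-jnp masked matmul stands in for megablox gmm, and a single square kernel per expert replaces the w0/w1/wo projections. The helper name reference_moe and all shapes are illustrative assumptions.

import jax
import jax.numpy as jnp


def reference_moe(inputs, gate_logits, kernels, num_experts_per_tok):
  """inputs: (batch, seq, emb); gate_logits: (batch, seq, num_experts); kernels: (num_experts, emb, emb)."""
  num_experts = kernels.shape[0]
  batch, seq, emb = inputs.shape
  inputs_2d = inputs.reshape(-1, emb)

  # Router: pick the top-k experts per token and softmax over the selected logits.
  weights, selected = jax.lax.top_k(gate_logits, num_experts_per_tok)
  weights = jax.nn.softmax(weights, axis=-1)

  # Permute: replicate each token once per selected expert and sort rows so
  # tokens routed to the same expert are contiguous, as MoeBlock.permute does.
  flat_experts = selected.reshape(-1)
  order = jnp.argsort(flat_experts)
  sorted_inputs = jnp.repeat(inputs_2d, num_experts_per_tok, axis=0)[order]
  sorted_experts = flat_experts[order]
  group_sizes = jnp.bincount(flat_experts, length=num_experts)  # what gmm consumes

  # Stand-in for mblx.gmm: each contiguous group of rows is multiplied by its
  # expert's kernel. A masked dense matmul keeps the sketch simple; megablox
  # touches only the rows that belong to each group.
  out = jnp.zeros((sorted_inputs.shape[0], kernels.shape[-1]), sorted_inputs.dtype)
  for e in range(num_experts):
    mask = (sorted_experts == e)[:, None]
    out = out + jnp.where(mask, sorted_inputs @ kernels[e], 0.0)

  # Unpermute: undo the sort, weight each row by its router probability, and
  # sum over the k experts chosen for each token.
  unsorted = out[jnp.argsort(order)]
  combined = unsorted * weights.reshape(-1)[:, None]
  return combined.reshape(-1, num_experts_per_tok, out.shape[-1]).sum(axis=1).reshape(batch, seq, -1)


# Tiny smoke test with illustrative shapes.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (2, 4, 8))        # (batch, seq, emb)
logits = jax.random.normal(key, (2, 4, 4))   # (batch, seq, num_experts)
kernels = jax.random.normal(key, (4, 8, 8))  # (num_experts, emb, emb)
print(reference_moe(x, logits, kernels, num_experts_per_tok=2).shape)  # (2, 4, 8)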
2 changes: 1 addition & 1 deletion MaxText/layers/mistral.py
@@ -125,11 +125,11 @@ def __call__(
if cfg.num_experts > 1:
# TODO(ranran): remove for loop implementation after adding expert parallelism
if cfg.moe_matmul:
max_logging.log("Running MoE matmul implementation.")
mlp_lnx = linears.MoeBlock(
config=cfg,
num_experts=cfg.num_experts,
num_experts_per_tok=cfg.num_experts_per_tok,
mesh=mesh,
kernel_init=initializers.nd_dense_init(1.0, 'fan_in', 'truncated_normal'),
kernel_axes=('embed', 'mlp'),
dtype=cfg.dtype,
3 changes: 3 additions & 0 deletions end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh
@@ -34,6 +34,9 @@ python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml bas
# Test whether the forward pass logits match the golden logits - matmul implementation
python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${MATMUL_SCANNED_CHECKPOINT} run_name=matmul_forward_pass_test per_device_batch_size=1 model_name=mixtral-8x7b tokenizer_path=gs://maxtext-external/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral ici_tensor_parallelism=4 ici_fsdp_parallelism=16 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=float32 moe_matmul=True --atol=3 --rtol=1 --token_size=4

# Test whether the forward pass logits match the golden logits - megablox implementation
python3 MaxText/tests/forward_pass_logit_checker.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${MATMUL_SCANNED_CHECKPOINT} run_name=megablox_forward_pass_test per_device_batch_size=1 model_name=mixtral-8x7b tokenizer_path=gs://maxtext-external/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral ici_tensor_parallelism=4 ici_fsdp_parallelism=16 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=float32 moe_matmul=True megablox=True --atol=3 --rtol=1 --token_size=4

# Run fine-tuning
python3 MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} load_parameters_path=${SCANNED_CHECKPOINT} run_name=fine_tuning per_device_batch_size=1 model_name=mixtral-8x7b ici_tensor_parallelism=4 ici_fsdp_parallelism=16 steps=10 max_target_length=1024 async_checkpointing=false tokenizer_path=gs://maxtext-external/mixtral-8x7B-v0.1-Instruct/tokenizer.mistral checkpoint_period=5

