Skip to content

Commit

Permalink
Merge pull request #685 from google:fix_matmul_scale
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 640594189
  • Loading branch information
maxtext authors committed Jun 5, 2024
2 parents d185209 + cdb4853 commit 16aeac4
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions MaxText/layers/linears.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ class MoeBlock(nn.Module):
weight_dtype: DType = jnp.float32
dtype: DType = jnp.float32

def generate_kernels(self, num_experts, base_emb_dim, mlp_dim):
def generate_kernels(self, num_experts, emb_dim, mlp_dim):

kernel_in_axis = np.arange(1)
kernel_out_axis = np.arange(1, 2)
Expand All @@ -289,7 +289,7 @@ def generate_kernels(self, num_experts, base_emb_dim, mlp_dim):
w0_kernel = self.param(
'wi_0',
nn.with_logical_partitioning(kernel_init, kernel_axes),
(num_experts, base_emb_dim, mlp_dim),
(num_experts, emb_dim, mlp_dim),
self.weight_dtype,
kernel_in_axis,
kernel_out_axis,
Expand All @@ -298,7 +298,7 @@ def generate_kernels(self, num_experts, base_emb_dim, mlp_dim):
w1_kernel = self.param(
'wi_1',
nn.with_logical_partitioning(kernel_init, kernel_axes),
(num_experts, base_emb_dim, mlp_dim),
(num_experts, emb_dim, mlp_dim),
self.weight_dtype,
kernel_in_axis,
kernel_out_axis,
Expand All @@ -307,7 +307,7 @@ def generate_kernels(self, num_experts, base_emb_dim, mlp_dim):
wo_kernel = self.param(
'wo',
nn.with_logical_partitioning(kernel_init, wo_kernel_axes),
(num_experts, mlp_dim, base_emb_dim),
(num_experts, mlp_dim, emb_dim),
self.weight_dtype,
kernel_in_axis,
kernel_out_axis,
Expand Down Expand Up @@ -337,7 +337,7 @@ def __call__(self, inputs):
weights = weights.at[index_update].set(softmax_probs)

w0_kernel, w1_kernel, wo_kernel = self.generate_kernels(cfg.num_experts,
cfg.base_emb_dim,
cfg.emb_dim,
cfg.mlp_dim)

with jax.named_scope("wi_0"):
Expand Down

0 comments on commit 16aeac4

Please sign in to comment.