diff --git a/.vscode/launch.json b/.vscode/launch.json
index da30010fb..ddd8eb0f6 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -31,6 +31,40 @@
         "dataset_path=gs://test-maxtext-dataset",
         "steps=2",
         "enable_checkpointing=false"]
-    }
+    },
+    {
+      "name": "Debug MaxText Inference Microbenchmark",
+      "type": "python",
+      "request": "launch",
+      "console": "integratedTerminal",
+      "justMyCode": false,
+      "python": "python3",
+      "program": "${workspaceFolder}/MaxText/inference_microbenchmark.py",
+      "args": [
+        "MaxText/configs/base.yml",
+        "model_name=llama2-7b",
+        "tokenizer_path=assets/tokenizer.llama2",
+        "weight_dtype=bfloat16",
+        "scan_layers=false",
+        "attention=dot_product",
+        "max_prefill_predict_length=1024",
+        "max_target_length=2048",
+        "ici_fsdp_parallelism=1",
+        "ici_tensor_parallelism=-1",
+        "ici_autoregressive_parallelism=1",
+        "inference_microbenchmark_prefill_lengths=32,64,128,256,512,1024",
+        "inference_microbenchmark_stages=generate",
+        "inference_microbenchmark_loop_iters=1",
+        "run_name=runner_$(date +%Y-%m-%d-%H-%M)",
+        "base_output_directory=gs://test-maxtext-output",
+        "prefill_cache_axis_order=0,2,1,3",
+        "ar_cache_axis_order=0,2,1,3",
+        "compute_axis_order=0,2,1,3",
+        "reshape_q=true",
+        "per_device_batch_size=24",
+        "quantization=int8",
+        "quantize_kvcache=True",
+      ]
+    },
   ]
 }
\ No newline at end of file
diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index 586064c4e..47f8b35b6 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -312,10 +312,13 @@ inference_microbenchmark_log_file_path: ""
 # KV Cache layout control
 # Logical layout: 0,1,2,3 ; CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV
 # Default layout: 1,2,0,3 ; CACHE_SEQUENCE, CACHE_HEADS, CACHE_BATCH, CACHE_KV
-prefill_key_axis_order: "1,2,0,3"
-prefill_value_axis_order: "1,2,0,3"
-ar_key_axis_order: "1,2,0,3"
-ar_value_axis_order: "1,2,0,3"
+prefill_cache_axis_order: "1,2,0,3"
+ar_cache_axis_order: "1,2,0,3"
+
+# Compute layout control
+# Default layout: 0,1,2,3 ; BATCH, LENGTH, HEAD, D_KV
+compute_axis_order: "0,1,2,3"
+
 reshape_q: False

 # Maxengine Metrics
diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py
index 4e3cd8ffb..117fda9ea 100644
--- a/MaxText/layers/attentions.py
+++ b/MaxText/layers/attentions.py
@@ -68,6 +68,12 @@
 # pytype: disable=attribute-error

+def validate_compute_axis_order(s: str) -> None:
+  valid_compute_axis_order = ((0,1,2,3), (0,2,1,3))
+  if s not in valid_compute_axis_order:  # currently supported compute_axis_order
+    raise ValueError("Invalid compute_axis_order was passed. Valid options ", valid_compute_axis_order)
+
+
 def apply_mask_to_logits(logits: Array, mask: Array):
   """Applies a floating-point mask to a set of logits.
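
The renamed settings above use a single permutation per cache: prefill_cache_axis_order and ar_cache_axis_order reorder the logical (batch, sequence, heads, head_dim) axes into the physical cache layout. A minimal standalone sketch of what such a permutation does to a shape (illustrative only, with made-up sizes; not the cached_kv_shape implementation):

# Illustrative sketch: map a logical KV-cache shape to the physical layout
# selected by an axis-order permutation such as "1,2,0,3".
def physical_cache_shape(logical_shape, axis_order):
  # logical_shape is (CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV)
  return tuple(logical_shape[i] for i in axis_order)

logical = (8, 1024, 32, 128)  # hypothetical (batch, sequence, heads, head_dim)
print(physical_cache_shape(logical, (1, 2, 0, 3)))  # (1024, 32, 8, 128), the default "1,2,0,3"
print(physical_cache_shape(logical, (0, 2, 1, 3)))  # (8, 32, 1024, 128), as in the launch config above
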
@@ -109,10 +115,9 @@ class AttentionOp(nn.Module):
   float32_logits: bool = False
   flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV)
   kv_cache_logical_layout: AxisNames = (CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV)
-  prefill_key_axis_order: AxisIdxes = (1, 2, 0, 3)
-  prefill_value_axis_order: AxisIdxes = (1, 2, 0, 3)
-  ar_key_axis_order: AxisIdxes = (1, 2, 0, 3)
-  ar_value_axis_order: AxisIdxes = (1, 2, 0, 3)
+  prefill_cache_axis_order: AxisIdxes = (1, 2, 0, 3)
+  ar_cache_axis_order: AxisIdxes = (1, 2, 0, 3)
+  compute_axis_order: AxisIdxes = (0, 1, 2, 3)
   reshape_q: bool = False
   dropout_rate: float = 0.0
   dtype: DType = jnp.float32
@@ -278,7 +283,7 @@ def cudnn_flash_attention(
     )
     return dpa_layer(query, key, value, mask=attn_mask)

-  def compute_local_attention(self, attn_weights: Array, value: Array, q_seq_len: int) -> tuple[Array, Array, Array]:
+  def compute_local_attention(self, attn_weights: Array, value: Array, q_seq_len: int, model_mode: str) -> tuple[Array, Array, Array]:
     """Computes the attention of a local subset of the kv cache.
     Local attention results will need to be combined with any other local attentions and normalized
     Based on https://github.com/google-research/google-research/blob/master/scaling_transformer_inference_efficiency/attention.py
@@ -304,11 +309,12 @@ def compute_local_attention(self, attn_weights: Array, value: Array, q_seq_len:
     local_max = jnp.reshape(local_max, (local_max.shape[0], local_max.shape[1], local_max.shape[2] * local_max.shape[3], 1))
     local_sum = jnp.reshape(local_sum, (local_sum.shape[0], local_sum.shape[1], local_sum.shape[2] * local_sum.shape[3], 1))

-    local_out = self.wv_product(local_exps, value, q_seq_len)
+    local_out = self.wv_product(local_exps, value, model_mode)

     if self.reshape_q and q_seq_len == 1:
       local_max = local_max[:,0:1,:,:]
       local_sum = local_sum[:,0:1,:,:]
+      local_out = local_out[:,0:1,:,:]

     return local_out, local_max, local_sum

@@ -321,23 +327,24 @@ def apply_attention_dot(
       model_mode: str = common_types.MODEL_MODE_TRAIN,
   ):
     """Apply Attention."""
+    validate_compute_axis_order(self.compute_axis_order)
     # Casting qk_product and softmaxt computation for float32 for model stability.
-    if self.float32_qk_product:
+    if model_mode == common_types.MODEL_MODE_TRAIN and self.float32_qk_product:
       query = query.astype(jnp.float32)
       key = key.astype(jnp.float32)

     q_seq_len = query.shape[1]
-    attn_weights = self.qk_product(query, key, q_seq_len)
+    attn_weights = self.qk_product(query, key, q_seq_len, model_mode)

     # Casting softmaxt computation for float32 for model stability.
-    if self.float32_logits:
+    if model_mode == common_types.MODEL_MODE_TRAIN and self.float32_logits:
       attn_weights = attn_weights.astype(jnp.float32)
     attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
     if attn_mask is not None:
       attn_weights = apply_mask_to_logits(attn_weights, attn_mask)
-    return self.compute_local_attention(attn_weights, value, q_seq_len)
+    return self.compute_local_attention(attn_weights, value, q_seq_len, model_mode)

-  def qk_product(self, query: Array, key: Array, q_seq_len: int) -> Array:
+  def qk_product(self, query: Array, key: Array, q_seq_len: int, model_mode: str) -> Array:
     """Query-Key product.

     Args:
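
The hunks below route qk_product and wv_product through one of two einsum formulations selected by compute_axis_order. A minimal standalone check of why the (0,2,1,3) path matches the default path (illustrative only, with made-up sizes; not the AttentionOp code itself):

import jax
import jax.numpy as jnp

b, t, s, n_kv, g, d = 2, 4, 4, 2, 2, 8  # tiny illustrative sizes; n = n_kv * g query heads
q = jax.random.normal(jax.random.PRNGKey(0), (b, t, n_kv * g, d))
k = jax.random.normal(jax.random.PRNGKey(1), (b, s, n_kv, d))

# compute_axis_order (0,1,2,3): keep (batch, length, heads, head_dim) and split the head axis.
w_default = jnp.einsum("btkgd,bskd->bkgts", q.reshape(b, t, n_kv, g, d), k)

# compute_axis_order (0,2,1,3): move heads ahead of length first, then split the head axis.
q_t = q.transpose(0, 2, 1, 3).reshape(b, n_kv, g, t, d)
k_t = k.transpose(0, 2, 1, 3)
w_transposed = jnp.einsum("bkgtd,bksd->bkgts", q_t, k_t)

assert jnp.allclose(w_default, w_transposed, atol=1e-6)
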
@@ -359,13 +366,21 @@ def qk_product(self, query: Array, key: Array, q_seq_len: int) -> Array:
     b, t, n, d = query.shape
     n_kv = key.shape[-2]
     assert n_kv == self.num_kv_heads
-    query = jnp.reshape(query, (b, t, n_kv, n // n_kv, d))
-    if self.reshape_q and q_seq_len == 1:
-      query = jnp.broadcast_to(query, (b, 2, n_kv, n // n_kv, d))
-    result = jnp.einsum("btkgd,bskd->bkgts", query, key)
+    if model_mode == common_types.MODEL_MODE_TRAIN or self.compute_axis_order == (0,1,2,3):
+      query = jnp.reshape(query, (b, t, n_kv, n // n_kv, d))
+      if self.reshape_q and q_seq_len == 1:
+        query = jnp.broadcast_to(query, (b, 2, n_kv, n // n_kv, d))
+      result = jnp.einsum("btkgd,bskd->bkgts", query, key)
+    elif self.compute_axis_order == (0,2,1,3):
+      query = jnp.transpose(query, axes=self.compute_axis_order)
+      key = jnp.transpose(key, axes=self.compute_axis_order)
+      query = jnp.reshape(query, (b, n_kv, n // n_kv, t, d))
+      if self.reshape_q and q_seq_len == 1:
+        query = jnp.broadcast_to(query, (b, n_kv, n // n_kv, 2, d))
+      result = jnp.einsum("bkgtd,bksd->bkgts", query, key)
     return result

-  def wv_product(self, attn_weights: Array, value: Array, q_seq_len: int) -> Array:
+  def wv_product(self, attn_weights: Array, value: Array, model_mode: str) -> Array:
     """weighted value product.

     Args:
@@ -384,11 +399,16 @@ def wv_product(self, attn_weights: Array, value: Array, q_seq_len: int) -> Array:
       n_kv: number of kv heads, sometimes annotated as k
       n // n_kv: number of group for query, sometimes annotated with g
     """
-    out = jnp.einsum("bkgts,bskd->btkgd", attn_weights, value)
-    b, t, n_kv, g, d = out.shape
-    result = jnp.reshape(out, (b, t, n_kv * g, d))
-    if self.reshape_q and q_seq_len == 1:
-      result = result[:, 0:1, :, :]
+    if model_mode == common_types.MODEL_MODE_TRAIN or self.compute_axis_order == (0,1,2,3):
+      out = jnp.einsum("bkgts,bskd->btkgd", attn_weights, value)
+      b, t, n_kv, g, d = out.shape
+      result = jnp.reshape(out, (b, t, n_kv * g, d))
+    elif self.compute_axis_order == (0,2,1,3):
+      value = jnp.transpose(value, axes=self.compute_axis_order)
+      out = jnp.einsum("bkgts,bksd->bkgtd", attn_weights, value)
+      b, n_kv, g, t, d = out.shape
+      result = jnp.reshape(out, (b, n_kv * g, t, d))
+      result = jnp.transpose(result, axes=self.compute_axis_order)
     return result

   def revert_kv_cache(self, kv, cached_axis_order):
@@ -458,11 +478,11 @@ def _get_prefill_cache(self, batch, heads, kv_head_size, quantize_kvcache):

     cache_logical_shape = (batch, self.max_prefill_predict_length, heads, kv_head_size)

-    key_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.prefill_key_axis_order)
-    value_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.prefill_value_axis_order)
+    key_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.prefill_cache_axis_order)
+    value_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.prefill_cache_axis_order)

-    key_shape = self.cached_kv_shape(cache_logical_shape, self.prefill_key_axis_order)
-    value_shape = self.cached_kv_shape(cache_logical_shape, self.prefill_value_axis_order)
+    key_shape = self.cached_kv_shape(cache_logical_shape, self.prefill_cache_axis_order)
+    value_shape = self.cached_kv_shape(cache_logical_shape, self.prefill_cache_axis_order)

     cached_key = self.variable(
         "cache",
@@ -490,8 +510,8 @@ def _get_prefill_cache(self, batch, heads, kv_head_size, quantize_kvcache):

     cache_logical_shape_scale = (batch, self.max_prefill_predict_length, heads, 1)

-    key_shape_scale = self.cached_kv_shape(cache_logical_shape_scale, self.prefill_key_axis_order)
-    value_shape_scale = self.cached_kv_shape(cache_logical_shape_scale, self.prefill_value_axis_order)
+    key_shape_scale = self.cached_kv_shape(cache_logical_shape_scale, self.prefill_cache_axis_order)
+    value_shape_scale = self.cached_kv_shape(cache_logical_shape_scale, self.prefill_cache_axis_order)

     cached_key_scale_var = self.variable(
         "cache",
@@ -521,11 +541,11 @@ def _get_ar_cache(self, batch, heads, kv_head_size, quantize_kvcache):

     cache_logical_shape = (batch, cache_length, heads, kv_head_size)

-    key_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.ar_key_axis_order)
-    value_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.ar_value_axis_order)
+    key_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.ar_cache_axis_order)
+    value_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.ar_cache_axis_order)

-    key_shape = self.cached_kv_shape(cache_logical_shape, self.ar_key_axis_order)
-    value_shape = self.cached_kv_shape(cache_logical_shape, self.ar_value_axis_order)
+    key_shape = self.cached_kv_shape(cache_logical_shape, self.ar_cache_axis_order)
+    value_shape = self.cached_kv_shape(cache_logical_shape, self.ar_cache_axis_order)

     # TODO(b/339703100): investigate the issue why with_logical_partitioning doesn't enforce sharding
     cached_key = self.variable(
@@ -564,8 +584,8 @@ def _get_ar_cache(self, batch, heads, kv_head_size, quantize_kvcache):

     cache_logical_shape_scale = (batch, cache_length, heads, 1)

-    key_shape_scale = self.cached_kv_shape(cache_logical_shape_scale, self.ar_key_axis_order)
-    value_shape_scale = self.cached_kv_shape(cache_logical_shape_scale, self.ar_value_axis_order)
+    key_shape_scale = self.cached_kv_shape(cache_logical_shape_scale, self.ar_cache_axis_order)
+    value_shape_scale = self.cached_kv_shape(cache_logical_shape_scale, self.ar_cache_axis_order)

     cached_key_scale_var = self.variable(
         "cache",
@@ -616,16 +636,16 @@ def kv_cache_prefill(
     )
     cached_ar_key_var, cached_ar_value_var, _, _ = self._get_ar_cache(batch, heads, kv_head_size, self.quantize_kvcache)  # initialize it now

-    assert cached_prefill_key_var[0].value.shape == self.cached_kv_shape((batch, self.max_prefill_predict_length, heads, kv_head_size), self.prefill_key_axis_order)
-    assert cached_prefill_value_var[0].value.shape == self.cached_kv_shape((batch, self.max_prefill_predict_length, heads, kv_head_size), self.prefill_value_axis_order)
-    assert cached_ar_key_var[0].value.shape == self.cached_kv_shape((batch, self.max_target_length - self.max_prefill_predict_length, heads, kv_head_size), self.ar_key_axis_order)
-    assert cached_ar_value_var[0].value.shape == self.cached_kv_shape((batch, self.max_target_length - self.max_prefill_predict_length, heads, kv_head_size), self.ar_value_axis_order)
+    assert cached_prefill_key_var[0].value.shape == self.cached_kv_shape((batch, self.max_prefill_predict_length, heads, kv_head_size), self.prefill_cache_axis_order)
+    assert cached_prefill_value_var[0].value.shape == self.cached_kv_shape((batch, self.max_prefill_predict_length, heads, kv_head_size), self.prefill_cache_axis_order)
+    assert cached_ar_key_var[0].value.shape == self.cached_kv_shape((batch, self.max_target_length - self.max_prefill_predict_length, heads, kv_head_size), self.ar_cache_axis_order)
+    assert cached_ar_value_var[0].value.shape == self.cached_kv_shape((batch, self.max_target_length - self.max_prefill_predict_length, heads, kv_head_size), self.ar_cache_axis_order)

-    prefill_key_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.prefill_key_axis_order)
-    prefill_value_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.prefill_value_axis_order)
+    prefill_key_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.prefill_cache_axis_order)
+    prefill_value_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.prefill_cache_axis_order)

-    key_shaped_for_cache = self.reshape_kv_cache(key, self.prefill_key_axis_order)
-    value_shaped_for_cache = self.reshape_kv_cache(value, self.prefill_value_axis_order)
+    key_shaped_for_cache = self.reshape_kv_cache(key, self.prefill_cache_axis_order)
+    value_shaped_for_cache = self.reshape_kv_cache(value, self.prefill_cache_axis_order)

     if self.quantize_kvcache:
       key_shaped_for_cache, key_scale = quantizations.quantize_kv(key_shaped_for_cache, prefill_key_layout.index(CACHE_KV))
@@ -667,11 +687,11 @@ def update_ar_key_value(

     # In order to update the key, value caches with the current key and
     # value, we reshape the one_token_key and one_token_value
-    one_token_key_shaped_for_cache = self.reshape_kv_cache(one_token_key, self.ar_key_axis_order)
-    one_token_value_shaped_for_cache = self.reshape_kv_cache(one_token_value, self.ar_value_axis_order)
+    one_token_key_shaped_for_cache = self.reshape_kv_cache(one_token_key, self.ar_cache_axis_order)
+    one_token_value_shaped_for_cache = self.reshape_kv_cache(one_token_value, self.ar_cache_axis_order)

-    ar_key_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.ar_key_axis_order)
-    ar_value_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.ar_value_axis_order)
+    ar_key_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.ar_cache_axis_order)
+    ar_value_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.ar_cache_axis_order)

     if self.quantize_kvcache:
       one_token_key_shaped_for_cache, one_token_key_scale = quantizations.quantize_kv(one_token_key_shaped_for_cache, ar_key_layout.index(CACHE_KV))
@@ -717,7 +737,7 @@ def update_ar_key_value(
       ar_value = quantizations.unquantize_kv(cached_value_var.value, cached_value_scale_var.value, one_token_value.dtype)

     # Revert the keys and values back to original logical shapes.
-    return self.revert_kv_cache(ar_key, self.ar_key_axis_order), self.revert_kv_cache(ar_value, self.ar_value_axis_order)
+    return self.revert_kv_cache(ar_key, self.ar_cache_axis_order), self.revert_kv_cache(ar_value, self.ar_cache_axis_order)

   def prefill_cache_var_model_var(self, cache_var, target_dtype, cache_axis_order):
     if not self.quantize_kvcache:
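
reshape_kv_cache and revert_kv_cache (whose body is not shown in this diff) convert between the logical layout and the configured cache layout. A minimal standalone illustration of undoing an axis permutation, assuming the revert is a plain inverse transpose (made-up sizes, not MaxText code):

import numpy as np

axis_order = (1, 2, 0, 3)                       # logical -> cache permutation
inverse_order = tuple(np.argsort(axis_order))   # cache -> logical permutation, here (2, 0, 1, 3)

x_logical = np.zeros((8, 1024, 32, 128))        # hypothetical (batch, sequence, heads, head_dim)
x_cached = np.transpose(x_logical, axis_order)  # shape (1024, 32, 8, 128)
assert np.transpose(x_cached, inverse_order).shape == x_logical.shape
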
@@ -756,8 +776,8 @@ def kv_cache_autoregressive(
         batch, heads, kv_head_size, self.quantize_kvcache
     )

-    assert cached_ar_key_var[0].value.shape == self.cached_kv_shape((batch, self.max_target_length - self.max_prefill_predict_length, heads, kv_head_size), self.ar_key_axis_order)
-    assert cached_ar_value_var[0].value.shape == self.cached_kv_shape((batch, self.max_target_length - self.max_prefill_predict_length, heads, kv_head_size), self.ar_value_axis_order)
+    assert cached_ar_key_var[0].value.shape == self.cached_kv_shape((batch, self.max_target_length - self.max_prefill_predict_length, heads, kv_head_size), self.ar_cache_axis_order)
+    assert cached_ar_value_var[0].value.shape == self.cached_kv_shape((batch, self.max_target_length - self.max_prefill_predict_length, heads, kv_head_size), self.ar_cache_axis_order)

     key = nn.with_logical_constraint(key, (BATCH, LENGTH, HEAD, D_KV))
     value = nn.with_logical_constraint(value, (BATCH, LENGTH, HEAD, D_KV))
@@ -773,12 +793,12 @@ def kv_cache_autoregressive(
     cached_prefill_key_var, cached_prefill_value_var, cached_prefill_segment_id = self._get_prefill_cache(
         batch, heads, kv_head_size, self.quantize_kvcache
     )
-    assert cached_prefill_key_var[0].value.shape == self.cached_kv_shape((batch, self.max_prefill_predict_length, heads, kv_head_size), self.prefill_key_axis_order)
-    assert cached_prefill_value_var[0].value.shape == self.cached_kv_shape((batch, self.max_prefill_predict_length, heads, kv_head_size), self.prefill_value_axis_order)
+    assert cached_prefill_key_var[0].value.shape == self.cached_kv_shape((batch, self.max_prefill_predict_length, heads, kv_head_size), self.prefill_cache_axis_order)
+    assert cached_prefill_value_var[0].value.shape == self.cached_kv_shape((batch, self.max_prefill_predict_length, heads, kv_head_size), self.prefill_cache_axis_order)

     cached_prefill = (
-        self.prefill_cache_var_model_var(cached_prefill_key_var, key.dtype, self.prefill_key_axis_order),
-        self.prefill_cache_var_model_var(cached_prefill_value_var, value.dtype, self.prefill_value_axis_order),
+        self.prefill_cache_var_model_var(cached_prefill_key_var, key.dtype, self.prefill_cache_axis_order),
+        self.prefill_cache_var_model_var(cached_prefill_value_var, value.dtype, self.prefill_cache_axis_order),
         cached_prefill_segment_id.value,
     )
     return cached_prefill, (ar_key, ar_value, cached_ar_segment_id.value)
@@ -914,10 +934,9 @@ class Attention(nn.Module):
   value_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV)
   out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV)

-  prefill_key_axis_order: AxisIdxes = (1, 2, 0, 3)
-  prefill_value_axis_order: AxisIdxes = (1, 2, 0, 3)
-  ar_key_axis_order: AxisIdxes = (1, 2, 0, 3)
-  ar_value_axis_order: AxisIdxes = (1, 2, 0, 3)
+  prefill_cache_axis_order: AxisIdxes = (1, 2, 0, 3)
+  ar_cache_axis_order: AxisIdxes = (1, 2, 0, 3)
+  compute_axis_order: AxisIdxes = (0, 1, 2, 3)
   reshape_q: bool = False

   def query_projection(self, inputs_q: Array) -> Array:
@@ -1074,10 +1093,9 @@ def __call__(
         num_kv_heads=self.num_kv_heads,
         dropout_rate=self.dropout_rate,
         dtype=self.dtype,
-        prefill_key_axis_order = self.prefill_key_axis_order,
-        prefill_value_axis_order = self.prefill_value_axis_order,
-        ar_key_axis_order = self.ar_key_axis_order,
-        ar_value_axis_order = self.ar_value_axis_order,
+        prefill_cache_axis_order=self.prefill_cache_axis_order,
+        ar_cache_axis_order=self.ar_cache_axis_order,
+        compute_axis_order=self.compute_axis_order,
         reshape_q = self.reshape_q,
     )
diff --git a/MaxText/layers/llama2.py b/MaxText/layers/llama2.py
index b61971631..d5baad836 100644
--- a/MaxText/layers/llama2.py
+++ b/MaxText/layers/llama2.py
@@ -96,10 +96,9 @@ def __call__(
         name="self_attention",
         quant=self.quant,
         quantize_kvcache=cfg.quantize_kvcache,
-        prefill_key_axis_order=tuple([int(i) for i in cfg.prefill_key_axis_order.split(",")]),
-        prefill_value_axis_order=tuple([int(i) for i in cfg.prefill_value_axis_order.split(",")]),
-        ar_key_axis_order=tuple([int(i) for i in cfg.ar_key_axis_order.split(",")]),
-        ar_value_axis_order=tuple([int(i) for i in cfg.ar_value_axis_order.split(",")]),
+        prefill_cache_axis_order=tuple([int(i) for i in cfg.prefill_cache_axis_order.split(",")]),
+        ar_cache_axis_order=tuple([int(i) for i in cfg.ar_cache_axis_order.split(",")]),
+        compute_axis_order=tuple([int(i) for i in cfg.compute_axis_order.split(",")]),
         reshape_q=cfg.reshape_q,
     )
diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py
index 7b824601a..319afee7e 100644
--- a/MaxText/layers/models.py
+++ b/MaxText/layers/models.py
@@ -93,10 +93,10 @@ def __call__(
         name="self_attention",
         quant=self.quant,
         quantize_kvcache=cfg.quantize_kvcache,
-        prefill_key_axis_order=tuple([int(i) for i in cfg.prefill_key_axis_order.split(",")]),
-        prefill_value_axis_order=tuple([int(i) for i in cfg.prefill_value_axis_order.split(",")]),
-        ar_key_axis_order=tuple([int(i) for i in cfg.ar_key_axis_order.split(",")]),
-        ar_value_axis_order=tuple([int(i) for i in cfg.ar_value_axis_order.split(",")]),
+        prefill_cache_axis_order=tuple([int(i) for i in cfg.prefill_cache_axis_order.split(",")]),
+        ar_cache_axis_order=tuple([int(i) for i in cfg.ar_cache_axis_order.split(",")]),
+        compute_axis_order=tuple([int(i) for i in cfg.compute_axis_order.split(",")]),
+        reshape_q=cfg.reshape_q,
     )

     attention_lnx = attention_layer(
diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py
index 714d9692f..96125e60a 100644
--- a/MaxText/pyconfig.py
+++ b/MaxText/pyconfig.py
@@ -52,6 +52,11 @@ def string_to_bool(s: str) -> bool:

 _yaml_types_to_parser = {str: str, int: int, float: float, bool: string_to_bool}

+def validate_compute_axis_order(s: str) -> None:
+  valid_compute_axis_order = ("0,1,2,3", "0,2,1,3")
+  if s not in valid_compute_axis_order:  # currently supported compute_axis_order
+    raise ValueError("Invalid compute_axis_order was passed. Valid options ", valid_compute_axis_order)
+
 def validate_attention_type(s: str) -> None:
   valid_attention_types = ("autoselected", "dot_product", "flash", "cudnn_flash_te")
   if s not in valid_attention_types:  # currently supported attention
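
The layout options travel as strings: pyconfig validates the raw comma-separated value (above), and llama2.py/models.py parse it into an integer tuple before constructing the Attention layer. A minimal standalone sketch of that round trip (illustrative only, not MaxText code):

def parse_axis_order(s: str) -> tuple:
  # "0,2,1,3" -> (0, 2, 1, 3), mirroring tuple([int(i) for i in cfg.compute_axis_order.split(",")])
  return tuple(int(i) for i in s.split(","))

valid_compute_axis_order = ("0,1,2,3", "0,2,1,3")
compute_axis_order = "0,2,1,3"
if compute_axis_order not in valid_compute_axis_order:
  raise ValueError("Invalid compute_axis_order was passed.", valid_compute_axis_order)
print(parse_axis_order(compute_axis_order))  # (0, 2, 1, 3)
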
@@ -66,6 +71,7 @@ def validate_profiler_type(s: str) -> None:

 def validate_keys(keys):
   validate_attention_type(keys["attention"])
   validate_profiler_type(keys["profiler"])
+  validate_compute_axis_order(keys["compute_axis_order"])

   assert (keys["load_parameters_path"] == "" and keys["load_full_state_path"] == "") or keys[
       "enable_checkpointing"
diff --git a/MaxText/tests/attention_test.py b/MaxText/tests/attention_test.py
index 57648ac28..e1d782cd1 100644
--- a/MaxText/tests/attention_test.py
+++ b/MaxText/tests/attention_test.py
@@ -14,6 +14,8 @@
 """Tests for Attentions."""

+import itertools
+import random
 import sys
 import unittest

@@ -29,7 +31,6 @@
 import pyconfig

 from layers import attentions
-from layers import embeddings

 Mesh = jax.sharding.Mesh
 Attention = attentions.Attention
@@ -257,71 +258,73 @@ def tpu_kernel_attention_helper(self, num_kv_heads):
     )

   @pytest.mark.tpu
-  def test_dot_product_1203_1203(self):
-    self.dot_product_attention_helper(
-        prefill_cache_axis_order=(1,2,0,3),
-        ar_cache_axis_order=(1,2,0,3)
-    )
+  def test_dot_product_cache_axis_order(self):
+    all_axis_orders = [axis_order for axis_order in itertools.permutations(range(4))]
+    for axis_order in random.choices(all_axis_orders, k=4):
+      self.dot_product_attention_helper(
+          prefill_cache_axis_order=axis_order,
+          ar_cache_axis_order=axis_order
+      )
+      print(f"passed test for {axis_order=}")

-  @pytest.mark.tpu
-  def test_dot_product_1203_2130(self):
-    self.dot_product_attention_helper(
-        prefill_cache_axis_order=(1,2,0,3),
-        ar_cache_axis_order=(2,1,3,0)
-    )
+  def dot_product_attention_helper(self, prefill_cache_axis_order, ar_cache_axis_order):
+    for compute_axis_order in [(0,1,2,3), (0,2,1,3)]:
+      self._dot_product_attention(
+          prefill_cache_axis_order,
+          ar_cache_axis_order,
+          compute_axis_order=compute_axis_order,
+      )
+      print(f"passed subtest for {compute_axis_order=}")
+
+  def _dot_product_attention(
+      self,
+      prefill_cache_axis_order,
+      ar_cache_axis_order,
+      compute_axis_order,
+  ):
+    """Test equivalence between different layout controls in dot_product"""

-  @pytest.mark.tpu
-  def test_dot_product_2130_1203(self):
-    self.dot_product_attention_helper(
-        prefill_cache_axis_order=(2,1,3,0),
-        ar_cache_axis_order=(1,2,0,3)
-    )
+    rtol, atol = 1e-02, 1e-02

-  @pytest.mark.tpu
-  def test_dot_product_2130_2130(self):
-    self.dot_product_attention_helper(
-        prefill_cache_axis_order=(2,1,3,0),
-        ar_cache_axis_order=(2,1,3,0),
+    pyconfig.initialize(
+        [sys.argv[0], "configs/base.yml"],
+        per_device_batch_size=1.0,
+        run_name="test",
+        enable_checkpointing=False,
+        max_target_length=128,
+        max_prefill_predict_length=16,
+        attention="dot_product",
     )
+    config = pyconfig.config

-  def dot_product_attention_helper(self, prefill_cache_axis_order, ar_cache_axis_order):
-    self._dot_product_attention(prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache=False, rtol=1e-02, atol=1e-01)
-    self._dot_product_attention(prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache=True, rtol=1e-01, atol=1e-01)
-
-  def _dot_product_attention(self, prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache, rtol, atol):
-    """Test equalvant between different layout control in dot_product"""
-    prefill_length = self.max_prefill_predict_length
-    decode_total_length = self.max_target_length
-    lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype)
+    prefill_length = config.max_prefill_predict_length
+    decode_total_length = config.max_target_length
+    lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(config.dtype)

     lnx_prefill = lnx[:, 0:prefill_length, :]
     decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length]
     decoder_positions_prefill = decoder_positions[:, 0:prefill_length]

     attention_w_layout = Attention(
-        config=self.cfg,
-        num_query_heads=self.num_query_heads,
-        num_kv_heads=self.num_kv_heads,
-        head_dim=self.head_dim,
-        max_target_length=self.max_target_length,
-        max_prefill_predict_length=self.max_prefill_predict_length,
         mesh=self.mesh,
-        attention_kernel="dot_product",
-        dtype=self.dtype,
-        prefill_key_axis_order=prefill_cache_axis_order,
-        prefill_value_axis_order=prefill_cache_axis_order,
-        ar_key_axis_order=ar_cache_axis_order,
-        ar_value_axis_order=ar_cache_axis_order,
-        quantize_kvcache=quantize_kvcache,
+        config=config,
+        num_query_heads=config.num_query_heads,
+        num_kv_heads=config.num_kv_heads,
+        head_dim=config.head_dim,
+        max_target_length=config.max_target_length,
+        max_prefill_predict_length=config.max_prefill_predict_length,
+        attention_kernel=config.attention,
+        dtype=config.dtype,
+        prefill_cache_axis_order=prefill_cache_axis_order,
+        ar_cache_axis_order=ar_cache_axis_order,
+        compute_axis_order=compute_axis_order,
     )
-
     attention_w_layout_variable = attention_w_layout.init(
         {"params": self.rng, "aqt": self.rng},
-        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
-        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
-        jnp.ones((self.global_batch_size, self.max_target_length)),
+        jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
+        jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
+        jnp.ones((self.global_batch_size, config.max_target_length)),
     )
-
     attention_w_layout_full = attention_w_layout.apply(
         attention_w_layout_variable,
         lnx,
@@ -371,59 +374,74 @@ def _dot_product_attention(self, prefill_cache_axis_order, ar_cache_axis_order,

   @pytest.mark.tpu
   def test_dot_product_reshape_q(self):
-    self._dot_product_attention_reshape_q(quantize_kvcache=True, rtol=1e-01, atol=1e-01)
-    self._dot_product_attention_reshape_q(quantize_kvcache=False, rtol=1e-02, atol=1e-02)
+    for compute_axis_order in [(0,1,2,3), (0,2,1,3)]:
+      self._dot_product_attention_reshape_q(
+          compute_axis_order=compute_axis_order,
+      )
+      print(f"test passed for compute_axis_order: {compute_axis_order}")

-  def _dot_product_attention_reshape_q(self, quantize_kvcache, rtol, atol):
+  def _dot_product_attention_reshape_q(self, compute_axis_order):
     """Test equalvant between q and reshape q in dot_product"""
-    prefill_length = self.max_prefill_predict_length
-    decode_total_length = self.max_target_length
-    lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype)
+
+    rtol, atol = 1e-02, 1e-02
+
+    pyconfig.initialize(
+        [sys.argv[0], "configs/base.yml"],
+        per_device_batch_size=1.0,
+        run_name="test",
+        enable_checkpointing=False,
+        max_target_length=128,
+        max_prefill_predict_length=16,
+        attention="dot_product",
+    )
+    config = pyconfig.config
+
+    prefill_length = config.max_prefill_predict_length
+    decode_total_length = config.max_target_length
+    lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(config.dtype)

     lnx_prefill = lnx[:, 0:prefill_length, :]
     decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length]
     decoder_positions_prefill = decoder_positions[:, 0:prefill_length]

     attention_wo_reshape_q = Attention(
-        config=self.cfg,
-        num_query_heads=self.num_query_heads,
-        num_kv_heads=self.num_kv_heads,
-        head_dim=self.head_dim,
-        max_target_length=self.max_target_length,
-        max_prefill_predict_length=self.max_prefill_predict_length,
         mesh=self.mesh,
-        attention_kernel="dot_product",
-        dtype=self.dtype,
+        config=config,
+        num_query_heads=config.num_query_heads,
+        num_kv_heads=config.num_kv_heads,
+        head_dim=config.head_dim,
+        max_target_length=config.max_target_length,
+        max_prefill_predict_length=config.max_prefill_predict_length,
+        attention_kernel=config.attention,
+        dtype=config.dtype,
+        compute_axis_order=compute_axis_order,
         reshape_q=False,
-        quantize_kvcache=quantize_kvcache,
+    )
+    attention_wo_reshape_q_variable = attention_wo_reshape_q.init(
+        {"params": self.rng, "aqt": self.rng},
+        jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
+        jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
+        jnp.ones((self.global_batch_size, config.max_target_length)),
     )

     attention_w_reshape_q = Attention(
-        config=self.cfg,
-        num_query_heads=self.num_query_heads,
-        num_kv_heads=self.num_kv_heads,
-        head_dim=self.head_dim,
-        max_target_length=self.max_target_length,
-        max_prefill_predict_length=self.max_prefill_predict_length,
         mesh=self.mesh,
-        attention_kernel="dot_product",
-        dtype=self.dtype,
+        config=config,
+        num_query_heads=config.num_query_heads,
+        num_kv_heads=config.num_kv_heads,
+        head_dim=config.head_dim,
+        max_target_length=config.max_target_length,
+        max_prefill_predict_length=config.max_prefill_predict_length,
+        attention_kernel=config.attention,
+        dtype=config.dtype,
+        compute_axis_order=compute_axis_order,
         reshape_q=True,
-        quantize_kvcache=quantize_kvcache,
-    )
-
-    attention_wo_reshape_q_variable = attention_wo_reshape_q.init(
-        {"params": self.rng, "aqt": self.rng},
-        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
-        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
-        jnp.ones((self.global_batch_size, self.max_target_length)),
     )
-
     attention_w_reshape_q_variable = attention_w_reshape_q.init(
         {"params": self.rng, "aqt": self.rng},
-        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
-        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
-        jnp.ones((self.global_batch_size, self.max_target_length)),
+        jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
+        jnp.ones((self.global_batch_size, config.max_target_length, config.base_emb_dim)),
+        jnp.ones((self.global_batch_size, config.max_target_length)),
     )

     attention_wo_reshape_q_full = attention_wo_reshape_q.apply(