From 02b681ef035a36b8f0c3490e234a099aaa9dabb2 Mon Sep 17 00:00:00 2001 From: Morgan Du Date: Wed, 5 Jun 2024 23:52:14 +0000 Subject: [PATCH] reshape q --- .gitignore | 1 + MaxText/configs/base.yml | 1 + MaxText/layers/attentions.py | 25 +++-- MaxText/layers/llama2.py | 1 + MaxText/tests/attention_test.py | 165 +++++++++++++++++++++++++++++++- 5 files changed, 182 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 69382978b..938243bb7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ *__pycache__* tmp/ +logs/ # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 27efc0ce2..ed4f5edd2 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -307,6 +307,7 @@ prefill_key_axis_order: "1,2,0,3" prefill_value_axis_order: "1,2,0,3" ar_key_axis_order: "1,2,0,3" ar_value_axis_order: "1,2,0,3" +reshape_q: False # Maxengine Metrics prometheus_port: 0 diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index 8261a14c8..4e3cd8ffb 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -113,6 +113,7 @@ class AttentionOp(nn.Module): prefill_value_axis_order: AxisIdxes = (1, 2, 0, 3) ar_key_axis_order: AxisIdxes = (1, 2, 0, 3) ar_value_axis_order: AxisIdxes = (1, 2, 0, 3) + reshape_q: bool = False dropout_rate: float = 0.0 dtype: DType = jnp.float32 quant: Optional[Quant] = None @@ -277,7 +278,7 @@ def cudnn_flash_attention( ) return dpa_layer(query, key, value, mask=attn_mask) - def compute_local_attention(self, attn_weights: Array, value: Array) -> tuple[Array, Array, Array]: + def compute_local_attention(self, attn_weights: Array, value: Array, q_seq_len: int) -> tuple[Array, Array, Array]: """Computes the attention of a local subset of the kv cache. Local attention results will need to be combined with any other local attentions and normalized Based on https://github.com/google-research/google-research/blob/master/scaling_transformer_inference_efficiency/attention.py @@ -303,7 +304,12 @@ def compute_local_attention(self, attn_weights: Array, value: Array) -> tuple[Ar local_max = jnp.reshape(local_max, (local_max.shape[0], local_max.shape[1], local_max.shape[2] * local_max.shape[3], 1)) local_sum = jnp.reshape(local_sum, (local_sum.shape[0], local_sum.shape[1], local_sum.shape[2] * local_sum.shape[3], 1)) - local_out = self.wv_product(local_exps, value) + local_out = self.wv_product(local_exps, value, q_seq_len) + + if self.reshape_q and q_seq_len == 1: + local_max = local_max[:,0:1,:,:] + local_sum = local_sum[:,0:1,:,:] + return local_out, local_max, local_sum def apply_attention_dot( @@ -320,7 +326,8 @@ def apply_attention_dot( query = query.astype(jnp.float32) key = key.astype(jnp.float32) - attn_weights = self.qk_product(query, key) + q_seq_len = query.shape[1] + attn_weights = self.qk_product(query, key, q_seq_len) # Casting softmaxt computation for float32 for model stability. if self.float32_logits: @@ -328,9 +335,9 @@ def apply_attention_dot( attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) if attn_mask is not None: attn_weights = apply_mask_to_logits(attn_weights, attn_mask) - return self.compute_local_attention(attn_weights, value) + return self.compute_local_attention(attn_weights, value, q_seq_len) - def qk_product(self, query: Array, key: Array) -> Array: + def qk_product(self, query: Array, key: Array, q_seq_len: int) -> Array: """Query-Key product. 
     Args:
@@ -353,10 +360,12 @@ def qk_product(self, query: Array, key: Array) -> Array:
     n_kv = key.shape[-2]
     assert n_kv == self.num_kv_heads
     query = jnp.reshape(query, (b, t, n_kv, n // n_kv, d))
+    if self.reshape_q and q_seq_len == 1:
+      query = jnp.broadcast_to(query, (b, 2, n_kv, n // n_kv, d))
     result = jnp.einsum("btkgd,bskd->bkgts", query, key)
     return result
 
-  def wv_product(self, attn_weights: Array, value: Array) -> Array:
+  def wv_product(self, attn_weights: Array, value: Array, q_seq_len: int) -> Array:
     """weighted value product.
 
     Args:
@@ -378,6 +387,8 @@ def wv_product(self, attn_weights: Array, value: Array) -> Array:
     out = jnp.einsum("bkgts,bskd->btkgd", attn_weights, value)
     b, t, n_kv, g, d = out.shape
     result = jnp.reshape(out, (b, t, n_kv * g, d))
+    if self.reshape_q and q_seq_len == 1:
+      result = result[:, 0:1, :, :]
     return result
 
   def revert_kv_cache(self, kv, cached_axis_order):
@@ -907,6 +918,7 @@ class Attention(nn.Module):
   prefill_value_axis_order: AxisIdxes = (1, 2, 0, 3)
   ar_key_axis_order: AxisIdxes = (1, 2, 0, 3)
   ar_value_axis_order: AxisIdxes = (1, 2, 0, 3)
+  reshape_q: bool = False
 
   def query_projection(self, inputs_q: Array) -> Array:
     """Query projection."""
@@ -1066,6 +1078,7 @@ def __call__(
         prefill_value_axis_order = self.prefill_value_axis_order,
         ar_key_axis_order = self.ar_key_axis_order,
         ar_value_axis_order = self.ar_value_axis_order,
+        reshape_q = self.reshape_q,
     )
 
     out = attention_op(query, key, value, decoder_segment_ids, model_mode)
diff --git a/MaxText/layers/llama2.py b/MaxText/layers/llama2.py
index 157bdf78f..b61971631 100644
--- a/MaxText/layers/llama2.py
+++ b/MaxText/layers/llama2.py
@@ -100,6 +100,7 @@ def __call__(
         prefill_value_axis_order=tuple([int(i) for i in cfg.prefill_value_axis_order.split(",")]),
         ar_key_axis_order=tuple([int(i) for i in cfg.ar_key_axis_order.split(",")]),
         ar_value_axis_order=tuple([int(i) for i in cfg.ar_value_axis_order.split(",")]),
+        reshape_q=cfg.reshape_q,
     )
 
     attention_lnx = attention_layer(
diff --git a/MaxText/tests/attention_test.py b/MaxText/tests/attention_test.py
index 2692cdad9..57648ac28 100644
--- a/MaxText/tests/attention_test.py
+++ b/MaxText/tests/attention_test.py
@@ -285,11 +285,11 @@ def test_dot_product_2130_2130(self):
     )
 
   def dot_product_attention_helper(self, prefill_cache_axis_order, ar_cache_axis_order):
-    self._dot_product_attention(prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache=False)
-    self._dot_product_attention(prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache=True)
+    self._dot_product_attention(prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache=False, rtol=1e-02, atol=1e-01)
+    self._dot_product_attention(prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache=True, rtol=1e-01, atol=1e-01)
 
-  def _dot_product_attention(self, prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache):
-    """Test equalvant between dot_product and TPU accelerated"""
+  def _dot_product_attention(self, prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache, rtol, atol):
+    """Test equivalence between different layout controls in dot_product"""
     prefill_length = self.max_prefill_predict_length
     decode_total_length = self.max_target_length
     lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype)
@@ -367,7 +367,162 @@ def _dot_product_attention(self, prefill_cache_axis_order, ar_cache_axis_order,
       attention_w_layout_full_this_idx = attention_w_layout_full[:, idx : idx + 1, :]
       self.assertTrue(attention_w_layout_full_this_idx.shape == attention_w_layout_idx.shape)
-      self.assertTrue(jax.numpy.allclose(attention_w_layout_full_this_idx, attention_w_layout_idx, rtol=1e-02, atol=1e-01, equal_nan=False))
+      self.assertTrue(jax.numpy.allclose(attention_w_layout_full_this_idx, attention_w_layout_idx, rtol=rtol, atol=atol, equal_nan=False))
+
+  @pytest.mark.tpu
+  def test_dot_product_reshape_q(self):
+    self._dot_product_attention_reshape_q(quantize_kvcache=True, rtol=1e-01, atol=1e-01)
+    self._dot_product_attention_reshape_q(quantize_kvcache=False, rtol=1e-02, atol=1e-02)
+
+  def _dot_product_attention_reshape_q(self, quantize_kvcache, rtol, atol):
+    """Test equivalence between q and reshaped q in dot_product"""
+    prefill_length = self.max_prefill_predict_length
+    decode_total_length = self.max_target_length
+    lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype)
+
+    lnx_prefill = lnx[:, 0:prefill_length, :]
+    decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length]
+    decoder_positions_prefill = decoder_positions[:, 0:prefill_length]
+
+    attention_wo_reshape_q = Attention(
+        config=self.cfg,
+        num_query_heads=self.num_query_heads,
+        num_kv_heads=self.num_kv_heads,
+        head_dim=self.head_dim,
+        max_target_length=self.max_target_length,
+        max_prefill_predict_length=self.max_prefill_predict_length,
+        mesh=self.mesh,
+        attention_kernel="dot_product",
+        dtype=self.dtype,
+        reshape_q=False,
+        quantize_kvcache=quantize_kvcache,
+    )
+
+    attention_w_reshape_q = Attention(
+        config=self.cfg,
+        num_query_heads=self.num_query_heads,
+        num_kv_heads=self.num_kv_heads,
+        head_dim=self.head_dim,
+        max_target_length=self.max_target_length,
+        max_prefill_predict_length=self.max_prefill_predict_length,
+        mesh=self.mesh,
+        attention_kernel="dot_product",
+        dtype=self.dtype,
+        reshape_q=True,
+        quantize_kvcache=quantize_kvcache,
+    )
+
+    attention_wo_reshape_q_variable = attention_wo_reshape_q.init(
+        {"params": self.rng, "aqt": self.rng},
+        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
+        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
+        jnp.ones((self.global_batch_size, self.max_target_length)),
+    )
+
+    attention_w_reshape_q_variable = attention_w_reshape_q.init(
+        {"params": self.rng, "aqt": self.rng},
+        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
+        jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
+        jnp.ones((self.global_batch_size, self.max_target_length)),
+    )
+
+    attention_wo_reshape_q_full = attention_wo_reshape_q.apply(
+        attention_wo_reshape_q_variable,
+        lnx,
+        lnx,
+        decoder_segment_ids=decoder_segment_ids,
+        inputs_positions=decoder_positions,
+        deterministic=True,
+        model_mode=common_types.MODEL_MODE_TRAIN,
+        rngs={"aqt": self.rng},
+    )
+
+    attention_w_reshape_q_full = attention_w_reshape_q.apply(
+        attention_w_reshape_q_variable,
+        lnx,
+        lnx,
+        decoder_segment_ids=decoder_segment_ids,
+        inputs_positions=decoder_positions,
+        deterministic=True,
+        model_mode=common_types.MODEL_MODE_TRAIN,
+        rngs={"aqt": self.rng},
+    )
+
+    attention_wo_reshape_q_prefill, attention_wo_reshape_q_output_cache = attention_wo_reshape_q.apply(
+        attention_wo_reshape_q_variable,
+        lnx_prefill,
+        lnx_prefill,
+        decoder_segment_ids=decoder_segment_ids_prefill,
+        inputs_positions=decoder_positions_prefill,
+        deterministic=True,
+        model_mode=common_types.MODEL_MODE_PREFILL,
+        rngs={"aqt": self.rng},
+        mutable=["cache"],
+    )
+    self.assertTrue(
+        jax.numpy.allclose(attention_wo_reshape_q_full[:, :prefill_length, :], 
attention_wo_reshape_q_prefill, equal_nan=False) + ) + + attention_w_reshape_q_prefill, attention_w_reshape_q_output_cache = attention_w_reshape_q.apply( + attention_w_reshape_q_variable, + lnx_prefill, + lnx_prefill, + decoder_segment_ids=decoder_segment_ids_prefill, + inputs_positions=decoder_positions_prefill, + deterministic=True, + model_mode=common_types.MODEL_MODE_PREFILL, + rngs={"aqt": self.rng}, + mutable=["cache"], + ) + self.assertTrue( + jax.numpy.allclose(attention_w_reshape_q_full[:, :prefill_length, :], attention_w_reshape_q_prefill, equal_nan=False) + ) + + self.assertTrue( + jax.numpy.allclose(attention_wo_reshape_q_prefill, attention_w_reshape_q_prefill, equal_nan=False) + ) + self.assertTrue( + jax.numpy.allclose(attention_wo_reshape_q_full[:, :prefill_length, :], attention_w_reshape_q_full[:, :prefill_length, :], equal_nan=False) + ) + + for idx in range(prefill_length, decode_total_length): + + lnx_idx = lnx[:, idx : idx + 1, :] + decoder_positions_idx = decoder_positions[:, idx : idx + 1] + + attention_wo_reshape_q_variable.update(attention_wo_reshape_q_output_cache) + attention_wo_reshape_q_idx, attention_wo_reshape_q_output_cache = attention_wo_reshape_q.apply( + attention_wo_reshape_q_variable, + lnx_idx, + lnx_idx, + inputs_positions=decoder_positions_idx, + deterministic=True, + model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE, + rngs={"aqt": self.rng}, + mutable=["cache"], + ) + + attention_wo_reshape_q_full_this_idx = attention_wo_reshape_q_full[:, idx : idx + 1, :] + self.assertTrue(attention_wo_reshape_q_full_this_idx.shape == attention_wo_reshape_q_idx.shape) + self.assertTrue(jax.numpy.allclose(attention_wo_reshape_q_full_this_idx, attention_wo_reshape_q_idx, rtol=rtol, atol=atol, equal_nan=False)) + + attention_w_reshape_q_variable.update(attention_w_reshape_q_output_cache) + attention_w_reshape_q_idx, attention_w_reshape_q_output_cache = attention_w_reshape_q.apply( + attention_w_reshape_q_variable, + lnx_idx, + lnx_idx, + inputs_positions=decoder_positions_idx, + deterministic=True, + model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE, + rngs={"aqt": self.rng}, + mutable=["cache"], + ) + + attention_w_reshape_q_full_this_idx = attention_w_reshape_q_full[:, idx : idx + 1, :] + self.assertTrue(attention_w_reshape_q_full_this_idx.shape == attention_w_reshape_q_idx.shape) + self.assertTrue(jax.numpy.allclose(attention_w_reshape_q_full_this_idx, attention_w_reshape_q_idx, rtol=rtol, atol=atol, equal_nan=False)) + + self.assertTrue(jax.numpy.allclose(attention_w_reshape_q_idx, attention_wo_reshape_q_idx, rtol=rtol, atol=atol, equal_nan=False)) if __name__ == "__main__":
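
Note (reviewer sketch, not part of the patch): with reshape_q enabled, the autoregressive length-1 query is broadcast to a query length of 2 before the qk einsum, and position 0 is sliced back out of the weighted-value product, so the output should match the plain path exactly. The standalone JAX snippet below checks that equivalence for the einsum/broadcast/slice mechanics only (it omits the softmax normalization done in compute_local_attention); the shapes b, n_kv, g, d, s are made up for illustration.

    import jax.numpy as jnp
    import numpy as np

    b, n_kv, g, d, s = 2, 4, 2, 8, 16  # batch, kv heads, query groups per kv head, head dim, kv length
    rng = np.random.default_rng(0)
    query = jnp.asarray(rng.normal(size=(b, 1, n_kv, g, d)), dtype=jnp.float32)  # q_seq_len == 1 (one decode step)
    key = jnp.asarray(rng.normal(size=(b, s, n_kv, d)), dtype=jnp.float32)
    value = jnp.asarray(rng.normal(size=(b, s, n_kv, d)), dtype=jnp.float32)

    # Plain path: qk product and wv product with the length-1 query.
    w_ref = jnp.einsum("btkgd,bskd->bkgts", query, key)
    out_ref = jnp.einsum("bkgts,bskd->btkgd", w_ref, value)

    # reshape_q path: pad the query-length axis to 2 via broadcast, then slice position 0 back out.
    query_2 = jnp.broadcast_to(query, (b, 2, n_kv, g, d))
    w_2 = jnp.einsum("btkgd,bskd->bkgts", query_2, key)
    out_2 = jnp.einsum("bkgts,bskd->btkgd", w_2, value)[:, 0:1, :, :, :]

    assert out_ref.shape == out_2.shape == (b, 1, n_kv, g, d)
    assert jnp.allclose(out_ref, out_2)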