Commit
Merge pull request #690 from google:mor--reshape-q
PiperOrigin-RevId: 642676990
maxtext authors committed Jun 12, 2024
2 parents 7cdca96 + 02b681e commit d0701b0
Showing 5 changed files with 182 additions and 11 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,5 +1,6 @@
*__pycache__*
tmp/
logs/

# Byte-compiled / optimized / DLL files
__pycache__/
1 change: 1 addition & 0 deletions MaxText/configs/base.yml
@@ -307,6 +307,7 @@ prefill_key_axis_order: "1,2,0,3"
prefill_value_axis_order: "1,2,0,3"
ar_key_axis_order: "1,2,0,3"
ar_value_axis_order: "1,2,0,3"
reshape_q: False

# Maxengine Metrics
prometheus_port: 0
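The new key defaults to False, so the query-reshape path introduced below is opt-in. As a minimal sketch (assuming a local checkout and PyYAML; the key name comes from this diff, everything else is illustrative), the default can be checked directly from the config file:

import yaml

# Load the base config and confirm the new flag's default value.
with open("MaxText/configs/base.yml") as f:
    cfg = yaml.safe_load(f)

assert cfg["reshape_q"] is False  # decode runs must enable the flag explicitly

In practice the flag would presumably be flipped like any other base.yml key, for example as a reshape_q=true override when launching a decode run.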
25 changes: 19 additions & 6 deletions MaxText/layers/attentions.py
@@ -113,6 +113,7 @@ class AttentionOp(nn.Module):
prefill_value_axis_order: AxisIdxes = (1, 2, 0, 3)
ar_key_axis_order: AxisIdxes = (1, 2, 0, 3)
ar_value_axis_order: AxisIdxes = (1, 2, 0, 3)
reshape_q: bool = False
dropout_rate: float = 0.0
dtype: DType = jnp.float32
quant: Optional[Quant] = None
@@ -277,7 +278,7 @@ def cudnn_flash_attention(
)
return dpa_layer(query, key, value, mask=attn_mask)

def compute_local_attention(self, attn_weights: Array, value: Array) -> tuple[Array, Array, Array]:
def compute_local_attention(self, attn_weights: Array, value: Array, q_seq_len: int) -> tuple[Array, Array, Array]:
"""Computes the attention of a local subset of the kv cache.
Local attention results will need to be combined with any other local attentions and normalized
Based on https://github.com/google-research/google-research/blob/master/scaling_transformer_inference_efficiency/attention.py
@@ -303,7 +304,12 @@ def compute_local_attention(self, attn_weights: Array, value: Array) -> tuple[Ar
local_max = jnp.reshape(local_max, (local_max.shape[0], local_max.shape[1], local_max.shape[2] * local_max.shape[3], 1))
local_sum = jnp.reshape(local_sum, (local_sum.shape[0], local_sum.shape[1], local_sum.shape[2] * local_sum.shape[3], 1))

local_out = self.wv_product(local_exps, value)
local_out = self.wv_product(local_exps, value, q_seq_len)

if self.reshape_q and q_seq_len == 1:
local_max = local_max[:, 0:1, :, :]
local_sum = local_sum[:, 0:1, :, :]

return local_out, local_max, local_sum
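The docstring above notes that local attention results "will need to be combined with any other local attentions and normalized", following the referenced scaling_transformer_inference_efficiency code. For intuition, here is a minimal, self-contained sketch of that combine step (single head, unbatched, illustrative shapes — not the MaxText layout): each chunk returns its unnormalized weighted values plus a running max and sum of exponentials, and merging rescales everything to a shared max before normalizing.

import jax
import jax.numpy as jnp

def local_partial(q, k, v):
    # q: (t, d), k/v: (s, d). Unnormalized local attention plus its max and sum of exps.
    logits = q @ k.T                                # (t, s)
    local_max = logits.max(axis=-1, keepdims=True)  # (t, 1)
    exps = jnp.exp(logits - local_max)              # (t, s)
    local_sum = exps.sum(axis=-1, keepdims=True)    # (t, 1)
    return exps @ v, local_max, local_sum           # out: (t, d), not yet normalized

def merge(out_a, max_a, sum_a, out_b, max_b, sum_b):
    # Rescale both partials to a shared max, then normalize by the combined sum.
    global_max = jnp.maximum(max_a, max_b)
    scale_a, scale_b = jnp.exp(max_a - global_max), jnp.exp(max_b - global_max)
    return (out_a * scale_a + out_b * scale_b) / (sum_a * scale_a + sum_b * scale_b)

q = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
k = jax.random.normal(jax.random.PRNGKey(1), (16, 8))
v = jax.random.normal(jax.random.PRNGKey(2), (16, 8))

merged = merge(*local_partial(q, k[:10], v[:10]), *local_partial(q, k[10:], v[10:]))
reference = jax.nn.softmax(q @ k.T, axis=-1) @ v
assert jnp.allclose(merged, reference, atol=1e-5)  # chunked combine matches full attention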

def apply_attention_dot(
@@ -320,17 +326,18 @@ def apply_attention_dot(
query = query.astype(jnp.float32)
key = key.astype(jnp.float32)

attn_weights = self.qk_product(query, key)
q_seq_len = query.shape[1]
attn_weights = self.qk_product(query, key, q_seq_len)

# Casting softmax computation to float32 for model stability.
if self.float32_logits:
attn_weights = attn_weights.astype(jnp.float32)
attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
if attn_mask is not None:
attn_weights = apply_mask_to_logits(attn_weights, attn_mask)
return self.compute_local_attention(attn_weights, value)
return self.compute_local_attention(attn_weights, value, q_seq_len)

def qk_product(self, query: Array, key: Array) -> Array:
def qk_product(self, query: Array, key: Array, q_seq_len: int) -> Array:
"""Query-Key product.
Args:
@@ -353,10 +360,12 @@ def qk_product(self, query: Array, key: Array) -> Array:
n_kv = key.shape[-2]
assert n_kv == self.num_kv_heads
query = jnp.reshape(query, (b, t, n_kv, n // n_kv, d))
if self.reshape_q and q_seq_len == 1:
query = jnp.broadcast_to(query, (b, 2, n_kv, n // n_kv, d))
result = jnp.einsum("btkgd,bskd->bkgts", query, key)
return result

def wv_product(self, attn_weights: Array, value: Array) -> Array:
def wv_product(self, attn_weights: Array, value: Array, q_seq_len: int) -> Array:
"""weighted value product.
Args:
@@ -378,6 +387,8 @@ def wv_product(self, attn_weights: Array, value: Array) -> Array:
out = jnp.einsum("bkgts,bskd->btkgd", attn_weights, value)
b, t, n_kv, g, d = out.shape
result = jnp.reshape(out, (b, t, n_kv * g, d))
if self.reshape_q and q_seq_len == 1:
result = result[:, 0:1, :, :]
return result

def revert_kv_cache(self, kv, cached_axis_order):
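Together, the qk_product and wv_product changes implement the round trip that reshape_q performs during single-token decode: broadcast the length-1 query to length 2 before the grouped einsum, then slice the output back to a single position. A minimal, self-contained JAX sketch of why that round trip is numerically a no-op (illustrative shapes, with the softmax/masking between the two einsums omitted):

import jax
import jax.numpy as jnp

# batch, kv heads, query heads per kv head, head dim, kv sequence length
b, n_kv, g, d, s = 2, 4, 2, 16, 64

q = jax.random.normal(jax.random.PRNGKey(0), (b, 1, n_kv, g, d))  # decode step: q_seq_len == 1
k = jax.random.normal(jax.random.PRNGKey(1), (b, s, n_kv, d))
v = jax.random.normal(jax.random.PRNGKey(2), (b, s, n_kv, d))

# Direct path: the query keeps sequence length 1 throughout.
w_direct = jnp.einsum("btkgd,bskd->bkgts", q, k)
out_direct = jnp.einsum("bkgts,bskd->btkgd", w_direct, v)

# reshape_q path: broadcast the single step to length 2, then trim the result.
q_wide = jnp.broadcast_to(q, (b, 2, n_kv, g, d))
w_wide = jnp.einsum("btkgd,bskd->bkgts", q_wide, k)
out_wide = jnp.einsum("bkgts,bskd->btkgd", w_wide, v)[:, 0:1, :, :]

assert out_direct.shape == out_wide.shape == (b, 1, n_kv, g, d)
assert jnp.allclose(out_direct, out_wide, atol=1e-3)  # the second position is a pure duplicate

The motivation is not stated in this diff; a plausible reading is that a length-2 query gives the decode-time matmuls a friendlier shape than length 1, while the final slice (and the matching local_max/local_sum slice above) keeps the results unchanged.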
@@ -907,6 +918,7 @@ class Attention(nn.Module):
prefill_value_axis_order: AxisIdxes = (1, 2, 0, 3)
ar_key_axis_order: AxisIdxes = (1, 2, 0, 3)
ar_value_axis_order: AxisIdxes = (1, 2, 0, 3)
reshape_q: bool = False

def query_projection(self, inputs_q: Array) -> Array:
"""Query projection."""
@@ -1066,6 +1078,7 @@ def __call__(
prefill_value_axis_order = self.prefill_value_axis_order,
ar_key_axis_order = self.ar_key_axis_order,
ar_value_axis_order = self.ar_value_axis_order,
reshape_q = self.reshape_q,
)

out = attention_op(query, key, value, decoder_segment_ids, model_mode)
1 change: 1 addition & 0 deletions MaxText/layers/llama2.py
@@ -100,6 +100,7 @@ def __call__(
prefill_value_axis_order=tuple([int(i) for i in cfg.prefill_value_axis_order.split(",")]),
ar_key_axis_order=tuple([int(i) for i in cfg.ar_key_axis_order.split(",")]),
ar_value_axis_order=tuple([int(i) for i in cfg.ar_value_axis_order.split(",")]),
reshape_q=cfg.reshape_q,
)

attention_lnx = attention_layer(
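As the context lines above show, llama2.py turns the comma-separated axis-order strings from base.yml into index tuples and now threads cfg.reshape_q through to the Attention layer in the same pass. A tiny sketch of that parsing pattern, using the default value shown earlier (no MaxText imports needed):

# "1,2,0,3" from base.yml becomes the axis-index tuple handed to the attention layer.
order = tuple(int(i) for i in "1,2,0,3".split(","))
assert order == (1, 2, 0, 3)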
165 changes: 160 additions & 5 deletions MaxText/tests/attention_test.py
@@ -285,11 +285,11 @@ def test_dot_product_2130_2130(self):
)

def dot_product_attention_helper(self, prefill_cache_axis_order, ar_cache_axis_order):
self._dot_product_attention(prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache=False)
self._dot_product_attention(prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache=True)
self._dot_product_attention(prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache=False, rtol=1e-02, atol=1e-01)
self._dot_product_attention(prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache=True, rtol=1e-01, atol=1e-01)

def _dot_product_attention(self, prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache):
"""Test equalvant between dot_product and TPU accelerated"""
def _dot_product_attention(self, prefill_cache_axis_order, ar_cache_axis_order, quantize_kvcache, rtol, atol):
"""Test equalvant between different layout control in dot_product"""
prefill_length = self.max_prefill_predict_length
decode_total_length = self.max_target_length
lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype)
@@ -367,7 +367,162 @@ def _dot_product_attention(self, prefill_cache_axis_order, ar_cache_axis_order,

attention_w_layout_full_this_idx = attention_w_layout_full[:, idx : idx + 1, :]
self.assertTrue(attention_w_layout_full_this_idx.shape == attention_w_layout_idx.shape)
self.assertTrue(jax.numpy.allclose(attention_w_layout_full_this_idx, attention_w_layout_idx, rtol=1e-02, atol=1e-01, equal_nan=False))
self.assertTrue(jax.numpy.allclose(attention_w_layout_full_this_idx, attention_w_layout_idx, rtol=rtol, atol=atol, equal_nan=False))

@pytest.mark.tpu
def test_dot_product_reshape_q(self):
self._dot_product_attention_reshape_q(quantize_kvcache=True, rtol=1e-01, atol=1e-01)
self._dot_product_attention_reshape_q(quantize_kvcache=False, rtol=1e-02, atol=1e-02)

def _dot_product_attention_reshape_q(self, quantize_kvcache, rtol, atol):
"""Test equalvant between q and reshape q in dot_product"""
prefill_length = self.max_prefill_predict_length
decode_total_length = self.max_target_length
lnx, decoder_segment_ids, decoder_positions = self.get_structured_data(self.dtype)

lnx_prefill = lnx[:, 0:prefill_length, :]
decoder_segment_ids_prefill = decoder_segment_ids[:, 0:prefill_length]
decoder_positions_prefill = decoder_positions[:, 0:prefill_length]

attention_wo_reshape_q = Attention(
config=self.cfg,
num_query_heads=self.num_query_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
max_target_length=self.max_target_length,
max_prefill_predict_length=self.max_prefill_predict_length,
mesh=self.mesh,
attention_kernel="dot_product",
dtype=self.dtype,
reshape_q=False,
quantize_kvcache=quantize_kvcache,
)

attention_w_reshape_q = Attention(
config=self.cfg,
num_query_heads=self.num_query_heads,
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
max_target_length=self.max_target_length,
max_prefill_predict_length=self.max_prefill_predict_length,
mesh=self.mesh,
attention_kernel="dot_product",
dtype=self.dtype,
reshape_q=True,
quantize_kvcache=quantize_kvcache,
)

attention_wo_reshape_q_variable = attention_wo_reshape_q.init(
{"params": self.rng, "aqt": self.rng},
jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
jnp.ones((self.global_batch_size, self.max_target_length)),
)

attention_w_reshape_q_variable = attention_w_reshape_q.init(
{"params": self.rng, "aqt": self.rng},
jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
jnp.ones((self.global_batch_size, self.max_target_length, self.embed_dim)),
jnp.ones((self.global_batch_size, self.max_target_length)),
)

attention_wo_reshape_q_full = attention_wo_reshape_q.apply(
attention_wo_reshape_q_variable,
lnx,
lnx,
decoder_segment_ids=decoder_segment_ids,
inputs_positions=decoder_positions,
deterministic=True,
model_mode=common_types.MODEL_MODE_TRAIN,
rngs={"aqt": self.rng},
)

attention_w_reshape_q_full = attention_w_reshape_q.apply(
attention_w_reshape_q_variable,
lnx,
lnx,
decoder_segment_ids=decoder_segment_ids,
inputs_positions=decoder_positions,
deterministic=True,
model_mode=common_types.MODEL_MODE_TRAIN,
rngs={"aqt": self.rng},
)

attention_wo_reshape_q_prefill, attention_wo_reshape_q_output_cache = attention_wo_reshape_q.apply(
attention_wo_reshape_q_variable,
lnx_prefill,
lnx_prefill,
decoder_segment_ids=decoder_segment_ids_prefill,
inputs_positions=decoder_positions_prefill,
deterministic=True,
model_mode=common_types.MODEL_MODE_PREFILL,
rngs={"aqt": self.rng},
mutable=["cache"],
)
self.assertTrue(
jax.numpy.allclose(attention_wo_reshape_q_full[:, :prefill_length, :], attention_wo_reshape_q_prefill, equal_nan=False)
)

attention_w_reshape_q_prefill, attention_w_reshape_q_output_cache = attention_w_reshape_q.apply(
attention_w_reshape_q_variable,
lnx_prefill,
lnx_prefill,
decoder_segment_ids=decoder_segment_ids_prefill,
inputs_positions=decoder_positions_prefill,
deterministic=True,
model_mode=common_types.MODEL_MODE_PREFILL,
rngs={"aqt": self.rng},
mutable=["cache"],
)
self.assertTrue(
jax.numpy.allclose(attention_w_reshape_q_full[:, :prefill_length, :], attention_w_reshape_q_prefill, equal_nan=False)
)

self.assertTrue(
jax.numpy.allclose(attention_wo_reshape_q_prefill, attention_w_reshape_q_prefill, equal_nan=False)
)
self.assertTrue(
jax.numpy.allclose(attention_wo_reshape_q_full[:, :prefill_length, :], attention_w_reshape_q_full[:, :prefill_length, :], equal_nan=False)
)

for idx in range(prefill_length, decode_total_length):

lnx_idx = lnx[:, idx : idx + 1, :]
decoder_positions_idx = decoder_positions[:, idx : idx + 1]

attention_wo_reshape_q_variable.update(attention_wo_reshape_q_output_cache)
attention_wo_reshape_q_idx, attention_wo_reshape_q_output_cache = attention_wo_reshape_q.apply(
attention_wo_reshape_q_variable,
lnx_idx,
lnx_idx,
inputs_positions=decoder_positions_idx,
deterministic=True,
model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE,
rngs={"aqt": self.rng},
mutable=["cache"],
)

attention_wo_reshape_q_full_this_idx = attention_wo_reshape_q_full[:, idx : idx + 1, :]
self.assertTrue(attention_wo_reshape_q_full_this_idx.shape == attention_wo_reshape_q_idx.shape)
self.assertTrue(jax.numpy.allclose(attention_wo_reshape_q_full_this_idx, attention_wo_reshape_q_idx, rtol=rtol, atol=atol, equal_nan=False))

attention_w_reshape_q_variable.update(attention_w_reshape_q_output_cache)
attention_w_reshape_q_idx, attention_w_reshape_q_output_cache = attention_w_reshape_q.apply(
attention_w_reshape_q_variable,
lnx_idx,
lnx_idx,
inputs_positions=decoder_positions_idx,
deterministic=True,
model_mode=common_types.MODEL_MODE_AUTOREGRESSIVE,
rngs={"aqt": self.rng},
mutable=["cache"],
)

attention_w_reshape_q_full_this_idx = attention_w_reshape_q_full[:, idx : idx + 1, :]
self.assertTrue(attention_w_reshape_q_full_this_idx.shape == attention_w_reshape_q_idx.shape)
self.assertTrue(jax.numpy.allclose(attention_w_reshape_q_full_this_idx, attention_w_reshape_q_idx, rtol=rtol, atol=atol, equal_nan=False))

self.assertTrue(jax.numpy.allclose(attention_w_reshape_q_idx, attention_wo_reshape_q_idx, rtol=rtol, atol=atol, equal_nan=False))


if __name__ == "__main__":
