-
Notifications
You must be signed in to change notification settings - Fork 232
/
attentions.py
1115 lines (935 loc) · 44.9 KB
/
attentions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Attentions Layers."""
import functools
import math
from typing import Optional, Sequence
from flax import linen as nn
import jax
from jax import lax
from jax import random
from jax.ad_checkpoint import checkpoint_name
from jax.experimental import shard_map
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
import jax.numpy as jnp
import common_types
from layers import embeddings
from layers import initializers
from layers import linears
from layers import quantizations
# Short local aliases for project-wide type names and helper classes.
Array = common_types.Array
Config = common_types.Config
DType = common_types.DType
Mesh = common_types.Mesh
PRNGKey = common_types.PRNGKey
DenseGeneral = linears.DenseGeneral
RotaryEmbedding = embeddings.RotaryEmbedding
NdInitializer = initializers.NdInitializer
Quant = quantizations.AqtQuantization
AxisNames = common_types.AxisNames
AxisIdxes = common_types.AxisIdxes
# Logical axis names (from common_types) used in the sharding annotations below,
# e.g. nn.logical_to_mesh_axes / nn.with_logical_partitioning.
BATCH = common_types.BATCH
KV_BATCH = common_types.KV_BATCH
LENGTH = common_types.LENGTH
HEAD = common_types.HEAD
KV_HEAD = common_types.KV_HEAD
D_KV = common_types.D_KV
KV_HEAD_DIM = common_types.KV_HEAD_DIM
# Logical axis names of the KV cache, in logical [batch, sequence, heads, head_dim] order.
CACHE_BATCH = common_types.CACHE_BATCH
CACHE_SEQUENCE = common_types.CACHE_SEQUENCE
CACHE_HEADS = common_types.CACHE_HEADS
CACHE_KV = common_types.CACHE_KV
# "Masked out" logit value: -0.7 times the float32 maximum, i.e. a large negative
# but finite number. NOTE(review): presumably kept finite (not -inf) so that fully
# masked rows do not produce NaNs in the softmax statistics — confirm.
DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
nd_dense_init = initializers.nd_dense_init
# Rebind the imported `shard_map` module name to the `shard_map` function it contains.
shard_map = shard_map.shard_map
# Vectorized lax.dynamic_slice_in_dim: one shared operand (in_axes None) sliced at a
# batch of start indices (in_axes 0), with a shared slice size and axis.
dynamic_vector_slice_in_dim = jax.vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
# pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes
# pytype: disable=attribute-error
def validate_compute_axis_order(s: AxisIdxes) -> None:
  """Checks that `s` is a supported attention compute_axis_order.

  Only two layouts are implemented by AttentionOp.qk_product / wv_product:
  (0, 1, 2, 3) and (0, 2, 1, 3). Membership is tested by equality against
  tuples, so e.g. a list with the same entries is rejected.

  Args:
    s: candidate axis order.

  Raises:
    ValueError: if `s` is not a supported order; the message names the rejected
      value and lists the valid options.
  """
  valid_compute_axis_order = ((0, 1, 2, 3), (0, 2, 1, 3))
  if s not in valid_compute_axis_order:  # currently supported compute_axis_order
    # Build one message string: passing extra positional args to ValueError renders
    # the whole args tuple and previously omitted the offending value entirely.
    raise ValueError(
        f"Invalid compute_axis_order {s!r} was passed. Valid options: {valid_compute_axis_order}"
    )
def apply_mask_to_logits(logits: Array, mask: Array):
  """Masks `logits` using a floating-point mask, via a select rather than an add.

  `mask` encodes "keep" as 0.0 and "drop" as DEFAULT_MASK_VALUE (a large negative
  float). Any mask entry >= DEFAULT_MASK_VALUE / 2 counts as keep and passes the
  corresponding logit through unchanged; every other position is overwritten
  with DEFAULT_MASK_VALUE.

  Selecting with `jnp.where` instead of computing `logits + mask` avoids a bad
  compiler fusion decision that keeps the float mask values in memory instead
  of just the boolean predicate. Adapted from
  https://github.com/google/praxis/blob/4712a6b9ee13e224b86e235ff55f7c6bab9fbab3/praxis/py_utils.py#L706

  Args:
    logits: attention logits.
    mask: float mask, broadcast against `logits` by `jnp.where`, encoded as above.

  Returns:
    The masked logits.
  """
  keep_threshold = DEFAULT_MASK_VALUE * 0.5
  keep = mask >= keep_threshold
  return jnp.where(keep, logits, DEFAULT_MASK_VALUE)
def _maybe_aqt_einsum(quant: Quant):
  """Picks the einsum implementation: plain `jnp.einsum` or the AQT-provided one.

  Args:
    quant: optional quantization config; None means no quantization.

  Returns:
    `jnp.einsum` when `quant` is None, otherwise whatever `quant.einsum()` returns.
  """
  if quant is None:
    return jnp.einsum
  return quant.einsum()
# Attention computation module: dispatches to dot-product, TPU splash ("flash") or
# Transformer Engine cuDNN kernels (see apply_attention) and manages the prefill and
# autoregressive (AR) KV cache variables.
class AttentionOp(nn.Module):
# Device mesh; its "data" and "fsdp" axis sizes are read in tpu_flash_attention.
mesh: Mesh
# One of "dot_product", "flash", "cudnn_flash_te" or "autoselected"; any other value
# makes apply_attention raise ValueError.
attention_kernel: str
# The AR cache holds max_target_length - max_prefill_predict_length positions.
max_target_length: int
# Forwarded to the cuDNN kernel as num_attention_heads.
num_query_heads: int
# qk_product asserts key has this many heads; query heads are grouped n // n_kv per KV head.
num_kv_heads: int
# If True, cast query and key to float32 before the QK product (train mode only).
float32_qk_product: bool = False
# Sequence length of the prefill KV cache.
max_prefill_predict_length: int = -1
# If True, cast logits to float32 before masking/softmax (train mode only); also
# forwarded to the cuDNN kernel.
float32_logits: bool = False
# Logical axis names of the [batch, heads, length, head_dim] tensors fed to the TPU flash kernel.
flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV)
# Logical axis names of the KV cache in logical [b, s, n_kv, d] order.
kv_cache_logical_layout: AxisNames = (CACHE_BATCH, CACHE_SEQUENCE, CACHE_HEADS, CACHE_KV)
# Axis permutations applied to logical [b, s, n_kv, d] when storing the prefill / AR
# caches; the default (1, 2, 0, 3) stores them as [s, n_kv, b, d].
prefill_cache_axis_order: AxisIdxes = (1, 2, 0, 3)
ar_cache_axis_order: AxisIdxes = (1, 2, 0, 3)
# Layout used by qk_product / wv_product outside train mode (train always uses the
# (0, 1, 2, 3) path); checked by validate_compute_axis_order.
compute_axis_order: AxisIdxes = (0, 1, 2, 3)
# When the query length is 1, qk_product broadcasts the query to length 2 and
# compute_local_attention keeps only the first position of the results.
reshape_q: bool = False
# Attention dropout; in the methods visible here it is only forwarded to the cuDNN kernel.
dropout_rate: float = 0.0
# Compute dtype forwarded to the cuDNN kernel.
dtype: DType = jnp.float32
# Optional AQT quantization config. NOTE(review): not read by the methods visible in this chunk.
quant: Optional[Quant] = None
# If True, KV caches are stored as int8 plus bfloat16 scale variables whose logical
# head_dim is 1.
quantize_kvcache: bool = False
def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None:
  """Asserts that query, key and value have mutually compatible shapes.

  The trailing three axes are read as [length, heads, head_dim]. All leading
  axes are treated as batch axes and must agree across the three inputs.

  Raises:
    AssertionError: on the first mismatch found.
  """
  assert key.ndim == value.ndim, "k, v must have same rank."
  q_shape, k_shape, v_shape = query.shape, key.shape, value.shape
  assert q_shape[:-3] == k_shape[:-3] == v_shape[:-3], "q, k, v batch dims must match."
  # Query heads may outnumber KV heads (grouped attention), but K and V must agree.
  assert k_shape[-2] == v_shape[-2], "k, v num_kv_heads must match."
  assert k_shape[-3] == v_shape[-3], "k, v lengths must match."
  assert q_shape[-1] == k_shape[-1], "q, k depths must match."
# Following Pallas MHA Flash Attention Reference.
# https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py
# The mask models (1) separate sequences (decoder_segment_ids) and (2) causality.
def generate_attention_mask(self, query, key, decoder_segment_ids: Array | None, model_mode: str) -> Array | None:
  """Builds the float attention mask for the given mode.

  Up to two boolean masks are combined with logical AND:
    * segment mask: in autoregressive mode, key position j is visible iff its
      segment id equals DECODING_ACTIVE_SEQUENCE_INDICATOR; in other modes,
      query i may see key j iff they share a segment id (skipped when
      decoder_segment_ids is None).
    * causal mask (all modes except autoregressive): query i may see key j iff j <= i.

  Args:
    query: rank-4 queries; only axis 1 (query length t) is used.
    key: rank-4 keys; only axis 1 (key length s) is used.
    decoder_segment_ids: per-token segment ids, or None. Must not be None in
      autoregressive mode (it is indexed unconditionally there).
    model_mode: one of the common_types.MODEL_MODE_* constants.

  Returns:
    A rank-5 mask (singleton head/group axes, last two axes [t, s]) holding 0.0
    where attention is allowed and DEFAULT_MASK_VALUE elsewhere, or None when
    neither mask applies.
  """
  is_decode = model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE

  segment_mask = None
  if is_decode:
    # Only cache slots tagged as active are visible to the token being decoded.
    segment_mask = decoder_segment_ids[:, None, None, None, :] == common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR
  elif decoder_segment_ids is not None:
    same_segment = decoder_segment_ids[:, :, None] == decoder_segment_ids[:, None, :]
    segment_mask = same_segment[:, None, None, :, :]

  causal_mask = None
  if not is_decode:
    _, q_len, _, _ = query.shape
    _, kv_len, _, _ = key.shape
    rows = jax.lax.broadcasted_iota(jnp.int32, (q_len, kv_len), 0)
    cols = jax.lax.broadcasted_iota(jnp.int32, (q_len, kv_len), 1)
    causal_mask = (cols <= rows)[None, None, None, :, :]

  present = [m for m in (segment_mask, causal_mask) if m is not None]
  if not present:
    return None
  allowed = present[0] if len(present) == 1 else jnp.logical_and(*present)
  return jnp.where(allowed, 0.0, DEFAULT_MASK_VALUE)
def apply_attention(self, query: Array, key: Array, value: Array, decoder_segment_ids: Array | None, model_mode: str):
  """Runs attention with the kernel selected by `self.attention_kernel`.

  "dot_product" always uses apply_attention_dot. "autoselected" uses it too when
  decoding (autoregressive mode) or when the query length (axis -3) is below
  128, and the TPU flash kernel otherwise. "flash" uses the TPU flash kernel and
  "cudnn_flash_te" the Transformer Engine cuDNN kernel; neither fused kernel
  supports autoregressive decoding.

  Returns:
    A 3-tuple. The dot-product path returns (local_out, local_max, local_sum)
    from compute_local_attention; the fused kernels return (out, None, None).

  Raises:
    ValueError: for autoregressive mode with a fused kernel, or an unknown kernel name.
  """
  self.check_attention_inputs(query, key, value)
  length = query.shape[-3]
  kernel = self.attention_kernel
  decoding = model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE
  decode_unsupported = "Decode not supported with flash attention.\nUse `dot_product` instead."

  if kernel == "dot_product" or (kernel == "autoselected" and (decoding or length < 128)):
    return self.apply_attention_dot(query, key, value, decoder_segment_ids, model_mode)

  if kernel in ("flash", "autoselected"):
    if decoding:
      raise ValueError(decode_unsupported)
    return self.tpu_flash_attention(query, key, value, decoder_segment_ids), None, None

  if kernel == "cudnn_flash_te":
    if decoding:
      raise ValueError(decode_unsupported)
    return self.cudnn_flash_attention(query, key, value, decoder_segment_ids, model_mode), None, None

  raise ValueError(f"Unexpected attention kernel {self.attention_kernel=}.")
def tpu_flash_attention(self, query: Array, key: Array, value: Array, decoder_segment_ids: Array | None) -> Array:
  """TPU flash attention using the Pallas splash-attention kernel under shard_map.

  Inputs are [batch, length, heads, head_dim]. They are transposed to
  [batch, heads, length, head_dim] for the kernel, which applies a causal mask
  (plus segment masking when decoder_segment_ids is given). The output is
  transposed back to the input layout.

  Args:
    query: queries, [batch, q_length, heads, head_dim].
    key: keys, [batch, kv_length, heads, head_dim].
    value: values, same layout as `key`.
    decoder_segment_ids: per-token segment ids, or None.

  Returns:
    Attention output in [batch, q_length, heads, head_dim] layout.

  Raises:
    AssertionError: if the batch is not divisible by the data*fsdp device count,
      or if the sequence axis is sharded while segment ids are used.
  """
  # Kernel layout: ('batch', 'heads', 'length', 'kv').
  query = jnp.transpose(query, axes=(0, 2, 1, 3))
  key = jnp.transpose(key, axes=(0, 2, 1, 3))
  value = jnp.transpose(value, axes=(0, 2, 1, 3))

  if decoder_segment_ids is not None:
    decoder_segment_ids = splash_attention_kernel.SegmentIds(decoder_segment_ids, decoder_segment_ids)
  axis_names = nn.logical_to_mesh_axes(self.flash_axis_names)
  segment_axis_names = nn.logical_to_mesh_axes((BATCH, "activation_length_no_heads"))

  @functools.partial(
      shard_map,
      mesh=self.mesh,
      in_specs=(
          axis_names,
          axis_names,
          axis_names,
          segment_axis_names,
      ),
      out_specs=axis_names,
      check_rep=False,
  )
  def wrap_flash_attention(query, key, value, decoder_segment_ids):
    # Per-shard layout: [batch, heads, length, head_dim].
    q_len = query.shape[2]
    kv_len = key.shape[2]
    if decoder_segment_ids is not None:
      assert (
          q_len == decoder_segment_ids.q.shape[1]
      ), "Sharding along sequence dimension not allowed in tpu kernel attention"
    # Query-axis blocks are bounded by the query length, KV-axis blocks by the key length.
    block_sizes = splash_attention_kernel.BlockSizes(
        block_q=min(512, q_len),
        block_kv_compute=min(512, kv_len),
        block_kv=min(512, kv_len),
        block_q_dkv=min(512, q_len),
        block_kv_dkv=min(512, kv_len),
        block_kv_dkv_compute=min(512, kv_len),
        block_q_dq=min(512, q_len),
        block_kv_dq=min(512, kv_len),
    )
    # Same [q_len, kv_len] causal mask for every head.
    causal_mask = splash_attention_mask.CausalMask(shape=(q_len, kv_len))
    multi_head_mask = splash_attention_mask.MultiHeadMask(masks=[causal_mask] * query.shape[1])
    splash_kernel = splash_attention_kernel.make_splash_mha(
        mask=multi_head_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes
    )
    # The kernel processes one example at a time; vmap it over the batch axis.
    return jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids)

  devices_in_data_fsdp = self.mesh.shape["data"] * self.mesh.shape["fsdp"]
  # Integer divisibility check (float division can misreport for large sizes).
  assert query.shape[0] % devices_in_data_fsdp == 0, (
      "Batch dimension should be shardable among the devices in data and fsdp" " axis"
  )
  x = wrap_flash_attention(query, key, value, decoder_segment_ids)
  # Back to [batch, length, heads, head_dim].
  return jnp.transpose(x, axes=(0, 2, 1, 3))
def cudnn_flash_attention(
    self,
    query: Array,
    key: Array,
    value: Array,
    decoder_segment_ids: Array | None,
    model_mode: str = common_types.MODEL_MODE_TRAIN,
) -> Array:
  """Fused GPU attention through Transformer Engine's cuDNN DotProductAttention.

  According to the original notes, this path has a stable API, supports GQA,
  and supports head_dim up to 128 (256 was planned).

  Args:
    query: rank-4 queries; the last axis is head_dim.
    key: keys in the same "BSHD" layout as `query`.
    value: values in the same "BSHD" layout as `query`.
    decoder_segment_ids: segment ids used to build the mask, or None.
    model_mode: forwarded to generate_attention_mask.

  Returns:
    The output of the Transformer Engine attention layer.
  """
  # Transformer Engine is only importable in GPU builds, hence the local import.
  from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error

  _batch, _seq, _heads, head_dim = query.shape
  mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
  softmax_scale = 1.0 / math.sqrt(head_dim)
  # NOTE(review): a causal mask type is requested while an explicit (causal +
  # segment) mask is also passed, and the dropout rng collection is named "aqt";
  # both are kept as-is — confirm intent.
  attention = DotProductAttention(
      head_dim=head_dim,
      num_attention_heads=self.num_query_heads,
      num_gqa_groups=self.num_kv_heads,
      attn_mask_type="causal",  # 'causal' or 'padding'
      attn_bias_type="NO_BIAS",  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
      attention_dropout=self.dropout_rate,
      dropout_rng_name="aqt",
      dtype=self.dtype,
      float32_logits=self.float32_logits,
      qkv_layout="BSHD_BSHD_BSHD",  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
      scale_factor=softmax_scale,
      transpose_batch_sequence=False,
  )
  return attention(query, key, value, mask=mask)
def compute_local_attention(self, attn_weights: Array, value: Array, q_seq_len: int, model_mode: str) -> tuple[Array, Array, Array]:
  """Weights `value` by exp(logits - max) without normalizing.

  The row max and the sum of exponentials are returned with the unnormalized
  output so that several partial attentions can later be merged and normalized
  once. Based on
  https://github.com/google-research/google-research/blob/master/scaling_transformer_inference_efficiency/attention.py

  Args:
    attn_weights: masked logits from qk_product, [b, n_kv, g, t, s].
    value: values, [b, s, n_kv, d] (passed to wv_product).
    q_seq_len: query length; when it is 1 and `reshape_q` is set, only the first
      of the two broadcast query positions is kept in every output.
    model_mode: forwarded to wv_product.

  Returns:
    (local_out, local_max, local_sum):
      local_out: unnormalized output from wv_product, [b, t, n, d].
      local_max: row max of attn_weights, [b, t, n, 1].
      local_sum: sum of exp(attn_weights - local_max) over s, [b, t, n, 1].
  """
  row_max = jnp.max(attn_weights, axis=-1, keepdims=True)
  exps = jnp.exp(attn_weights - row_max)
  row_sum = jnp.sum(exps, axis=-1, keepdims=True)

  def _to_query_major(stat):
    # [b, n_kv, g, t, 1] -> [b, t, n_kv, g, 1] -> [b, t, n_kv * g, 1]
    stat = jnp.moveaxis(stat, -2, 1)
    dims = stat.shape
    return jnp.reshape(stat, (dims[0], dims[1], dims[2] * dims[3], 1))

  local_sum = _to_query_major(row_sum)
  local_max = _to_query_major(row_max)
  local_out = self.wv_product(exps, value, model_mode)

  if self.reshape_q and q_seq_len == 1:
    # qk_product duplicated the single query position; drop the copy.
    local_out, local_max, local_sum = (x[:, 0:1, :, :] for x in (local_out, local_max, local_sum))
  return local_out, local_max, local_sum
def apply_attention_dot(
    self,
    query: Array,
    key: Array,
    value: Array,
    decoder_segment_ids: Array | None,
    model_mode: str = common_types.MODEL_MODE_TRAIN,
):
  """Non-fused attention: QK product, optional masking, unnormalized local softmax.

  In train mode, query and key are cast to float32 before the QK product when
  `float32_qk_product` is set, and the logits are cast to float32 when
  `float32_logits` is set, for numerical stability.

  Returns:
    The (local_out, local_max, local_sum) triple from compute_local_attention.

  Raises:
    ValueError: if `self.compute_axis_order` is not supported.
  """
  validate_compute_axis_order(self.compute_axis_order)
  training = model_mode == common_types.MODEL_MODE_TRAIN

  if training and self.float32_qk_product:
    query, key = query.astype(jnp.float32), key.astype(jnp.float32)

  q_seq_len = query.shape[1]
  attn_weights = self.qk_product(query, key, q_seq_len, model_mode)
  if training and self.float32_logits:
    attn_weights = attn_weights.astype(jnp.float32)

  attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
  if attn_mask is not None:
    attn_weights = apply_mask_to_logits(attn_weights, attn_mask)
  return self.compute_local_attention(attn_weights, value, q_seq_len, model_mode)
def qk_product(self, query: Array, key: Array, q_seq_len: int, model_mode: str) -> Array:
  """Query-Key product.

  Args:
    query: Query projection, in shape of [b, t, n, d]
    key: Key projection in shape of [b, s, n_kv, d]
    q_seq_len: query length; when it is 1 and `self.reshape_q` is set, the query
      is broadcast to length 2 (so t becomes 2 in the result).
    model_mode: train mode always uses the (0, 1, 2, 3) computation; other modes
      follow `self.compute_axis_order`.

  Returns:
    results in shape [b, n_kv, n // n_kv, t, s].

  Raises:
    ValueError: if a non-train mode is combined with an unsupported
      `compute_axis_order` (instead of failing on an unbound local).

  Annotations:
    b: batch size
    t: query length
    s: key / value length
    d: head / kv dimension
    n: number of query heads
    n_kv: number of kv heads, sometimes annotated as k
    n // n_kv: number of group for query, sometimes annotated with g
  """
  b, t, n, d = query.shape
  n_kv = key.shape[-2]
  assert n_kv == self.num_kv_heads
  g = n // n_kv
  expand_q = self.reshape_q and q_seq_len == 1

  if model_mode == common_types.MODEL_MODE_TRAIN or self.compute_axis_order == (0, 1, 2, 3):
    query = jnp.reshape(query, (b, t, n_kv, g, d))
    if expand_q:
      query = jnp.broadcast_to(query, (b, 2, n_kv, g, d))
    return jnp.einsum("btkgd,bskd->bkgts", query, key)

  if self.compute_axis_order == (0, 2, 1, 3):
    # Swap length and head axes: query -> [b, n, t, d], key -> [b, n_kv, s, d].
    query = jnp.transpose(query, axes=self.compute_axis_order)
    key = jnp.transpose(key, axes=self.compute_axis_order)
    query = jnp.reshape(query, (b, n_kv, g, t, d))
    if expand_q:
      query = jnp.broadcast_to(query, (b, n_kv, g, 2, d))
    return jnp.einsum("bkgtd,bksd->bkgts", query, key)

  raise ValueError(
      f"Unsupported compute_axis_order {self.compute_axis_order}; expected (0, 1, 2, 3) or (0, 2, 1, 3)."
  )
def wv_product(self, attn_weights: Array, value: Array, model_mode: str) -> Array:
  """weighted value product.

  Args:
    attn_weights: Computed results of qk_einsum, in shape [b, n_kv, n // n_kv, t, s]
    value: Value projection, in shape of [b, s, n_kv, d]
    model_mode: train mode always uses the (0, 1, 2, 3) computation; other modes
      follow `self.compute_axis_order`.

  Returns:
    result in shape [b, t, n, d]

  Raises:
    ValueError: if a non-train mode is combined with an unsupported
      `compute_axis_order` (instead of failing on an unbound local).

  Annotations:
    b: batch size
    t: query length
    s: key / value length
    d: head / kv dimension
    n: number of query heads
    n_kv: number of kv heads, sometimes annotated as k
    n // n_kv: number of group for query, sometimes annotated with g
  """
  if model_mode == common_types.MODEL_MODE_TRAIN or self.compute_axis_order == (0, 1, 2, 3):
    out = jnp.einsum("bkgts,bskd->btkgd", attn_weights, value)
    b, t, n_kv, g, d = out.shape
    return jnp.reshape(out, (b, t, n_kv * g, d))

  if self.compute_axis_order == (0, 2, 1, 3):
    value = jnp.transpose(value, axes=self.compute_axis_order)
    out = jnp.einsum("bkgts,bksd->bkgtd", attn_weights, value)
    b, n_kv, g, t, d = out.shape
    result = jnp.reshape(out, (b, n_kv * g, t, d))
    # (0, 2, 1, 3) is its own inverse: [b, n, t, d] -> [b, t, n, d].
    return jnp.transpose(result, axes=self.compute_axis_order)

  raise ValueError(
      f"Unsupported compute_axis_order {self.compute_axis_order}; expected (0, 1, 2, 3) or (0, 2, 1, 3)."
  )
def revert_kv_cache(self, kv, cached_axis_order):
  """Inverse of reshape_kv_cache: permute a cached tensor back to logical order.

  Args:
    kv: cached tensor whose axis i holds logical axis cached_axis_order[i].
    cached_axis_order: the permutation the cache was stored with.

  Returns:
    `kv` with axes in logical [b, s, n_kv, d] order (batch, key/value length,
    kv heads, head dim).
  """
  stored_axes = (0, 1, 2, 3)
  # Stored axis i returns to logical position cached_axis_order[i].
  return jnp.moveaxis(kv, stored_axes, cached_axis_order)
def reshape_kv_cache(self, kv, cached_axis_order):
  """Permute a logical [b, s, n_kv, d] tensor into the cache layout.

  After the permutation, axis i of the result is logical axis
  cached_axis_order[i], matching the shape produced by cached_kv_shape.

  Args:
    kv: tensor in logical [b, s, n_kv, d] order.
    cached_axis_order: permutation describing the cache layout.

  Returns:
    `kv` with its axes reordered for caching.
  """
  position_of = dict((axis, pos) for pos, axis in enumerate(cached_axis_order))
  # Logical axis `a` moves to the slot where `a` appears in cached_axis_order.
  destination = tuple(position_of[axis] for axis in sorted(position_of))
  return jnp.moveaxis(kv, (0, 1, 2, 3), destination)
def cached_kv_layout(self, kv_layout, cached_axis_order):
  """Reorder logical axis names to match the cache layout.

  Args:
    kv_layout: axis names for the logical [b, s, n_kv, d] tensor.
    cached_axis_order: indices into `kv_layout`, in cache order.

  Returns:
    Tuple whose i-th entry is kv_layout[cached_axis_order[i]].
  """
  return tuple(kv_layout[axis] for axis in cached_axis_order)
def cached_kv_shape(self, kv_shape, cached_axis_order):
  """Shape of a KV cache stored in `cached_axis_order`.

  Keys and values are logically [b, s, n_kv, d] (batch, key/value length,
  kv heads, head dim), but the cache stores them permuted by
  `cached_axis_order` for faster reads and writes.

  Args:
    kv_shape: logical [b, s, n_kv, d] shape.
    cached_axis_order: indices into `kv_shape`, in cache order.

  Returns:
    Tuple whose i-th entry is kv_shape[cached_axis_order[i]].
  """
  permuted = [kv_shape[axis] for axis in cached_axis_order]
  return tuple(permuted)
def _get_prefill_cache(self, batch, heads, kv_head_size, quantize_kvcache):
  """Creates (or retrieves) the prefill KV cache variables in the "cache" collection.

  Cache shapes are the logical [batch, max_prefill_predict_length, heads,
  kv_head_size] permuted by `self.prefill_cache_axis_order`.

  Args:
    batch: batch size.
    heads: number of KV heads.
    kv_head_size: per-head dimension.
    quantize_kvcache: if True, keys and values are stored as int8, and bfloat16
      scale variables (logical head dim 1) are created. Otherwise storage is
      bfloat16 and no scale variables exist. This argument drives both choices,
      so the cache dtype and the presence of scale variables always agree.

  Returns:
    (key_vars, value_vars, cached_segment_id). key_vars and value_vars are
    (cache_variable, scale_variable_or_None) pairs. cached_segment_id is an int32
    [batch, max_prefill_predict_length] variable.
  """
  dtype = jnp.int8 if quantize_kvcache else jnp.bfloat16
  seq_len = self.max_prefill_predict_length
  order = self.prefill_cache_axis_order
  # Keys and values share one layout and one shape.
  layout = self.cached_kv_layout(self.kv_cache_logical_layout, order)
  shape = self.cached_kv_shape((batch, seq_len, heads, kv_head_size), order)

  cached_key = self.variable(
      "cache",
      "cached_prefill_key",
      nn.with_logical_partitioning(jnp.zeros, layout),
      shape,
      dtype,
  )
  cached_value = self.variable(
      "cache",
      "cached_prefill_value",
      nn.with_logical_partitioning(jnp.zeros, layout),
      shape,
      dtype,
  )
  cached_segment_id = self.variable(
      "cache",
      "cache_prefill_segment_id",
      nn.with_logical_partitioning(jnp.zeros, (CACHE_BATCH, CACHE_SEQUENCE)),
      (batch, seq_len),
      jnp.int32,
  )

  cached_key_scale_var = None
  cached_value_scale_var = None
  if quantize_kvcache:
    scale_shape = self.cached_kv_shape((batch, seq_len, heads, 1), order)
    cached_key_scale_var = self.variable(
        "cache",
        "cached_prefill_key_scale",
        nn.with_logical_partitioning(jnp.zeros, layout),
        scale_shape,
        jnp.bfloat16,
    )
    cached_value_scale_var = self.variable(
        "cache",
        "cached_prefill_value_scale",
        nn.with_logical_partitioning(jnp.zeros, layout),
        scale_shape,
        jnp.bfloat16,
    )

  key_vars = (cached_key, cached_key_scale_var)
  value_vars = (cached_value, cached_value_scale_var)
  return key_vars, value_vars, cached_segment_id
def _get_ar_cache(self, batch, heads, kv_head_size, quantize_kvcache):
  """Creates (or retrieves) the autoregressive (AR) KV cache variables.

  The AR cache holds max_target_length - max_prefill_predict_length positions.
  Its shapes are the logical [batch, cache_length, heads, kv_head_size] permuted
  by `self.ar_cache_axis_order`.

  Args:
    batch: batch size.
    heads: number of KV heads.
    kv_head_size: per-head dimension.
    quantize_kvcache: if True, keys and values are stored as int8, and bfloat16
      scale variables (logical head dim 1) are created. Otherwise storage is
      bfloat16 and no scale variables exist. This argument drives both choices.

  Returns:
    (key_vars, value_vars, cached_segment_id, cache_index).
      key_vars, value_vars: (cache_variable, scale_variable_or_None) pairs.
      cached_segment_id: int32 [batch, cache_length] variable.
      cache_index: int32 variable of shape (1,) holding the next write position.
  """
  dtype = jnp.int8 if quantize_kvcache else jnp.bfloat16
  cache_length = self.max_target_length - self.max_prefill_predict_length
  order = self.ar_cache_axis_order
  # Keys and values share one layout and one shape.
  layout = self.cached_kv_layout(self.kv_cache_logical_layout, order)
  shape = self.cached_kv_shape((batch, cache_length, heads, kv_head_size), order)

  # TODO(b/339703100): investigate the issue why with_logical_partitioning doesn't enforce sharding
  cached_key = self.variable(
      "cache",
      "cached_ar_key",
      nn.with_logical_partitioning(jnp.zeros, layout),
      shape,
      dtype,
  )
  # Re-apply the sharding constraint explicitly (see the TODO above).
  cached_key.value = nn.with_logical_constraint(cached_key.value, layout)
  cached_value = self.variable(
      "cache",
      "cached_ar_value",
      nn.with_logical_partitioning(jnp.zeros, layout),
      shape,
      dtype,
  )
  cached_value.value = nn.with_logical_constraint(cached_value.value, layout)
  cached_segment_id = self.variable(
      "cache",
      "cache_ar_segment_id",
      nn.with_logical_partitioning(jnp.zeros, (CACHE_BATCH, CACHE_SEQUENCE)),
      (batch, cache_length),
      jnp.int32,
  )

  cached_key_scale_var = None
  cached_value_scale_var = None
  if quantize_kvcache:
    scale_shape = self.cached_kv_shape((batch, cache_length, heads, 1), order)
    cached_key_scale_var = self.variable(
        "cache",
        "cached_ar_key_scale",
        nn.with_logical_partitioning(jnp.zeros, layout),
        scale_shape,
        jnp.bfloat16,
    )
    cached_value_scale_var = self.variable(
        "cache",
        "cached_ar_value_scale",
        nn.with_logical_partitioning(jnp.zeros, layout),
        scale_shape,
        jnp.bfloat16,
    )

  cache_index = self.variable("cache", "cache_ar_index", nn.with_logical_partitioning(jnp.zeros, ()), (1,), jnp.int32)
  key_vars = (cached_key, cached_key_scale_var)
  value_vars = (cached_value, cached_value_scale_var)
  return key_vars, value_vars, cached_segment_id, cache_index
def kv_cache_prefill(
    self,
    key: Array,
    value: Array,
    decoder_segment_ids: Array,
):
  """Stores the prompt's keys and values in the prefill cache.

  The prefill cache variables are overwritten with `key` and `value`, permuted by
  `prefill_cache_axis_order` and, when `quantize_kvcache` is set, quantized to
  int8 with separately stored scales. The AR cache variables are also created
  here, because autoregressive decoding refuses to run unless they exist.
  Segment ids are stored only when provided.

  Args:
    key: keys, [b, s, n, d].
    value: values, [b, s, n, d]; must have the same dtype as `key`.
    decoder_segment_ids: [b, s] segment ids, or None.

  Returns:
    (key, value, decoder_segment_ids), unchanged.
  """
  batch, _, heads, kv_head_size = key.shape
  assert key.dtype == value.dtype, "Key and Value Dtypes should match."

  prefill_key_vars, prefill_value_vars, prefill_segment_id_var = self._get_prefill_cache(
      batch, heads, kv_head_size, self.quantize_kvcache
  )
  # Created now so the AR cache exists before decoding starts.
  ar_key_vars, ar_value_vars, _, _ = self._get_ar_cache(batch, heads, kv_head_size, self.quantize_kvcache)

  expected_prefill_shape = self.cached_kv_shape(
      (batch, self.max_prefill_predict_length, heads, kv_head_size), self.prefill_cache_axis_order
  )
  expected_ar_shape = self.cached_kv_shape(
      (batch, self.max_target_length - self.max_prefill_predict_length, heads, kv_head_size), self.ar_cache_axis_order
  )
  assert prefill_key_vars[0].value.shape == expected_prefill_shape
  assert prefill_value_vars[0].value.shape == expected_prefill_shape
  assert ar_key_vars[0].value.shape == expected_ar_shape
  assert ar_value_vars[0].value.shape == expected_ar_shape

  layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.prefill_cache_axis_order)
  cache_key = self.reshape_kv_cache(key, self.prefill_cache_axis_order)
  cache_value = self.reshape_kv_cache(value, self.prefill_cache_axis_order)

  if self.quantize_kvcache:
    kv_axis = layout.index(CACHE_KV)
    cache_key, key_scale = quantizations.quantize_kv(cache_key, kv_axis)
    cache_value, value_scale = quantizations.quantize_kv(cache_value, kv_axis)
    prefill_key_vars[1].value = key_scale
    prefill_value_vars[1].value = value_scale

  prefill_key_vars[0].value = cache_key
  prefill_value_vars[0].value = cache_value
  if decoder_segment_ids is not None:
    prefill_segment_id_var.value = decoder_segment_ids
  return key, value, decoder_segment_ids
def update_ar_key_value(
    self,
    one_token_key: Array,
    one_token_value: Array,
    cached_key_vars: tuple[nn.Variable, nn.Variable | None],
    cached_value_vars: tuple[nn.Variable, nn.Variable | None],
    one_hot_indices: Array,
) -> tuple[Array, Array]:
  """Writes one decoded token's key and value into the autoregressive (AR) cache.

  Args:
    one_token_key: key of the new token in logical [b, 1, n_kv, d] order.
    one_token_value: value of the new token in logical [b, 1, n_kv, d] order.
    cached_key_vars: (key cache variable, key scale variable or None).
    cached_value_vars: (value cache variable, value scale variable or None).
    one_hot_indices: despite the name, the integer cache position to write.
      It is a single-element array and is squeezed to a scalar here.

  Returns:
    (ar_key, ar_value): the full updated caches in logical [b, s, n_kv, d]
    order. When quantize_kvcache is set, they are dequantized to the dtypes of
    the incoming token key and value.
  """
  cached_key_var, cached_key_scale_var = cached_key_vars
  cached_value_var, cached_value_scale_var = cached_value_vars

  ar_key_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.ar_cache_axis_order)
  ar_value_layout = self.cached_kv_layout(self.kv_cache_logical_layout, self.ar_cache_axis_order)

  # Put the new token into the same axis order as the cache.
  one_token_key_shaped_for_cache = self.reshape_kv_cache(one_token_key, self.ar_cache_axis_order)
  one_token_value_shaped_for_cache = self.reshape_kv_cache(one_token_value, self.ar_cache_axis_order)
  if self.quantize_kvcache:
    one_token_key_shaped_for_cache, one_token_key_scale = quantizations.quantize_kv(
        one_token_key_shaped_for_cache, ar_key_layout.index(CACHE_KV)
    )
    one_token_value_shaped_for_cache, one_token_value_scale = quantizations.quantize_kv(
        one_token_value_shaped_for_cache, ar_value_layout.index(CACHE_KV)
    )

  write_index = jnp.squeeze(one_hot_indices.astype(int))

  def _write(cache, update, layout):
    # Insert `update` at `write_index` along this layout's own sequence axis,
    # then re-apply the cache's sharding constraint.
    updated = jax.lax.dynamic_update_index_in_dim(cache, update, write_index, layout.index(CACHE_SEQUENCE))
    return nn.with_logical_constraint(updated, layout)

  ar_key = _write(cached_key_var.value, one_token_key_shaped_for_cache, ar_key_layout)
  cached_key_var.value = ar_key
  ar_value = _write(cached_value_var.value, one_token_value_shaped_for_cache, ar_value_layout)
  cached_value_var.value = ar_value

  if self.quantize_kvcache:
    cached_key_scale_var.value = _write(cached_key_scale_var.value, one_token_key_scale, ar_key_layout)
    cached_value_scale_var.value = _write(cached_value_scale_var.value, one_token_value_scale, ar_value_layout)
    ar_key = quantizations.unquantize_kv(cached_key_var.value, cached_key_scale_var.value, one_token_key.dtype)
    ar_value = quantizations.unquantize_kv(cached_value_var.value, cached_value_scale_var.value, one_token_value.dtype)

  # Revert the keys and values back to their logical shapes.
  return self.revert_kv_cache(ar_key, self.ar_cache_axis_order), self.revert_kv_cache(ar_value, self.ar_cache_axis_order)
def prefill_cache_var_model_var(self, cache_var, target_dtype, cache_axis_order):
  """Reads a cache back as an array in logical [b, s, n_kv, d] order.

  Args:
    cache_var: (cache_variable, scale_variable_or_None) pair.
    target_dtype: dtype to dequantize into when quantize_kvcache is set.
    cache_axis_order: axis order the cache was stored with.

  Returns:
    The cache contents with axes reverted to logical order. Quantized caches are
    first dequantized with their scales.
  """
  if self.quantize_kvcache:
    raw_cache, quant_scale = cache_var
    stored = quantizations.unquantize_kv(raw_cache.value, quant_scale.value, target_dtype)
  else:
    stored = cache_var[0].value
  return self.revert_kv_cache(stored, cache_axis_order)
def kv_cache_autoregressive(
    self,
    key: Array,
    value: Array,
):
    """Write one decoded token into the AR cache and return both caches.

    The current token's key and value are written into the autoregressive (AR)
    cache at the slot given by ``cache_ar_index``. The segment id at that slot
    is set to ``common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR``. The index is
    then advanced, wrapping modulo the AR cache length
    (``max_target_length - max_prefill_predict_length``).

    Args:
      key: in shape [b, 1, n, d].
      value: in shape [b, 1, n, d].

    Returns:
      A pair ``(prefill, ar)``, where each element is a tuple
      ``(key, value, segment_ids)``.
      - ``prefill`` is read from the existing prefill cache. It is dequantized
        when ``quantize_kvcache`` is set and passed through ``revert_kv_cache``.
      - ``ar`` holds the updated AR cache as returned by
        ``update_ar_key_value``, plus the AR segment ids.

    Raises:
      ValueError: if ``key``'s sequence dimension is not 1 (or ``key`` is not
        rank 4), or if the cache has not been seeded, i.e. there is no
        ``cache_ar_index`` variable in the "cache" collection.
    """
    batch, sequence, heads, kv_head_size = key.shape
    if sequence != 1:
        raise ValueError(f"Sequence length should be 1 during autoregression, got {sequence=}")
    is_initialized = self.has_variable("cache", "cache_ar_index")
    if not is_initialized:
        raise ValueError("Error, we can't do autoregression if we haven't seeded the KV Cache.")

    # The AR cache covers the positions after the prefill region.
    ar_cache_length = self.max_target_length - self.max_prefill_predict_length

    cached_ar_key_var, cached_ar_value_var, cached_ar_segment_id, cache_ar_index = self._get_ar_cache(
        batch, heads, kv_head_size, self.quantize_kvcache
    )
    expected_ar_shape = self.cached_kv_shape(
        (batch, ar_cache_length, heads, kv_head_size), self.ar_cache_axis_order
    )
    assert cached_ar_key_var[0].value.shape == expected_ar_shape, (
        f"AR key cache shape {cached_ar_key_var[0].value.shape} != expected {expected_ar_shape}"
    )
    assert cached_ar_value_var[0].value.shape == expected_ar_shape, (
        f"AR value cache shape {cached_ar_value_var[0].value.shape} != expected {expected_ar_shape}"
    )

    ar_key, ar_value = self.update_ar_key_value(key, value, cached_ar_key_var, cached_ar_value_var, cache_ar_index.value)

    # Write DECODING_ACTIVE_SEQUENCE_INDICATOR into the segment-id slot at the
    # current index (axis 1 of the segment-id cache).
    active_indicator = jnp.zeros((batch, 1), dtype=jnp.int32) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR
    cached_ar_segment_id.value = jax.lax.dynamic_update_index_in_dim(
        cached_ar_segment_id.value, active_indicator, jnp.squeeze(cache_ar_index.value), 1
    )
    # Advance the write position, wrapping at the end of the AR cache.
    cache_ar_index.value = jnp.mod(cache_ar_index.value + 1, ar_cache_length)

    # The below retrieves the existing prefill cache variables, not creating new ones.
    cached_prefill_key_var, cached_prefill_value_var, cached_prefill_segment_id = self._get_prefill_cache(
        batch, heads, kv_head_size, self.quantize_kvcache
    )
    expected_prefill_shape = self.cached_kv_shape(
        (batch, self.max_prefill_predict_length, heads, kv_head_size), self.prefill_cache_axis_order
    )
    assert cached_prefill_key_var[0].value.shape == expected_prefill_shape, (
        f"Prefill key cache shape {cached_prefill_key_var[0].value.shape} != expected {expected_prefill_shape}"
    )
    assert cached_prefill_value_var[0].value.shape == expected_prefill_shape, (
        f"Prefill value cache shape {cached_prefill_value_var[0].value.shape} != expected {expected_prefill_shape}"
    )

    cached_prefill = (
        self.prefill_cache_var_model_var(cached_prefill_key_var, key.dtype, self.prefill_cache_axis_order),
        self.prefill_cache_var_model_var(cached_prefill_value_var, value.dtype, self.prefill_cache_axis_order),
        cached_prefill_segment_id.value,
    )
    return cached_prefill, (ar_key, ar_value, cached_ar_segment_id.value)
def kv_cache(self, key: Array, value: Array, decoder_segment_ids: Array, model_mode: str) -> tuple:
    """Update the KV cache for ``model_mode`` and return what to attend over.

    The key and value have dimension [b, s, n_kv, d]. They are cached with a
    reshape defined by the ``*_cache_axis_order`` settings, as a TPU fusion
    optimization.

    NOTE(review): an earlier version of this docstring credited a "scatter via
    one-hot broadcast" trick (a 3-4x speedup). The AR update visible in this
    file uses ``jax.lax.dynamic_update_index_in_dim``; confirm whether that
    note still applies.

    Args:
      key: in shape [b, s, n_kv, d].
      value: in shape [b, s, n_kv, d]. Must have the same shape as ``key``.
      decoder_segment_ids: segment ids for the tokens in ``key``/``value``,
        presumably [b, s] (TODO confirm). Not used in autoregressive mode,
        where the cache writes its own segment ids.
      model_mode: one of ``common_types.MODEL_MODE_TRAIN``,
        ``MODEL_MODE_PREFILL`` or ``MODEL_MODE_AUTOREGRESSIVE``.

    Returns:
      A pair ``(primary, ar)``.
      - In train mode, ``primary`` is ``(key, value, decoder_segment_ids)``
        passed through unchanged.
      - In prefill mode, ``primary`` is the result of ``kv_cache_prefill``.
      - In autoregressive mode, ``primary`` is the prefill cache.
      ``ar`` is the autoregressive cache ``(key, value, segment_ids)`` in
      autoregressive mode and ``None`` otherwise.

    Raises:
      ValueError: if ``key`` and ``value`` shapes differ, if ``model_mode`` is
        not supported, or if the autoregressive cache update rejects the input.
    """
    if key.shape != value.shape:
        raise ValueError(f"Can't KV cache with mismatched shapes {key.shape=}, {value.shape=}")

    if model_mode == common_types.MODEL_MODE_TRAIN:
        # Training bypasses the cache: attend directly over the inputs.
        return (key, value, decoder_segment_ids), None
    if model_mode == common_types.MODEL_MODE_PREFILL:
        return self.kv_cache_prefill(key, value, decoder_segment_ids), None
    if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
        # decoder_segment_ids is intentionally not forwarded; the AR path
        # writes DECODING_ACTIVE_SEQUENCE_INDICATOR into its own segment ids.
        return self.kv_cache_autoregressive(key, value)
    raise ValueError(f"Model Mode isn't supported! {model_mode=}")
def normalize_attention(self, local_outs, local_maxes, local_sums):
    """Normalize across multiple localized attentions.

    Merges partial attention results computed over disjoint sets of keys into
    a single softmax-normalized output. The method uses the usual max-rescaling
    trick: each partial result is rescaled by ``exp(local_max - global_max)``,
    so that all partials share the same reference maximum, before they are
    combined. For the result to be the exact softmax attention, each
    ``local_outs[i]`` and ``local_sums[i]`` must be computed with exponents
    offset by ``local_maxes[i]``.

    Args:
      local_outs (iterable): unnormalized output for each local attention.
      local_maxes (iterable): max exponent offset for each local attention.
      local_sums (iterable): sum of exponentials for each local attention.

    Returns:
      Array: combined attention that has been normalized.

    Raises:
      ValueError: if no local attentions are given, or if the three inputs
        have different lengths.
    """
    # Based on https://github.com/google-research/google-research/blob/master/scaling_transformer_inference_efficiency/attention.py
    # Materialize once: each input is iterated more than once below.
    local_outs, local_maxes, local_sums = list(local_outs), list(local_maxes), list(local_sums)
    if not len(local_outs) == len(local_maxes) == len(local_sums):
        raise ValueError(
            "normalize_attention got mismatched inputs: "
            f"{len(local_outs)=}, {len(local_maxes)=}, {len(local_sums)=}"
        )
    if not local_outs:
        raise ValueError("normalize_attention needs at least one local attention.")

    global_max = functools.reduce(jnp.maximum, local_maxes)
    # Factor that moves each partial from its own max to the global max. It is
    # computed once and reused for both the denominator and the outputs.
    rescales = [jnp.exp(local_max - global_max) for local_max in local_maxes]
    global_sum = sum(rescale * local_sum for rescale, local_sum in zip(rescales, local_sums))
    return sum((rescale / global_sum) * local_out for rescale, local_out in zip(rescales, local_outs))
@nn.compact
def __call__(self, query, key, value, decoder_segment_ids, model_mode):
    """Run attention over the KV cache(s) for the current ``model_mode``.

    The cache is updated via ``kv_cache`` and attention is computed over the
    returned primary (prefill) cache. If ``kv_cache`` also returns an
    autoregressive cache, attention is computed over it separately and the two
    partial results are merged with ``normalize_attention``. Otherwise the
    primary result is divided by its exponential sum when ``apply_attention``
    returns one, and returned as-is when that sum is ``None``.

    Args:
      query: query tensor, passed straight to ``apply_attention``.
      key: key tensor for the current step; see ``kv_cache``.
      value: value tensor for the current step; see ``kv_cache``.
      decoder_segment_ids: segment ids for the current tokens; see ``kv_cache``.
      model_mode: one of the ``common_types.MODEL_MODE_*`` values.

    Returns:
      The attention output.
    """
    primary_cache, ar_cache = self.kv_cache(key, value, decoder_segment_ids, model_mode)

    def attend(cache):
        # cache is (key, value, segment_ids); returns (unnormalized_out, max, sum).
        return self.apply_attention(
            query=query,
            key=cache[0],
            value=cache[1],
            decoder_segment_ids=cache[2],
            model_mode=model_mode,
        )

    primary_out, primary_max, primary_sum = attend(primary_cache)

    if ar_cache is not None:
        # Two disjoint key sets: merge their partial softmax results.
        ar_out, ar_max, ar_sum = attend(ar_cache)
        return self.normalize_attention(
            [primary_out, ar_out],
            [primary_max, ar_max],
            [primary_sum, ar_sum],
        )

    # Single cache (train / prefill): the primary result is the full answer.
    return primary_out if primary_sum is None else primary_out / primary_sum
class Attention(nn.Module):
"""Generic Attention.
Attributes:
num_query_heads: number of query attention heads. Features (i.e. inputs_q.shape[-1])
should be divisible by the number of heads.
num_kv_heads: number of kv attention heads.
head_dim: dimension of each head.
mesh: Mesh, device mesh
attention_kernel: str, guidance on if we should use an attention kernel
dtype: the dtype of the computation.
weight_dtype: the dtype of the weights.
max_target_length: maximum target length
max_prefill_predict_length: size of the maximum prefill
dropout_rate: dropout rate
kernel_init: initializer for the kernel of the Dense layers.
float32_qk_product: bool, if True then compute logits via float32 qk_product to avoid
numerical issues with bfloat16.
float32_logits: bool, if True then cast logits to float32 before softmax to avoid
numerical issues with bfloat16.
quant: Quant, stores quantization parameters, defaults to None implying no quantization.
quantize_kvcache: bool, quantize the kv cache.
"""
config: Config
num_query_heads: int
num_kv_heads: int
head_dim: int
max_target_length: int
mesh: Mesh
attention_kernel: str
dtype: DType = jnp.float32
weight_dtype: DType = jnp.float32
max_prefill_predict_length: int = -1
dropout_rate: float = 0.0
kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal")
float32_qk_product: bool = False # computes logits in float32 for stability.
float32_logits: bool = False # cast logits in float32 for stability.
quant: Optional[Quant] = None
quantize_kvcache: bool = False
# Shard the query activation as the same as the key and value.
# TODO: Find a better sharding axis name.
query_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM)
key_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM)
value_axis_names: AxisNames = (KV_BATCH, LENGTH, KV_HEAD, KV_HEAD_DIM)
out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV)
prefill_cache_axis_order: AxisIdxes = (1, 2, 0, 3)
ar_cache_axis_order: AxisIdxes = (1, 2, 0, 3)
compute_axis_order: AxisIdxes = (0, 1, 2, 3)
reshape_q: bool = False
def query_projection(self, inputs_q: Array) -> Array:
    """Project ``inputs_q`` into per-head queries.

    The last axis of ``inputs_q`` is projected by a ``DenseGeneral`` named
    "query" with features ``(num_query_heads, head_dim)`` and kernel axes
    ("embed", "heads", "kv").

    Args:
      inputs_q: input activations to project.

    Returns:
      The query projection.
    """
    # NOTE: T5 does not explicitly rescale the attention logits by
    # 1/sqrt(depth_kq)! Instead, the scaling is folded into the query kernel's
    # initializer, which is equivalent under Adafactor.
    scale = jnp.sqrt(self.head_dim).astype(self.dtype)

    def scaled_kernel_init(*init_args):
        # pylint: disable=no-value-for-parameter
        return self.kernel_init(*init_args) / scale

    projection = DenseGeneral(
        features=(self.num_query_heads, self.head_dim),
        axis=-1,
        kernel_init=scaled_kernel_init,
        kernel_axes=("embed", "heads", "kv"),
        dtype=self.dtype,
        weight_dtype=self.weight_dtype,
        name="query",
        quant=self.quant,
    )
    return projection(inputs_q)
def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array:
    """Projection for Key and Value.

    Args:
      inputs_kv: input activations. The last axis is contracted against the
        kernel's "embed" axis, so this is presumably
        `[batch, kv_length, embed_dim]` (TODO confirm with callers).
      proj_name: name of projection, `key` or `value`.

    Returns:
      Projection of key or value. ``DenseGeneral`` is applied with ``axis=-1``
      and ``features=(num_kv_heads, head_dim)``, so the result is presumably
      `[batch, kv_length, num_kv_heads, head_dim]`.

    Raises:
      ValueError: if ``num_kv_heads`` is -1 (not defined), if it is not
        positive, or if it does not evenly divide ``num_query_heads``, which
        grouped-query attention requires.
    """
    if self.num_kv_heads == -1:
        raise ValueError("num_kv_heads is not defined.")
    # Reject 0 and negative values before the modulo. Otherwise 0 would raise
    # ZeroDivisionError, and some negative values would pass the check.
    if self.num_kv_heads <= 0 or self.num_query_heads % self.num_kv_heads != 0:
        raise ValueError("Invalid num_kv_heads for GQA.")

    kernel_axes = ("embed", "kv_heads", "kv_head_dim")
    kv_proj = DenseGeneral(
        features=(self.num_kv_heads, self.head_dim),
        axis=-1,
        kernel_init=self.kernel_init,
        kernel_axes=kernel_axes,
        dtype=self.dtype,
        weight_dtype=self.weight_dtype,
        name=proj_name,
        quant=self.quant,
    )(inputs_kv)
    return kv_proj
def qkv_projection(self, inputs: Array, proj_name: str):
"""Fused QKV projection"""