Semantyka operacji

Poniżej opisujemy semantykę operacji zdefiniowanych w interfejsie XlaBuilder. Zwykle te operacje mapują jeden do jednego na operacje zdefiniowane w interfejsie RPC w xla_data.proto.

Uwaga na temat nazewnictwa: uogólniony typ danych XLA dotyczy n-wymiarowej tablicy przechowującej elementy jakiegoś typu jednolitego (na przykład 32-bitową liczbę zmiennoprzecinkową). W dokumentacji argument tablica jest używany do oznaczania tablicy o dowolnym wymiarze. Dla wygody specjalne przypadki mają bardziej szczegółowe i znane nazwy. Na przykład wektor to jednowymiarowa tablica, a macierz to tablica dwuwymiarowa.

AfterAll

Zobacz też XlaBuilder::AfterAll.

AfterAll pobiera zmienną liczbę tokenów i generuje jeden token. Tokeny są typami podstawowymi, które można łączyć w wątki między operacjami ubocznymi, aby wymusić kolejność. AfterAll może służyć do łączenia tokenów w celu zamawiania operacji po operacjach zestawu.

AfterAll(operands)

Argumenty Typ Semantyka
operands XlaOp zmienna liczba tokenów

AllGather

Zobacz też XlaBuilder::AllGather.

Przeprowadza konkatenację w replikach.

AllGather(operand, all_gather_dim, shard_count, replica_group_ids, channel_id)

Argumenty Typ Semantyka
operand XlaOp Tablica konkatenacji między replikami
all_gather_dim int64 Wymiar łączenia
replica_groups wektor wektorów int64 Grupy, między którymi odbywa się łączenie,
channel_id opcjonalnie: int64 Opcjonalny identyfikator kanału na potrzeby komunikacji między modułami
  • replica_groups to lista grup replik, między którymi odbywa się konkatenacja (identyfikator repliki bieżącej repliki można pobrać za pomocą ReplicaId). Kolejność replik w każdej grupie określa kolejność, w której dane wejściowe znajdują się w wyniku. replica_groups musi być pusta (w takim przypadku wszystkie repliki należą do jednej grupy, uporządkowanej od 0 do N - 1) lub zawierać taką samą liczbę elementów jak liczba replik. Na przykład replica_groups = {0, 2}, {1, 3} wykonuje łączenie replik 0 i 2 oraz 1 i 3.
  • shard_count to rozmiar każdej grupy replik. Jest on potrzebny w sytuacjach, gdy pole replica_groups jest puste.
  • channel_id służy do komunikacji między modułami: tylko operacje all-gather z tym samym atrybutem channel_id mogą się ze sobą komunikować.

Kształt wyjściowy to kształt wejściowy, w którym element all_gather_dim został powiększony shard_count razy. Jeśli na przykład są 2 repliki, a operand w 2 replikach ma wartość [1.0, 2.5] i [3.0, 5.25], wartość wyjściowa tej operacji, gdzie all_gather_dim to 0, będzie wynosić [1.0, 2.5, 3.0, 5.25] w obu replikach.

AllReduce

Zobacz też XlaBuilder::AllReduce.

Przeprowadza niestandardowe obliczenia w replikach.

AllReduce(operand, computation, replica_group_ids, channel_id)

Argumenty Typ Semantyka
operand XlaOp Tablica lub niepusta krotka tablic, która ogranicza liczbę w replikach,
computation XlaComputation Obliczanie redukcji
replica_groups wektor wektorów int64 Grupy, w których są stosowane redukcje
channel_id opcjonalnie: int64 Opcjonalny identyfikator kanału na potrzeby komunikacji między modułami
  • Gdy operand jest kropką tablicy, w każdym jej elemencie wykonywane jest wszystkie obliczenia.
  • replica_groups to lista grup replik, w których następuje zmniejszenie (identyfikator repliki bieżącej repliki można pobrać za pomocą narzędzia ReplicaId). replica_groups musi być pusty (w takim przypadku wszystkie repliki należą do jednej grupy) lub zawierać taką samą liczbę elementów jak liczba replik. Na przykład replica_groups = {0, 2}, {1, 3} zmniejsza liczbę replik 0 i 2 oraz 1 i 3.
  • channel_id służy do komunikacji między modułami: tylko operacje all-reduce z tym samym atrybutem channel_id mogą się ze sobą komunikować.

Kształt wyjściowy jest taki sam jak kształt wejściowy. Jeśli na przykład są 2 repliki, a operand ma w każdej z nich wartości [1.0, 2.5] i [3.0, 5.25], w obu replikach wartość wyjściowa tej operacji i obliczenia sumy będą wynosić [4.0, 7.75]. Jeśli dane wejściowe są kropką, dane wyjściowe też są kropką.

Obliczenie wyniku AllReduce wymaga 1 danych wejściowych z każdej repliki, więc jeśli jedna replika uruchomi węzeł AllReduce więcej razy niż drugi, dotychczasowa replika będzie czekać w nieskończoność. Ponieważ wszystkie repliki działają w ramach tego samego programu, nie ma wielu sposobów, aby tak się stało. Jest jednak możliwe, gdy stan pętli podczas wykonywania zależy od danych z InFeed, a dodane dane powodują, że pętla podczas wykonywania powtarza się więcej razy w przypadku jednej repliki niż w drugiej.

AllToAll

Zobacz też XlaBuilder::AllToAll.

AllToAll to operacja zbiorcza, która wysyła dane ze wszystkich rdzeni do wszystkich rdzeni. Składa się z 2 faz:

  1. Faza punktowa. W każdym rdzeniu operand jest dzielony na split_count bloków wzdłuż linii split_dimensions, a bloki są rozłożone na wszystkie rdzenie, np. i-ty blok jest wysyłany do i-tego rdzenia.
  2. Faza gromadzenia. Każdy rdzeń łączy odebrane bloki wzdłuż concat_dimension.

Podstawowe rdzenie można skonfigurować:

  • replica_groups: każda grupa ReplicaGroup zawiera listę identyfikatorów replik uczestniczących w obliczeniach (identyfikator repliki bieżącej repliki można pobrać za pomocą ReplicaId). Opcja AllToAll zostanie zastosowana w podgrupach w określonej kolejności. Na przykład replica_groups = { {1,2,3}, {4,5,0} } oznacza, że element AllToAll zostanie zastosowany w replikach {1, 2, 3} na etapie zbierania danych, a odebrane bloki będą łączone w tej samej kolejności 1, 2, 3. Następnie w replikach 4, 5, 0 zostanie zastosowany kolejny parametr AllToAll, a kolejność konkatenacji również wynosi 4, 5, 0. Jeśli replica_groups jest pusty, wszystkie repliki należą do jednej grupy w kolejności konkatenacji ich wyglądu.

Wymagania wstępne:

  • Rozmiar operandu w elemencie split_dimension jest podzielny przez split_count.
  • Kształt operandu nie jest kropką.

AllToAll(operand, split_dimension, concat_dimension, split_count, replica_groups)

Argumenty Typ Semantyka
operand XlaOp n-wymiarowa tablica wejściowa
split_dimension int64 Wartość w przedziale [0, n) określająca wymiar, którego operand jest podzielony
concat_dimension int64 Wartość w przedziale [0, n) określająca wymiar, do którego są łączone podzielone bloki
split_count int64 Liczba rdzeni używanych w tej operacji. Jeśli właściwość replica_groups jest pusta, powinna to być liczba replik. W przeciwnym razie powinna być równa liczbie replik w każdej grupie.
replica_groups Wektor ReplicaGroup Każda grupa zawiera listę identyfikatorów replik.

Poniżej znajdziesz przykład Alltoall.

XlaBuilder b("alltoall");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4);

W tym przykładzie wszystkie 4 rdzenie są częścią zasobu Alltoall. W każdym rdzeniu operand jest podzielony na 4 części wzdłuż wymiaru 0, co oznacza, że każda część ma kształt f32[4,4]. Te 4 części są rozłożone na wszystkie rdzenie. Następnie każdy rdzeń łączy otrzymane części w wymiarze 1 w kolejności rdzeni 0–4. Dane wyjściowe na każdym rdzeniu mają więc kształt f32[16,4].

BatchNormGrad

Szczegółowy opis algorytmu znajdziesz też w XlaBuilder::BatchNormGrad i pierwotnej dokumentacji do normalizacji wsadowej.

Oblicza gradienty normy wsadowej.

BatchNormGrad(operand, scale, mean, variance, grad_output, epsilon, feature_index)

Argumenty Typ Semantyka
operand XlaOp n tablica wymiarowa do znormalizowania (x)
scale XlaOp 1 tablica wymiarowa (\(\gamma\))
mean XlaOp 1 tablica wymiarowa (\(\mu\))
variance XlaOp 1 tablica wymiarowa (\(\sigma^2\))
grad_output XlaOp Gradienty przekazane do: BatchNormTraining (\(\nabla y\))
epsilon float Wartość epsilonu (\(\epsilon\))
feature_index int64 Indeks do wymiaru cechy w: operand

W przypadku każdej cechy w wymiarze cech (feature_index to indeks wymiaru obiektu w argumencie operand) operacja oblicza gradienty z uwzględnieniem operand, offset i scale wszystkich pozostałych wymiarów. Wartość feature_index musi być prawidłowym indeksem dla wymiaru cechy w polu operand.

Trzy gradienty są zdefiniowane za pomocą następujących wzorów (przy założeniu, że czterowymiarowa tablica ma postać operand, indeks wymiarów cech l, rozmiar wsadu m oraz rozmiary przestrzenne w i h):

\[ \begin{split} c_l&= \frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sigma^2_l+\epsilon} \right) \\\\ d_l&= \frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl} \\\\ \nabla x_{ijkl} &= \frac{\gamma_{l} }{\sqrt{\sigma^2_{l}+\epsilon} } \left( \nabla y_{ijkl} - d_l - c_l (x_{ijkl} - \mu_{l}) \right) \\\\ \nabla \gamma_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sqrt{\sigma^2_{l}+\epsilon} } \right) \\\\\ \nabla \beta_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl} \end{split} \]

Dane wejściowe mean i variance reprezentują wartości momentów w wymiarach wsadowych i przestrzennych.

Typ wyjściowy to krotka z 3 nickami:

Wyjście Typ Semantyka
grad_operand XlaOp gradientu względem danych wejściowych operand ($\nabla x$)
grad_scale XlaOp gradientu względem danych wejściowych scale ($\nabla\gamma$)
grad_offset XlaOp gradient w odniesieniu do danych wejściowych offset($\nabla \beta$)

BatchNormInference

Szczegółowy opis algorytmu znajdziesz też w XlaBuilder::BatchNormInference i pierwotnej dokumentacji do normalizacji wsadowej.

Normalizuje tablicę w wymiarach wsadowych i przestrzennych.

BatchNormInference(operand, scale, offset, mean, variance, epsilon, feature_index)

Argumenty Typ Semantyka
operand XlaOp n tablica wymiarowa do znormalizowania
scale XlaOp 1 tablica wymiarowa
offset XlaOp 1 tablica wymiarowa
mean XlaOp 1 tablica wymiarowa
variance XlaOp 1 tablica wymiarowa
epsilon float Wartość epsilonu
feature_index int64 Indeks do wymiaru cechy w: operand

Dla każdej cechy w wymiarze cech (feature_index to indeks wymiaru cech w elemencie operand) operacja oblicza średnią i wariancję wszystkich pozostałych wymiarów, a następnie wykorzystuje średnią i wariancję do normalizacji każdego elementu w obiekcie operand. Wartość feature_index musi być prawidłowym indeksem dla wymiaru cechy w tabeli operand.

Funkcja BatchNormInference jest odpowiednikiem wywoływania metody BatchNormTraining bez obliczania wartości mean i variance dla każdej grupy. Używa danych wejściowych mean i variance jako wartości szacunkowych. Celem tej operacji jest skrócenie czasu oczekiwania na wnioskowanie, stąd nazwa BatchNormInference.

Wynikiem jest n-wymiarowa, znormalizowana tablica o tym samym kształcie co dane wejściowe operand.

BatchNormTraining

Szczegółowy opis algorytmu znajdziesz też w sekcjach XlaBuilder::BatchNormTraining i the original batch normalization paper.

Normalizuje tablicę w wymiarach wsadowych i przestrzennych.

BatchNormTraining(operand, scale, offset, epsilon, feature_index)

Argumenty Typ Semantyka
operand XlaOp n tablica wymiarowa do znormalizowania (x)
scale XlaOp 1 tablica wymiarowa (\(\gamma\))
offset XlaOp 1 tablica wymiarowa (\(\beta\))
epsilon float Wartość epsilonu (\(\epsilon\))
feature_index int64 Indeks do wymiaru cechy w: operand

Dla każdej cechy w wymiarze cech (feature_index to indeks wymiaru cech w elemencie operand) operacja oblicza średnią i wariancję wszystkich pozostałych wymiarów, a następnie wykorzystuje średnią i wariancję do normalizacji każdego elementu w obiekcie operand. Wartość feature_index musi być prawidłowym indeksem dla wymiaru cechy w tabeli operand.

Algorytm w przypadku każdej wsadu operand \(x\) zawiera m elementy, w których rozmiarem wymiarów przestrzennych są w i h (przy założeniu, że operand jest tablicą 4-wymiarową):

  • Oblicza średnią wsadową \(\mu_l\) dla każdej cechy l w wymiarze cech:\(\mu_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h x_{ijkl}\)

  • Oblicza wariancję wsadu \(\sigma^2_l\): $\sigma^2l=\frac{1}{mwh}\sum{i=1}^m\sum{j=1}^w\sum{k=1}^h (x_{ijkl} - \mu_l)^2$

  • Normalizuje, skaluje i przesuwa: \(y_{ijkl}=\frac{\gamma_l(x_{ijkl}-\mu_l)}{\sqrt[2]{\sigma^2_l+\epsilon} }+\beta_l\)

Dodaje się wartość epsilonu, zwykle małą liczbę, aby uniknąć błędów dzielenia przez 0.

Typem wyjściowym jest krotka z 3 elementami XlaOp:

Wyjście Typ Semantyka
output XlaOp n tablica wymiarowa o tym samym kształcie co dane wejściowe operand (y)
batch_mean XlaOp 1 tablica wymiarowa (\(\mu\))
batch_var XlaOp 1 tablica wymiarowa (\(\sigma^2\))

batch_mean i batch_var to momenty obliczone dla wsadów i wymiarów przestrzennych na podstawie powyższych wzorów.

BitcastConvertType

Zobacz też XlaBuilder::BitcastConvertType.

Podobnie jak tf.bitcast w TensorFlow, wykonuje operację bitcastu z wykorzystaniem elementu – od kształtu danych do docelowego kształtu. Rozmiar danych wejściowych i wyjściowych musi być taki sam, np. elementy s32 stają się elementami f32 przez procedurę bitcastu, a 1 element s32 zmienia się w 4 elementy s8. Bitcast jest implementowany jako niskopoziomowy typ przesyłania, dlatego maszyny z różnymi reprezentacjami zmiennoprzecinkowymi dają różne wyniki.

BitcastConvertType(operand, new_element_type)

Argumenty Typ Semantyka
operand XlaOp tablica typu T z przyciemnieniem D
new_element_type PrimitiveType typ U

Wymiary operandu i docelowego kształtu muszą być zgodne z wyjątkiem ostatniego wymiaru, który zmieni się ze względu na współczynnik rozmiaru podstawowego przed konwersją i po niej.

Typy elementów źródłowych i docelowych nie mogą być krotkami.

Konwertowanie bitcastu na typ podstawowy o różnej szerokości

BitcastConvert Instrukcja HLO obsługuje przypadek, w którym rozmiar elementu wyjściowego typu T' nie jest równy rozmiarowi elementu wejściowego T. Cała operacja jest koncepcyjnie oparta na bitcastach i nie zmienia bazowych bajtów, więc kształt elementu wyjściowego musi się zmienić. W przypadku usługi B = sizeof(T), B' = sizeof(T') możliwe są 2 przypadki.

Po pierwsze, gdy parametr B > B' ma kształt wyjściowy, otrzymuje nowy, mniejszy wymiar rozmiaru B/B'. Na przykład:

  f16[10,2]{1,0} %output = f16[10,2]{1,0} bitcast-convert(f32[10]{0} %input)

W przypadku skutecznych skalarów reguła pozostaje taka sama:

  f16[2]{0} %output = f16[2]{0} bitcast-convert(f32[] %input)

Z kolei w przypadku funkcji B' > B instrukcja wymaga, aby ostatni wymiar logiczny wejściowego kształtu był równy B'/B, a ten wymiar jest pomijany podczas konwersji:

  f32[10]{0} %output = f32[10]{0} bitcast-convert(f16[10,2]{1,0} %input)

Pamiętaj, że konwersje między różnymi szerokościami bitowymi nie zależą od elementu.

Komunikaty

Zobacz też XlaBuilder::Broadcast.

Dodaje wymiary do tablicy, duplikując zawarte w niej dane.

Broadcast(operand, broadcast_sizes)

Argumenty Typ Semantyka
operand XlaOp Tablica do duplikowania
broadcast_sizes ArraySlice<int64> Rozmiary nowych wymiarów

Nowe wymiary są wstawiane po lewej stronie, tj. jeśli broadcast_sizes ma wartości {a0, ..., aN}, a kształt argumentu ma wymiary {b0, ..., bM}, kształt danych wyjściowych ma wymiary {a0, ..., aN, b0, ..., bM}.

Nowe wymiary są indeksowane do kopii argumentu, tj.

output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]

Jeśli na przykład operand jest skalarnym f32 o wartości 2.0f, a broadcast_sizes ma wartość {2, 3}, wynik będzie tablicą z kształtem f32[2, 3], a wszystkie wartości w wyniku będą wynosić 2.0f.

BroadcastInDim

Zobacz też XlaBuilder::BroadcastInDim.

Rozszerza rozmiar i pozycję tablicy, duplikując w niej dane.

BroadcastInDim(operand, out_dim_size, broadcast_dimensions)

Argumenty Typ Semantyka
operand XlaOp Tablica do duplikowania
out_dim_size ArraySlice<int64> Rozmiary wymiarów docelowego kształtu
broadcast_dimensions ArraySlice<int64> Któremu wymiarowi w kształcie docelowym odpowiada każdy wymiar w kształcie operandu?

Działa podobnie jak transmisja, ale umożliwia dodawanie wymiarów w dowolnym miejscu i rozszerzanie istniejących wymiarów o rozmiar 1.

Moduł operand jest transmitowany do kształtu opisanego przez out_dim_size. Funkcja broadcast_dimensions mapuje wymiary elementu operand na wymiary docelowego kształtu, tzn. i-ty wymiar operandu jest mapowany na wymiar kształtu wyjściowego. Wymiary obiektu operand muszą mieć rozmiar 1 lub być identyczne z wymiarem w kształcie wyjściowym, do którego są zmapowane. Pozostałe wymiary zostaną wypełnione wymiarami rozmiaru 1. Wiadomość o zdegenerowanych wymiarach będzie przesyłana wzdłuż tych zdegenerowanych wymiarów, aby osiągnąć kształt wyjściowy. Semantyka jest szczegółowo opisana na stronie transmisji.

Połączenie

Zobacz też XlaBuilder::Call.

Wywołuje obliczenie przy użyciu podanych argumentów.

Call(computation, args...)

Argumenty Typ Semantyka
computation XlaComputation obliczenia typu T_0, T_1, ..., T_{N-1} -> S przy użyciu N parametrów dowolnego typu
args sekwencja N XlaOp s N argumentów dowolnego typu

Argument i typy właściwości args muszą być zgodne z parametrami tagu computation. Nie może zawierać żadnych elementów typu args.

Kraj choleski

Zobacz też XlaBuilder::Cholesky.

Oblicza rozkład Cholesky’ego w grupie macierzy o określonych symetrycznych (hermitach) dodatnich.

Cholesky(a, lower)

Argumenty Typ Semantyka
a XlaOp tablicy rangowej > 2 w przypadku typu złożonego lub zmiennoprzecinkowego.
lower bool czy użyć górnego czy dolnego trójkąta a.

Jeśli lower to true, oblicza macierze dolne trójkątów l w taki sposób, że $a = l . l^T$. Jeśli lower ma wartość false, oblicza macierze górnego trójkąta u w taki sposób, że\(a = u^T . u\).

Dane wejściowe są odczytywane tylko z trójkąta dolnego/górnego a w zależności od wartości lower. Wartości z drugiego trójkąta są ignorowane. Dane wyjściowe są zwracane w tym samym trójkącie. Wartości w drugim trójkącie są zdefiniowane przez implementację i mogą być dowolne.

Jeśli pozycja a jest większa niż 2, a jest traktowana jako zbiór macierzy, gdzie wszystkie oprócz 2 dodatkowych wymiarów są wymiarami wsadu.

Jeśli definicja a nie jest symetryczna (hermita) wartością dodatnią, wynik jest zdefiniowany przez implementację.

Z klipsem

Zobacz też XlaBuilder::Clamp.

Łączy operand między wartością minimalną a maksymalną.

Clamp(min, operand, max)

Argumenty Typ Semantyka
min XlaOp tablica typu T
operand XlaOp tablica typu T
max XlaOp tablica typu T

Biorąc pod uwagę operand oraz wartości minimalną i maksymalną, zwraca operand, jeśli mieści się w zakresie między wartością minimalną a maksymalną, w przeciwnym razie zwraca wartość minimalną, jeśli operand jest poniżej tego zakresu, lub wartość maksymalną, jeśli operand przekracza ten zakres. Czyli clamp(a, x, b) = min(max(a, x), b).

Wszystkie trzy tablice muszą mieć ten sam kształt. W przypadku ograniczonej formy transmisji min lub max mogą być skalarem typu T.

Przykład ze skalarnymi wartościami min i max:

let operand: s32[3] = {-1, 5, 9};
let min: s32 = 0;
let max: s32 = 6;
==>
Clamp(min, operand, max) = s32[3]{0, 5, 6};

Zwiń

Zobacz też XlaBuilder::Collapse i operację tf.reshape.

Zwija wymiary tablicy do jednego wymiaru.

Collapse(operand, dimensions)

Argumenty Typ Semantyka
operand XlaOp tablica typu T
dimensions Wektor int64 w kolejności, w kolejności podzbioru wymiarów T.

Zwijanie zastępuje dany podzbiór wymiarów operandu jednym wymiarem. Argumenty wejściowe to dowolna tablica typu T i stały wektor indeksów wymiarów w czasie kompilacji. Indeksy wymiarów muszą być ustalone w odpowiedniej kolejności (liczby od najmniejszej liczby wymiarów) i stanowić się w kolejnym podzbiorze wymiarów T. Zatem {0, 1, 2}, {0, 1} i {1, 2} są prawidłowymi zestawami wymiarów, ale {1, 0} ani {0, 2} już nie. Są one zastępowane jednym nowym wymiarem, które znajdują się w tej samej pozycji w sekwencji wymiarów, które zastępują, oraz z nowym rozmiarem równym iloczynowi rozmiarów pierwotnych wymiarów. Najniższa liczba wymiarów w elemencie dimensions to najwolniejszy wymiar zmieniający się (największy) w zagnieżdżeniu pętli, który zwija te wymiary. Najwyższa liczba wymiarów zmienia się najszybciej (w większości mniejszych). Jeśli potrzebna jest bardziej ogólna kolejność zwijania, zapoznaj się z operatorem tf.reshape.

Na przykład niech v będzie tablicą 24 elementów:

let v = f32[4x2x3] { { {10, 11, 12},  {15, 16, 17} },
{ {20, 21, 22},  {25, 26, 27} },
{ {30, 31, 32},  {35, 36, 37} },
{ {40, 41, 42},  {45, 46, 47} } };

// Collapse to a single dimension, leaving one dimension.
let v012 = Collapse(v, {0,1,2});
then v012 == f32[24] {10, 11, 12, 15, 16, 17,
20, 21, 22, 25, 26, 27,
30, 31, 32, 35, 36, 37,
40, 41, 42, 45, 46, 47};

// Collapse the two lower dimensions, leaving two dimensions.
let v01 = Collapse(v, {0,1});
then v01 == f32[4x6] { {10, 11, 12, 15, 16, 17},
{20, 21, 22, 25, 26, 27},
{30, 31, 32, 35, 36, 37},
{40, 41, 42, 45, 46, 47} };

// Collapse the two higher dimensions, leaving two dimensions.
let v12 = Collapse(v, {1,2});
then v12 == f32[8x3] { {10, 11, 12},
{15, 16, 17},
{20, 21, 22},
{25, 26, 27},
{30, 31, 32},
{35, 36, 37},
{40, 41, 42},
{45, 46, 47} };

CollectivePermute

Zobacz też XlaBuilder::CollectivePermute.

CollectivePermute to operacja zbiorcza, która wysyła i odbiera krzyżowe repliki danych.

CollectivePermute(operand, source_target_pairs)

Argumenty Typ Semantyka
operand XlaOp n-wymiarowa tablica wejściowa
source_target_pairs Wektor <int64, int64> Lista par (source_replica_id, target_replica_id). Dla każdej pary operand jest wysyłany z repliki źródłowej do repliki docelowej.

Pamiętaj, że source_target_pair podlega tym ograniczeniom:

  • Dowolne 2 pary nie powinny mieć tego samego identyfikatora repliki docelowej i nie powinny mieć tego samego identyfikatora repliki źródłowej.
  • Jeśli identyfikator repliki nie jest elementem docelowym w żadnej parze, dane wyjściowe tej repliki to tensor składający się z 0 o takim samym kształcie jak dane wejściowe.

Połącz

Zobacz też XlaBuilder::ConcatInDim.

Conkatenacja tworzy tablicę z wielu operandów tablicy. Tablica ma tę samą pozycję co każdy operand tablicy wejściowej (musi mieć tę samą pozycję co pozostałe) i zawiera argumenty w takiej kolejności, w jakiej zostały określone.

Concatenate(operands..., dimension)

Argumenty Typ Semantyka
operands sekwencja N XlaOp N tablic typu T o wymiarach [L0, L1, ...]. Wymagana wartość N >= 1.
dimension int64 Wartość z przedziału [0, N) określająca wymiar, który ma zostać połączony z komponentem operands.

Wszystkie wymiary z wyjątkiem wartości dimension muszą być takie same. Dzieje się tak, ponieważ XLA nie obsługuje „obciążonych” tablic. Pamiętaj też, że wartości rankingu 0 nie mogą być łączone (ponieważ nie można nazwać wymiaru, w którym następuje konkatenacja).

Przykład jednowymiarowego:

Concat({ {2, 3}, {4, 5}, {6, 7} }, 0)
>>> {2, 3, 4, 5, 6, 7}

Przykład obrazu dwuwymiarowego:

let a = {
{1, 2},
{3, 4},
{5, 6},
};
let b = {
{7, 8},
};
Concat({a, b}, 0)
>>> {
{1, 2},
{3, 4},
{5, 6},
{7, 8},
}

Schemat:

Warunkowy

Zobacz też XlaBuilder::Conditional.

Conditional(pred, true_operand, true_computation, false_operand, false_computation)

Argumenty Typ Semantyka
pred XlaOp Skala typu PRED
true_operand XlaOp Argument typu \(T_0\)
true_computation XlaComputation XlaComputation typu \(T_0 \to S\)
false_operand XlaOp Argument typu \(T_1\)
false_computation XlaComputation XlaComputation typu \(T_1 \to S\)

Wykonuje polecenie true_computation, jeśli pred ma wartość true, lub false_computation, jeśli pred ma wartość false, i zwraca wynik.

true_computation musi przyjmować pojedynczy argument typu \(T_0\) i zostanie wywołana z funkcją true_operand, która musi być tego samego typu. Element false_computation musi przyjmować pojedynczy argument typu \(T_1\) i jest wywoływany za pomocą funkcji false_operand, która musi być tego samego typu. Typ zwróconej wartości true_computation i false_computation musi być taki sam.

Pamiętaj, że w zależności od wartości pred zostanie wykonane tylko jedno z działań true_computation i false_computation.

Conditional(branch_index, branch_computations, branch_operands)

Argumenty Typ Semantyka
branch_index XlaOp Skala typu S32
branch_computations sekwencja N XlaComputation XlaComputations typu \(T_0 \to S , T_1 \to S , ..., T_{N-1} \to S\)
branch_operands sekwencja N XlaOp Argumenty typu \(T_0 , T_1 , ..., T_{N-1}\)

Wykonuje polecenie branch_computations[branch_index] i zwraca wynik. Jeśli branch_index to S32, który ma wartość < 0 lub >= N, jako gałąź domyślną wykonywany jest branch_computations[N-1].

Każdy element branch_computations[b] musi przyjmować 1 argument typu \(T_b\) i jest wywoływany z funkcją branch_operands[b], która musi być tego samego typu. Typ zwracanej wartości każdej branch_computations[b] musi być taki sam.

Pamiętaj, że w zależności od wartości branch_index zostanie wykonany tylko jeden z elementów branch_computations.

Konw (splot)

Zobacz też XlaBuilder::Conv.

W zasadzie ConvWithGeneralPadding, ale dopełnienie jest określane w skrócie jako SAME lub VALID. SAME dopełnienie wypełnia dane wejściowe (lhs) zerami, aby dane wyjściowe miały taki sam kształt jak dane wejściowe, gdy nie są brane pod uwagę. PRAWIDŁOWE dopełnienie oznacza po prostu brak dopełnienia.

ConvWithGeneralPadding (splot)

Zobacz też XlaBuilder::ConvWithGeneralPadding.

Oblicza splot takiego rodzaju jak w sieciach neuronowych. Splot można tu traktować jako okno n-wymiarowe poruszające się po n-wymiarowym obszarze podstawy, a obliczenia wykonywane są dla każdej możliwej pozycji okna.

Argumenty Typ Semantyka
lhs XlaOp tablica danych wejściowych rankingu n+2
rhs XlaOp ranking n+2 wag jądra systemu
window_strides ArraySlice<int64> Tablica n-d kroków jądra
padding ArraySlice< pair<int64,int64>> Tablica n-d z dopełnieniem (niskie, wysokie)
lhs_dilation ArraySlice<int64> n-d lhs tablica współczynników rozszerzania
rhs_dilation ArraySlice<int64> macierz współczynnika dylatacji n-d rhs
feature_group_count int64 liczbę grup cech
batch_group_count int64 liczbę grup wsadowych,

Niech n będzie liczbą wymiarów przestrzennych. Argument lhs to tablica ran+ n+2 opisująca pole podstawy. Nazywamy to danymi wejściowymi, chociaż prawo też jest wartością wejściową. W sieci neuronowej są to aktywacje wejściowe. Wymiary n+2 to w tej kolejności:

  • batch: każda współrzędna w tym wymiarze stanowi niezależne dane wejściowe, dla których przeprowadzany jest splot.
  • z/depth/features: z każdą pozycją (y,x) w obszarze bazowym jest powiązany wektor, który jest przesyłany do tego wymiaru.
  • spatial_dims: opisuje wymiary przestrzenne (n) określające obszar bazowy, przez który przechodzi okno.

Argument rhs to tablica rangi n+2 opisująca filtr splotowy/jądro/okno. Wymiary są podane w tej kolejności:

  • output-z: wymiar danych wyjściowych (z).
  • input-z: rozmiar tego wymiaru pomnożony przez feature_group_count powinien być równy rozmiarowi z w lh.
  • spatial_dims: opisuje wymiary przestrzenne (n) określające okno n-d, które porusza się po obszarze bazowym.

Argument window_strides określa krok okna splotowego w wymiarach przestrzennych. Jeśli np. krok w pierwszym wymiarze przestrzennym wynosi 3, okno można umieścić tylko we współrzędnych, w których pierwszy indeks przestrzenny jest podzielny przez 3.

Argument padding określa zero dopełnienia, które ma być zastosowane do obszaru podstawowego. Wielkość dopełnienia może być ujemna – wartość bezwzględna ujemnego dopełnienia wskazuje liczbę elementów, które należy usunąć z określonego wymiaru przed wykonaniem splotu. padding[0] określa dopełnienie wymiaru y, a padding[1] – dopełnienie wymiaru x. W przypadku każdej pary pierwszy element ma niskie dopełnienie, a drugi – duże. Niskie dopełnienie jest stosowane w kierunku dolnych indeksów, a wysokie – w kierunku wyższych indeksów. Jeśli np. parametr padding[1] ma wartość (2,3), w drugim wymiarze przestrzennym pojawi się dopełnienie o 2 0 z lewej strony i 3 zera po prawej stronie. Użycie dopełnienia jest równoważne wstawieniu do danych wejściowych (lhs) tych samych wartości zerowych przed wykonaniem splotu.

Argumenty lhs_dilation i rhs_dilation określają współczynnik dylatacji, który zostanie zastosowany do wartości lhs i rhs w każdym wymiarze przestrzennym. Jeśli współczynnik rozszerzania w wymiarze przestrzennym wynosi d, otwory d-1 są domyślnie umieszczone między poszczególnymi elementami w tym wymiarze, co zwiększa rozmiar tablicy. Otwory są wypełniane wartością „brak operacji”, co w przypadku splotu oznacza zero.

Rozszerzanie się rhs jest również nazywane splotem zapięcia. Więcej informacji: tf.nn.atrous_conv2d. Rozszerzanie komórek jest też nazywane splotem transponowanym. Więcej informacji: tf.nn.conv2d_transpose.

Argument feature_group_count (wartość domyślna 1) może być używany do zgrupowanych rozwinięć. Funkcja feature_group_count musi być dzielnikiem zarówno danych wejściowych, jak i wyjściowych. Jeśli feature_group_count ma wartość większą niż 1, oznacza to, że koncepcyjnie wymiar cech wejściowych i wyjściowych oraz wymiar cechy wyjściowej rhs są dzielone po równo na wiele grup feature_group_count, z których każda składa się z kolejnej części cech. Wymiar cechy wejściowej rhs musi być równy wymiarowi cechy wejściowej lhs podzielonej przez feature_group_count (więc ma już rozmiar grupy cech wejściowych). Grupy i-te są używane razem do obliczania feature_group_count dla wielu oddzielnych splotów. Wyniki tych splotów są łączone w wymiarze cech wyjściowych.

W przypadku splotu głębokiego argument feature_group_count zostałby ustawiony na wymiar cechy wejściowej, a kształt filtra zmieniłby się z [filter_height, filter_width, in_channels, channel_multiplier] na [filter_height, filter_width, 1, in_channels * channel_multiplier]. Więcej informacji: tf.nn.depthwise_conv2d.

Argument batch_group_count (wartość domyślna 1) może być używany do zgrupowanych filtrów podczas propagacji wstecznej. Wartość batch_group_count musi być dzielnikiem rozmiaru w wymiarze wsadowym lhs (wejściowy). Jeśli batch_group_count ma wartość większą niż 1, oznacza to, że rozmiar wyjściowego wsadu powinien mieć rozmiar input batch / batch_group_count. Element batch_group_count musi być dzielnikiem rozmiaru cechy wyjściowej.

Kształt wyjściowy ma te wymiary w tej kolejności:

  • batch: rozmiar tego wymiaru pomnożony przez batch_group_count powinien być równy rozmiarowi wymiaru batch w l.
  • z: ten sam rozmiar co output-z w jądrze (rhs).
  • spatial_dims: po 1 wartości na każde prawidłowe miejsce docelowe okna konwolucyjnego.

Na ilustracji powyżej pokazujemy, jak działa pole batch_group_count. W efekcie dzielimy każdą wsadę na grupy batch_group_count i to samo robimy z funkcjami wyjściowymi. Następnie w przypadku każdej z tych grup tworzymy pary splotów i łączymy dane wyjściowe razem z wymiarem cech wyjściowych. Semantyka operacyjna pozostałych wymiarów (funkcji i przestrzennego) jest taka sama.

Prawidłowe rozmieszczenie okna splotowego zależy od liczby kroków i rozmiaru obszaru podstawowego po dopełnieniu.

Aby opisać działanie splotu, weź pod uwagę splot dwuwymiarowy i wybierz stałe współrzędne batch, z, y i x w wynikach. Następnie (y,x) określa położenie narożnika okna w obszarze podstawy (np. lewy górny róg, zależnie od interpretacji wymiarów przestrzennych). Mamy teraz okno 2D, wyjęte z obszaru podstawowego, gdzie każdy punkt 2D jest powiązany z wektorem 1d. W efekcie otrzymujemy prostokąt 3D. W przypadku jądra splotowego, ponieważ poprawiliśmy współrzędną wyjściową z, mamy również pole 3D. Te 2 pudełka mają te same wymiary, więc możemy obliczyć sumę iloczynów obu pudełek (podobnie do iloczynu skalarnego). To jest wartość wyjściowa.

Pamiętaj, że jeśli output-z to np. 5, to każda pozycja okna zwraca 5 wartości w danych wyjściowych w wymiarze z danych wyjściowych. Wartości te różnią się pod względem wykorzystania części jądra splotowego – każda współrzędna output-z ma osobne pole 3D z wartościami. Wyobraź sobie 5 osobnych splotów z innym filtrem.

Oto pseudokod dwuwymiarowego splotu z dopełnieniem i krokiem:

for (b, oz, oy, ox) {  // output coordinates
  value = 0;
  for (iz, ky, kx) {  // kernel coordinates and input z
    iy = oy*stride_y + ky - pad_low_y;
    ix = ox*stride_x + kx - pad_low_x;
    if ((iy, ix) inside the base area considered without padding) {
      value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx);
    }
  }
  output(b, oz, oy, ox) = value;
}

ConvertElementType

Zobacz też XlaBuilder::ConvertElementType.

Podobnie jak static_cast z elementami w języku C++, wykonuje operację konwersji na podstawie elementu, z kształtu danych na kształt docelowy. Wymiary muszą się zgadzać, a konwersja jest ustalana na podstawie elementów, np. w ramach procedury konwersji s32-f32 elementy s32 zmieniają się w elementy f32.

ConvertElementType(operand, new_element_type)

Argumenty Typ Semantyka
operand XlaOp tablica typu T z przyciemnieniem D
new_element_type PrimitiveType typ U

Wymiary operandu i docelowego kształtu muszą być takie same. Typy elementów źródłowych i docelowych nie mogą być krotkami.

Konwersja, np. z T=s32 na U=f32, wykona rutynę normalizowania liczby zmiennoprzecinkowej, np. zaokrąglanie do najmniejszej równomierności.

let a: s32[3] = {0, 1, 2};
let b: f32[3] = convert(a, f32);
then b == f32[3]{0.0, 1.0, 2.0}

CrossReplicaSum

Wykonuje obliczenia związane z funkcją AllReduce podczas obliczania sumy.

CustomCall

Zobacz też XlaBuilder::CustomCall.

Wywoływanie w ramach obliczeń funkcji udostępnionej przez użytkownika.

CustomCall(target_name, args..., shape)

Argumenty Typ Semantyka
target_name string Nazwa funkcji. Zostanie wysłana instrukcja wywołania kierowana na tę nazwę symbolu.
args sekwencja N XlaOp s N argumentów dowolnego typu, które zostaną przekazane do funkcji.
shape Shape Kształt wyjściowy funkcji

Podpis funkcji jest taki sam, niezależnie od argumentu czy typu argumentów:

extern "C" void target_name(void* out, void** in);

Jeśli na przykład funkcja CustomCall jest używana w następujący sposób:

let x = f32[2] {1,2};
let y = f32[2x3] { {10, 20, 30}, {40, 50, 60} };

CustomCall("myfunc", {x, y}, f32[3x3])

Oto przykład implementacji myfunc:

extern "C" void myfunc(void* out, void** in) {
  float (&x)[2] = *static_cast<float(*)[2]>(in[0]);
  float (&y)[2][3] = *static_cast<float(*)[2][3]>(in[1]);
  EXPECT_EQ(1, x[0]);
  EXPECT_EQ(2, x[1]);
  EXPECT_EQ(10, y[0][0]);
  EXPECT_EQ(20, y[0][1]);
  EXPECT_EQ(30, y[0][2]);
  EXPECT_EQ(40, y[1][0]);
  EXPECT_EQ(50, y[1][1]);
  EXPECT_EQ(60, y[1][2]);
  float (&z)[3][3] = *static_cast<float(*)[3][3]>(out);
  z[0][0] = x[1] + y[1][0];
  // ...
}

Funkcja podana przez użytkownika nie może mieć skutków ubocznych, a jej wykonanie musi być identyczne.

Kropka

Zobacz też XlaBuilder::Dot.

Dot(lhs, rhs)

Argumenty Typ Semantyka
lhs XlaOp tablica typu T
rhs XlaOp tablica typu T

Dokładna semantyka tej operacji zależy od rang operandów:

Dane wejściowe Wyniki Semantyka
wektor [n] wektor dot [n] wartość skalarna iloczyn skalarny wektorowy
macierz [m x k] wektor dot [k] wektor [m] mnożenie wektorów macierzy
macierz [m x k] dot macierz [k x n] macierz [m x n] mnożenie macierzy

Operacja wykonuje sumę produktów w przypadku drugiego wymiaru, czyli lhs (lub pierwszego, jeśli ma pozycję 1) i pierwszego wymiaru o wartości rhs. Są to wymiary „skrócone”. Zakontraktowane wymiary lhs i rhs muszą mieć ten sam rozmiar. W praktyce można jej używać do wykonywania iloczynów skalarnych między wektorami, mnożenia wektorów i macierzy oraz do mnożenia macierzy/macierzy.

DotGeneral

Zobacz też XlaBuilder::DotGeneral.

DotGeneral(lhs, rhs, dimension_numbers)

Argumenty Typ Semantyka
lhs XlaOp tablica typu T
rhs XlaOp tablica typu T
dimension_numbers DotDimensionNumbers numery umów i wymiarów wsadowych

Działa podobnie jak kropka, ale umożliwia określanie liczb wymiarów kontraktowych i wsadowych zarówno w przypadku lhs, jak i rhs.

Pola wymiarów kropek Typ Semantyka
lhs_contracting_dimensions powtórzony int64 lhs numeru wymiaru umownego
rhs_contracting_dimensions powtórzony int64 rhs numeru wymiaru umownego
lhs_batch_dimensions powtórzony int64 lhs numerów wymiarów wsadowych
rhs_batch_dimensions powtórzony int64 rhs numerów wymiarów wsadowych

Usługa DotGeneral podaje sumę produktów w porównaniu do wymiarów umowy określonych w dimension_numbers.

Numery powiązanych wymiarów umownych z lhs i rhs nie muszą być takie same, ale muszą mieć te same rozmiary.

Przykład z numerami wymiarów umownych:

lhs = { {1.0, 2.0, 3.0},
{4.0, 5.0, 6.0} }

rhs = { {1.0, 1.0, 1.0},
{2.0, 2.0, 2.0} }

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(1);

DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0},
{15.0, 30.0} }

Powiązane numery wymiarów wsadowych z lhs i rhs muszą mieć te same rozmiary.

Przykład z numerami wymiarów wsadu (rozmiar wsadu 2, macierze 2 x 2):

lhs = { { {1.0, 2.0},
{3.0, 4.0} },
{ {5.0, 6.0},
{7.0, 8.0} } }

rhs = { { {1.0, 0.0},
{0.0, 1.0} },
{ {1.0, 0.0},
{0.0, 1.0} } }

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(2);
dnums.add_rhs_contracting_dimensions(1);
dnums.add_lhs_batch_dimensions(0);
dnums.add_rhs_batch_dimensions(0);

DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0},
{3.0, 4.0} },
{ {5.0, 6.0},
{7.0, 8.0} } }
Dane wejściowe Wyniki Semantyka
[b0, m, k] dot [b0, k, n] [b0, m, n] Batch matmul
[b0, b1, m, k] dot [b0, b1, k, n] [b0, b1, m, n] Batch matmul

Wynikowy numer wymiaru zaczyna się od wymiaru wsadowego, następnie wymiaru lhs nieskróconego/niezbiorczego, a na koniec wymiaru rhs nieskróconego/niezbiorczego.

DynamicSlice

Zobacz też XlaBuilder::DynamicSlice.

DynamicSlice wyodrębnia tablicę podrzędną z tablicy wejściowej w dynamicznym start_indices. Rozmiar wycinka w każdym wymiarze jest przekazywany w parametrze size_indices, który określa punkt końcowy wyłącznych przedziałów wycinków w każdym wymiarze: [początek, początek + rozmiar). Kształt elementu start_indices musi mieć pozycję == 1, a rozmiar wymiaru równa się pozycji operand.

DynamicSlice(operand, start_indices, size_indices)

Argumenty Typ Semantyka
operand XlaOp N tablica wymiarowa typu T
start_indices sekwencja N XlaOp Lista N skalarnych liczb całkowitych zawierających początkowe indeksy wycinka dla każdego wymiaru. Wartość nie może być mniejsza niż 0.
size_indices ArraySlice<int64> Lista N liczb całkowitych zawierających rozmiar wycinka dla każdego wymiaru. Każda wartość musi być większa niż 0, a wartość początkowa + rozmiar musi być mniejsza od rozmiaru wymiaru lub jej równa, aby uniknąć zawijania rozmiaru wymiaru modułu.

Efektywne indeksy wycinków są obliczane przez zastosowanie tej przekształcenia do każdego indeksu i w tabeli [1, N) przed wykonaniem wycinka:

start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i])

Dzięki temu wyodrębniony wycinek będzie zawsze w granicach względem tablicy argumentów. Jeśli przed zastosowaniem przekształcenia wycinek mieści się w granicach, przekształcenie nie przyniesie żadnego efektu.

Przykład jednowymiarowego:

let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let s = {2}

DynamicSlice(a, s, {2}) produces:
{2.0, 3.0}

Przykład obrazu dwuwymiarowego:

let b =
{ {0.0,  1.0,  2.0},
{3.0,  4.0,  5.0},
{6.0,  7.0,  8.0},
{9.0, 10.0, 11.0} }
let s = {2, 1}

DynamicSlice(b, s, {2, 2}) produces:
{ { 7.0,  8.0},
{10.0, 11.0} }

DynamicUpdateSlice

Zobacz też XlaBuilder::DynamicUpdateSlice.

DynamicUpdateSlice generuje wynik będący wartością tablicy wejściowej operand, z wycinkiem update zastąpionym w start_indices. Kształt tablicy update określa kształt tablicy podrzędnej uaktualnionego wyniku. Kształt elementu start_indices musi mieć pozycję == 1, a rozmiar wymiaru równy rangiowi elementu operand.

DynamicUpdateSlice(operand, update, start_indices)

Argumenty Typ Semantyka
operand XlaOp N tablica wymiarowa typu T
update XlaOp N tablica wymiarowa typu T zawierająca aktualizację wycinka. Każdy wymiar kształtu aktualizacji musi być ściśle większy niż 0, a rozmiar pola rozpoczęcia i aktualizacji musi być mniejszy lub równy rozmiarowi operandu każdego wymiaru, aby uniknąć generowania poza granicami indeksów aktualizacji.
start_indices sekwencja N XlaOp Lista N skalarnych liczb całkowitych zawierających początkowe indeksy wycinka dla każdego wymiaru. Wartość nie może być mniejsza niż 0.

Efektywne indeksy wycinków są obliczane przez zastosowanie tej przekształcenia do każdego indeksu i w tabeli [1, N) przed wykonaniem wycinka:

start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - update.dimension_size[i])

Dzięki temu zaktualizowany wycinek będzie zawsze w granicach względem tablicy argumentów. Jeśli przed zastosowaniem przekształcenia wycinek mieści się w granicach, przekształcenie nie przyniesie żadnego efektu.

Przykład jednowymiarowego:

let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let u = {5.0, 6.0}
let s = {2}

DynamicUpdateSlice(a, u, s) produces:
{0.0, 1.0, 5.0, 6.0, 4.0}

Przykład obrazu dwuwymiarowego:

let b =
{ {0.0,  1.0,  2.0},
{3.0,  4.0,  5.0},
{6.0,  7.0,  8.0},
{9.0, 10.0, 11.0} }
let u =
{ {12.0,  13.0},
{14.0,  15.0},
{16.0,  17.0} }

let s = {1, 1}

DynamicUpdateSlice(b, u, s) produces:
{ {0.0,  1.0,  2.0},
{3.0, 12.0, 13.0},
{6.0, 14.0, 15.0},
{9.0, 16.0, 17.0} }

Binarne operacje arytmetyczne dotyczące elementów

Zobacz też XlaBuilder::Add.

Obsługiwany jest zestaw binarnych operacji arytmetycznych z uwzględnieniem elementów.

Op(lhs, rhs)

Gdzie Op to jedna z wartości: Add (dodawanie), Sub (odejmowanie), Mul (mnożenie), Div (dzielenie), Rem (reszta), Max (maksymalna), Min (minimalna), LogicalAnd (logiczna ORAZ) lub LogicalOr (logiczna LUB).

Argumenty Typ Semantyka
lhs XlaOp operand po lewej stronie: tablica typu T
rhs XlaOp operand po prawej stronie: tablica typu T

Kształty argumentów muszą być podobne lub zgodne. Przeczytaj dokumentację transmitowania, by dowiedzieć się, co oznacza zgodność kształtów. Wynik operacji ma kształt będący wynikiem publikacji 2 tablic wejściowych. W tym wariancie operacje między tablicami o różnych poziomach nie są obsługiwane, chyba że jeden z operandów jest skalarem.

Gdy Op ma wartość Rem, znak wyniku jest odczytywany z dzielnicy, a wartość bezwzględna wyniku jest zawsze mniejsza niż wartość bezwzględna dzielnika.

Zbyt duża część dzielenia liczby całkowitej (dzielenie/reszta ze znakiem/bez podpisu lub dzielenie/reszta ze znaku INT_SMIN z wartością -1) daje zdefiniowaną wartość implementacji.

W przypadku tych operacji dostępny jest alternatywny wariant z obsługą transmisji o innej pozycji:

Op(lhs, rhs, broadcast_dimensions)

Gdzie Op jest taki sam jak powyżej. Ten wariant operacji powinien być używany w przypadku operacji arytmetycznych na tablicach o różnych poziomach równania (np. podczas dodawania macierzy do wektora).

Dodatkowy operand broadcast_dimensions to wycinek liczb całkowitych używany do zwiększania pozycji operandu o niższej pozycji do pozycji operandu o wyższej pozycji. broadcast_dimensions przyporządkowuje wymiary kształtów o niższej pozycji do wymiarów tych, które są wyżej w rankingu. Niezmapowane wymiary rozwiniętego kształtu są wypełniane wymiarami rozmiaru 1. Przesyłanie o zdegenerowanych wymiarach przekazuje kształty wzdłuż tych zdegenerowanych wymiarów, aby wyrównać kształty obu argumentów. Semantyka jest szczegółowo opisana na stronie transmisji.

Operacje porównania dotyczące elementów

Zobacz też XlaBuilder::Eq.

Obsługiwany jest zestaw standardowych operacji porównania binarnych elementów. Pamiętaj, że przy porównywaniu typów zmiennoprzecinkowych obowiązuje standardowa semantyka porównawcza zmiennoprzecinkowych IEEE 754.

Op(lhs, rhs)

Gdzie Op to jedna z tych wartości: Eq (równe), Ne (nie równa się), Ge (większe lub równe), Gt (większe niż), Le (mniejsze lub równe), Lt (mniejsze niż). Inny zestaw operatorów: EqTotalOrder, NeTotalOrder, GeTotalOrder, GtTotalOrder, LeTotalOrder i LtTotalOrder z innymi operatorami udostępnia te same funkcje z wyjątkiem tego, że dodatkowo obsługują porządek całkowity w przypadku liczb zmiennoprzecinkowych przez egzekwowanie -NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf.

Argumenty Typ Semantyka
lhs XlaOp operand po lewej stronie: tablica typu T
rhs XlaOp operand po prawej stronie: tablica typu T

Kształty argumentów muszą być podobne lub zgodne. Przeczytaj dokumentację transmitowania, by dowiedzieć się, co oznacza zgodność kształtów. Wynik operacji ma kształt będący wynikiem publikacji 2 tablic wejściowych z elementem typu PRED. W tym wariancie operacje między tablicami o różnych poziomach nie są obsługiwane, chyba że jeden z operandów jest skalarem.

W przypadku tych operacji dostępny jest alternatywny wariant z obsługą transmisji o innej pozycji:

Op(lhs, rhs, broadcast_dimensions)

Gdzie Op jest taki sam jak powyżej. Tego wariantu należy używać do porównywania operacji między tablicami o różnych poziomach (np. dodawania macierzy do wektora).

Dodatkowy operand broadcast_dimensions to wycinek liczb całkowitych określających wymiary używane do rozgłaszania operandów. Semantyka jest szczegółowo opisana na stronie transmisji.

Funkcje jednoargumentowe dotyczące elementów

XlaBuilder obsługuje następujące jednoargumentowe funkcje związane z elementami:

Abs(operand) x -> |x| z uwzględnieniem elementów.

Ceil(operand) Ceila żywiołów: x -> ⌈x⌉.

Cos(operand) Cosinus x -> cos(x) elementu.

Exp(operand) W przypadku elementów naturalny wykładniczy x -> e^x.

Floor(operand) piętro z uwzględnieniem elementów: x -> ⌊x⌋.

Imag(operand) Fragment urojony z uwzględnieniem elementów złożonego (lub rzeczywistego) kształtu. x -> imag(x). Jeśli operand jest liczbą zmiennoprzecinkową, zwraca 0.

IsFinite(operand) Sprawdza, czy każdy element właściwości operand jest skończony, tj. nie jest dodatni lub ujemny i nie ma wartości NaN. Zwraca tablicę wartości PRED o tym samym kształcie co wartość wejściowa, gdzie każdy element ma wartość true, jeśli i tylko jeśli odpowiedni element wejściowy jest skończony.

Log(operand) Logarytm naturalny elementów x -> ln(x).

LogicalNot(operand) Funkcja logiczna dotycząca elementów, a nie x -> !(x).

Logistic(operand) Obliczanie funkcji logistycznej z uwzględnieniem elementów x -> logistic(x).

PopulationCount(operand) oblicza liczbę bitów ustawionych w każdym elemencie właściwości operand.

Neg(operand) negacja z uwzględnieniem elementów x -> -x.

Real(operand) Rzeczywista część złożonego (lub rzeczywistego) elementu. x -> real(x). Jeśli operand jest typu zmiennoprzecinkowego, zwraca tę samą wartość.

Rsqrt(operand) Odwrotność działania pierwiastka kwadratowego z uwzględnieniem elementówx -> 1.0 / sqrt(x).

Sign(operand) Operacja podpisywania na podstawie elementu x -> sgn(x), gdzie

\[\text{sgn}(x) = \begin{cases} -1 & x < 0\\ -0 & x = -0\\ NaN & x = NaN\\ +0 & x = +0\\ 1 & x > 0 \end{cases}\]

przy użyciu operatora porównania typu elementu operand.

Sqrt(operand) Operacja na pierwiastku kwadratowym według elementu x -> sqrt(x).

Cbrt(operand) Operacje na pierwiastku sześciennym x -> cbrt(x) na podstawie elementu.

Tanh(operand) tangens hiperboliczny elementów x -> tanh(x).

Round(operand) Zaokrąglenie według elementu, wyrówna się od zera.

RoundNearestEven(operand) Zaokrąglanie według elementów, łączy się z najbliższymi parzystymi.

Argumenty Typ Semantyka
operand XlaOp Argument dla funkcji

Funkcja jest stosowana do każdego elementu w tablicy operand, co daje tablicę o tym samym kształcie. Wartość operand może być wartością skalarną (pozycja 0).

Fft

Operacja FFT XLA implementuje transformaty Fouriera do przodu i odwrotnego do rzeczywistych i złożonych danych wejściowych i wyjściowych. Obsługuje wielowymiarowe FFT na maksymalnie 3 osiach.

Zobacz też XlaBuilder::Fft.

Argumenty Typ Semantyka
operand XlaOp Macierz, którą przekształcamy Fouriera.
fft_type FftType Patrz tabela poniżej.
fft_length ArraySlice<int64> Długości w domenie czasu przekształcanych osi. Jest to szczególnie potrzebne w przypadku narzędzia IRFFT, aby dopasować rozmiar najbardziej wewnętrznej osi, ponieważ funkcja RFFT(fft_length=[16]) ma taki sam kształt wyjściowy co RFFT(fft_length=[17]).
FftType Semantyka
FFT Przekieruj ze złożonym na złożonym przekształceniem. Kształt jest niezmieniony.
IFFT Odwrotność funkcji zespolonej do złożonego. Kształt jest niezmieniony.
RFFT Przekieruj ruch rzeczywisty do złożony. Kształt najbardziej wewnętrznej osi zostanie zmniejszony do fft_length[-1] // 2 + 1, jeśli fft_length[-1] ma wartość inną niż 0, a odwrotna część przekształconego sygnału zostanie pominięta poza częstotliwość Nyquista.
IRFFT Odwrotność funkcji FFT z rzeczywistością do złożoności (tj. funkcja ma złożoność, zwraca wartość rzeczywistą). Kształt najbardziej wewnętrznej osi jest rozszerzony do fft_length[-1], jeśli fft_length[-1] ma wartość inną niż 0, co oznacza, że część przekształconego sygnału wykracza poza częstotliwość Nyquista z odwrotnej sprzężenia z wartościami 1 i fft_length[-1] // 2 + 1.

Wielowymiarowe

Jeśli podana jest więcej niż 1 fft_length, jest to równoznaczne z zastosowaniem kaskady operacji FFT do każdej z najbardziej wewnętrznych osi. Zauważ, że w przypadkach rzeczywistych > złożonych i złożonych> najpierw wykonywana jest najbardziej wewnętrzna oś (RFFT; ostatni w przypadku IRFFT), dlatego najbardziej wewnętrzna oś to ta, która zmienia rozmiar. Inne przekształcenia osi będą wtedy złożone->złożone.

Szczegóły implementacji

FFT procesora jest obsługiwany przez TensorFFT firmy Eigen. Funkcja FFT GPU korzysta z funkcji cuFFT.

Zbieraj

Operacja zbierania XLA łączy ze sobą kilka wycinków (z każdego wycinka o potencjalnie różnych opóźnieniach czasu działania) tablicy wejściowej.

Semantyka ogólna

Zobacz też XlaBuilder::Gather. Bardziej intuicyjny opis znajdziesz w sekcji „Informalny opis” poniżej.

gather(operand, start_indices, offset_dims, collapsed_slice_dims, slice_sizes, start_index_map)

Argumenty Typ Semantyka
operand XlaOp Tablica, z której zbieramy dane.
start_indices XlaOp Tablica zawierająca indeksy początkowe zebranych przez nas wycinków.
index_vector_dim int64 Wymiar w elemencie start_indices, który „zawiera” indeksy początkowe. Szczegółowy opis znajdziesz poniżej.
offset_dims ArraySlice<int64> Zbiór wymiarów w kształcie wyjściowym, które są odsunięte do tablicy wyciętej z argumentu argumentu.
slice_sizes ArraySlice<int64> slice_sizes[i] to granice wycinka o wymiarze i.
collapsed_slice_dims ArraySlice<int64> Zbiór wymiarów w każdym wycinku, który jest zwinięty. Te wymiary muszą mieć rozmiar 1.
start_index_map ArraySlice<int64> Mapa pokazująca, jak mapować indeksy w polu start_indices na indeksy prawne w operand.
indices_are_sorted bool Określa, czy indeksy na pewno zostaną sortowane według elementu wywołującego.

Dla wygody wymiary w tablicy wyjściowej oznaczamy etykietą batch_dims, a nie offset_dims.

Wynikiem jest tablica pozycji batch_dims.size + offset_dims.size.

Wartość operand.rank musi być równa sumie wartości offset_dims.size i collapsed_slice_dims.size. Dodatkowo slice_sizes.size musi być równy operand.rank.

Jeśli index_vector_dim ma wartość start_indices.rank, domyślnie uznajemy, że start_indices ma na końcu wymiar 1 (np. jeśli start_indices ma kształt [6,7], a index_vector_dim to 2, domyślnie zakładamy, że kształt start_indices ma postać [6,7,1]).

Granice tablicy wyjściowej wzdłuż wymiaru i są obliczane w ten sposób:

  1. Jeśli parametr i występuje w elemencie batch_dims (tj. równa się batch_dims[k] w przypadku niektórych k), wybieramy odpowiednie granice wymiaru poza start_indices.shape, pomijając index_vector_dim (np. jeśli k < index_vector_dim wynosi batch_dims[k], a w przeciwnym razie wybieramy start_indices.shape.dims[k+1]).kstart_indices.shape.dims

  2. Jeśli i występuje w tabeli offset_dims (tj. równa się offset_dims[k] w przypadku niektórych k), wybieramy odpowiednią wartość z slice_sizes po uwzględnieniu parametru collapsed_slice_dims (np. wybieramy adjusted_slice_sizes[k], gdzie adjusted_slice_sizes to slice_sizes z usuniętymi granicami w indeksach collapsed_slice_dims).

Formalnie indeks operacji In odpowiadający danemu indeksowi wyjściowemu Out jest obliczany w ten sposób:

  1. Niech G = { Out[k] dla k w batch_dims }. Użyj funkcji G, aby wyciąć wektor S w taki sposób, że S[i] = start_indices[Połącz(G, i)], gdzie Połączenie(A; b) wstawia b w pozycji index_vector_dim do A. Pamiętaj, że ten parametr jest dobrze zdefiniowany nawet wtedy, gdy pole G jest puste: jeśli pole G jest puste, to S = start_indices.

  2. Utwórz indeks początkowy Sin w lokalizacji operand przy użyciu elementu S, rozpraszając S przy użyciu elementu start_index_map. A dokładniej:

    1. Sin[start_index_map[k]] = S[k], jeśli k < start_index_map.size.

    2. Sin[_] = 0 w innym przypadku.

  3. Utwórz indeks Oin w formacie operand, rozpraszając indeksy w wymiarach przesunięcia w formacie Out zgodnie ze zbiorem collapsed_slice_dims. A dokładniej:

    1. Oin[remapped_offset_dims(k)] = Out[offset_dims[k]], jeśli k < offset_dims.size (definicja remapped_offset_dims jest definiowana poniżej).

    2. Oin[_] = 0 w innym przypadku.

  4. In to Oin + Sin, gdzie + oznacza dodawanie elementu.

remapped_offset_dims to funkcja monotoniczna z domeną [0, offset_dims.size) i zakresem [0, operand.rank) \ collapsed_slice_dims. Jeśli np. offset_dims.size to 4, operand.rank to 6, a collapsed_slice_dims to {0, 2}, a następnie remapped_offset_dims to {01, 13, 24, 35}.

Jeśli indices_are_sorted ma wartość Prawda, XLA może przyjąć, że elementy start_indices są posortowane (w kolejności rosnącej start_index_map) według użytkownika. W przeciwnym razie definicja semantyki jest implementowana.

Nieformalny opis i przykłady

Nieformalnie każdy indeks Out w tablicy wyjściowej odpowiada elementowi E w tablicy operandu, obliczony w ten sposób:

  • Aby wyszukać początkowy indeks z start_indices, używamy wymiarów wsadowych w pliku Out.

  • Używamy start_index_map, aby zmapować indeks początkowy (który może być mniejszy niż operand.rank) na „pełny” indeks początkowy w indeksie operand.

  • Dynamicznie wycinamy wycinek o rozmiarze slice_sizes przy użyciu pełnego indeksu początkowego.

  • Zmieniamy kształt wycinka, zwijając wymiary collapsed_slice_dims. Ponieważ wszystkie wymiary zwiniętego wycinka muszą mieć wartość 1, to zmiana kształtu jest zawsze legalna.

  • Do indeksowania tego wycinka używamy wymiarów przesunięcia w polu Out. W ten sposób uzyskujemy element wejściowy E odpowiadający indeksowi wyjściowemu Out.

We wszystkich podanych niżej przykładach index_vector_dim ma wartość start_indices.rank1. Ciekawsze wartości parametru index_vector_dim nie zmieniają zasadniczo działania, ale sprawiają, że wizualna prezentacja jest bardziej uciążliwa.

Aby pokazać, jak łączą się wszystkie powyższe elementy, spójrzmy na przykład, który zbiera 5 wycinków kształtu [8,6] z tablicy [16,11]. Pozycja wycinka w tablicy [16,11] może być reprezentowana jako wektor indeksu kształtu S64[2], więc zbiór 5 pozycji można przedstawić jako tablica S64[5,2].

Działanie operacji zbierania można przedstawić jako przekształcenie indeksu, które pobiera [G,O0,O1] (indeks w kształcie danych wyjściowych) i mapuje go na element w tablicy wejściowej w ten sposób:

Najpierw wybieramy wektor (X,Y) z tablicy indeksów zbierania danych za pomocą funkcji G. Element tablicy wyjściowej w pozycji [G,O0,O1] jest elementem tablicy wejściowej w indeksie [X+O0,Y+O1].

slice_sizes ma wartość [8,6], która określa zakres wartości O0 i O1, a to z kolei wyznacza granice wycinka.

Ta operacja gromadzenia działa jako wsadowy wycinek dynamiczny z wymiarem G jako wymiarem wsadowym.

Indeksy zbierania mogą być wielowymiarowe. Na przykład bardziej ogólna wersja z powyższego przykładu, w której użyto tablicy „indeksy” o kształtach [4,5,2], można przetłumaczyć indeksy w ten sposób:

Ponownie jest to dynamiczny wycinek G0 i G1 jako wymiary wsadu. Rozmiar wycinka nadal to [8,6].

Operacja zbierania w XLA uogólnia omówioną powyżej nieformalną semantyczną semantykę w następujący sposób:

  1. Możemy skonfigurować, które wymiary w kształcie wyjściowym są wymiarami przesunięcia (w ostatnim przykładzie wymiary zawierające O0 i O1). Wyjściowe wymiary wsadu (wymiary zawierające G0, G1 w ostatnim przykładzie) są zdefiniowane jako wymiary wyjściowe, które nie są wymiarami przesunięcia.

  2. Liczba wymiarów przesunięcia wyjściowego bezpośrednio występujących w kształcie wyjściowym może być mniejsza niż pozycja wejściowa. Te „brakujące” wymiary, które są wymienione wprost jako collapsed_slice_dims, muszą mieć wycinek o rozmiarze 1. Ponieważ mają wycinek o rozmiarze 1, jedynym prawidłowym indeksem jest dla nich 0, a ich usunięcie nie wprowadza niejednoznacznych informacji.

  3. Wycinek wyodrębniony z tablicy „Gather Indices” ((X, Y) w ostatnim przykładzie) może zawierać mniej elementów niż pozycja tablicy wejściowej, a jawne mapowanie określa sposób rozszerzania indeksu tak, aby miał tę samą pozycję co dane wejściowe.

Na koniec korzystamy z punktów (2) i (3) do implementacji tf.gather_nd:

G0 i G1 są używane do wycinania początkowego indeksu z tablicy indeksów zbierania w zwykły sposób, z tą różnicą, że indeks początkowy ma tylko jeden element: X. Podobnie jest tylko 1 wyjściowy indeks przesunięcia o wartości O0. Jednak przed wykorzystaniem ich jako indeksów do tablicy wejściowej są one rozszerzane zgodnie z metodami „Gather Index Mapping” (start_index_map w formalnym opisie) i „Mapowanie przesunięcia” (remapped_offset_dims w formalnym opisie) do wartości [X,0] i [0,O0], które dodają do [X,O0] indeks {1GatherIndicesO0], czyli indeks {1GatherIndicesO0], czyli indeks danych wyjściowych: [G1}G0.00GG11tf.gather_nd

Pole slice_sizes w tym przypadku wynosi [1,11]. Oznacza to intuicyjnie, że każdy indeks X w tablicy indeksów zbierających wybiera cały wiersz, a wynikiem jest łączenie wszystkich tych wierszy.

GetDimensionSize

Zobacz też XlaBuilder::GetDimensionSize.

Zwraca rozmiar danego wymiaru operandu. Argument musi mieć kształt tablicy.

GetDimensionSize(operand, dimension)

Argumenty Typ Semantyka
operand XlaOp n-wymiarowa tablica wejściowa
dimension int64 Wartość z przedziału [0, n), która określa wymiar

SetDimensionSize

Zobacz też XlaBuilder::SetDimensionSize.

Ustawia dynamiczny rozmiar danego wymiaru XlaOp. Argument musi mieć kształt tablicy.

SetDimensionSize(operand, size, dimension)

Argumenty Typ Semantyka
operand XlaOp n-wymiarowa tablica wejściowa.
size XlaOp int32 reprezentujący dynamiczny rozmiar środowiska wykonawczego.
dimension int64 Wartość z przedziału [0, n), która określa wymiar.

W rezultacie przekazuj operand w sposób, który uwzględnia dynamiczny wymiar śledzony przez kompilator.

Dopełnione wartości będą ignorowane przez operacje redukcji na dalszych etapach.

let v: f32[10] = f32[10]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
let five: s32 = 5;
let six: s32 = 6;

// Setting dynamic dimension size doesn't change the upper bound of the static
// shape.
let padded_v_five: f32[10] = set_dimension_size(v, five, /*dimension=*/0);
let padded_v_six: f32[10] = set_dimension_size(v, six, /*dimension=*/0);

// sum == 1 + 2 + 3 + 4 + 5
let sum:f32[] = reduce_sum(padded_v_five);
// product == 1 * 2 * 3 * 4 * 5
let product:f32[] = reduce_product(padded_v_five);

// Changing padding size will yield different result.
// sum == 1 + 2 + 3 + 4 + 5 + 6
let sum:f32[] = reduce_sum(padded_v_six);

GetTupleElement

Zobacz też XlaBuilder::GetTupleElement.

Indeksuje do krotki z wartością stałą w czasie kompilacji.

Wartość musi być stałą w czasie kompilacji, aby wnioskowanie na temat kształtu mogło określać typ wartości wynikowej.

Jest to odpowiednik std::get<int N>(t) w C++. Ogólnie rzecz biorąc:

let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);
let element_1: s32 = gettupleelement(t, 1);  // Inferred shape matches s32.

Zobacz też tf.tuple.

W kanale

Zobacz też XlaBuilder::Infeed.

Infeed(shape)

Argument Typ Semantyka
shape Shape Kształt danych odczytywanych przez interfejs InFeed. Pole układu kształtu musi być ustawione tak, aby pasowało do układu danych wysyłanych do urządzenia. W przeciwnym razie zachowanie kształtu pozostanie nieokreślone.

Odczytuje pojedynczy element danych z małego interfejsu strumieniowego przesyłania danych In-Feed na urządzeniu, interpretuje dane jako dany kształt i jego układ, a następnie zwraca XlaOp danych. Dozwolonych jest wiele operacji związanych z InFeed, ale musi istnieć ich łączna kolejność. Na przykład 2 kanały In-Feed w poniższym kodzie uporządkowane są w porządku całkowitym, ponieważ istnieje zależność między pętlami podczas.

result1 = while (condition, init = init_value) {
  Infeed(shape)
}

result2 = while (condition, init = result1) {
  Infeed(shape)
}

Zagnieżdżone kształty krotek nie są obsługiwane. W przypadku pustego kształtu krotki działanie Infeed jest w rzeczywistości bezobsługowe i działa bez odczytywania żadnych danych z kanału In-Feed w urządzeniu.

Jota

Zobacz też XlaBuilder::Iota.

Iota(shape, iota_dimension)

Tworzy stały literał na urządzeniu, a nie potencjalnie duży transfer hosta. Tworzy tablicę o określonym kształcie i przechowującą wartości, które zaczynają się od zera i rosną o jeden w danym wymiarze. W przypadku typów zmiennoprzecinkowych wygenerowana tablica jest odpowiednikiem funkcji ConvertElementType(Iota(...)), gdzie Iota jest typu całkowego, a przeliczenie na typ zmiennoprzecinkowy.

Argumenty Typ Semantyka
shape Shape Kształt tablicy utworzonej przez funkcję Iota()
iota_dimension int64 Wymiar, który ma być przyrostowy.

Na przykład Iota(s32[4, 8], 0) zwraca:

  [[0, 0, 0, 0, 0, 0, 0, 0 ],
   [1, 1, 1, 1, 1, 1, 1, 1 ],
   [2, 2, 2, 2, 2, 2, 2, 2 ],
   [3, 3, 3, 3, 3, 3, 3, 3 ]]

Iota(s32[4, 8], 1) zwroty

  [[0, 1, 2, 3, 4, 5, 6, 7 ],
   [0, 1, 2, 3, 4, 5, 6, 7 ],
   [0, 1, 2, 3, 4, 5, 6, 7 ],
   [0, 1, 2, 3, 4, 5, 6, 7 ]]

Mapa

Zobacz też XlaBuilder::Map.

Map(operands..., computation)

Argumenty Typ Semantyka
operands sekwencja N XlaOp s N tablic typów T0..T{N-1}
computation XlaComputation obliczenia typu T_0, T_1, .., T_{N + M -1} -> S z N parametrami typów T i M dowolnego typu
dimensions Tablica int64 tablica wymiarów mapy

Stosuje funkcję skalarną do określonych tablic operands, tworząc tablicę o tych samych wymiarach, w której każdy element jest wynikiem zmapowanej funkcji zastosowanej do odpowiednich elementów w tablicach wejściowych.

Zmapowana funkcja to dowolne obliczenia z ograniczeniem, które obejmuje N danych wejściowych typu skalarnego T i pojedyncze dane wyjściowe typu S. Dane wyjściowe mają te same wymiary co operandy, ale typ elementu T zostaje zastąpiony przez S.

Na przykład: Map(op1, op2, op3, computation, par1) mapuje elem_out <- computation(elem1, elem2, elem3, par1) w każdym (wielowymiarowym) indeksie w tablicach wejściowych, aby utworzyć tablicę wyjściową.

OptimizationBarrier

Blokuje możliwość przejścia na optymalizację przed przesunięciem obliczeń przez barierę.

Zapewnia, że wszystkie dane wejściowe są sprawdzane przed użyciem operatorów zależnych od danych wyjściowych bariery.

Wkładka

Zobacz też XlaBuilder::Pad.

Pad(operand, padding_value, padding_config)

Argumenty Typ Semantyka
operand XlaOp tablica typu T
padding_value XlaOp skalar typu T służący do wypełnienia dodanego dopełnienia
padding_config PaddingConfig wielkość dopełnienia na obu krawędziach (niska i wysoka) oraz między elementami każdego wymiaru

Rozszerza podaną tablicę operand przez dopełnienie wokół tablicy, a także między jej elementami z podaną wartością padding_value. padding_config określa stopień dopełnienia krawędzi i dopełnienia wewnętrznego dla każdego wymiaru.

PaddingConfig to powtarzane pole PaddingConfigDimension, które zawiera 3 pola na każdy wymiar: edge_padding_low, edge_padding_high i interior_padding.

edge_padding_low i edge_padding_high określają stopień dopełnienia dodanego odpowiednio najniższego poziomu (obok indeksu 0) i najwyższego (obok najwyższego indeksu) każdego wymiaru. Wartość dopełnienia krawędzi może być ujemna – wartość bezwzględna dopełnienia ujemnego wskazuje liczbę elementów, które mają zostać usunięte z określonego wymiaru.

interior_padding określa stopień dopełnienia dodanego między dowolnymi dwoma elementami w każdym wymiarze. Wartość nie może być ujemna. Dopełnienie wewnętrzne odbywa się logicznie przed dopełnieniem krawędzi, dlatego w przypadku tego dopełnienia elementy są usuwane z operandu z wypełnieniem wewnętrznym.

Ta operacja jest niedostępna, jeśli wszystkie pary dopełnienia krawędzi mają wartość (0, 0), a wartości dopełnienia wewnętrznego mają wartość 0. Na rysunku poniżej widać przykłady różnych wartości edge_padding i interior_padding w tablicy dwuwymiarowej.

RecV

Zobacz też XlaBuilder::Recv.

Recv(shape, channel_handle)

Argumenty Typ Semantyka
shape Shape kształtu danych do otrzymania
channel_handle ChannelHandle unikalny identyfikator dla każdej pary wysyłania/odbierania.

Odbiera dane o danym kształcie z instrukcji Send w innych obliczeniach, które mają ten sam uchwyt kanału. Zwraca XlaOp dla odebranych danych.

Interfejs API klienta operacji Recv reprezentuje komunikację synchroniczną. Instrukcja jest jednak wewnętrznie podzielona na 2 instrukcje HLO (Recv i RecvDone), aby umożliwić asynchroniczne przesyłanie danych. Zobacz też HloInstruction::CreateRecv i HloInstruction::CreateRecvDone.

Recv(const Shape& shape, int64 channel_id)

Przydziela zasoby wymagane do odbierania danych z instrukcji Send o tym samym identyfikatorze kanału. Zwraca kontekst dla przydzielonych zasobów, który jest używany przez następującą instrukcję RecvDone do oczekiwania na zakończenie transferu danych. Kontekst jest kropką {receive buffer (shape), identyfikator żądania (U32)}, której można używać tylko w instrukcji RecvDone.

RecvDone(HloInstruction context)

Biorąc pod uwagę kontekst utworzony przez instrukcję Recv, czeka na zakończenie przenoszenia danych i zwraca odebrane dane.

Ograniczamy

Zobacz też XlaBuilder::Reduce.

Stosuje funkcję redukcji do co najmniej 1 tablicy równolegle.

Reduce(operands..., init_values..., computation, dimensions)

Argumenty Typ Semantyka
operands Sekwencja N XlaOp N tablic typów T_0, ..., T_{N-1}.
init_values Sekwencja N XlaOp N skalary typu T_0, ..., T_{N-1}.
computation XlaComputation obliczenia typu T_0, ..., T_{N-1}, T_0, ..., T_{N-1} -> Collate(T_0, ..., T_{N-1}).
dimensions Tablica int64 nieuporządkowaną tablicę wymiarów do zredukowania.

Gdzie:

  • Wartość N musi być większa od lub równa 1.
  • Obliczenie musi przybierać w przybliżeniu (patrz niżej)
  • Wszystkie tablice wejściowe muszą mieć te same wymiary.
  • Wszystkie początkowe wartości muszą tworzyć tożsamość poniżej computation.
  • Jeśli N = 1, Collate(T) ma wartość T.
  • Jeśli N > 1, Collate(T_0, ..., T_{N-1}) jest kropką elementów N typu T.

Ta operacja zmniejsza co najmniej 1 wymiar każdej tablicy wejściowej do postaci skalarnych. Pozycja każdej zwróconej tablicy to rank(operand) - len(dimensions). Wynik operacji to Collate(Q_0, ..., Q_N), gdzie Q_i to tablica typu T_i, których wymiary opisano poniżej.

Różne backendy mogą ponownie powiązać obliczenia redukcji. Może to prowadzić do różnic liczbowych, ponieważ niektóre funkcje redukcji, np. dodawanie, nie wiążą się z liczbami zmiennoprzecinkowymi. Jeśli jednak zakres danych jest ograniczony, dodawanie liczb zmiennoprzecinkowych jest na tyle blisko, że można je powiązać w większości praktycznych zastosowań.

Przykłady

Ograniczając do jednego wymiaru w pojedynczej tablicy 1D z wartościami [10, 11, 12, 13] przy użyciu funkcji redukcji f (czyli computation), można to obliczyć jako

f(10, f(11, f(12, f(init_value, 13)))

ale jest też wiele innych możliwości, np.

f(init_value, f(f(10, f(init_value, 11)), f(f(init_value, 12), f(init_value, 13))))

Poniżej znajduje się skrócony przykład wdrożenia redukcji z użyciem sumowania z wartością początkową 0 wartości.

result_shape <- remove all dims in dimensions from operand_shape

# Iterate over all elements in result_shape. The number of r's here is equal
# to the rank of the result
for r0 in range(result_shape[0]), r1 in range(result_shape[1]), ...:
  # Initialize this result element
  result[r0, r1...] <- 0

  # Iterate over all the reduction dimensions
  for d0 in range(dimensions[0]), d1 in range(dimensions[1]), ...:
    # Increment the result element with the value of the operand's element.
    # The index of the operand's element is constructed from all ri's and di's
    # in the right order (by construction ri's and di's together index over the
    # whole operand shape).
    result[r0, r1...] += operand[ri... di]

Oto przykład redukcji tablicy 2D (macierzy). Kształt ma pozycję 2, wymiar 0 rozmiaru 2 i wymiar 1 rozmiaru 3:

Wyniki redukcji wymiarów 0 lub 1 przy użyciu funkcji „dodaj”:

Zwróć uwagę, że oba wyniki redukcji są tablicami 1D. Jedna z nich jest widoczna na diagramie jako kolumna, a druga – jako wiersz, dla wygody wizualnej.

Bardziej złożonym przykładem jest tablica 3D. Jego pozycja to 3, wymiar 0 dla rozmiaru 4, wymiar 1 dla rozmiaru 2 i wymiar 2 dla rozmiaru 3. Dla uproszczenia wartości od 1 do 6 są replikowane w wymiarze 0.

Podobnie jak w przypadku przykładu 2D, możemy zredukować tylko 1 wymiar. Jeśli np. zmniejszymy wymiar 0, uzyskamy tablicę rankingową 2, w której wszystkie wartości w wymiarze 0 zostały zwinięte w skalar:

|  4   8  12 |
| 16  20  24 |

Jeśli zmniejszymy wymiar 2, otrzymamy również tablicę rangi 2, w której wszystkie wartości w wymiarze 2 zostały złożone w skalar:

| 6  15 |
| 6  15 |
| 6  15 |
| 6  15 |

Zauważ, że względna kolejność między pozostałymi wymiarami w danych wejściowych jest zachowywana w danych wyjściowych, ale niektóre wymiary mogą otrzymywać nowe liczby (ponieważ zmienią się pozycje w rankingu).

Możemy też ograniczyć liczbę wymiarów. Po dodaniu wymiarów 0 i 1 powstaje macierz 1D [20, 28, 36].

Zmniejszenie tablicy 3D do wszystkich jej wymiarów daje skalarny 84.

Redukcja Variadic

W przypadku N > 1 zastosowanie funkcji redukcji jest nieco bardziej złożone, ponieważ jest stosowane jednocześnie do wszystkich danych wejściowych. operandy są dostarczane do obliczeń w tej kolejności:

  • Uruchamiam obniżoną wartość pierwszego argumentu
  • ...
  • Działanie obniżonej wartości dla n-tego operandu
  • Wartość wejściowa pierwszego argumentu
  • ...
  • Wartość wejściowa dla n-tego operandu

Spójrzmy na przykład na tę funkcję redukcji, za pomocą której można obliczyć równolegle wartość maksymalną i argument armax tablicy 1D:

f: (Float, Int, Float, Int) -> Float, Int
f(max, argmax, value, index):
  if value >= max:
    return (value, index)
  else:
    return (max, argmax)

W przypadku tablic wejściowych 1D V = Float[N], K = Int[N] i wartości inicjowania I_V = Float, I_K = Int wynik f_(N-1) ograniczania w jedynym wymiarze wejściowym jest odpowiednikiem tej aplikacji rekurencyjnej:

f_0 = f(I_V, I_K, V_0, K_0)
f_1 = f(f_0.first, f_0.second, V_1, K_1)
...
f_(N-1) = f(f_(N-2).first, f_(N-2).second, V_(N-1), K_(N-1))

Zastosowanie tego ograniczenia do tablicy wartości i tablicy indeksów sekwencyjnych (tj. jota) spowoduje powtarzanie iteracji w tablicach i zwrócenie kropki zawierającej maksymalną wartość i pasujący indeks.

ReducePrecision

Zobacz też XlaBuilder::ReducePrecision.

Modeluje efekt konwersji wartości zmiennoprzecinkowych na format o niższej dokładności (np. IEEE-FP16) i z powrotem do formatu oryginalnego. Liczbę bitów wykładnika i mantysy w formacie o niższym precyzji można określić samodzielnie. Jednak niektóre rozmiary bitów mogą nie być obsługiwane przez wszystkie implementacje sprzętowe.

ReducePrecision(operand, mantissa_bits, exponent_bits)

Argumenty Typ Semantyka
operand XlaOp tablicy zmiennoprzecinkowej typu T.
exponent_bits int32 liczba bitów wykładniczych w formacie o niższej precyzji
mantissa_bits int32 liczba bitów mantissy w mniej precyzyjnym formacie

Wynik to tablica typu T. Wartości wejściowe są zaokrąglane do najbliższej wartości reprezentowanej przez podaną liczbę bitów mantysy (za pomocą semantyki „równomiernie”), a wszystkie wartości przekraczające zakres określony przez liczbę bitów wykładnika są ograniczane do dodatniej lub ujemnej nieskończoności. Wartości NaN są zachowane, ale mogą zostać przekonwertowane na kanoniczne wartości NaN.

Format o niższej precyzji musi mieć co najmniej 1 bit wykładnika (aby odróżnić wartość 0 od nieskończoności, ponieważ obie mają mantysę zerową) i musi mieć nieujemną liczbę bitów mantysy. Liczba bitów wykładnika lub mantysy może być większa niż odpowiadająca jej wartość dla typu T. Odpowiednia część konwersji to w takim przypadku brak operacji.

ReduceScatter

Zobacz też XlaBuilder::ReduceScatter.

ReduceScatter to operacja zbiorcza, która skutecznie wykonuje operację AllReduce, a następnie rozkłada wynik, dzieląc go na bloki shard_count wzdłuż scatter_dimension, a replika i w grupie replik otrzymuje fragment ith.

ReduceScatter(operand, computation, scatter_dim, shard_count, replica_group_ids, channel_id)

Argumenty Typ Semantyka
operand XlaOp Tablica lub niepusta krotka tablic do zredukowania w replikach.
computation XlaComputation Obliczanie redukcji
scatter_dimension int64 Wymiar do rozproszenia.
shard_count int64 Liczba bloków do podziału: scatter_dimension
replica_groups wektor wektorów elementu int64 Grupy, w których następuje redukcja
channel_id opcjonalnie: int64 Opcjonalny identyfikator kanału na potrzeby komunikacji między modułami
  • Gdy operand to krotka tablicy, funkcja redukcji rozproszenia jest wykonywana w każdym jej elemencie.
  • replica_groups to lista grup replik, między którymi odbywa się redukcja (identyfikator repliki bieżącej repliki można pobrać za pomocą narzędzia ReplicaId). Kolejność replik w każdej grupie określa kolejność, w której zostanie rozproszony wynik całościowy. replica_groups musi być pusta (wówczas wszystkie repliki należą do jednej grupy) lub zawierać taką samą liczbę elementów jak liczba replik. Jeśli istnieje więcej niż 1 grupa replik, wszystkie muszą mieć ten sam rozmiar. Na przykład replica_groups = {0, 2}, {1, 3} przeprowadza redukcję między replikami 0 i 2 oraz 1 i 3, a następnie rozprasza wynik.
  • shard_count to rozmiar każdej grupy replik. Jest on potrzebny w sytuacjach, gdy pole replica_groups jest puste. Jeśli replica_groups nie jest pusty, shard_count musi być równy rozmiarowi każdej grupy replik.
  • channel_id służy do komunikacji między modułami: tylko operacje reduce-scatter z tym samym atrybutem channel_id mogą się ze sobą komunikować.

Kształt wyjściowy to kształt wejściowy, w którym wartość scatter_dimension jest zmniejszona shard_count razy. Jeśli na przykład są 2 repliki, a operand w 2 replikach ma wartość [1.0, 2.25] i [3.0, 5.25], to wartość wyjściowa tej operacji, gdzie scatter_dim to 0, będzie wynosić [4.0] dla pierwszej repliki i [7.5] dla drugiej repliki.

ReduceWindow

Zobacz też XlaBuilder::ReduceWindow.

Stosuje funkcję redukcji do wszystkich elementów w każdym oknie sekwencji N wielowymiarowych tablic, tworząc jako dane wyjściowe jedną lub kropkę N tablic wielowymiarowych. Każda tablica wyjściowa ma taką samą liczbę elementów, jak liczba prawidłowych pozycji okna. Warstwa puli może być wyrażana jako ReduceWindow. Podobnie jak w przypadku Reduce zastosowany computation jest zawsze wyprzedzony init_values po lewej stronie.

ReduceWindow(operands..., init_values..., computation, window_dimensions, window_strides, padding)

Argumenty Typ Semantyka
operands N XlaOps Sekwencja N wielowymiarowych tablic typów T_0,..., T_{N-1}, z których każda reprezentuje obszar bazowy, na którym znajduje się okno.
init_values N XlaOps Początkowe wartości N redukcji, po jednej dla każdego z N operandów. Więcej informacji znajdziesz w sekcji Zmniejszanie.
computation XlaComputation Funkcja redukcji typu T_0, ..., T_{N-1}, T_0, ..., T_{N-1} -> Collate(T_0, ..., T_{N-1}), stosowana do elementów w każdym oknie wszystkich argumentów wejściowych.
window_dimensions ArraySlice<int64> tablica liczb całkowitych dla wartości wymiarów okna
window_strides ArraySlice<int64> tablica liczb całkowitych dla wartości kroku w oknie
base_dilations ArraySlice<int64> tablica liczb całkowitych dla wartości rozszerzania podstawowego
window_dilations ArraySlice<int64> tablica liczb całkowitych dla wartości rozszerzenia okna
padding Padding typ dopełnienia okna (Dopełnienie::kSame, które powoduje, że kształt wyjściowy jest taki sam jak w przypadku kroku 1, lub Dopełnienie::kValid, który nie używa dopełnienia i „zatrzymuje” okno, gdy już zostanie dopasowane)

Gdzie:

  • Wartość N musi być większa od lub równa 1.
  • Wszystkie tablice wejściowe muszą mieć te same wymiary.
  • Jeśli N = 1, Collate(T) ma wartość T.
  • Jeśli N > 1, Collate(T_0, ..., T_{N-1}) jest kropką elementów N typu (T0,...T{N-1}).

Poniżej kodu i ilustracji przedstawiono przykład użycia funkcji ReduceWindow. Dane wejściowe to macierz o wymiarach [4 x 6], a zarówno window_dimensions, jak i window_stride_dimensions to [2x3].

// Create a computation for the reduction (maximum).
XlaComputation max;
{
  XlaBuilder builder(client_, "max");
  auto y = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y");
  auto x = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "x");
  builder.Max(y, x);
  max = builder.Build().value();
}

// Create a ReduceWindow computation with the max reduction computation.
XlaBuilder builder(client_, "reduce_window_2x3");
auto shape = ShapeUtil::MakeShape(F32, {4, 6});
auto input = builder.Parameter(0, shape, "input");
builder.ReduceWindow(
    input,
    /*init_val=*/builder.ConstantLiteral(LiteralUtil::MinValue(F32)),
    *max,
    /*window_dimensions=*/{2, 3},
    /*window_stride_dimensions=*/{2, 3},
    Padding::kValid);

Krok 1 w wymiarze określa, że położenie okna w wymiarze jest oddalone o 1 element od sąsiedniego okna. Aby wskazać, że żadne okna nie nakładają się na siebie, parametr window_stride_dimensions powinien mieć wartość window_dimensions. Na rysunku poniżej widać wykorzystanie 2 różnych wartości kroku. Dopełnienie jest stosowane do każdego wymiaru danych wejściowych, a obliczenia są takie same, jak gdyby dane wejściowe zostały dostarczone z wymiarami po dopełnieniu.

W przypadku nieprostego dopełnienia rozważ obliczenie minimalnego okresu ważności (wartość początkowa to MAX_FLOAT) z wymiarem 3 i przejściem 2 na tablicę wejściową [10000, 1000, 100, 10, 1]. Dopełnienie kValid oblicza wartości minimalne w 2 prawidłowych oknach: [10000, 1000, 100] i [100, 10, 1], co daje wynik [100, 1]. Dopełnienie kSame najpierw dopełnia tablicę, tak aby kształt po oknie redukcji był taki sam jak dane wejściowe dla pierwszego kroku. W tym celu dodaj elementy początkowe po obu stronach, uzyskując w ten sposób wartość [MAX_VALUE, 10000, 1000, 100, 10, 1, MAX_VALUE]. Uruchomienie skrócenia okna nad tablicą z dopełnieniem działa na 3 oknach [MAX_VALUE, 10000, 1000], [1000, 100, 10], [10, 1, MAX_VALUE] i zyskach [1000, 10, 1].

Kolejność oceny funkcji redukcji jest arbitralna i może nie być deterministyczna. Dlatego funkcja redukcji nie powinna być zbyt czuła na ponowne powiązanie. Więcej informacji znajdziesz w dyskusji na temat powiązań w kontekście zapytania Reduce.

ReplicaId

Zobacz też XlaBuilder::ReplicaId.

Zwraca unikalny identyfikator (skalarny U32) repliki.

ReplicaId()

Unikalny identyfikator każdej repliki jest nieoznaczoną liczbą całkowitą w przedziale [0, N), gdzie N to liczba replik. Ponieważ wszystkie repliki działają w ramach tego samego programu, wywołanie funkcji ReplicaId() w programie zwróci dla każdej repliki inną wartość.

Zmień kształt

Zobacz też XlaBuilder::Reshape i operację Collapse.

Zmienia wymiary tablicy w nową konfigurację.

Reshape(operand, new_sizes) Reshape(operand, dimensions, new_sizes)

Argumenty Typ Semantyka
operand XlaOp tablica typu T
dimensions Wektor int64 kolejność zwiniętych wymiarów;
new_sizes Wektor int64 wektor rozmiarów nowych wymiarów

W kontekście kształtowanie najpierw spłaszcza tablicę w jednowymiarowy wektor wartości danych, a następnie precyzuje ten wektor do nowego kształtu. Argumenty wejściowe to arbitralna tablica typu T, wektor stałego kompilowania indeksów wymiarów oraz wektor stałego w czasie kompilacji wielkości wymiarów wyniku. Wartości we wektorze dimension (jeśli podano) muszą być permutacją wszystkich wymiarów T. Wartość domyślna, jeśli nie podano, to {0, ..., rank - 1}. Kolejność wymiarów w elemencie dimensions jest zgodna z kolejnością wymiarów od najwolniej (największy) do najbardziej zmieniających się wymiarów (najmniejszy z nich) w zagnieżdżeniu pętli, co zwija tablicę wejściową do jednego wymiaru. Wektor new_sizes określa rozmiar tablicy wyjściowej. Wartość w indeksie 0 w polu new_sizes to rozmiar wymiaru 0, wartość w indeksie 1 to rozmiar wymiaru 1 itd. Iloczyn wymiarów new_size musi być iloczynowi rozmiarów wymiarów operandu. Podczas ulepszania zwiniętej tablicy w tablicę wielowymiarową zdefiniowaną przez funkcję new_sizes wymiary w new_sizes są uporządkowane od najmniejszej zmiany (największa) do najszybciej zmieniającej się (najbardziej drobne).

Na przykład niech v będzie tablicą 24 elementów:

let v = f32[4x2x3] { { {10, 11, 12}, {15, 16, 17} },
                    { {20, 21, 22}, {25, 26, 27} },
                    { {30, 31, 32}, {35, 36, 37} },
                    { {40, 41, 42}, {45, 46, 47} } };

In-order collapse:
let v012_24 = Reshape(v, {0,1,2}, {24});
then v012_24 == f32[24] {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
                         30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47};

let v012_83 = Reshape(v, {0,1,2}, {8,3});
then v012_83 == f32[8x3] { {10, 11, 12}, {15, 16, 17},
                          {20, 21, 22}, {25, 26, 27},
                          {30, 31, 32}, {35, 36, 37},
                          {40, 41, 42}, {45, 46, 47} };

Out-of-order collapse:
let v021_24 = Reshape(v, {1,2,0}, {24});
then v012_24 == f32[24]  {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
                          15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47};

let v021_83 = Reshape(v, {1,2,0}, {8,3});
then v021_83 == f32[8x3] { {10, 20, 30}, {40, 11, 21},
                          {31, 41, 12}, {22, 32, 42},
                          {15, 25, 35}, {45, 16, 26},
                          {36, 46, 17}, {27, 37, 47} };


let v021_262 = Reshape(v, {1,2,0}, {2,6,2});
then v021_262 == f32[2x6x2] { { {10, 20}, {30, 40},
                              {11, 21}, {31, 41},
                              {12, 22}, {32, 42} },
                             { {15, 25}, {35, 45},
                              {16, 26}, {36, 46},
                              {17, 27}, {37, 47} } };

W szczególności funkcja zmiany kształtu może przekształcić tablicę jednoelementową w skalarną i odwrotnie. Przykład:

Reshape(f32[1x1] { {5} }, {0,1}, {}) == 5;
Reshape(5, {}, {1,1}) == f32[1x1] { {5} };

Obr. (odwrócone)

Zobacz też XlaBuilder::Rev.

Rev(operand, dimensions)

Argumenty Typ Semantyka
operand XlaOp tablica typu T
dimensions ArraySlice<int64> wymiary do odwrócenia

Odwraca kolejność elementów w tablicy operand wzdłuż określonej wartości dimensions, generując tablicę wyjściową o tym samym kształcie. Każdy element tablicy operandu w indeksie wielowymiarowym jest przechowywany w tablicy wyjściowej w indeksie przekształconym. Indeks wielowymiarowy jest przekształcany przez odwrócenie indeksu w każdym wymiarze, który ma zostać odwrócony (np. jeśli wymiar o rozmiarze N jest jednym z wymiarów odwróconych, jego indeks i zostanie przekształcony w N–1–i).

Operacja Rev polega na odwróceniu tablicy wagi splotu wzdłuż 2 wymiarów okien podczas obliczania gradientu w sieciach neuronowych.

RngNormal

Zobacz też XlaBuilder::RngNormal.

Konstruuje dane wyjściowe określonego kształtu z liczbami losowymi wygenerowanymi po \(N(\mu, \sigma)\) rozkładzie normalnym. Parametry \(\mu\) i \(\sigma\)oraz kształt wyjściowy muszą mieć typ zmiennoprzecinkowy. Dodatkowo parametry muszą mieć wartości skalarne.

RngNormal(mu, sigma, shape)

Argumenty Typ Semantyka
mu XlaOp Skalar typu T określający średnią generowanych liczb
sigma XlaOp Skalar typu T określający odchylenie standardowe wygenerowanego
shape Shape Kształt wyjściowy typu T

RngUniform

Zobacz też XlaBuilder::RngUniform.

Konstruuje dane wyjściowe danego kształtu z liczbami losowymi wygenerowanymi po jednolitym rozkładzie w przedziale \([a,b)\). Parametry i elementy wyjściowe muszą być wartościami logicznymi, całkowitymi lub zmiennoprzecinowymi, a typy muszą być spójne. Backendy procesora i GPU obsługują obecnie tylko modele F64, F32, F16, BF16, S64, U64, S32 i U32. Ponadto parametry muszą mieć wartości skalarne. Jeśli \(b <= a\) wynik jest zdefiniowany w ramach implementacji.

RngUniform(a, b, shape)

Argumenty Typ Semantyka
a XlaOp Skalar typu T określający dolną granicę odstępu
b XlaOp Skalar typu T określający górny limit interwału
shape Shape Kształt wyjściowy typu T

RngBitGenerator

Generuje dane wyjściowe o danym kształcie wypełnionym jednolitymi losowymi bitami za pomocą określonego algorytmu (lub domyślnego backendu) i zwraca zaktualizowany stan (z tym samym kształtem co stan początkowy) oraz wygenerowane dane losowe.

Stan początkowy to początkowy stan bieżącego generowania liczb losowych. Jego wymagany kształt i prawidłowe wartości zależą od zastosowanego algorytmu.

Gwarantujemy, że dane wyjściowe będą deterministyczną funkcją stanu początkowego, ale nie można zagwarantować, że dane wyjściowe będą deterministyczne między backendami a różnymi wersjami kompilatora.

RngBitGenerator(algorithm, key, shape)

Argumenty Typ Semantyka
algorithm RandomAlgorithm Do użycia algorytm PRNG.
initial_state XlaOp Stan początkowy algorytmu PRNG.
shape Shape Kształt wyjściowy wygenerowanych danych.

Dostępne wartości dla parametru algorithm:

Punktowy

Operacja rozkładu XLA generuje sekwencję wyników, które są wartościami tablicy wejściowej operands, przy czym kilka wycinków (w indeksach określonych przez scatter_indices) jest aktualizowanych o sekwencje wartości w updates przy użyciu update_computation.

Zobacz też XlaBuilder::Scatter.

scatter(operands..., scatter_indices, updates..., update_computation, index_vector_dim, update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)

Argumenty Typ Semantyka
operands Sekwencja N XlaOp N tablice typów T_0, ..., T_N, na których zostanie rozproszone.
scatter_indices XlaOp Tablica zawierająca początkowe indeksy wycinków, w których muszą być rozproszone.
updates Sekwencja N XlaOp N tablic typów T_0, ..., T_N. updates[i] zawiera wartości, których należy użyć do rozproszenia elementu operands[i].
update_computation XlaComputation Obliczenia używane do łączenia istniejących wartości w tablicy wejściowej z aktualizacjami podczas rozproszenia. Obliczenie powinno być typu T_0, ..., T_N, T_0, ..., T_N -> Collate(T_0, ..., T_N).
index_vector_dim int64 Wymiar w tabeli scatter_indices, który zawiera początkowe indeksy.
update_window_dims ArraySlice<int64> Zbiór wymiarów w kształcie updates, które są wymiarami okna.
inserted_window_dims ArraySlice<int64> Zestaw wymiarów okna, które należy wstawić w kształt o kształcie updates.
scatter_dims_to_operand_dims ArraySlice<int64> Wymiary są mapowane od indeksów punktowych na przestrzeń indeksu operacji. Ta tablica jest interpretowana jako mapowanie i na scatter_dims_to_operand_dims[i] . Musi to być jeden do jednego i całość.
indices_are_sorted bool Określa, czy indeksy na pewno zostaną sortowane według elementu wywołującego.
unique_indices bool Czy element wywołujący gwarantuje, że indeksy będą unikalne.

Gdzie:

  • Wartość N musi być większa od lub równa 1.
  • Parametry operands[0], ..., operands[N-1] muszą mieć te same wymiary.
  • Parametry updates[0], ..., updates[N-1] muszą mieć te same wymiary.
  • Jeśli N = 1, Collate(T) ma wartość T.
  • Jeśli N > 1, Collate(T_0, ..., T_N) jest kropką elementów N typu T.

Jeśli index_vector_dim ma wartość scatter_indices.rank, domyślnie uznajemy, że parametr scatter_indices ma na końcu wymiar 1.

Definiujemy update_scatter_dims typu ArraySlice<int64> jako zbiór wymiarów w kształcie updates, które nie są podane w kolejności rosnącej: update_window_dims.

Argumenty punktowe powinny być zgodne z tymi ograniczeniami:

  • Każda tablica updates musi mieć pozycję update_window_dims.size + scatter_indices.rank - 1.

  • Ograniczenia wymiaru i w każdej tablicy updates muszą spełniać te wymagania:

    • Jeśli parametr i występuje w parametrze update_window_dims (tj. równa się update_window_dims[k] w przypadku niektórych elementów k), granica wymiaru i w kolumnie updates nie może przekraczać odpowiadającej mu progu operand po uwzględnieniu parametru inserted_window_dims (tj. adjusted_window_bounds[k], gdzie adjusted_window_bounds zawiera granice operand z usuniętymi granicami w indeksach inserted_window_dims).
    • Jeśli parametr i występuje w parametrze update_scatter_dims (tj. równa się update_scatter_dims[k] w przypadku niektórych elementów k), granica wymiaru i w parametrze updates musi być równa odpowiedniej granicy wymiaru scatter_indices z pominięciem elementu index_vector_dim (czyli scatter_indices.shape.dims[k], jeśli k < index_vector_dim i scatter_indices.shape.dims[k+1])
  • update_window_dims musi być w kolejności rosnącej, nie może zawierać powtarzających się numerów wymiarów i należeć do zakresu [0, updates.rank).

  • inserted_window_dims musi być w kolejności rosnącej, nie może zawierać powtarzających się numerów wymiarów i należeć do zakresu [0, operand.rank).

  • operand.rank musi być równy sumie wartości update_window_dims.size i inserted_window_dims.size.

  • Wartość scatter_dims_to_operand_dims.size musi być równa scatter_indices.shape.dims[index_vector_dim], a jej wartości muszą się mieścić w zakresie [0, operand.rank).

Dla danego indeksu U w każdej tablicy updates odpowiedni indeks I w odpowiedniej tablicy operands, do której należy zastosować tę aktualizację, jest obliczany w następujący sposób:

  1. Niech G = { U[k] dla k w update_scatter_dims }. Użyj funkcji G, aby wyszukać wektor indeksu S w tablicy scatter_indices w taki sposób, że S[i] = scatter_indices[Połącz(G, i)], gdzie Połącz(A, b) wstawia b w pozycji index_vector_dim do A.
  2. Utwórz indeks Sin w lokalizacji operand przy użyciu pola S, rozpraszając S za pomocą mapy scatter_dims_to_operand_dims. Bardziej formalnie:
    1. Sin[scatter_dims_to_operand_dims[k]] = S[k], jeśli k < scatter_dims_to_operand_dims.size.
    2. Sin[_] = 0 w innym przypadku.
  3. Utwórz indeks Win w każdej tablicy operands, rozpraszając indeksy na poziomie update_window_dims w tabeli U zgodnie z inserted_window_dims. Bardziej formalnie:
    1. Win[window_dims_to_operand_dims(k)] = U[k], jeśli k znajduje się w elemencie update_window_dims, gdzie window_dims_to_operand_dims to funkcja monotoniczna z domeną [0, update_window_dims.size) i zakresem [0, operand.rank) \ inserted_window_dims. (Jeśli np. update_window_dims.size to 4, operand.rank to 6, a inserted_window_dims to {0, 2}, to window_dims_to_operand_dims to {01, 13, 24, 35}).
    2. Win[_] = 0 w innym przypadku.
  4. I to Win + Sin, gdzie + oznacza dodawanie elementu.

Podsumowując, operację rozproszoną można zdefiniować w ten sposób.

  • Zainicjuj output z operands, tj. dla wszystkich indeksów J dla wszystkich indeksów O w tablicy operands[J]:
    output[J][O] = operands[J][O]
  • Dla każdego indeksu U w tablicy updates[J] i odpowiadającego mu indeksu O w tablicy operand[J], jeśli O jest prawidłowym indeksem dla output:
    (output[0][O], ..., output[N-1][O]) =update_computation(output[0][O], ..., ,output[N-1][O],updates[0][U], ...,updates[N-1][U])

Kolejność stosowania aktualizacji nie jest deterministyczna. Jeśli wiele indeksów w tabeli updates odnosi się do tego samego indeksu w tabeli operands, odpowiednia wartość w elemencie output będzie niedeterministyczna.

Pamiętaj, że pierwszy parametr przekazywany do funkcji update_computation jest zawsze bieżącą wartością z tablicy output, a drugi – wartością z tablicy updates. Jest to ważne zwłaszcza w przypadkach, gdy właściwość update_computation nie jest przemienna.

Jeśli indices_are_sorted ma wartość Prawda, XLA może przyjąć, że elementy start_indices są posortowane (w kolejności rosnącej start_index_map) według użytkownika. W przeciwnym razie definicja semantyki jest implementowana.

Jeśli unique_indices ma wartość Prawda, XLA może przyjąć, że wszystkie rozproszone elementy są unikalne. XLA może więc wykorzystywać operacje nieatomowe. Jeśli unique_indices ma wartość Prawda, a rozproszone indeksy nie są unikalne, semantyka jest zdefiniowana.

Nieformalnie operacja rozproszona może być postrzegana jako odwrotność operacji zbierania, tj. operacja punktowa aktualizuje elementy w danych wejściowych wyodrębnione przez odpowiednią operację zbierania danych.

Szczegółowy nieformalny opis i przykłady znajdziesz w sekcji „Informacyjny opis” w sekcji Gather.

Wybierz

Zobacz też XlaBuilder::Select.

Konstruuje tablicę wyjściową z elementów 2 tablic wejściowych na podstawie wartości tablicy predykatów.

Select(pred, on_true, on_false)

Argumenty Typ Semantyka
pred XlaOp tablica typu PRED
on_true XlaOp tablica typu T
on_false XlaOp tablica typu T

Tablice on_true i on_false muszą mieć ten sam kształt. Jest to również kształt tablicy wyjściowej. Tablica pred musi mieć tę samą wymiary co on_true i on_false oraz mieć typ elementu PRED.

Dla każdego elementu P elementu pred odpowiedni element tablicy wyjściowej jest pobierany z on_true, jeśli wartością P jest true, lub z on_false, jeśli wartością P jest false. Jako ograniczonej formy transmisji pred może być skalarem typu PRED. W tym przypadku tablica wyjściowa jest pobierana w całości z tabeli on_true, jeśli pred ma wartość true, lub z on_false, jeśli pred to false.

Przykład z nieskalarnym pred:

let pred: PRED[4] = {true, false, false, true};
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 200, 300, 4};

Przykład ze skalarnym pred:

let pred: PRED = true;
let v1: s32[4] = {1, 2, 3, 4};
let v2: s32[4] = {100, 200, 300, 400};
==>
Select(pred, v1, v2) = s32[4]{1, 2, 3, 4};

Obsługiwany jest wybór krotek. W tym celu krotki są uznawane za typy skalarne. Jeśli on_true i on_false są krotkami (które muszą mieć ten sam kształt), pred musi być skalarem typu PRED.

SelectAndScatter

Zobacz też XlaBuilder::SelectAndScatter.

Tę operację można uznać za operację złożoną, która najpierw oblicza ReduceWindow w tablicy operand, aby wybrać element z każdego okna, a następnie rozkłada tablicę source na indeksy wybranych elementów, aby utworzyć tablicę wyjściową o tym samym kształcie co tablica operandu. Funkcja binarna select służy do wybierania elementu z każdego okna przez zastosowanie go w każdym oknie. Jest ona wywoływana z właściwością, że wektor indeksu pierwszego parametru jest leksykograficznie mniejszy niż wektor indeksu drugiego parametru. Funkcja select zwraca wartość true, jeśli wybrano pierwszy parametr i zwraca false, jeśli wybrano drugi, a funkcja musi utrzymywać przechodność (np. jeśli select(a, b) i select(b, c) mają wartość true, a select(a, c) to również true), tak aby wybrany element nie zależał od kolejności elementów przemierzanych w danym oknie.

Funkcja scatter jest stosowana w każdym wybranym indeksie w tablicy wyjściowej. Przyjmuje 2 parametry skalarne:

  1. Bieżąca wartość w wybranym indeksie w tablicy wyjściowej
  2. Wartość punktowa z pola source, która ma zastosowanie do wybranego indeksu

Łączy on 2 parametry i zwraca wartość skalarną używaną do aktualizacji wartości w wybranym indeksie w tablicy wyjściowej. Początkowo wszystkie indeksy tablicy wyjściowej mają wartość init_value.

Tablica wyjściowa ma ten sam kształt co tablica operand, a tablica source musi mieć ten sam kształt w wyniku zastosowania operacji ReduceWindow do tablicy operand. Za pomocą SelectAndScatter można cofnąć propagację wartości gradientu w warstwie puli w sieci neuronowej.

SelectAndScatter(operand, select, window_dimensions, window_strides, padding, source, init_value, scatter)

Argumenty Typ Semantyka
operand XlaOp tablica typu T, po której przesuwają się okna
select XlaComputation obliczenia binarne typu T, T -> PRED, które zostaną zastosowane do wszystkich elementów w każdym oknie; zwraca wartość true, jeśli wybrano pierwszy parametr, i zwraca false, jeśli wybrano drugi parametr
window_dimensions ArraySlice<int64> tablica liczb całkowitych dla wartości wymiarów okna
window_strides ArraySlice<int64> tablica liczb całkowitych dla wartości kroku w oknie
padding Padding typ dopełnienia okna (Dopełnienie::kSame lub Dopełnienie::kValid)
source XlaOp tablica typu T z wartościami do rozproszenia
init_value XlaOp wartość skalarna typu T dla wartości początkowej tablicy wyjściowej
scatter XlaComputation obliczenia binarne typu T, T -> T, by zastosować każdy rozproszony element źródłowy z jego elementem docelowym;

Na rysunku poniżej widać przykłady użycia funkcji SelectAndScatter, gdzie funkcja select oblicza maksymalną wartość spośród jej parametrów. Pamiętaj, że gdy okna się nakładają, tak jak na ilustracji (2) poniżej, indeks tablicy operand może zostać wybrany kilka razy przez różne okna. Na ilustracji element wartości 9 jest wybierany przez górne okna (niebieski i czerwony), a funkcja dodawania binarnego scatter tworzy element wyjściowy o wartości 8 (2 + 6).

Kolejność oceny funkcji scatter jest dowolna i może nie być deterministyczna. Dlatego funkcja scatter nie powinna być zbyt wrażliwa na ponowne powiązanie. Więcej informacji znajdziesz w dyskusji na temat powiązań w kontekście zapytania Reduce.

Wyślij

Zobacz też XlaBuilder::Send.

Send(operand, channel_handle)

Argumenty Typ Semantyka
operand XlaOp dane do wysłania (tablica typu T)
channel_handle ChannelHandle unikalny identyfikator dla każdej pary wysyłania/odbierania.

Wysyła dane argumentu do instrukcji Recv w innych obliczeniach, które mają ten sam uchwyt kanału. Nie zwraca żadnych danych.

Podobnie jak w przypadku Recv, interfejs API klienta Send reprezentuje komunikację synchroniczną i jest wewnętrznie podzielony na 2 instrukcje HLO (Send i SendDone), aby umożliwić asynchroniczne przesyłanie danych. Zobacz też HloInstruction::CreateSend i HloInstruction::CreateSendDone.

Send(HloInstruction operand, int64 channel_id)

Inicjuje asynchroniczne przesyłanie operandu do zasobów przydzielonych przez instrukcję Recv o tym samym identyfikatorze kanału. Zwraca kontekst, który jest używany przez następującą instrukcję SendDone do oczekiwania na zakończenie przenoszenia danych. Kontekst jest kropką {operand (kształt), identyfikatorem żądania (U32)}, której można użyć tylko w instrukcji SendDone.

SendDone(HloInstruction context)

Biorąc pod uwagę kontekst utworzony przez instrukcję Send, czeka na zakończenie przenoszenia danych. Instrukcja nie zwraca żadnych danych.

Przygotowanie instrukcji dotyczących kanału

Kolejność wykonania 4 instrukcji dla każdego kanału (Recv, RecvDone, Send, SendDone) jest taka jak poniżej.

  • Recv ma miejsce przed Send
  • Send ma miejsce przed RecvDone
  • Recv ma miejsce przed RecvDone
  • Send ma miejsce przed SendDone

Gdy kompilatory backendu generują liniowy harmonogram dla każdego obliczenia, które komunikują się zgodnie z instrukcjami dotyczącymi kanału, nie mogą występować cykle między obliczeniami. Na przykład poniższe harmonogramy prowadzą do zakleszczeń.

Wycinek

Zobacz też XlaBuilder::Slice.

Wyodrębnianie powoduje wyodrębnienie tablicy podrzędnej z tablicy wejściowej. Tablica podrzędna ma taką samą pozycję jak dane wejściowe i zawiera wartości wewnątrz ramki ograniczającej w tablicy wejściowej, gdzie wymiary i indeksy ramki ograniczającej są podawane jako argumenty operacji wycinka.

Slice(operand, start_indices, limit_indices, strides)

Argumenty Typ Semantyka
operand XlaOp N tablica wymiarowa typu T
start_indices ArraySlice<int64> Lista N liczb całkowitych zawierających początkowe indeksy wycinka dla każdego wymiaru. Wartości nie mogą być większe niż zero.
limit_indices ArraySlice<int64> Lista N liczb całkowitych zawierających końcowe indeksy (wyłącznie) dla wycinka każdego wymiaru. Każda wartość musi być równa lub większa od odpowiedniej wartości start_indices dla wymiaru i mniejsza lub równa rozmiarowi wymiaru.
strides ArraySlice<int64> Lista N liczb całkowitych, która określa krok wejściowy wycinka. Wycinek wybiera każdy element strides[d] w wymiarze d.

Przykład jednowymiarowego:

let a = {0.0, 1.0, 2.0, 3.0, 4.0}
Slice(a, {2}, {4}) produces:
  {2.0, 3.0}

Przykład obrazu dwuwymiarowego:

let b =
 { {0.0,  1.0,  2.0},
   {3.0,  4.0,  5.0},
   {6.0,  7.0,  8.0},
   {9.0, 10.0, 11.0} }

Slice(b, {2, 1}, {4, 3}) produces:
  { { 7.0,  8.0},
    {10.0, 11.0} }

Sortuj

Zobacz też XlaBuilder::Sort.

Sort(operands, comparator, dimension, is_stable)

Argumenty Typ Semantyka
operands ArraySlice<XlaOp> operandy do sortowania.
comparator XlaComputation Wartość porównawcza do użycia.
dimension int64 Wymiar, według którego ma być sortowane.
is_stable bool Określa, czy należy używać stabilnego sortowania.

Jeśli podany jest tylko 1 operand:

  • Jeśli operand jest tensorem o określonej pozycji (tablica), wynik jest posortowaną tablicą. Jeśli chcesz posortować tablicę w kolejności rosnącej, komparator powinien wykonać porównanie mniejsze niż w przypadku. Ogólnie po posortowaniu tablicy obowiązuje ona dla wszystkich pozycji indeksu i, j z wartością i < j, która może przyjmować wartość comparator(value[i], value[j]) = comparator(value[j], value[i]) = false lub comparator(value[i], value[j]) = true.

  • Jeśli operand ma wyższą pozycję w rankingu, jest on posortowany według podanego wymiaru. Na przykład w przypadku tensora rankingu 2 (macierzy) wartość wymiaru 0 będzie niezależnie sortować każdą kolumnę, a wartość wymiaru 1 – niezależnie od sortowania każdego wiersza. Jeśli nie podasz numeru wymiaru, domyślnie zostanie wybrany ostatni wymiar. W przypadku posortowanego wymiaru obowiązuje ta sama kolejność sortowania jak w przypadku argumentu „ranking-1”.

Jeśli podane są operandy n > 1:

  • Wszystkie operandy n muszą być tensorami o tych samych wymiarach. Typy tensorów mogą być różne.

  • Wszystkie operandy są sortowane razem, a nie pojedynczo. operandy są z założenia traktowane jako krotka. Podczas sprawdzania, czy elementy każdego argumentu na pozycjach indeksu i i j muszą być zamienne, komparator jest wywoływany z parametrami skalarnymi 2 * n, gdzie parametr 2 * k odpowiada wartości na pozycji i z operandu k-th, a parametr 2 * k + 1 odpowiada wartości na pozycji j z operandu k-th. Zwykle komparator porównuje ze sobą parametry 2 * k i 2 * k + 1, a potem wykorzystuje inne pary parametrów jako elementy decydujące.

  • Wynikiem jest krotka, która składa się z operandów w posortowanej kolejności (wzdłuż podanego wymiaru, jak powyżej). Argument i-th krotki odpowiada operandowi i-th elementu Sort.

Jeśli np. są 3 operandy operand0 = [3, 1], operand1 = [42, 50] i operand2 = [-3.0, 1.1], a komparator porównuje tylko wartości argumentu operand0 z mniejszym niż, wynikiem sortowania jest kropka ([1, 3], [50, 42], [1.1, -3.0]).

Jeśli zasada is_stable ma wartość Prawda, sortowanie ma gwarantowaną stabilność, czyli w przypadku elementów, których komparator uznaje, że są równe, następuje względna kolejność równych wartości. Dwa elementy e1 i e2 są równe tylko wtedy, gdy comparator(e1, e2) = comparator(e2, e1) = false. Domyślnie zasada is_stable ma wartość Fałsz.

Transponuj

Zobacz też operację tf.reshape.

Transpose(operand)

Argumenty Typ Semantyka
operand XlaOp Argument do transponowania.
permutation ArraySlice<int64> Jak permutować wymiary.

Permutuje wymiary operandu z daną permutacją, więc ∀ i . 0 ≤ i < rank ⇒ input_dimensions[permutation[i]] = output_dimensions[i].

Ta funkcja działa tak samo jak funkcja Reshape(operand, permutation, Permute(permutation, operand.shape.dimensions)).

TriangularSolve

Zobacz też XlaBuilder::TriangularSolve.

Rozwiązanie układów równań liniowych o niższym lub górnym współczynniku trójkąta przez podstawienie do przodu lub wstecz. Ta rutyna przesyła dane zgodnie z wymiarami wiodącymi i rozwiązuje jeden z układów macierzy op(a) * x = b, czyli x * op(a) = b, dla zmiennej x z danymi a i b, gdzie op(a) to op(a) = a, op(a) = Transpose(a) lub op(a) = Conj(Transpose(a)).

TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose_a)

Argumenty Typ Semantyka
a XlaOp tablica rankingowa > 2 w przypadku typu złożonego lub zmiennoprzecinkowego o kształcie [..., M, M].
b XlaOp tablica rankingowa > 2 tego samego typu z kształtem [..., M, K], jeśli zasada left_side ma wartość prawda. W przeciwnym razie [..., K, M].
left_side bool wskazuje, czy należy rozwiązać układ o postaci op(a) * x = b (true) czy x * op(a) = b (false).
lower bool czy użyć górnego czy dolnego trójkąta a.
unit_diagonal bool jeśli jest ustawiona wartość true, przyjmuje się, że elementy a po przekątnej mają wartość 1 i nie są one otwierane.
transpose_a Transpose czy użyć funkcji a w obecnej postaci, przetransponować ją czy zastosować jej sprzężenie.

Dane wejściowe są odczytywane tylko z trójkąta dolnego/górnego a w zależności od wartości lower. Wartości z drugiego trójkąta są ignorowane. Dane wyjściowe są zwracane w tym samym trójkącie. Wartości w drugim trójkącie są zdefiniowane przez implementację i mogą być dowolne.

Jeśli pozycja elementów a i b jest większa niż 2, są one traktowane jako grupy matryc, gdzie wszystkie oprócz 2 podrzędnych wymiarów są wymiarami wsadu. Elementy a i b muszą mieć jednakowe wymiary wsadu.

Kropka

Zobacz też XlaBuilder::Tuple.

Kropka zawierająca zmienną liczbę nicków danych, z których każdy ma własny kształt.

Jest to odpowiednik std::tuple w C++. Ogólnie rzecz biorąc:

let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);

Kropki można dekonstruować (czyli uzyskać do nich dostęp) za pomocą operacji GetTupleElement.

I jednocześnie

Zobacz też XlaBuilder::While.

While(condition, body, init)

Argumenty Typ Semantyka
condition XlaComputation XlaComputation typu T -> PRED, który określa warunek zakończenia pętli.
body XlaComputation XlaComputation typu T -> T, który określa treść pętli.
init T Początkowa wartość parametru condition i body.

Wykonuje kolejno body, aż do wystąpienia błędu condition. Działa to podobnie do typowej pętli podczas wykonywania w wielu innych językach, z wyjątkiem różnic i ograniczeń wymienionych poniżej.

  • Węzeł While zwraca wartość typu T, która jest wynikiem ostatniego wykonania metody body.
  • Kształt typu T jest określony statycznie i musi być taki sam we wszystkich iteracjach.

Parametry T obliczeń są inicjowane z wartością init w pierwszej iteracji, a każde kolejne są automatycznie aktualizowane do nowego wyniku z body.

Jednym z głównych przypadków użycia węzła While jest wdrożenie powtarzającego się trenowania w sieciach neuronowych. Poniżej znajduje się uproszczony pseudokod z wykresem, który odzwierciedla obliczenie. Kod znajdziesz tutaj: while_test.cc. W tym przykładzie typ T to typ Tuple składający się z int32 oznaczający liczbę iteracji i vector[10] dla zasobnika. W przypadku 1000 iteracji pętla dodaje do akumulatora stały wektor.

// Pseudocode for the computation.
init = {0, zero_vector[10]} // Tuple of int32 and float[10].
result = init;
while (result(0) < 1000) {
  iteration = result(0) + 1;
  new_vector = result(1) + constant_vector[10];
  result = {iteration, new_vector};
}