Zrozumienie einsum NumPy

190

Próbuję zrozumieć, jak dokładnie einsumdziała. Przejrzałem dokumentację i kilka przykładów, ale wydaje się, że się nie trzyma.

Oto przykład, który przeszliśmy w klasie:

C = np.einsum("ij,jk->ki", A, B)

dla dwóch tablic AiB

Myślę, że to zajmie A^T * B, ale nie jestem pewien (czy transponowanie jednego z nich jest prawidłowe?). Czy ktoś może poprowadzić mnie dokładnie przez to, co się tutaj dzieje (i ogólnie podczas korzystania einsum)?

Cieśnina Lance'a
źródło
7
W rzeczywistości tak będzie (A * B)^Tlub równoważnie B^T * A^T.
Tigran Saluev
21
Napisałem krótki post na blogu o podstawach einsum tutaj . (Z przyjemnością przeszczepię najbardziej odpowiednie bity do odpowiedzi na temat przepełnienia stosu, jeśli jest to przydatne).
Alex Riley
1
@ajcr - piękny link. Dzięki. numpyDokumentacja jest zdecydowanie niewystarczająca przy wyjaśnianiu szczegółów.
rayryeng
Dziękujemy za wotum zaufania! Z opóźnieniem udzieliłem odpowiedzi poniżej .
Alex Riley,
Zauważ, że w Pythonie *nie jest to mnożenie macierzy, ale mnożenie elementarne. Uważaj!
ComputerScientist

Odpowiedzi:

368

(Uwaga: ta odpowiedź jest oparta na krótkim wpisie na blogu, o którym einsumnapisałem jakiś czas temu).

Co ma einsumzrobić?

Wyobraź sobie, że mamy dwie tablice wielowymiarowe Ai B. Załóżmy teraz, że chcemy ...

  • pomnóż A ze Bw określony sposób, aby stworzyć nową gamę produktów; a potem może
  • zsumuj tę nową tablicę wzdłuż poszczególnych osi; a potem może
  • transponować osie nowej tablicy w określonej kolejności.

Istnieje duża szansa, że einsumpomoże nam to zrobić szybciej i bardziej wydajnie pod względem pamięci niż kombinacje funkcji NumPymultiply , sumi transposepozwoli.

Jak einsumdziała

Oto prosty (ale nie całkowicie trywialny) przykład. Weź następujące dwie tablice:

A = np.array([0, 1, 2])

B = np.array([[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]])

Pomnożymy Ai Belementarnie, a następnie zsumujemy wzdłuż wierszy nowej tablicy. W „normalnym” NumPy pisalibyśmy:

>>> (A[:, np.newaxis] * B).sum(axis=1)
array([ 0, 22, 76])

Więc tutaj operacja indeksowania włączona A liniach pierwszych osi dwóch tablic, aby można było nadawać mnożenie. Rzędy szeregu produktów są następnie sumowane w celu zwrócenia odpowiedzi.

Teraz, jeśli chcielibyśmy użyć einsumzamiast tego, moglibyśmy napisać:

>>> np.einsum('i,ij->i', A, B)
array([ 0, 22, 76])

Podpis ciąg 'i,ij->i'jest kluczem tutaj i potrzebuje trochę wyjaśnień. Możesz myśleć o tym w dwóch połowach. Po lewej stronie (po lewej stronie ->) oznaczyliśmy dwie tablice wejściowe. Po prawej stronie ->oznaczyliśmy tablicę, z którą chcemy się skończyć.

Oto, co stanie się potem:

  • Ama jedną oś; oznaczyliśmy to i. I Bma dwie osie; oznaczyliśmy oś 0 jako ii oś 1 jako j.

  • Poprzez powtarzanie etykietę iw obu tablic wejściowych, mówimy einsum, że te dwie osie należy pomnożyć razem. Innymi słowy, mnożymy tablicę Az każdą kolumną tablicy B, podobnie jakA[:, np.newaxis] * B robi.

  • Zauważ, że jnie pojawia się jako etykieta w naszym pożądanym wyniku; właśnie użyliśmy i(chcemy skończyć z tablicą 1D). Przez pomijając etykietę, mówimy einsumdo podsumowania wzdłuż tej osi. Innymi słowy, sumujemy rzędy produktów, podobnie jak .sum(axis=1)robi.

To właściwie wszystko, co musisz wiedzieć, aby używać einsum. Pomaga trochę się bawić; jeśli zostawimy obie etykiety w danych wyjściowych, 'i,ij->ij'otrzymamy tablicę 2D produktów (taką samą jak A[:, np.newaxis] * B). Jeśli powiemy, że nie ma etykiet wyjściowych, 'i,ij->otrzymamy pojedynczy numer (taki sam jak czynność (A[:, np.newaxis] * B).sum()).

Wspaniałą rzeczą einsumjest to, że nie tworzy najpierw tymczasowego zestawu produktów; po prostu sumuje produkty na bieżąco. Może to prowadzić do dużych oszczędności w zużyciu pamięci.

Nieco większy przykład

Aby wyjaśnić iloczyn skalarny, oto dwie nowe tablice:

A = array([[1, 1, 1],
           [2, 2, 2],
           [5, 5, 5]])

B = array([[0, 1, 0],
           [1, 1, 0],
           [1, 1, 1]])

Obliczymy iloczyn skalarny za pomocą np.einsum('ij,jk->ik', A, B). Oto zdjęcie przedstawiające etykietę Ai Boraz tablicę wyjściową, którą otrzymujemy z funkcji:

wprowadź opis zdjęcia tutaj

Widać, że etykieta jsię powtarza - oznacza to, że mnożymy wiersze Az kolumnami B. Ponadto etykieta jnie jest uwzględniona w danych wyjściowych - podsumowujemy te produkty. Etykietyi i ksą przechowywane dla danych wyjściowych, więc otrzymujemy tablicę 2D.

Jeszcze bardziej przejrzyste może być porównanie tego wyniku z tablicą, w której etykieta niej jest sumowana. Poniżej po lewej stronie widać tablicę 3D, która powstaje podczas pisania (tzn. Zachowaliśmy etykietę ):np.einsum('ij,jk->ijk', A, B)j

wprowadź opis zdjęcia tutaj

Oś sumująca j daje oczekiwany iloczyn punktowy pokazany po prawej stronie.

Niektóre ćwiczenia

Aby poczuć więcej einsum , przydatne może być zaimplementowanie znanych operacji tablicowych NumPy za pomocą notacji indeksu dolnego. Wszystko, co obejmuje kombinacje mnożenia i sumowania osi, można zapisać za pomocą einsum .

Niech A i B będą dwiema tablicami 1D o tej samej długości. Na przykład A = np.arange(10)iB = np.arange(5, 15) .

  • Suma Amoże być zapisana:

    np.einsum('i->', A)
  • Mnożenie elementarne A * B, można zapisać:

    np.einsum('i,i->i', A, B)
  • Produkt wewnętrzny lub produkt kropkowy np.inner(A, B)lub np.dot(A, B)można zapisać:

    np.einsum('i,i->', A, B) # or just use 'i,i'
  • Produkt zewnętrzny np.outer(A, B)można napisać:

    np.einsum('i,j->ij', A, B)

Dla tablic 2D CiD pod warunkiem, że osie mają zgodne długości (zarówno ta sama długość, jak i jedna z nich ma długość 1), oto kilka przykładów:

  • Ślad C(suma głównej przekątnej) np.trace(C)można zapisać:

    np.einsum('ii', C)
  • Element mnożenie Ci transpozycją D, C * D.Tmożna zapisać:

    np.einsum('ij,ji->ij', C, D)
  • Mnożąc każdy element Cprzez tablicę D(aby utworzyć tablicę 4D) C[:, :, None, None] * D, można zapisać:

    np.einsum('ij,kl->ijkl', C, D)  
Alex Riley
źródło
1
Bardzo ładne wyjaśnienie, dzięki. „Zauważ, że nie pojawia się jako etykieta w naszym pożądanym wyniku” - prawda?
Ian Hincks,
Dzięki @IanHincks! To wygląda na literówkę; Poprawiłem to teraz.
Alex Riley,
1
Bardzo dobra odpowiedź. Warto również zauważyć, że ij,jkmoże działać samodzielnie (bez strzałek), tworząc mnożenie macierzy. Ale wydaje się, że dla jasności najlepiej jest umieścić strzałki, a następnie wymiary wyjściowe. To jest na blogu.
ComputerScientist
1
@Peaceful: jest to jedna z tych sytuacji, w których trudno jest wybrać właściwe słowo! Wydaje mi się, że „kolumna” pasuje tutaj trochę lepiej, ponieważ Ama długość 3, taką samą jak długość kolumn B(podczas gdy rzędy Bmają długość 4 i nie można ich pomnożyć przez element A).
Alex Riley,
1
Zauważ, że pominięcie ->wpływu ma wpływ na semantykę: „W trybie niejawnym wybrane indeksy dolne są ważne, ponieważ osie danych wyjściowych są uporządkowane alfabetycznie. Oznacza to, że np.einsum('ij', a)nie wpływa na tablicę 2D, a np.einsum('ji', a)dokonuje jej transpozycji”.
BallpointBen
40

Uchwycenie idei numpy.einsum()jest bardzo łatwe, jeśli zrozumiesz ją intuicyjnie. Na przykład zacznijmy od prostego opisu dotyczącego mnożenia macierzy .


Aby użyć numpy.einsum(), wystarczy przekazać jako argument tak zwany ciąg indeksu dolnego , a następnie tablice wejściowe .

Powiedzmy, że masz dwie tablice 2D, Ai B, i chcesz zrobić mnożenia macierzy. Więc robisz:

np.einsum("ij, jk -> ik", A, B)

Tutaj łańcuch dolny ij odpowiada tablicy, Apodczas gdy łańcuch dolny jk odpowiada tablicy B. Najważniejszą rzeczą, na którą należy tutaj zwrócić uwagę, jest to, że liczba znaków w każdym łańcuchu indeksu dolnego musi pasować do wymiarów tablicy. (tj. dwa znaki dla tablic 2D, trzy znaki dla tablic 3D itd.) A jeśli powtórzysz znaki między łańcuchami indeksu dolnego ( jw naszym przypadku), oznacza to, że chcesz, aby einsuma występowała wzdłuż tych wymiarów. W ten sposób zostaną zmniejszone sumy. (tzn. ten wymiar zniknie )

Dolny ciąg po tym ->, będzie naszym wypadkowa tablicą. Jeśli pozostawisz to puste, wszystko zostanie zsumowane, w wyniku czego zostanie zwrócona wartość skalarna. W przeciwnym razie wynikowa tablica będzie miała wymiary zgodne z ciągiem indeksu dolnego . W naszym przykładzie tak będzie ik. Jest to intuicyjne, ponieważ wiemy, że do mnożenia macierzy liczba kolumn w tablicy Amusi odpowiadać liczbie wierszy w tablicy, Bco się tutaj dzieje (tj. Kodujemy tę wiedzę, powtarzając znak jw łańcuchu indeksu dolnego )


Oto kilka innych przykładów ilustrujących wykorzystanie / moc np.einsum()wdrażania niektórych typowych operacji tensorowych lub nd-macierzowych , zwięźle.

Wejścia

# a vector
In [197]: vec
Out[197]: array([0, 1, 2, 3])

# an array
In [198]: A
Out[198]: 
array([[11, 12, 13, 14],
       [21, 22, 23, 24],
       [31, 32, 33, 34],
       [41, 42, 43, 44]])

# another array
In [199]: B
Out[199]: 
array([[1, 1, 1, 1],
       [2, 2, 2, 2],
       [3, 3, 3, 3],
       [4, 4, 4, 4]])

1) Mnożenie macierzy (podobne do np.matmul(arr1, arr2))

In [200]: np.einsum("ij, jk -> ik", A, B)
Out[200]: 
array([[130, 130, 130, 130],
       [230, 230, 230, 230],
       [330, 330, 330, 330],
       [430, 430, 430, 430]])

2) Wyodrębnij elementy wzdłuż głównej przekątnej (podobnie do np.diag(arr))

In [202]: np.einsum("ii -> i", A)
Out[202]: array([11, 22, 33, 44])

3) Produkt Hadamarda (tj. Elementarny produkt dwóch tablic) (podobny do arr1 * arr2)

In [203]: np.einsum("ij, ij -> ij", A, B)
Out[203]: 
array([[ 11,  12,  13,  14],
       [ 42,  44,  46,  48],
       [ 93,  96,  99, 102],
       [164, 168, 172, 176]])

4) Elementarne kwadraty (podobne do np.square(arr)lub arr ** 2)

In [210]: np.einsum("ij, ij -> ij", B, B)
Out[210]: 
array([[ 1,  1,  1,  1],
       [ 4,  4,  4,  4],
       [ 9,  9,  9,  9],
       [16, 16, 16, 16]])

5) Ślad (tj. Suma elementów głównych przekątnych) (podobny do np.trace(arr))

In [217]: np.einsum("ii -> ", A)
Out[217]: 110

6) Transpozycja macierzy (podobna do np.transpose(arr))

In [221]: np.einsum("ij -> ji", A)
Out[221]: 
array([[11, 21, 31, 41],
       [12, 22, 32, 42],
       [13, 23, 33, 43],
       [14, 24, 34, 44]])

7) Produkt zewnętrzny (wektorów) (podobny do np.outer(vec1, vec2))

In [255]: np.einsum("i, j -> ij", vec, vec)
Out[255]: 
array([[0, 0, 0, 0],
       [0, 1, 2, 3],
       [0, 2, 4, 6],
       [0, 3, 6, 9]])

8) Produkt wewnętrzny (wektorów) (podobny do np.inner(vec1, vec2))

In [256]: np.einsum("i, i -> ", vec, vec)
Out[256]: 14

9) Suma wzdłuż osi 0 (podobnie do np.sum(arr, axis=0))

In [260]: np.einsum("ij -> j", B)
Out[260]: array([10, 10, 10, 10])

10) Suma wzdłuż osi 1 (podobnie do np.sum(arr, axis=1))

In [261]: np.einsum("ij -> i", B)
Out[261]: array([ 4,  8, 12, 16])

11) Mnożenie macierzy partii

In [287]: BM = np.stack((A, B), axis=0)

In [288]: BM
Out[288]: 
array([[[11, 12, 13, 14],
        [21, 22, 23, 24],
        [31, 32, 33, 34],
        [41, 42, 43, 44]],

       [[ 1,  1,  1,  1],
        [ 2,  2,  2,  2],
        [ 3,  3,  3,  3],
        [ 4,  4,  4,  4]]])

In [289]: BM.shape
Out[289]: (2, 4, 4)

# batch matrix multiply using einsum
In [292]: BMM = np.einsum("bij, bjk -> bik", BM, BM)

In [293]: BMM
Out[293]: 
array([[[1350, 1400, 1450, 1500],
        [2390, 2480, 2570, 2660],
        [3430, 3560, 3690, 3820],
        [4470, 4640, 4810, 4980]],

       [[  10,   10,   10,   10],
        [  20,   20,   20,   20],
        [  30,   30,   30,   30],
        [  40,   40,   40,   40]]])

In [294]: BMM.shape
Out[294]: (2, 4, 4)

12) Suma wzdłuż osi 2 (podobnie do np.sum(arr, axis=2))

In [330]: np.einsum("ijk -> ij", BM)
Out[330]: 
array([[ 50,  90, 130, 170],
       [  4,   8,  12,  16]])

13) Zsumuj wszystkie elementy w tablicy (podobnie do np.sum(arr))

In [335]: np.einsum("ijk -> ", BM)
Out[335]: 480

14) Suma na wielu osiach (tj. Marginalizacja)
(podobnie do np.sum(arr, axis=(axis0, axis1, axis2, axis3, axis4, axis6, axis7)))

# 8D array
In [354]: R = np.random.standard_normal((3,5,4,6,8,2,7,9))

# marginalize out axis 5 (i.e. "n" here)
In [363]: esum = np.einsum("ijklmnop -> n", R)

# marginalize out axis 5 (i.e. sum over rest of the axes)
In [364]: nsum = np.sum(R, axis=(0,1,2,3,4,6,7))

In [365]: np.allclose(esum, nsum)
Out[365]: True

15) Produkty z podwójną kropką ( podobne do np. Suma (produkt hadamardowy) porównaj 3 )

In [772]: A
Out[772]: 
array([[1, 2, 3],
       [4, 2, 2],
       [2, 3, 4]])

In [773]: B
Out[773]: 
array([[1, 4, 7],
       [2, 5, 8],
       [3, 6, 9]])

In [774]: np.einsum("ij, ij -> ", A, B)
Out[774]: 124

16) Mnożenie tablic 2D i 3D

Takie mnożenie może być bardzo przydatne przy rozwiązywaniu liniowego układu równań ( Ax = b ), w którym chcesz zweryfikować wynik.

# inputs
In [115]: A = np.random.rand(3,3)
In [116]: b = np.random.rand(3, 4, 5)

# solve for x
In [117]: x = np.linalg.solve(A, b.reshape(b.shape[0], -1)).reshape(b.shape)

# 2D and 3D array multiplication :)
In [118]: Ax = np.einsum('ij, jkl', A, x)

# indeed the same!
In [119]: np.allclose(Ax, b)
Out[119]: True

Wręcz przeciwnie, jeśli trzeba użyć np.matmul()tej weryfikacji, musimy wykonać kilka reshapeoperacji, aby osiągnąć ten sam wynik, jak:

# reshape 3D array `x` to 2D, perform matmul
# then reshape the resultant array to 3D
In [123]: Ax_matmul = np.matmul(A, x.reshape(x.shape[0], -1)).reshape(x.shape)

# indeed correct!
In [124]: np.allclose(Ax, Ax_matmul)
Out[124]: True

Bonus : Przeczytaj więcej matematyki tutaj: Podsumowanie Einsteina i zdecydowanie tutaj: Notacja Tensor

kmario23
źródło
7

Zróbmy 2 tablice o różnych, ale kompatybilnych wymiarach, aby podkreślić ich wzajemne oddziaływanie

In [43]: A=np.arange(6).reshape(2,3)
Out[43]: 
array([[0, 1, 2],
       [3, 4, 5]])


In [44]: B=np.arange(12).reshape(3,4)
Out[44]: 
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])

Twoje obliczenia wymagają „kropki” (sumy produktów) (2,3) z (3,4), aby utworzyć tablicę (4,2). ijest pierwszym przyciemnieniem A, ostatnim z C; kostatni z B, 1 z C. jjest „pochłonięty” przez sumowanie.

In [45]: C=np.einsum('ij,jk->ki',A,B)
Out[45]: 
array([[20, 56],
       [23, 68],
       [26, 80],
       [29, 92]])

To jest to samo co np.dot(A,B).T - transponowane jest końcowe wyjście.

Aby zobaczyć więcej tego, co się dzieje j, zmień Cindeksy dolne na ijk:

In [46]: np.einsum('ij,jk->ijk',A,B)
Out[46]: 
array([[[ 0,  0,  0,  0],
        [ 4,  5,  6,  7],
        [16, 18, 20, 22]],

       [[ 0,  3,  6,  9],
        [16, 20, 24, 28],
        [40, 45, 50, 55]]])

Można to również wyprodukować z:

A[:,:,None]*B[None,:,:]

Oznacza to, że dodaj kwymiar na końcu Ai ido przodu B, uzyskując tablicę (2,3,4).

0 + 4 + 16 = 20, 9 + 28 + 55 = 92itp .; Suma ji transpozycja, aby uzyskać wcześniejszy wynik:

np.sum(A[:,:,None] * B[None,:,:], axis=1).T

# C[k,i] = sum(j) A[i,j (,k) ] * B[(i,)  j,k]
hpaulj
źródło
7

Uważam, że NumPy: Sztuczki handlu (część II) są pouczające

Używamy ->, aby wskazać kolejność tablicy wyjściowej. Pomyśl więc o „ij, i-> j” jako o lewej stronie (LHS) i prawej stronie (RHS). Każde powtórzenie etykiet na LHS oblicza mądrze element produktu, a następnie sumuje. Zmieniając etykietę po stronie RHS (wyjściowej), możemy zdefiniować oś, w której chcemy kontynuować w odniesieniu do tablicy wejściowej, tj. Sumowanie wzdłuż osi 0, 1 i tak dalej.

import numpy as np

>>> a
array([[1, 1, 1],
       [2, 2, 2],
       [3, 3, 3]])
>>> b
array([[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]])
>>> d = np.einsum('ij, jk->ki', a, b)

Zauważ, że istnieją trzy osie, i, j, k, i że j jest powtarzane (po lewej stronie). i,jreprezentują wiersze i kolumny dla a. j,kdla b.

Aby obliczyć iloczyn i wyrównać joś, musimy dodać oś a. ( bbędzie nadawany wzdłuż (?) pierwszej osi)

a[i, j, k]
   b[j, k]

>>> c = a[:,:,np.newaxis] * b
>>> c
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 0,  2,  4],
        [ 6,  8, 10],
        [12, 14, 16]],

       [[ 0,  3,  6],
        [ 9, 12, 15],
        [18, 21, 24]]])

jjest nieobecny z prawej strony, więc sumujemy, nad jktórą jest druga oś tablicy 3x3x3

>>> c = c.sum(1)
>>> c
array([[ 9, 12, 15],
       [18, 24, 30],
       [27, 36, 45]])

Wreszcie indeksy są (alfabetycznie) odwrócone po prawej stronie, więc dokonujemy transpozycji.

>>> c.T
array([[ 9, 18, 27],
       [12, 24, 36],
       [15, 30, 45]])

>>> np.einsum('ij, jk->ki', a, b)
array([[ 9, 18, 27],
       [12, 24, 36],
       [15, 30, 45]])
>>>
wwii
źródło
NumPy: Sztuczki handlu (część II) wydają się wymagać zaproszenia od właściciela strony, a także konta Wordpress
Tejas Shetty
... zaktualizowałem link, na szczęście znalazłem go podczas wyszukiwania. - Dziękuję.
wwii
@TejasShetty Tutaj jest wiele lepszych odpowiedzi - może powinienem usunąć tę.
wwii
2
Nie usuwaj swojej odpowiedzi.
Tejas Shetty
5

Czytając równania einsum, najbardziej pomocne było po prostu sprowadzenie ich mentalnie do ich imperatywnych wersji.

Zacznijmy od następującej (nakładającej) instrukcji:

C = np.einsum('bhwi,bhwj->bij', A, B)

Najpierw sprawdzając znaki interpunkcyjne, widzimy, że mamy dwa 4-literowe kropki oddzielone przecinkami - bhwia bhwjprzed strzałką i jeden 3-literowy kropelka bijpo nim. Dlatego równanie daje wynik tensora rangi 3 z dwóch wejść tensora rangi 4.

Teraz niech każda litera w każdym obiekcie blob będzie nazwą zmiennej zakresu. Pozycja, w której litera pojawia się w kropli, jest indeksem osi, nad którą się onaga w tym tensorze. Dlatego sumowanie imperatywne, które wytwarza każdy element C, musi zaczynać się od trzech zagnieżdżonych dla pętli, po jednym dla każdego indeksu C.

for b in range(...):
    for i in range(...):
        for j in range(...):
            # the variables b, i and j index C in the order of their appearance in the equation
            C[b, i, j] = ...

Zasadniczo masz forpętlę dla każdego indeksu wyjściowego C. Na razie pozostawimy zakresy nieokreślone.

Następnie patrzymy na lewą stronę - czy są tam jakieś zmienne zakresu, które nie pojawiają się po prawej stronie? W naszym przypadku - tak hi w. Dodaj wewnętrzną zagnieżdżoną forpętlę dla każdej takiej zmiennej:

for b in range(...):
    for i in range(...):
        for j in range(...):
            C[b, i, j] = 0
            for h in range(...):
                for w in range(...):
                    ...

Wewnątrz najbardziej wewnętrznej pętli mamy teraz zdefiniowane wszystkie indeksy, więc możemy napisać faktyczne podsumowanie i tłumaczenie jest kompletne:

# three nested for-loops that index the elements of C
for b in range(...):
    for i in range(...):
        for j in range(...):

            # prepare to sum
            C[b, i, j] = 0

            # two nested for-loops for the two indexes that don't appear on the right-hand side
            for h in range(...):
                for w in range(...):
                    # Sum! Compare the statement below with the original einsum formula
                    # 'bhwi,bhwj->bij'

                    C[b, i, j] += A[b, h, w, i] * B[b, h, w, j]

Jeśli do tej pory potrafiłeś postępować zgodnie z tym kodem, gratulacje! To wszystko, czego potrzebujesz, aby móc czytać równania einsum. Zwróć uwagę w szczególności na to, w jaki sposób oryginalna formuła einsum mapuje się na końcową instrukcję podsumowania we fragmencie powyżej. Pętle for i zakresy są po prostu puchate, a to ostatnie stwierdzenie jest wszystkim, czego naprawdę potrzebujesz, aby zrozumieć, co się dzieje.

Dla kompletności zobaczmy, jak określić zakresy dla każdej zmiennej zakresu. Cóż, zasięg każdej zmiennej jest po prostu długością wymiaru, który indeksuje. Oczywiście, jeśli zmienna indeksuje więcej niż jeden wymiar w jednym lub więcej tensorach, wówczas długości każdego z tych wymiarów muszą być równe. Oto powyższy kod z pełnymi zakresami:

# C's shape is determined by the shapes of the inputs
# b indexes both A and B, so its range can come from either A.shape or B.shape
# i indexes only A, so its range can only come from A.shape, the same is true for j and B
assert A.shape[0] == B.shape[0]
assert A.shape[1] == B.shape[1]
assert A.shape[2] == B.shape[2]
C = np.zeros((A.shape[0], A.shape[3], B.shape[3]))
for b in range(A.shape[0]): # b indexes both A and B, or B.shape[0], which must be the same
    for i in range(A.shape[3]):
        for j in range(B.shape[3]):
            # h and w can come from either A or B
            for h in range(A.shape[1]):
                for w in range(A.shape[2]):
                    C[b, i, j] += A[b, h, w, i] * B[b, h, w, j]
Stefan Dragnev
źródło
0

Myślę, że najprostszym przykładem są dokumenty tensorflow

Istnieją cztery kroki do przekształcenia równania w notację einsum. Weźmy to równanie jako przykładC[i,k] = sum_j A[i,j] * B[j,k]

  1. Najpierw upuszczamy nazwy zmiennych. Dostajemyik = sum_j ij * jk
  2. Porzucamy ten sum_jtermin, ponieważ jest niejawny. Dostajemyik = ij * jk
  3. Zamieniamy *się ,. Dostajemyik = ij, jk
  4. Dane wyjściowe znajdują się na RHS i są oddzielone ->znakiem. Dostajemyij, jk -> ik

Tłumacz einsum po prostu wykonuje te 4 kroki w odwrotnej kolejności. Wszystkie wskaźniki brakujące w wyniku są sumowane.

Oto kilka przykładów z dokumentów

# Matrix multiplication
einsum('ij,jk->ik', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]

# Dot product
einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]

# Outer product
einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]

# Transpose
einsum('ij->ji', m)  # output[j,i] = m[i,j]

# Trace
einsum('ii', m)  # output[j,i] = trace(m) = sum_i m[i, i]

# Batch matrix multiplication
einsum('aij,ajk->aik', s, t)  # out[a,i,k] = sum_j s[a,i,j] * t[a, j, k]
Souradeep Nanda
źródło