NumPy: funkcja jednoczesnego max () i min ()

109

numpy.amax () znajdzie maksymalną wartość w tablicy, a numpy.amin () zrobi to samo dla wartości minimalnej. Jeśli chcę znaleźć zarówno max, jak i min, muszę wywołać obie funkcje, co wymaga dwukrotnego przepuszczenia (bardzo dużej) tablicy, co wydaje się powolne.

Czy w numpy API jest funkcja, która znajduje zarówno max, jak i min za pomocą jednego przejścia przez dane?

Stuart Berg
źródło
1
Jak duże jest bardzo duże? Jeśli zdobędę trochę czasu, przeprowadzę kilka testów porównujących implementację amaxamin
Fortran
1
Przyznam, że „bardzo duże” jest subiektywne. W moim przypadku mówię o macierzach, które mają kilka GB.
Stuart Berg
to całkiem duże. Stworzyłem przykład, aby obliczyć go w fortranie (nawet jeśli nie znasz fortranu, zrozumienie kodu powinno być dość łatwe). To naprawdę robi różnicę, jeśli chodzi o przejście od fortranu do biegania przez numpy. (Prawdopodobnie powinieneś być w stanie uzyskać taką samą wydajność od C ...) Nie jestem pewien - przypuszczam, że potrzebowalibyśmy odrętwiałego programisty, aby skomentować, dlaczego moje funkcje działają o wiele lepiej niż ich ...
mgilson
Oczywiście nie jest to nowatorski pomysł. Na przykład biblioteka boost minmax (C ++) zapewnia implementację algorytmu, którego szukam.
Stuart Berg
3
Niezbyt odpowiedź na zadane pytanie, ale prawdopodobnie zainteresuje ludzi w tym wątku. Zapytano NumPy o dodanie minmaxdo biblioteki, której dotyczy problem ( github.com/numpy/numpy/issues/9836 ).
jakirkham

Odpowiedzi:

49

Czy w numpy API jest funkcja, która znajduje zarówno max, jak i min za pomocą jednego przejścia przez dane?

Nie. W chwili pisania tego tekstu nie ma takiej funkcji. (I tak, jeśli nie było takiej funkcji, jej wyniki byłyby znacznie lepsze niż dzwonienie numpy.amin()i numpy.amax()kolejno na dużej tablicy.)

Stuart Berg
źródło
31

Nie sądzę, aby dwukrotne przepuszczenie tablicy było problemem. Rozważmy następujący pseudokod:

minval = array[0]
maxval = array[0]
for i in array:
    if i < minval:
       minval = i
    if i > maxval:
       maxval = i

Chociaż jest tu tylko 1 pętla, nadal są 2 sprawdzenia. (Zamiast 2 pętli z 1 czekiem każda). Naprawdę jedyne, co oszczędzasz, to narzut 1 pętli. Jeśli tablice są naprawdę duże, jak mówisz, to narzut jest niewielki w porównaniu z rzeczywistym obciążeniem pracą pętli. (Zauważ, że to wszystko jest zaimplementowane w C, więc pętle i tak są mniej więcej wolne).


EDYTUJ Przepraszamy 4 z was, którzy głosowali za i wierzyli we mnie. Zdecydowanie możesz to zoptymalizować.

Oto kod Fortran, który można wkompilować w moduł Pythona za pośrednictwem f2py(może Cythonprzyjdzie guru i porówna to ze zoptymalizowaną wersją C ...):

subroutine minmax1(a,n,amin,amax)
  implicit none
  !f2py intent(hidden) :: n
  !f2py intent(out) :: amin,amax
  !f2py intent(in) :: a
  integer n
  real a(n),amin,amax
  integer i

  amin = a(1)
  amax = a(1)
  do i=2, n
     if(a(i) > amax)then
        amax = a(i)
     elseif(a(i) < amin) then
        amin = a(i)
     endif
  enddo
end subroutine minmax1

subroutine minmax2(a,n,amin,amax)
  implicit none
  !f2py intent(hidden) :: n
  !f2py intent(out) :: amin,amax
  !f2py intent(in) :: a
  integer n
  real a(n),amin,amax
  amin = minval(a)
  amax = maxval(a)
end subroutine minmax2

Skompiluj przez:

f2py -m untitled -c fortran_code.f90

A teraz jesteśmy w miejscu, w którym możemy to przetestować:

import timeit

size = 100000
repeat = 10000

print timeit.timeit(
    'np.min(a); np.max(a)',
    setup='import numpy as np; a = np.arange(%d, dtype=np.float32)' % size,
    number=repeat), " # numpy min/max"

print timeit.timeit(
    'untitled.minmax1(a)',
    setup='import numpy as np; import untitled; a = np.arange(%d, dtype=np.float32)' % size,
    number=repeat), '# minmax1'

print timeit.timeit(
    'untitled.minmax2(a)',
    setup='import numpy as np; import untitled; a = np.arange(%d, dtype=np.float32)' % size,
    number=repeat), '# minmax2'

Wyniki są dla mnie nieco oszałamiające:

8.61869883537 # numpy min/max
1.60417699814 # minmax1
2.30169081688 # minmax2

Muszę powiedzieć, że nie do końca to rozumiem. Porównywanie tylko np.minz minmax1i minmax2nadal jest przegraną bitwą, więc nie jest to tylko kwestia pamięci ...

uwagi - Zwiększanie rozmiaru o współczynnik 10**ai zmniejszanie liczby powtórzeń o współczynnik 10**a(utrzymywanie stałego rozmiaru problemu) zmienia wydajność, ale nie w pozornie spójny sposób, co pokazuje, że istnieje pewna zależność między wydajnością pamięci a narzutem wywołań funkcji w pyton. Nawet porównanie prostej minimplementacji w fortran beats numpy's przez współczynnik około 2 ...

mgilson
źródło
21
Zaletą pojedynczego przebiegu jest wydajność pamięci. W szczególności, jeśli macierz jest wystarczająco duża, aby ją wymienić, może to być ogromne.
Dougal
4
To nie do końca prawda, jest prawie o połowę mniejszy, ponieważ w tego rodzaju tablicach szybkość pamięci jest zwykle czynnikiem ograniczającym, więc może być o połowę mniejsza ...
seberg
3
Nie zawsze potrzebujesz dwóch czeków. Jeśli i < minvaljest prawdziwe, i > maxvalto zawsze jest fałszywe, więc wystarczy wykonać średnio tylko 1,5 sprawdzenia na iterację, gdy sekunda ifjest zastępowana przez elif.
Fred Foo
2
Mała uwaga: wątpię, czy Cython jest sposobem na uzyskanie najbardziej zoptymalizowanego modułu C, który można wywołać w Pythonie. Celem Cythona jest być rodzajem Pythona z adnotacjami typu, który jest następnie tłumaczony maszynowo na C, podczas gdy f2pypo prostu zawija kodowany ręcznie Fortran, tak aby był wywoływany przez Pythona. „Sprawiedliwszy” test to prawdopodobnie ręczne kodowanie C, a następnie użycie f2py(!), Aby opakować go w Pythonie. Jeśli pozwalasz C ++, to Shed Skin może być idealnym miejscem do równoważenia łatwości kodowania z wydajnością.
John Y,
4
od numpy 1.8 min i max są wektoryzowane na platformach amd64, na moim core2duo numpy działa równie dobrze jak ten kod fortran. Ale pojedynczy przebieg byłby korzystny, gdyby tablica przekraczała rozmiar większych pamięci podręcznych procesora.
jtaylor
23

Istnieje funkcja do znajdowania (max-min) o nazwie numpy.ptp, jeśli jest to dla Ciebie przydatne:

>>> import numpy
>>> x = numpy.array([1,2,3,4,5,6])
>>> x.ptp()
5

ale nie sądzę, aby można było znaleźć zarówno wartość minimalną, jak i maksymalną przy jednym przejściu.

EDYCJA: ptp po prostu wywołuje min i max pod maską

jterrace
źródło
2
Jest to denerwujące, ponieważ przypuszczalnie sposób implementacji ptp musi śledzić max i min!
Andy Hayden
1
Lub może po prostu zadzwonić na max i min, nie jestem pewien
jterrace
3
@hayden okazuje się, że ptp po prostu wywołuje max i min
jterrace
1
To był kod zamaskowanej tablicy; główny kod ndarray znajduje się w C. Ale okazuje się, że kod C również dwukrotnie iteruje po tablicy: github.com/numpy/numpy/blob/… .
Ken Arnold
20

Możesz użyć Numba , który jest dynamicznym kompilatorem Pythona obsługującym NumPy używającym LLVM. Wynikowa implementacja jest dość prosta i przejrzysta:

import numpy
import numba


@numba.jit
def minmax(x):
    maximum = x[0]
    minimum = x[0]
    for i in x[1:]:
        if i > maximum:
            maximum = i
        elif i < minimum:
            minimum = i
    return (minimum, maximum)


numpy.random.seed(1)
x = numpy.random.rand(1000000)
print(minmax(x) == (x.min(), x.max()))

Powinien też być szybszy niż min() & max()implementacja Numpy . A wszystko to bez konieczności pisania jednej linii kodu w języku C / Fortran.

Wykonaj własne testy wydajności, ponieważ zawsze zależy to od architektury, danych, wersji pakietów ...

Peque
źródło
2
> Powinien być również szybszy niż implementacja min () i max () Numpy'ego. Nie sądzę, żeby to było właściwe. numpy nie jest natywnym Pythonem - to C. `` x = numpy.random.rand (10000000) t = time () for i in range (1000): minmax (x) print ('numba', time () - t) t = time () for i in range (1000): x.min () x.max () print ('numpy', time () - t) `` Wyniki w: ('numba', 10.299750089645386 ) ('numpy', 9.898081064224243)
Authman Apatira
1
@AuthmanApatira: Tak, benchmarki są zawsze takie, dlatego powiedziałem, że „ powinno ” (być szybsze) i „ zrób własne testy wydajności, ponieważ zawsze zależy to od Twojej architektury, danych… ”. W moim przypadku próbowałem z 3 komputerami i uzyskałem ten sam wynik (Numba był szybszy niż Numpy), ale na twoim komputerze wyniki mogą się różnić ... Czy próbowałeś wykonać numbafunkcję raz przed testem porównawczym, aby upewnić się, że jest skompilowana w JIT ?. Ponadto, jeśli używasz ipython, dla uproszczenia, sugerowałbym użycie %timeit whatever_code()do pomiaru czasu wykonania.
Peque
3
@AuthmanApatira: W każdym razie tym, co próbowałem pokazać tą odpowiedzią, jest to, że czasami kod Pythona (w tym przypadku JIT-skompilowany z Numba) może być tak szybki, jak najszybsza skompilowana biblioteka C (przynajmniej mówimy o tej samej kolejności wielkości), co robi wrażenie, biorąc pod uwagę, że nie napisaliśmy nic poza czystym kodem Pythona, nie zgadzasz się? ^^
Peque
Zgadzam się =) Również dziękuję za wskazówki w poprzednim komentarzu dotyczące jupyter i kompilacji funkcji poza kodem czasowym.
Authman Apatira
1
Po prostu natrafiłem na to, co nie ma znaczenia w praktycznych przypadkach, ale elifpozwala na to, aby twoje minimum było większe niż twoje maksimum. Np. W przypadku tablicy o długości 1, max będzie taka sama jak ta wartość, podczas gdy min to + nieskończoność. Nie jest to wielka sprawa dla jednorazowego, ale nie jest to dobry kod, który można wrzucić głęboko w brzuch bestii produkcyjnej.
Mike Williamson
12

Ogólnie rzecz biorąc, można zmniejszyć liczbę porównań dla algorytmu minmax, przetwarzając jednocześnie dwa elementy i porównując tylko mniejszy z tymczasowym minimum, a większy z tymczasowym maksimum. Średnio potrzeba tylko 3/4 porównań niż podejście naiwne.

Można to zaimplementować w c lub fortran (lub jakimkolwiek innym języku niskiego poziomu) i powinno być prawie nie do pobicia pod względem wydajności. używam aby zilustrować zasadę i uzyskać bardzo szybką, niezależną od typu implementację:

import numba as nb
import numpy as np

@nb.njit
def minmax(array):
    # Ravel the array and return early if it's empty
    array = array.ravel()
    length = array.size
    if not length:
        return

    # We want to process two elements at once so we need
    # an even sized array, but we preprocess the first and
    # start with the second element, so we want it "odd"
    odd = length % 2
    if not odd:
        length -= 1

    # Initialize min and max with the first item
    minimum = maximum = array[0]

    i = 1
    while i < length:
        # Get the next two items and swap them if necessary
        x = array[i]
        y = array[i+1]
        if x > y:
            x, y = y, x
        # Compare the min with the smaller one and the max
        # with the bigger one
        minimum = min(x, minimum)
        maximum = max(y, maximum)
        i += 2

    # If we had an even sized array we need to compare the
    # one remaining item too.
    if not odd:
        x = array[length]
        minimum = min(x, minimum)
        maximum = max(x, maximum)

    return minimum, maximum

Jest zdecydowanie szybszy niż naiwne podejście, które przedstawił Peque :

arr = np.random.random(3000000)
assert minmax(arr) == minmax_peque(arr)  # warmup and making sure they are identical 
%timeit minmax(arr)            # 100 loops, best of 3: 2.1 ms per loop
%timeit minmax_peque(arr)      # 100 loops, best of 3: 2.75 ms per loop

Zgodnie z oczekiwaniami nowa implementacja minmax zajmuje tylko około 3/4 czasu, jaki zajęła naiwna implementacja ( 2.1 / 2.75 = 0.7636363636363637)

MSeifert
źródło
1
Na moim komputerze Twoje rozwiązanie nie jest szybsze niż rozwiązanie Peque. Numba 0.33.
John Zwinck
@johnzwinck, czy przeprowadziłeś test porównawczy w mojej odpowiedzi, czy inny? Jeśli tak, czy mógłbyś się nim podzielić? Ale jest to możliwe: zauważyłem pewne regresje w nowszych wersjach.
MSeifert
Sprawdziłem twój punkt odniesienia. Czasy twojego rozwiązania i @ Peque były prawie takie same (~ 2,8 ms).
John Zwinck
@JohnZwinck To dziwne, właśnie przetestowałem to ponownie i na moim komputerze jest zdecydowanie szybsze. Może ma to coś wspólnego z numba i LLVM, co zależy od sprzętu.
MSeifert
Wypróbowałem teraz na innej maszynie (mocna stacja robocza) i uzyskałem 2,4 ms dla twojego vs 2,6 dla Peque'a. Więc mała wygrana.
John Zwinck
11

Aby uzyskać kilka pomysłów na liczby, których można się spodziewać, biorąc pod uwagę następujące podejścia:

import numpy as np


def extrema_np(arr):
    return np.max(arr), np.min(arr)
import numba as nb


@nb.jit(nopython=True)
def extrema_loop_nb(arr):
    n = arr.size
    max_val = min_val = arr[0]
    for i in range(1, n):
        item = arr[i]
        if item > max_val:
            max_val = item
        elif item < min_val:
            min_val = item
    return max_val, min_val
import numba as nb


@nb.jit(nopython=True)
def extrema_while_nb(arr):
    n = arr.size
    odd = n % 2
    if not odd:
        n -= 1
    max_val = min_val = arr[0]
    i = 1
    while i < n:
        x = arr[i]
        y = arr[i + 1]
        if x > y:
            x, y = y, x
        min_val = min(x, min_val)
        max_val = max(y, max_val)
        i += 2
    if not odd:
        x = arr[n]
        min_val = min(x, min_val)
        max_val = max(x, max_val)
    return max_val, min_val
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


import numpy as np


cdef void _extrema_loop_cy(
        long[:] arr,
        size_t n,
        long[:] result):
    cdef size_t i
    cdef long item, max_val, min_val
    max_val = arr[0]
    min_val = arr[0]
    for i in range(1, n):
        item = arr[i]
        if item > max_val:
            max_val = item
        elif item < min_val:
            min_val = item
    result[0] = max_val
    result[1] = min_val


def extrema_loop_cy(arr):
    result = np.zeros(2, dtype=arr.dtype)
    _extrema_loop_cy(arr, arr.size, result)
    return result[0], result[1]
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


import numpy as np


cdef void _extrema_while_cy(
        long[:] arr,
        size_t n,
        long[:] result):
    cdef size_t i, odd
    cdef long x, y, max_val, min_val
    max_val = arr[0]
    min_val = arr[0]
    odd = n % 2
    if not odd:
        n -= 1
    max_val = min_val = arr[0]
    i = 1
    while i < n:
        x = arr[i]
        y = arr[i + 1]
        if x > y:
            x, y = y, x
        min_val = min(x, min_val)
        max_val = max(y, max_val)
        i += 2
    if not odd:
        x = arr[n]
        min_val = min(x, min_val)
        max_val = max(x, max_val)
    result[0] = max_val
    result[1] = min_val


def extrema_while_cy(arr):
    result = np.zeros(2, dtype=arr.dtype)
    _extrema_while_cy(arr, arr.size, result)
    return result[0], result[1]

( extrema_loop_*()podejścia są podobne do tego, co jest proponowane tutaj , podczas gdy extrema_while_*()podejścia są oparte na kodzie z tego miejsca )

Następujące terminy:

bm

wskazują, że extrema_while_*()są najszybsi i extrema_while_nb()najszybsi. W każdym razie, także extrema_loop_nb()i extrema_loop_cy()roztwory do przewyższają NumPy tylko do podejścia (użyciu np.max()i np.min()oddzielnie).

Na koniec zwróć uwagę, że żaden z nich nie jest tak elastyczny jak np.min()/ np.max()(pod względem obsługi n-dim, axisparametrów itp.).

(pełny kod jest dostępny tutaj )

norok2
źródło
2
Wydaje się, że możesz uzyskać dodatkowe 10% prędkości, jeśli użyjesz @njit (fastmath = True)extrema_while_nb
argenisleon
10

Nikt nie wspomniał o numpy.percentile , więc pomyślałem, że tak. Jeśli poprosisz o [0, 100]percentyle, otrzymasz tablicę dwóch elementów, min (0-ty percentyl) i maksymalny (100-ty percentyl).

Nie spełnia to jednak celu PO: nie jest szybsze niż oddzielnie min i max. Jest to prawdopodobnie spowodowane niektórymi maszynami, które pozwoliłyby na nie-ekstremalne percentyle (trudniejszy problem, który powinien zająć więcej czasu).

In [1]: import numpy

In [2]: a = numpy.random.normal(0, 1, 1000000)

In [3]: %%timeit
   ...: lo, hi = numpy.amin(a), numpy.amax(a)
   ...: 
100 loops, best of 3: 4.08 ms per loop

In [4]: %%timeit
   ...: lo, hi = numpy.percentile(a, [0, 100])
   ...: 
100 loops, best of 3: 17.2 ms per loop

In [5]: numpy.__version__
Out[5]: '1.14.4'

Przyszła wersja Numpy mogłaby umieścić specjalny przypadek, aby pominąć normalne obliczanie percentyla, jeśli tylko [0, 100]jest to wymagane. Bez dodawania czegokolwiek do interfejsu istnieje sposób, aby poprosić Numpy'ego o min i max w jednym wywołaniu (w przeciwieństwie do tego, co zostało powiedziane w akceptowanej odpowiedzi), ale standardowa implementacja biblioteki nie wykorzystuje tego przypadku, aby to zrobić wart.

Jim Pivarski
źródło
9

To stary wątek, ale w każdym razie, jeśli ktoś jeszcze raz na to spojrzy ...

Szukając jednocześnie wartości min i max, można zmniejszyć liczbę porównań. Jeśli porównujesz dane zmiennoprzecinkowe (a wydaje mi się, że tak jest), może to zaoszczędzić trochę czasu, chociaż nie jest to złożoność obliczeniowa.

Zamiast (kod Pythona):

_max = ar[0]
_min=  ar[0]
for ii in xrange(len(ar)):
    if _max > ar[ii]: _max = ar[ii]
    if _min < ar[ii]: _min = ar[ii]

możesz najpierw porównać dwie sąsiednie wartości w tablicy, a następnie porównać tylko mniejszą z bieżącym minimum, a większą z bieżącym maksimum:

## for an even-sized array
_max = ar[0]
_min = ar[0]
for ii in xrange(0, len(ar), 2)):  ## iterate over every other value in the array
    f1 = ar[ii]
    f2 = ar[ii+1]
    if (f1 < f2):
        if f1 < _min: _min = f1
        if f2 > _max: _max = f2
    else:
        if f2 < _min: _min = f2
        if f1 > _max: _max = f1

Kod tutaj jest napisany w Pythonie, najwyraźniej ze względu na szybkość używałbyś C, Fortran lub Cython, ale w ten sposób wykonujesz 3 porównania na iterację, z len (ar) / 2 iteracjami, dając porównania 3/2 * len (ar). W przeciwieństwie do tego, wykonując porównanie „w sposób oczywisty”, wykonujesz dwa porównania na iterację, co prowadzi do porównań 2 * len (ar). Oszczędza 25% czasu porównania.

Może ktoś kiedyś uzna to za przydatne.

Bennet
źródło
6
czy sprawdziłeś to? Na nowoczesnym sprzęcie x86 masz instrukcje maszynowe dla min i max, jak zastosowano w pierwszym wariancie, dzięki czemu unikasz potrzeby rozgałęzień, podczas gdy twój kod nakłada zależność sterowania, która prawdopodobnie nie jest tak dobrze mapowana na sprzęt.
jtaylor
Właściwie nie. Zrobię, jeśli będę miał szansę. Myślę, że jest całkiem jasne, że czysty kod Pythona przegra z każdą rozsądną skompilowaną implementacją, ale zastanawiam się, czy można było zobaczyć przyspieszenie w Cythonie ...
Bennet
13
Istnieje implementacja minmax w numpy, pod maską, używana przez np.bincount, patrz tutaj . Nie wykorzystuje sztuczki, którą wskazałeś, bo okazało się, że jest nawet 2x wolniejsze niż podejście naiwne. Istnieje link z PR do niektórych kompleksowych testów porównawczych obu metod.
Jaime,
5

Na pierwszy rzut oka wydaje się, że załatwia sprawę:numpy.histogram

count, (amin, amax) = numpy.histogram(a, bins=1)

... ale jeśli spojrzysz na źródło tej funkcji, po prostu wywołuje a.min()i a.max()niezależnie, a zatem nie pozwala uniknąć problemów z wydajnością, o których mowa w tym pytaniu. :-(

Podobnie scipy.ndimage.measurements.extremawygląda na możliwość, ale też po prostu dzwoni a.min()i a.max()samodzielnie.

Stuart Berg
źródło
3
np.histogramnie zawsze działa w tym przypadku, ponieważ zwracane (amin, amax)wartości dotyczą minimalnych i maksymalnych wartości pojemnika. Jeśli mam, na przykład a = np.zeros(10), np.histogram(a, bins=1)zwraca (array([10]), array([-0.5, 0.5])). W takim przypadku użytkownik szuka (amin, amax)= (0, 0).
eclark
3

Zresztą i tak było to dla mnie warte wysiłku, więc każdemu zainteresowanemu zaproponuję najtrudniejsze i najmniej eleganckie rozwiązanie. Moim rozwiązaniem jest zaimplementowanie wielowątkowego algorytmu min-max w jednym przebiegu w C ++ i użycie go do utworzenia modułu rozszerzenia Python. Ten wysiłek wymaga trochę narzutu, aby nauczyć się korzystać z interfejsów API Python i NumPy C / C ++, a tutaj pokażę kod i podam kilka małych wyjaśnień i odniesień dla każdego, kto chce podążać tą ścieżką.

Wielowątkowe min./maks

Nie ma tu nic ciekawego. Tablica jest podzielona na fragmenty o rozmiarze length / workers. Wartości min / max są obliczane dla każdego fragmentu w a future, które są następnie skanowane pod kątem globalnego min / max.

    // mt_np.cc
    //
    // multi-threaded min/max algorithm

    #include <algorithm>
    #include <future>
    #include <vector>

    namespace mt_np {

    /*
     * Get {min,max} in interval [begin,end)
     */
    template <typename T> std::pair<T, T> min_max(T *begin, T *end) {
      T min{*begin};
      T max{*begin};
      while (++begin < end) {
        if (*begin < min) {
          min = *begin;
          continue;
        } else if (*begin > max) {
          max = *begin;
        }
      }
      return {min, max};
    }

    /*
     * get {min,max} in interval [begin,end) using #workers for concurrency
     */
    template <typename T>
    std::pair<T, T> min_max_mt(T *begin, T *end, int workers) {
      const long int chunk_size = std::max((end - begin) / workers, 1l);
      std::vector<std::future<std::pair<T, T>>> min_maxes;
      // fire up the workers
      while (begin < end) {
        T *next = std::min(end, begin + chunk_size);
        min_maxes.push_back(std::async(min_max<T>, begin, next));
        begin = next;
      }
      // retrieve the results
      auto min_max_it = min_maxes.begin();
      auto v{min_max_it->get()};
      T min{v.first};
      T max{v.second};
      while (++min_max_it != min_maxes.end()) {
        v = min_max_it->get();
        min = std::min(min, v.first);
        max = std::max(max, v.second);
      }
      return {min, max};
    }
    }; // namespace mt_np

Moduł rozszerzeń języka Python

Tutaj zaczyna się brzydko ... Jednym ze sposobów wykorzystania kodu C ++ w Pythonie jest zaimplementowanie modułu rozszerzającego. Ten moduł można zbudować i zainstalować przy użyciu distutils.corestandardowego modułu. Pełny opis tego, co się z tym wiąże, znajduje się w dokumentacji Pythona: https://docs.python.org/3/extending/extending.html . UWAGA: z pewnością istnieją inne sposoby uzyskania podobnych wyników, cytując https://docs.python.org/3/extending/index.html#extending-index :

Ten przewodnik obejmuje tylko podstawowe narzędzia do tworzenia rozszerzeń dostarczane jako część tej wersji CPython. Narzędzia innych firm, takie jak Cython, cffi, SWIG i Numba, oferują zarówno prostsze, jak i bardziej wyrafinowane podejście do tworzenia rozszerzeń C i C ++ dla Pythona.

Zasadniczo ta trasa jest prawdopodobnie bardziej akademicka niż praktyczna. Mając to na uwadze, następnym krokiem było trzymanie się samouczka i utworzenie pliku modułu. Jest to zasadniczo szablon dla distutils, aby wiedzieć, co zrobić z kodem i stworzyć z niego moduł Pythona. Zanim to zrobisz, prawdopodobnie dobrze jest utworzyć wirtualne środowisko Pythona, aby nie zanieczyszczać pakietów systemowych (patrz https://docs.python.org/3/library/venv.html#module-venv ).

Oto plik modułu:

// mt_np_forpy.cc
//
// C++ module implementation for multi-threaded min/max for np

#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION

#include <python3.6/numpy/arrayobject.h>

#include "mt_np.h"

#include <cstdint>
#include <iostream>

using namespace std;

/*
 * check:
 *  shape
 *  stride
 *  data_type
 *  byteorder
 *  alignment
 */
static bool check_array(PyArrayObject *arr) {
  if (PyArray_NDIM(arr) != 1) {
    PyErr_SetString(PyExc_RuntimeError, "Wrong shape, require (1,n)");
    return false;
  }
  if (PyArray_STRIDES(arr)[0] != 8) {
    PyErr_SetString(PyExc_RuntimeError, "Expected stride of 8");
    return false;
  }
  PyArray_Descr *descr = PyArray_DESCR(arr);
  if (descr->type != NPY_LONGLTR && descr->type != NPY_DOUBLELTR) {
    PyErr_SetString(PyExc_RuntimeError, "Wrong type, require l or d");
    return false;
  }
  if (descr->byteorder != '=') {
    PyErr_SetString(PyExc_RuntimeError, "Expected native byteorder");
    return false;
  }
  if (descr->alignment != 8) {
    cerr << "alignment: " << descr->alignment << endl;
    PyErr_SetString(PyExc_RuntimeError, "Require proper alignement");
    return false;
  }
  return true;
}

template <typename T>
static PyObject *mt_np_minmax_dispatch(PyArrayObject *arr) {
  npy_intp size = PyArray_SHAPE(arr)[0];
  T *begin = (T *)PyArray_DATA(arr);
  auto minmax =
      mt_np::min_max_mt(begin, begin + size, thread::hardware_concurrency());
  return Py_BuildValue("(L,L)", minmax.first, minmax.second);
}

static PyObject *mt_np_minmax(PyObject *self, PyObject *args) {
  PyArrayObject *arr;
  if (!PyArg_ParseTuple(args, "O", &arr))
    return NULL;
  if (!check_array(arr))
    return NULL;
  switch (PyArray_DESCR(arr)->type) {
  case NPY_LONGLTR: {
    return mt_np_minmax_dispatch<int64_t>(arr);
  } break;
  case NPY_DOUBLELTR: {
    return mt_np_minmax_dispatch<double>(arr);
  } break;
  default: {
    PyErr_SetString(PyExc_RuntimeError, "Unknown error");
    return NULL;
  }
  }
}

static PyObject *get_concurrency(PyObject *self, PyObject *args) {
  return Py_BuildValue("I", thread::hardware_concurrency());
}

static PyMethodDef mt_np_Methods[] = {
    {"mt_np_minmax", mt_np_minmax, METH_VARARGS, "multi-threaded np min/max"},
    {"get_concurrency", get_concurrency, METH_VARARGS,
     "retrieve thread::hardware_concurrency()"},
    {NULL, NULL, 0, NULL} /* sentinel */
};

static struct PyModuleDef mt_np_module = {PyModuleDef_HEAD_INIT, "mt_np", NULL,
                                          -1, mt_np_Methods};

PyMODINIT_FUNC PyInit_mt_np() { return PyModule_Create(&mt_np_module); }

W tym pliku znajduje się znaczące wykorzystanie Pythona, a także NumPy API, aby uzyskać więcej informacji, odwiedź: https://docs.python.org/3/c-api/arg.html#c.PyArg_ParseTuple i NumPy : https://docs.scipy.org/doc/numpy/reference/c-api.array.html .

Instalowanie modułu

Następną rzeczą do zrobienia jest użycie distutils do zainstalowania modułu. Wymaga to pliku instalacyjnego:

# setup.py

from distutils.core import setup,Extension

module = Extension('mt_np', sources = ['mt_np_module.cc'])

setup (name = 'mt_np', 
       version = '1.0', 
       description = 'multi-threaded min/max for np arrays',
       ext_modules = [module])

Aby ostatecznie zainstalować moduł, wykonaj python3 setup.py installz poziomu środowiska wirtualnego.

Testowanie modułu

Na koniec możemy przetestować, czy implementacja C ++ faktycznie przewyższa naiwne użycie NumPy. Aby to zrobić, oto prosty skrypt testowy:

# timing.py
# compare numpy min/max vs multi-threaded min/max

import numpy as np
import mt_np
import timeit

def normal_min_max(X):
  return (np.min(X),np.max(X))

print(mt_np.get_concurrency())

for ssize in np.logspace(3,8,6):
  size = int(ssize)
  print('********************')
  print('sample size:', size)
  print('********************')
  samples = np.random.normal(0,50,(2,size))
  for sample in samples:
    print('np:', timeit.timeit('normal_min_max(sample)',
                 globals=globals(),number=10))
    print('mt:', timeit.timeit('mt_np.mt_np_minmax(sample)',
                 globals=globals(),number=10))

Oto wyniki, które otrzymałem, robiąc to wszystko:

8  
********************  
sample size: 1000  
********************  
np: 0.00012079699808964506  
mt: 0.002468645994667895  
np: 0.00011947099847020581  
mt: 0.0020772050047526136  
********************  
sample size: 10000  
********************  
np: 0.00024697799381101504  
mt: 0.002037393998762127  
np: 0.0002713389985729009  
mt: 0.0020942929986631498  
********************  
sample size: 100000  
********************  
np: 0.0007130410012905486  
mt: 0.0019842900001094677  
np: 0.0007540129954577424  
mt: 0.0029724110063398257  
********************  
sample size: 1000000  
********************  
np: 0.0094779249993735  
mt: 0.007134920000680722  
np: 0.009129883001151029  
mt: 0.012836456997320056  
********************  
sample size: 10000000  
********************  
np: 0.09471094200125663  
mt: 0.0453535050037317  
np: 0.09436299200024223  
mt: 0.04188535599678289  
********************  
sample size: 100000000  
********************  
np: 0.9537652180006262  
mt: 0.3957935369980987  
np: 0.9624398809974082  
mt: 0.4019058070043684  

Są one znacznie mniej zachęcające, niż wskazują wyniki wcześniej w wątku, które wskazywały na około 3,5-krotne przyspieszenie i nie obejmowały wielowątkowości. Wyniki, które osiągnąłem, są dość rozsądne, spodziewałbym się, że narzut wątków i będzie dominował w czasie, aż tablice staną się bardzo duże, w którym to momencie wzrost wydajności zacznie zbliżać się do std::thread::hardware_concurrencywzrostu x.

Wniosek

Wydaje się, że z pewnością jest miejsce na optymalizacje specyficzne dla aplikacji w niektórych kodach NumPy, w szczególności w odniesieniu do wielowątkowości. Nie jest dla mnie jasne, czy warto podjąć ten wysiłek, ale z pewnością wydaje się to dobrym ćwiczeniem (lub czymś w tym rodzaju). Myślę, że nauczenie się niektórych z tych „narzędzi stron trzecich”, takich jak Cython, może być lepszym sposobem wykorzystania czasu, ale kto wie.

Nathan Chappell
źródło
1
Zaczynam studiować Twój kod, znam trochę C ++, ale nadal nie używam std :: future i std :: async. Skąd w funkcji szablonu „min_max_mt” wie, że każdy pracownik skończył pracę między zwolnieniem a pobraniem wyników? (Proszę tylko zrozumieć, nie mówiąc, że jest w tym coś złego)
ChrCury78
Linia v = min_max_it->get();. Te getmetody bloki aż wynik jest gotowy i zwraca go. Ponieważ pętla przechodzi przez każdą przyszłość, nie zakończy się, dopóki wszystkie nie zostaną ukończone. future.get ()
Nathan Chappell
0

Najkrótszy sposób, jaki wymyśliłem, to:

mn, mx = np.sort(ar)[[0, -1]]

Ale ponieważ sortuje tablicę, nie jest najbardziej wydajna.

Innym krótkim sposobem byłoby:

mn, mx = np.percentile(ar, [0, 100])

Powinno to być bardziej wydajne, ale wynik jest obliczany i zwracana jest wartość zmiennoprzecinkowa.

Israel Unterman
źródło
Szkoda, że ​​te dwa rozwiązania są najwolniejsze w porównaniu do innych na tej stronie: m = np.min (a); M = np.max (a) -> 0,54002 ||| m, M = f90_minmax1 (a) -> 0,72134 ||| m, M = numba_minmax (a) -> 0,77323 ||| m, M = np.sort (a) [[0, -1]] -> 12,01456 ||| m, M = np.percentyl (a, [0, 100]) -> 11,09418 ||| w sekundach dla 10000 powtórzeń dla tablicy 100000 elementów
Isaías