Sprawdź, czy tablica numpy zawiera tylko zera

93

Inicjalizujemy tablicę numpy zerami, jak poniżej:

np.zeros((N,N+1))

Ale jak sprawdzić, czy wszystkie elementy w danej macierzy tablicowej n * n numpy mają wartość zero.
Metoda musi po prostu zwrócić True, jeśli wszystkie wartości są rzeczywiście zerowe.

Nieznany
źródło

Odpowiedzi:

73

Sprawdź numpy.count_nonzero .

>>> np.count_nonzero(np.eye(4))
4
>>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]])
5
Prashant Kumar
źródło
9
Chciałbyś not np.count_nonzero(np.eye(4))zwrócić Truetylko wtedy, gdy wszystkie wartości są
równe
166

Inne zamieszczone tutaj odpowiedzi będą działać, ale najbardziej przejrzystą i najbardziej wydajną funkcją jest numpy.any():

>>> all_zeros = not np.any(a)

lub

>>> all_zeros = not a.any()
  • Jest to preferowane rozwiązanie, numpy.all(a==0)ponieważ zużywa mniej pamięci RAM. (Nie wymaga tymczasowej tablicy utworzonej przez a==0termin).
  • Jest również szybsze niż numpy.count_nonzero(a)to, że może powrócić natychmiast po znalezieniu pierwszego niezerowego elementu.
    • Edycja: Jak @Rachel wskazała w komentarzach, np.any()nie używa już logiki „zwarcia”, więc nie zobaczysz korzyści związanych z szybkością w przypadku małych macierzy.
Stuart Berg
źródło
3
Jak z minuty temu numpy użytkownika anyi allzrobić nie zwarcia. Uważam, że są cukrem dla logical_or.reducei logical_and.reduce. Porównaj ze sobą i moje zwarcie is_in: all_false = np.zeros(10**8) all_true = np.ones(10**8) %timeit np.any(all_false) 91.5 ms ± 1.82 ms per loop %timeit np.any(all_true) 93.7 ms ± 6.16 ms per loop %timeit is_in(1, all_true) 293 ns ± 1.65 ns per loop
Rachel
3
To świetna uwaga, dzięki. Wygląda na to, że kiedyś występowało zwarcie , ale w pewnym momencie zostało to utracone. W odpowiedziach na to pytanie jest ciekawa dyskusja .
Stuart Berg
50

Użyłbym tutaj np.all, jeśli masz tablicę a:

>>> np.all(a==0)
J. Martinot-Lagarde
źródło
3
Podoba mi się, że ta odpowiedź sprawdza również wartości niezerowe. Na przykład można sprawdzić, czy wszystkie elementy w tablicy są takie same, wykonując czynność np.all(a==a[0]). Wielkie dzięki!
aignas
9

Jak mówi inna odpowiedź, możesz skorzystać z prawdziwych / fałszywych ocen, jeśli wiesz, że 0jest to jedyny fałszywy element w tablicy. Wszystkie elementy w tablicy są fałszywe, jeśli nie ma w niej żadnych prawdziwych elementów. *

>>> a = np.zeros(10)
>>> not np.any(a)
True

Jednak odpowiedź twierdziła, że anyjest szybszy niż inne opcje, częściowo z powodu zwarcia. Od 2018 roku Numpy alli any nie powodują zwarcia .

Jeśli często robisz takie rzeczy, bardzo łatwo jest stworzyć własne wersje powodujące zwarcie za pomocą numba:

import numba as nb

# short-circuiting replacement for np.any()
@nb.jit(nopython=True)
def sc_any(array):
    for x in array.flat:
        if x:
            return True
    return False

# short-circuiting replacement for np.all()
@nb.jit(nopython=True)
def sc_all(array):
    for x in array.flat:
        if not x:
            return False
    return True

Wydaje się, że są szybsze niż wersje Numpy, nawet jeśli nie powodują zwarcia. count_nonzerojest najwolniejszy.

Niektóre dane wejściowe do sprawdzenia wydajności:

import numpy as np

n = 10**8
middle = n//2
all_0 = np.zeros(n, dtype=int)
all_1 = np.ones(n, dtype=int)
mid_0 = np.ones(n, dtype=int)
mid_1 = np.zeros(n, dtype=int)
np.put(mid_0, middle, 0)
np.put(mid_1, middle, 1)
# mid_0 = [1 1 1 ... 1 0 1 ... 1 1 1]
# mid_1 = [0 0 0 ... 0 1 0 ... 0 0 0]

Czek:

## count_nonzero
%timeit np.count_nonzero(all_0) 
# 220 ms ± 8.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.count_nonzero(all_1)
# 150 ms ± 4.56 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

### all
# np.all
%timeit np.all(all_1)
%timeit np.all(mid_0)
%timeit np.all(all_0)
# 56.8 ms ± 3.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.4 ms ± 1.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 55.9 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_all
%timeit sc_all(all_1)
%timeit sc_all(mid_0)
%timeit sc_all(all_0)
# 44.4 ms ± 2.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.7 ms ± 599 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 288 ns ± 6.36 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

### any
# np.any
%timeit np.any(all_0)
%timeit np.any(mid_1)
%timeit np.any(all_1)
# 60.7 ms ± 1.38 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 60 ms ± 287 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.7 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# sc_any
%timeit sc_any(all_0)
%timeit sc_any(mid_1)
%timeit sc_any(all_1)
# 41.7 ms ± 1.24 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.4 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 287 ns ± 12.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

* Pomocne alli anyrównoważne:

np.all(a) == np.logical_not(np.any(np.logical_not(a)))
np.any(a) == np.logical_not(np.all(np.logical_not(a)))
not np.all(a) == np.any(np.logical_not(a))
not np.any(a) == np.all(np.logical_not(a))
Rachel
źródło
-8

Jeśli testujesz wszystkie zera, aby uniknąć ostrzeżenia o innej funkcji numpy, a następnie zawijasz wiersz próbą, z wyjątkiem bloku, który zaoszczędzi konieczności wykonywania testu zer przed operacją, którą jesteś zainteresowany, tj.

try: # removes output noise for empty slice 
    mean = np.mean(array)
except:
    mean = 0
ReaddyEddy
źródło