Numpy: pobierz indeks elementów tablicy 1d jako tablicy 2d

10

Mam tablicę numpy taką jak ta: [1 2 2 0 0 1 3 5]

Czy można uzyskać indeks elementów w postaci tablicy 2d? Na przykład odpowiedzią na powyższe dane wejściowe byłoby[[3 4], [0 5], [1 2], [6], [], [7]]

Obecnie muszę zapętlać różne wartości i wywoływać numpy.where(input == i)każdą wartość, która ma straszną wydajność przy wystarczająco dużym wejściu.

Frederico Schardong
źródło
np.argsort([1, 2, 2, 0, 0, 1, 3, 5])daje array([3, 4, 0, 5, 1, 2, 6, 7], dtype=int64). wtedy możesz po prostu porównać kolejne elementy.
vb_rises

Odpowiedzi:

11

Oto podejście O (max (x) + len (x)) przy użyciu scipy.sparse:

import numpy as np
from scipy import sparse

x = np.array("1 2 2 0 0 1 3 5".split(),int)
x
# array([1, 2, 2, 0, 0, 1, 3, 5])


M,N = x.max()+1,x.size
sparse.csc_matrix((x,x,np.arange(N+1)),(M,N)).tolil().rows.tolist()
# [[3, 4], [0, 5], [1, 2], [6], [], [7]]

Działa to poprzez utworzenie rzadkiej macierzy z wpisami w pozycjach (x [0], 0), (x [1], 1), ... Przy użyciu CSCformatu (skompresowanej kolumny rzadkiej) jest to dość proste. Macierz jest następnie konwertowana do LILformatu (lista połączona). Ten format przechowuje indeksy kolumn dla każdego wiersza jako listę w swoim rowsatrybucie, więc wszystko, co musimy zrobić, to wziąć to i przekonwertować na listę.

Należy zauważyć, że w przypadku argsortrozwiązań opartych na małych tablicach są one prawdopodobnie szybsze, ale przy niektórych, nie tak niesamowicie dużych rozmiarach, to się krzyżuje.

EDYTOWAĆ:

argsortna podstawie - numpytylko rozwiązanie:

np.split(x.argsort(kind="stable"),np.bincount(x)[:-1].cumsum())
# [array([3, 4]), array([0, 5]), array([1, 2]), array([6]), array([], dtype=int64), array([7])]

Jeśli kolejność indeksów w grupach nie ma znaczenia, możesz także spróbować argpartition(w tym małym przykładzie nie robi to różnicy, ale ogólnie nie jest to gwarantowane):

bb = np.bincount(x)[:-1].cumsum()
np.split(x.argpartition(bb),bb)
# [array([3, 4]), array([0, 5]), array([1, 2]), array([6]), array([], dtype=int64), array([7])]

EDYTOWAĆ:

@Divakar odradza korzystanie z np.split. Zamiast tego pętla jest prawdopodobnie szybsza:

A = x.argsort(kind="stable")
B = np.bincount(x+1).cumsum()
[A[B[i-1]:B[i]] for i in range(1,len(B))]

Lub możesz użyć zupełnie nowego operatora morsa (Python3.8 +):

A = x.argsort(kind="stable")
B = np.bincount(x)
L = 0
[A[L:(L:=L+b)] for b in B.tolist()]

EDYCJA (EDYCJA):

(Not pure numpy): Alternatywnie do numby (patrz post @ senderle) możemy również użyć pythran.

Połącz z pythran -O3 <filename.py>

import numpy as np

#pythran export sort_to_bins(int[:],int)

def sort_to_bins(idx, mx):
    if mx==-1: 
        mx = idx.max() + 1
    cnts = np.zeros(mx + 2, int)
    for i in range(idx.size):
        cnts[idx[i] + 2] += 1
    for i in range(3, cnts.size):
        cnts[i] += cnts[i-1]
    res = np.empty_like(idx)
    for i in range(idx.size):
        res[cnts[idx[i]+1]] = i
        cnts[idx[i]+1] += 1
    return [res[cnts[i]:cnts[i+1]] for i in range(mx)]

Tutaj numbawygrywa bokser pod względem wydajności:

repeat(lambda:enum_bins_numba_buffer(x),number=10)
# [0.6235917090671137, 0.6071486569708213, 0.6096088469494134]
repeat(lambda:sort_to_bins(x,-1),number=10)
# [0.6235359431011602, 0.6264424560358748, 0.6217901279451326]

Starsze rzeczy:

import numpy as np

#pythran export bincollect(int[:])

def bincollect(a):
    o = [[] for _ in range(a.max()+1)]
    for i,j in enumerate(a):
        o[j].append(i)
    return o

Czasy vs. Numba (stary)

timeit(lambda:bincollect(x),number=10)
# 3.5732191529823467
timeit(lambda:enumerate_bins(x),number=10)
# 6.7462647299980745
Paul Panzer
źródło
Ostatecznie okazało
Pętla powinna być lepsza niż np.split.
Divakar
@Divakar dobry punkt, dzięki!
Paul Panzer
8

Jedną z potencjalnych opcji w zależności od rozmiaru danych jest po prostu porzucenie numpyi użycie collections.defaultdict:

In [248]: from collections import defaultdict

In [249]: d = defaultdict(list)

In [250]: l = np.random.randint(0, 100, 100000)

In [251]: %%timeit
     ...: for k, v in enumerate(l):
     ...:     d[v].append(k)
     ...:
10 loops, best of 3: 22.8 ms per loop

Potem skończysz ze słownikiem {value1: [index1, index2, ...], value2: [index3, index4, ...]}. Skalowanie czasu jest bardzo zbliżone do liniowego z rozmiarem tablicy, więc 10 000 000 zajmuje ~ 2,7 s na moim komputerze, co wydaje się dość rozsądne.

Niespokojny
źródło
7

Chociaż prośba dotyczy numpyrozwiązania, postanowiłem sprawdzić, czy istnieje numbarozwiązanie oparte na ciekawych rozwiązaniach. I rzeczywiście jest! Oto podejście, które reprezentuje podzieloną na partycje listę jako poszarpaną tablicę przechowywaną w pojedynczym wstępnie przydzielonym buforze. Inspiruje to argsortpodejście zaproponowane przez Paula Panzera . (W przypadku starszej wersji, która nie działała tak dobrze, ale była prostsza, patrz poniżej).

@numba.jit(numba.void(numba.int64[:], 
                      numba.int64[:], 
                      numba.int64[:]), 
           nopython=True)
def enum_bins_numba_buffer_inner(ints, bins, starts):
    for x in range(len(ints)):
        i = ints[x]
        bins[starts[i]] = x
        starts[i] += 1

@numba.jit(nopython=False)  # Not 100% sure this does anything...
def enum_bins_numba_buffer(ints):
    ends = np.bincount(ints).cumsum()
    starts = np.empty(ends.shape, dtype=np.int64)
    starts[1:] = ends[:-1]
    starts[0] = 0

    bins = np.empty(ints.shape, dtype=np.int64)
    enum_bins_numba_buffer_inner(ints, bins, starts)

    starts[1:] = ends[:-1]
    starts[0] = 0
    return [bins[s:e] for s, e in zip(starts, ends)]

Przetwarza dziesięciomilionową listę elementów w 75 ms, co stanowi prawie 50-krotne przyspieszenie w porównaniu z wersją opartą na listach napisaną w czystym języku Python.

W przypadku wolniejszej, ale nieco bardziej czytelnej wersji, oto co miałem wcześniej, w oparciu o niedawno dodane eksperymentalne wsparcie dla dynamicznie zmieniających się „list maszynowych”, które pozwalają nam szybciej zapełniać każdy pojemnik w niewłaściwym porządku.

To numbatrochę zmaga się z silnikiem wnioskowania typu i jestem pewien, że jest lepszy sposób na poradzenie sobie z tą częścią. To również okazuje się prawie 10 razy wolniejsze niż powyższe.

@numba.jit(nopython=True)
def enum_bins_numba(ints):
    bins = numba.typed.List()
    for i in range(ints.max() + 1):
        inner = numba.typed.List()
        inner.append(0)  # An awkward way of forcing type inference.
        inner.pop()
        bins.append(inner)

    for x, i in enumerate(ints):
        bins[i].append(x)

    return bins

Przetestowałem je pod kątem następujących elementów:

def enum_bins_dict(ints):
    enum_bins = defaultdict(list)
    for k, v in enumerate(ints):
        enum_bins[v].append(k)
    return enum_bins

def enum_bins_list(ints):
    enum_bins = [[] for i in range(ints.max() + 1)]
    for x, i in enumerate(ints):
        enum_bins[i].append(x)
    return enum_bins

def enum_bins_sparse(ints):
    M, N = ints.max() + 1, ints.size
    return sparse.csc_matrix((ints, ints, np.arange(N + 1)),
                             (M, N)).tolil().rows.tolist()

Przetestowałem je również w stosunku do wstępnie skompilowanej wersji cytonu podobnej do enum_bins_numba_buffer(opisanej szczegółowo poniżej).

Na liście dziesięciu milionów losowych liczb całkowitych ( ints = np.random.randint(0, 100, 10000000)) otrzymuję następujące wyniki:

enum_bins_dict(ints)
3.71 s ± 80.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

enum_bins_list(ints)
3.28 s ± 52.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

enum_bins_sparse(ints)
1.02 s ± 34.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

enum_bins_numba(ints)
693 ms ± 5.81 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

enum_bins_cython(ints)
82.3 ms ± 1.77 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

enum_bins_numba_buffer(ints)
77.4 ms ± 2.06 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Imponująco, ten sposób pracy z numbaprogramem przewyższa cythonwersję tej samej funkcji, nawet przy wyłączonym sprawdzaniu granic. Nie mam jeszcze wystarczającej znajomości, pythranaby przetestować to podejście przy użyciu tej metody, ale chciałbym zobaczyć porównanie. Wydaje się prawdopodobne, na podstawie tego przyspieszenia, że ​​ta pythranwersja może być nieco szybsza.

Oto cythonwersja w celach informacyjnych z kilkoma instrukcjami kompilacji. Po cythonzainstalowaniu będziesz potrzebować prostego setup.pypliku takiego jak ten:

from distutils.core import setup
from distutils.extension import Extension
from Cython.Build import cythonize
import numpy

ext_modules = [
    Extension(
        'enum_bins_cython',
        ['enum_bins_cython.pyx'],
    )
]

setup(
    ext_modules=cythonize(ext_modules),
    include_dirs=[numpy.get_include()]
)

I moduł cytonowy enum_bins_cython.pyx:

# cython: language_level=3

import cython
import numpy
cimport numpy

@cython.boundscheck(False)
@cython.cdivision(True)
@cython.wraparound(False)
cdef void enum_bins_inner(long[:] ints, long[:] bins, long[:] starts) nogil:
    cdef long i, x
    for x in range(len(ints)):
        i = ints[x]
        bins[starts[i]] = x
        starts[i] = starts[i] + 1

def enum_bins_cython(ints):
    assert (ints >= 0).all()
    # There might be a way to avoid storing two offset arrays and
    # save memory, but `enum_bins_inner` modifies the input, and
    # having separate lists of starts and ends is convenient for
    # the final partition stage.
    ends = numpy.bincount(ints).cumsum()
    starts = numpy.empty(ends.shape, dtype=numpy.int64)
    starts[1:] = ends[:-1]
    starts[0] = 0

    bins = numpy.empty(ints.shape, dtype=numpy.int64)
    enum_bins_inner(ints, bins, starts)

    starts[1:] = ends[:-1]
    starts[0] = 0
    return [bins[s:e] for s, e in zip(starts, ends)]

Z tymi dwoma plikami w katalogu roboczym uruchom następującą komendę:

python setup.py build_ext --inplace

Następnie możesz zaimportować funkcję za pomocą from enum_bins_cython import enum_bins_cython.

senderle
źródło
Zastanawiam się, czy zdajesz sobie sprawę z pythranu, który w bardzo szerokim ujęciu jest podobny do numby. Dodałem rozwiązanie z pythranem do mojego postu. Przy tej okazji wydaje się, że pythran ma przewagę, zapewniając szybsze i znacznie bardziej pytoniczne rozwiązanie.
Paul Panzer
@PaulPanzer ciekawe! Nie słyszałem o tym. Rozumiem, że deweloperzy numba będą dodawać oczekiwany cukier składniowy, gdy kod listy będzie stabilny. Wydaje się również, że istnieje tutaj wygoda / szybkość kompromisu - dekorator jitów jest bardzo łatwy do zintegrowania ze zwykłą bazą kodu Pythona w porównaniu z podejściem wymagającym oddzielnych wstępnie skompilowanych modułów. Ale trzykrotne przyspieszenie w stosunku do podejrzanego podejścia jest naprawdę imponujące, a nawet zaskakujące!
senderle
Właśnie pamiętałem, że zasadniczo zrobiłem to wcześniej: stackoverflow.com/q/55226662/7207392 . Czy miałbyś coś przeciwko dodaniu wersji Numba i Cython do tego pytania i odpowiedzi? Jedyna różnica polega na tym, że nie binujemy wskaźników 0,1,2, ... ale zamiast tego kolejną tablicę. I nie zawracamy sobie głowy rozdrabnianiem powstałej tablicy.
Paul Panzer
@PaulPanzer ah bardzo fajne. Spróbuję go dodać w pewnym momencie dzisiaj lub jutro. Sugerujesz osobną odpowiedź lub po prostu edycję odpowiedzi? Szczęśliwy z obu stron!
senderle
Świetny! Myślę, że osobny post byłby lepszy, ale bez silnych preferencji.
Paul Panzer
6

Oto naprawdę dziwny sposób na zrobienie tego, co jest okropne, ale uznałem, że to zbyt zabawne, aby nie udostępniać - i wszystko numpy!

out = np.array([''] * (x.max() + 1), dtype = object)
np.add.at(out, x, ["{} ".format(i) for i in range(x.size)])
[[int(i) for i in o.split()] for o in out]

Out[]:
[[3, 4], [0, 5], [1, 2], [6], [], [7]]

EDYCJA: to najlepsza metoda, jaką mogłem znaleźć na tej ścieżce. Jest wciąż 10 razy wolniejszy niż argsortrozwiązanie @PaulPanzer :

out = np.empty((x.max() + 1), dtype = object)
out[:] = [[]] * (x.max() + 1)
coords = np.empty(x.size, dtype = object)
coords[:] = [[i] for i in range(x.size)]
np.add.at(out, x, coords)
list(out)
Daniel F.
źródło
2

Możesz to zrobić, tworząc słownik liczb, kluczami byłyby liczby, a wartości powinny być indeksami, które widziały liczby, jest to jeden z najszybszych sposobów, aby to zrobić, możesz zobaczyć kod poniżej:

>>> import numpy as np
>>> a = np.array([1 ,2 ,2 ,0 ,0 ,1 ,3, 5])
>>> b = {}
# Creating an empty list for the numbers that exist in array a
>>> for i in range(np.min(a),np.max(a)+1):
    b[str(i)] = []

# Adding indices to the corresponding key
>>> for i in range(len(a)):
    b[str(a[i])].append(i)

# Resulting Dictionary
>>> b
{'0': [3, 4], '1': [0, 5], '2': [1, 2], '3': [6], '4': [], '5': [7]}

# Printing the result in the way you wanted.
>>> for i in sorted (b.keys()) :
     print(b[i], end = " ")

[3, 4] [0, 5] [1, 2] [6] [] [7] 
Mohsen_Fatemi
źródło
1

Pseudo kod:

  1. uzyskaj „liczbę tablic 1d w tablicy 2d”, odejmując minimalną wartość tablicy numpy od wartości maksymalnej, a następnie plus jeden. W twoim przypadku będzie to 5-0 + 1 = 6

  2. zainicjuj tablicę 2d liczbą zawartych w niej tablic 1d. W twoim przypadku zainicjuj tablicę 2d z tablicą 6 1d. Każda tablica 1d odpowiada unikalnemu elementowi w tablicy numpy, na przykład pierwsza tablica 1d odpowiada „0”, druga tablica 1d odpowiada „1”, ...

  3. zapętlić pętlę przez tablicę numpy, umieścić indeks elementu w odpowiedniej odpowiedniej tablicy 1d. W twoim przypadku indeks pierwszego elementu w tablicy numpy zostanie umieszczony w drugiej tablicy 1d, indeks drugiego elementu w tablicy numpy zostanie umieszczony w trzeciej tablicy 1d ...

Uruchomienie tego pseudokodu zajmie czas liniowy, ponieważ zależy to od długości tablicy numpy.

ubikayu
źródło
1

To daje dokładnie to, czego chcesz i zajęłoby około 2,5 sekundy na 10 000 000 na moim komputerze:

import numpy as np
import timeit

# x = np.array("1 2 2 0 0 1 3 5".split(),int)
x = np.random.randint(0, 100, 100000)

def create_index_list(x):
    d = {}
    max_value = -1
    for i,v in enumerate(x):
        if v > max_value:
            max_value = v
        try:
            d[v].append(i)
        except:
            d[v] = [i]
    result_list = []
    for i in range(max_value+1):
        if i in d:
            result_list.append(d[i])
        else:
            result_list.append([])
    return result_list

# print(create_index_list(x))
print(timeit.timeit(stmt='create_index_list(x)', number=1, globals=globals()))
Eli Mintz
źródło
0

Biorąc pod uwagę listę elementów, chcesz utworzyć pary (element, indeks). W czasie liniowym można to zrobić jako:

hashtable = dict()
for idx, val in enumerate(mylist):
    if val not in hashtable.keys():
         hashtable[val] = list()
    hashtable[val].append(idx)
newlist = sorted(hashtable.values())

Powinno to zająć czas O (n). Nie mogę teraz wymyślić szybszego rozwiązania, ale zaktualizuję tutaj, jeśli to zrobię.

Ramsha Siddiqui
źródło