Wykresy punktowe w Pandas / Pyplot: Jak kreślić według kategorii

89

Próbuję wykonać prosty wykres punktowy w pyplocie przy użyciu obiektu Pandas DataFrame, ale chcę wydajnego sposobu wykreślania dwóch zmiennych, ale symbole mają podyktowane przez trzecią kolumnę (klucz). Próbowałem różnych sposobów korzystania z df.groupby, ale bez powodzenia. Przykładowy skrypt df znajduje się poniżej. Powoduje to kolorowanie znaczników zgodnie z „kluczem1”, ale chciałbym zobaczyć legendę z kategoriami „klucz1”. Jestem blisko Dzięki.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
plt.show()
user2989613
źródło

Odpowiedzi:

118

Możesz użyć scatterdo tego, ale wymaga to posiadania wartości liczbowych dla twojego key1, a nie będziesz miał legendy, jak zauważyłeś.

Lepiej jest po prostu używać takich plotdyskretnych kategorii. Na przykład:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
fig, ax = plt.subplots()
ax.margins(0.05) # Optional, just adds 5% padding to the autoscaling
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend()

plt.show()

wprowadź opis obrazu tutaj

Jeśli chcesz, aby rzeczy wyglądały jak domyślny pandasstyl, po prostu zaktualizuj rcParamsarkusz stylów pandy i użyj jego generatora kolorów. (Poprawiam też nieco legendę):

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
plt.rcParams.update(pd.tools.plotting.mpl_stylesheet)
colors = pd.tools.plotting._get_standard_colors(len(groups), color_type='random')

fig, ax = plt.subplots()
ax.set_color_cycle(colors)
ax.margins(0.05)
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend(numpoints=1, loc='upper left')

plt.show()

wprowadź opis obrazu tutaj

Joe Kington
źródło
Dlaczego w powyższym przykładzie RGB symbol jest pokazany dwukrotnie w legendzie? Jak pokazać tylko raz?
Steve Schulist
1
@SteveSchulist - użyj, ax.legend(numpoints=1)aby wyświetlić tylko jeden znacznik. Są dwa, tak jak w przypadku Line2D, często istnieje linia łącząca dwa znaczniki.
Joe Kington
Ten kod działał u mnie tylko po dodaniu plt.hold(True)po ax.plot()poleceniu. Każdy pomysł, dlaczego?
Yuval Atzmon
set_color_cycle() został uznany za przestarzały w matplotlib 1.5. Jest set_prop_cycle()teraz.
ale
52

Można to łatwo zrobić z Seaborn ( pip install seaborn) jako onelinerem

sns.scatterplot(x_vars="one", y_vars="two", data=df, hue="key1") :

import seaborn as sns
import pandas as pd
import numpy as np
np.random.seed(1974)

df = pd.DataFrame(
    np.random.normal(10, 1, 30).reshape(10, 3),
    index=pd.date_range('2010-01-01', freq='M', periods=10),
    columns=('one', 'two', 'three'))
df['key1'] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)

sns.scatterplot(x="one", y="two", data=df, hue="key1")

wprowadź opis obrazu tutaj

Oto ramka danych w celach informacyjnych:

wprowadź opis obrazu tutaj

Ponieważ dane zawierają trzy zmienne kolumny, warto wykreślić wszystkie wymiary parami za pomocą:

sns.pairplot(vars=["one","two","three"], data=df, hue="key1")

wprowadź opis obrazu tutaj

https://rasbt.github.io/mlxtend/user_guide/plotting/category_scatter/ to kolejna opcja.

Bob Baxley
źródło
19

Z pomocą przychodzi plt.scattermi do głowy tylko jedno: użyć artysty proxy:

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
x=ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)

ccm=x.get_cmap()
circles=[Line2D(range(1), range(1), color='w', marker='o', markersize=10, markerfacecolor=item) for item in ccm((array([4,6,8])-4.0)/4)]
leg = plt.legend(circles, ['4','6','8'], loc = "center left", bbox_to_anchor = (1, 0.5), numpoints = 1)

A wynik jest taki:

wprowadź opis obrazu tutaj

CT Zhu
źródło
10

Możesz użyć df.plot.scatter i przekazać tablicę do c = argument definiującą kolor każdego punktu:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
colors = np.where(df["key1"]==4,'r','-')
colors[df["key1"]==6] = 'g'
colors[df["key1"]==8] = 'b'
print(colors)
df.plot.scatter(x="one",y="two",c=colors)
plt.show()

wprowadź opis obrazu tutaj

Arjaan Buijk
źródło
4

Możesz także wypróbować Altair lub ggpot, które koncentrują się na deklaratywnych wizualizacjach.

import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

Kod Altaira

from altair import Chart
c = Chart(df)
c.mark_circle().encode(x='x', y='y', color='label')

wprowadź opis obrazu tutaj

kod ggplot

from ggplot import *
ggplot(aes(x='x', y='y', color='label'), data=df) +\
geom_point(size=50) +\
theme_bw()

wprowadź opis obrazu tutaj

Nipun Batra
źródło
3

Począwszy od Matplotlib 3.1 możesz używać .legend_elements(). Przykład pokazano w sekcji Automatyczne tworzenie legendy . Zaletą jest to, że można użyć pojedynczego wywołania scatter.

W tym przypadku:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), 
                  index = pd.date_range('2010-01-01', freq = 'M', periods = 10), 
                  columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)


fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
ax.legend(*sc.legend_elements())
plt.show()

wprowadź opis obrazu tutaj

Gdyby klucze nie zostały bezpośrednio podane jako liczby, wyglądałoby to tak

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), 
                  index = pd.date_range('2010-01-01', freq = 'M', periods = 10), 
                  columns = ('one', 'two', 'three'))
df['key1'] = list("AAABBBCCCC")

labels, index = np.unique(df["key1"], return_inverse=True)

fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = index, alpha = 0.8)
ax.legend(sc.legend_elements()[0], labels)
plt.show()

wprowadź opis obrazu tutaj

ImportanceOfBeingErnest
źródło
Wystąpił błąd informujący, że obiekt „PathCollection” nie ma atrybutu „legends_elements”. Mój kod jest następujący. fig, ax = plt.subplots(1, 1, figsize = (4,4)) scat = ax.scatter(rand_jitter(important_dataframe["workout_type_int"], jitter = 0.04), important_dataframe["distance"], c = color_list, marker = 'o', alpha = 0.9) print(scat.legends_elements()) #ax.legend(*scat.legend_elements())
Nandish Patel
1
@NandishPatel Sprawdź pierwsze zdanie tej odpowiedzi. Pamiętaj też, aby nie pomylić legends_elementsi legend_elements.
ImportanceOfBeingErnest
Tak dziękuję. To była literówka (legendy / legenda). Pracowałem nad czymś od 6 godzin, więc wersja Matplotlib nie przyszła mi do głowy. Myślałem, że używam najnowszego. Byłem zdezorientowany, że dokumentacja mówi, że istnieje taka metoda, ale kod dawał błąd. Jeszcze raz dziękuję. Mogę teraz spać.
Nandish Patel
2

Jest to dość hakerskie, ale możesz użyć one1jako, Float64Indexaby zrobić wszystko za jednym razem:

df.set_index('one').sort_index().groupby('key1')['two'].plot(style='--o', legend=True)

wprowadź opis obrazu tutaj

Zauważ, że od wersji 0.20.3 sortowanie indeksu jest konieczne , a legenda jest nieco niepewna .

zepsuty
źródło
1

Seaaborn ma funkcję owijania, scatterplotktóra robi to wydajniej.

sns.scatterplot(data = df, x = 'one', y = 'two', data =  'key1'])
yosemite_k
źródło