Jak wybrać pierwszy wiersz każdej grupy?

143

Mam DataFrame wygenerowaną w następujący sposób:

df.groupBy($"Hour", $"Category")
  .agg(sum($"value") as "TotalValue")
  .sort($"Hour".asc, $"TotalValue".desc))

Wyniki wyglądają następująco:

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   0|   cat13|      22.1|
|   0|   cat95|      19.6|
|   0|  cat105|       1.3|
|   1|   cat67|      28.5|
|   1|    cat4|      26.8|
|   1|   cat13|      12.6|
|   1|   cat23|       5.3|
|   2|   cat56|      39.6|
|   2|   cat40|      29.7|
|   2|  cat187|      27.9|
|   2|   cat68|       9.8|
|   3|    cat8|      35.6|
| ...|    ....|      ....|
+----+--------+----------+

Jak widać, DataFrame jest uporządkowana według Hourrosnącej kolejności, a następnie TotalValuemalejącej.

Chciałbym zaznaczyć górny wiersz każdej grupy, tj

  • z grupy Hour == 0 select (0, cat26,30.9)
  • z grupy Hour == 1 select (1, cat67,28,5)
  • z grupy Hour == 2 select (2, cat56,39.6)
  • i tak dalej

Zatem pożądane wyjście to:

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   1|   cat67|      28.5|
|   2|   cat56|      39.6|
|   3|    cat8|      35.6|
| ...|     ...|       ...|
+----+--------+----------+

Przydatna może być również możliwość wybrania N górnych wierszy z każdej grupy.

Każda pomoc jest bardzo ceniona.

Rami
źródło

Odpowiedzi:

231

Funkcje okna :

Coś takiego powinno załatwić sprawę:

import org.apache.spark.sql.functions.{row_number, max, broadcast}
import org.apache.spark.sql.expressions.Window

val df = sc.parallelize(Seq(
  (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
  (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
  (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
  (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")

val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc)

val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

Ta metoda będzie nieefektywna w przypadku znacznego odchylenia danych.

Zwykła agregacja SQL, po której następujejoin :

Alternatywnie możesz dołączyć do zagregowanej ramki danych:

val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))

val dfTopByJoin = df.join(broadcast(dfMax),
    ($"hour" === $"max_hour") && ($"TotalValue" === $"max_value"))
  .drop("max_hour")
  .drop("max_value")

dfTopByJoin.show

// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

Zachowa zduplikowane wartości (jeśli na godzinę jest więcej niż jedna kategoria o tej samej łącznej wartości). Możesz je usunąć w następujący sposób:

dfTopByJoin
  .groupBy($"hour")
  .agg(
    first("category").alias("category"),
    first("TotalValue").alias("TotalValue"))

Korzystanie z zamawiania powyżejstructs :

Zgrabna, choć niezbyt dobrze przetestowana sztuczka, która nie wymaga łączenia ani funkcji okna:

val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
  .groupBy($"hour")
  .agg(max("vs").alias("vs"))
  .select($"Hour", $"vs.Category", $"vs.TotalValue")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

Z DataSet API (Spark 1.6+, 2.0+):

Spark 1.6 :

case class Record(Hour: Integer, Category: String, TotalValue: Double)

df.as[Record]
  .groupBy($"hour")
  .reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)
  .show

// +---+--------------+
// | _1|            _2|
// +---+--------------+
// |[0]|[0,cat26,30.9]|
// |[1]|[1,cat67,28.5]|
// |[2]|[2,cat56,39.6]|
// |[3]| [3,cat8,35.6]|
// +---+--------------+

Spark 2.0 lub nowszy :

df.as[Record]
  .groupByKey(_.Hour)
  .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)

Dwie ostatnie metody mogą wykorzystywać łączenie strony mapy i nie wymagają pełnego tasowania, więc przez większość czasu powinny wykazywać lepszą wydajność w porównaniu z funkcjami okna i złączeniami. Te laski mogą być również używane z Structured Streaming w completedtrybie wyjściowym.

Nie używaj :

df.orderBy(...).groupBy(...).agg(first(...), ...)

Może się wydawać, że działa (szczególnie w localtrybie), ale jest zawodne (zobacz SPARK-16207 , podziękowania dla Tzacha Zohara za powiązanie odpowiedniego problemu z JIRA i SPARK-30335 ).

Ta sama uwaga dotyczy

df.orderBy(...).dropDuplicates(...)

który wewnętrznie używa równoważnego planu wykonania.

zero323
źródło
3
Wygląda na to, że od iskry 1.6 to row_number () zamiast rowNumber
Adam Szałucha
O nie używaj df.orderBy (...). GropBy (...). W jakich okolicznościach możemy polegać na orderBy (...)? lub jeśli nie możemy być pewni, czy polecenie orderBy () da poprawny wynik, jakie mamy alternatywy?
Ignacio Alorre
Mogę coś przeoczyć, ale generalnie zaleca się unikanie groupByKey , zamiast tego należy użyć lowerByKey. Oszczędzisz też jedną linię.
Thomas,
3
@Thomas unikanie groupBy / groupByKey jest tylko wtedy, gdy mamy do czynienia z RDD, zauważysz, że interfejs API zestawu danych nie ma nawet funkcji redukujByKey.
posmakuj
16

W przypadku Spark 2.0.2 z grupowaniem według wielu kolumn:

import org.apache.spark.sql.functions.row_number
import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc)

val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
Antonín Hoskovec
źródło
8

To jest dokładnie to samo z zero323 „s odpowiedzi , ale w SQL kwerendy sposób.

Zakładając, że dataframe jest tworzona i rejestrowana jako

df.createOrReplaceTempView("table")
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|0   |cat26   |30.9      |
//|0   |cat13   |22.1      |
//|0   |cat95   |19.6      |
//|0   |cat105  |1.3       |
//|1   |cat67   |28.5      |
//|1   |cat4    |26.8      |
//|1   |cat13   |12.6      |
//|1   |cat23   |5.3       |
//|2   |cat56   |39.6      |
//|2   |cat40   |29.7      |
//|2   |cat187  |27.9      |
//|2   |cat68   |9.8       |
//|3   |cat8    |35.6      |
//+----+--------+----------+

Funkcja okna:

sqlContext.sql("select Hour, Category, TotalValue from (select *, row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) as rn  FROM table) tmp where rn = 1").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Zwykła agregacja SQL, po której następuje łączenie:

sqlContext.sql("select Hour, first(Category) as Category, first(TotalValue) as TotalValue from " +
  "(select Hour, Category, TotalValue from table tmp1 " +
  "join " +
  "(select Hour as max_hour, max(TotalValue) as max_value from table group by Hour) tmp2 " +
  "on " +
  "tmp1.Hour = tmp2.max_hour and tmp1.TotalValue = tmp2.max_value) tmp3 " +
  "group by tmp3.Hour")
  .show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Korzystanie z porządkowania na strukturach:

sqlContext.sql("select Hour, vs.Category, vs.TotalValue from (select Hour, max(struct(TotalValue, Category)) as vs from table group by Hour)").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

DataSets sposób i nie robią są takie same jak w oryginalnej odpowiedzi

Ramesh Maharjan
źródło
2

Wzorzec jest grupowany według klawiszy => zrób coś dla każdej grupy, np. Zredukuj => wróć do ramki danych

Myślałem, że abstrakcja Dataframe jest w tym przypadku trochę kłopotliwa, więc użyłem funkcjonalności RDD

 val rdd: RDD[Row] = originalDf
  .rdd
  .groupBy(row => row.getAs[String]("grouping_row"))
  .map(iterableTuple => {
    iterableTuple._2.reduce(reduceFunction)
  })

val productDf = sqlContext.createDataFrame(rdd, originalDf.schema)
Gumowa kaczuszka
źródło
1

Poniższe rozwiązanie wykonuje tylko jedną grupę GroupBy i wyodrębnia wiersze ramki danych, które zawierają wartość maxValue w jednym ujęciu. Nie ma potrzeby dalszych połączeń ani systemu Windows.

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.DataFrame

//df is the dataframe with Day, Category, TotalValue

implicit val dfEnc = RowEncoder(df.schema)

val res: DataFrame = df.groupByKey{(r) => r.getInt(0)}.mapGroups[Row]{(day: Int, rows: Iterator[Row]) => i.maxBy{(r) => r.getDouble(2)}}
elghoto
źródło
Ale najpierw wszystko tasuje. Nie jest to żadna poprawa (może nie gorsza niż funkcje okna, w zależności od danych).
Alper T. Turker
masz grupę na pierwszym miejscu, co spowoduje przetasowanie. Nie jest gorsza niż funkcja okna, ponieważ w funkcji okna będzie oceniać okno dla każdego pojedynczego wiersza w ramce danych.
elghoto
1

Dobrym sposobem na zrobienie tego z interfejsem dataframe API jest użycie takiej logiki argmax

  val df = Seq(
    (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
    (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
    (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
    (3,"cat8",35.6)).toDF("Hour", "Category", "TotalValue")

  df.groupBy($"Hour")
    .agg(max(struct($"TotalValue", $"Category")).as("argmax"))
    .select($"Hour", $"argmax.*").show

 +----+----------+--------+
 |Hour|TotalValue|Category|
 +----+----------+--------+
 |   1|      28.5|   cat67|
 |   3|      35.6|    cat8|
 |   2|      39.6|   cat56|
 |   0|      30.9|   cat26|
 +----+----------+--------+
randal25
źródło
0

Tutaj możesz to zrobić -

   val data = df.groupBy("Hour").agg(first("Hour").as("_1"),first("Category").as("Category"),first("TotalValue").as("TotalValue")).drop("Hour")

data.withColumnRenamed("_1","Hour").show
Shubham Agrawal
źródło
-2

Możemy użyć funkcji okna rank () (gdzie wybrałbyś rank = 1) rank po prostu dodaje liczbę dla każdego wiersza grupy (w tym przypadku byłaby to godzina)

oto przykład. (z https://github.com/jaceklaskowski/mastering-apache-spark-book/blob/master/spark-sql-functions.adoc#rank )

val dataset = spark.range(9).withColumn("bucket", 'id % 3)

import org.apache.spark.sql.expressions.Window
val byBucket = Window.partitionBy('bucket).orderBy('id)

scala> dataset.withColumn("rank", rank over byBucket).show
+---+------+----+
| id|bucket|rank|
+---+------+----+
|  0|     0|   1|
|  3|     0|   2|
|  6|     0|   3|
|  1|     1|   1|
|  4|     1|   2|
|  7|     1|   3|
|  2|     2|   1|
|  5|     2|   2|
|  8|     2|   3|
+---+------+----+
Vasile Surdu
źródło