
How to Get the Top N Records for Each Group in a Spark DataFrame?

By DDD | Original | 2024-12-23 11:46:36


Getting the Top N Records for Each Group in a DataFrame

When working with a Spark DataFrame, you often need to group data by a certain column and retrieve the top N records from each group. For example, given a DataFrame of user-item rating data, you may want to find the top-rated items for each user.

Scala Solution

The Scala solution to this problem involves using a rank window function. Here's how you can do it:

  1. Define the top N value:

    val n: Int = ???
  2. Create a window definition that partitions the data by the user column and orders each partition by rating in descending order (this requires `import org.apache.spark.sql.expressions.Window` and `import org.apache.spark.sql.functions.{rank, desc}`):

    val w = Window.partitionBy($"user").orderBy(desc("rating"))
  3. Add the rank column to the DataFrame using the rank function:

    val ranked = df.withColumn("rank", rank().over(w))
  4. Filter the ranked DataFrame (not the original df, which has no rank column) to keep only the top N records from each group:

    ranked.where($"rank" <= n)
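The steps above can be combined into a single chain. Here is a minimal, self-contained sketch; the sample data and the column names user, item, and rating are assumptions based on the example in this article:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{rank, desc}

val spark = SparkSession.builder().appName("TopNPerGroup").getOrCreate()
import spark.implicits._

// Hypothetical user-item rating data, for illustration only
val df = Seq(
  ("alice", "book", 5), ("alice", "film", 4), ("alice", "game", 3),
  ("bob",   "book", 2), ("bob",   "film", 5)
).toDF("user", "item", "rating")

val n = 2

// Partition by user, order each partition by rating descending
val w = Window.partitionBy($"user").orderBy(desc("rating"))

// Rank within each partition, keep the top n, drop the helper column
val topN = df.withColumn("rank", rank().over(w))
  .where($"rank" <= n)
  .drop("rank")

topN.show()
```

Dropping the helper column at the end is optional; it simply keeps the output schema identical to the input.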

Alternative with Row Number

The rank function assigns the same rank to tied ratings, so `$"rank" <= n` can return more than N rows per group. If you want exactly N rows per group and don't care how ties are broken, use the row_number function instead, which assigns a unique row number to each record within its group:

df.withColumn("row_number", row_number().over(w))

You can then filter on `$"row_number" <= n` to keep exactly the top N records from each group.
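To see the difference concretely: if a user rated two items with the same top score, rank gives both items rank 1, so `rank <= 1` returns both rows, while row_number numbers them 1 and 2 (in an unspecified order among ties) and returns exactly one. A sketch of the row_number variant, reusing the window definition w from above:

```scala
import org.apache.spark.sql.functions.row_number

// row_number gives each row a unique position within its partition,
// so exactly n rows survive per user even when ratings tie
val topNExact = df.withColumn("row_number", row_number().over(w))
  .where($"row_number" <= n)
  .drop("row_number")
```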

