首页 >数据库 >mysql教程 >如何获取 Spark DataFrame 中每组的前 N ​​条记录?

如何获取 Spark DataFrame 中每组的前 N ​​条记录?

DDD
DDD原创
2024-12-23 11:46:36438浏览

How to Get the Top N Records for Each Group in a Spark DataFrame?

获取DataFrame中每个组的TopN

在Spark DataFrame中,我们经常需要按某一列对数据进行分组并检索顶部每组 N 条记录。例如,您有一个包含用户项目评分数据的 DataFrame,并且您希望找到每个用户评分最高的项目。

Scala 解决方案

Scala该问题的解决方案涉及使用排名窗口函数。操作方法如下:

  1. 定义前 N 个值:

    val n: Int = ???
  2. 创建一个窗口定义以按用户对数据进行分区列并按其降序排列记录rating:

    val w = Window.partitionBy($"user").orderBy(desc("rating"))
  3. 使用rank函数将排名列添加到DataFrame中:

    df.withColumn("rank", rank().over(w))
  4. 过滤DataFrame以仅保留每个的前 N ​​条记录group:

    df.where($"rank" <= n)

替代行号

如果不需要打破平局,可以使用 row_number 函数的秩函数。这将为每个组中的每条记录提供唯一的行号:

df.withColumn("row_number", row_number().over(w))

然后,您可以使用与之前相同的条件过滤 DataFrame 以仅保留每个组中的前 N ​​条记录。

以上是如何获取 Spark DataFrame 中每组的前 N ​​条记录?的详细内容。更多信息请关注PHP中文网其他相关文章!

声明:
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn