from listenbrainz_spark.stats import run_query
from pyspark.sql.functions import collect_list, sort_array, struct


def get_artists(table):
    """ Collect, for every user, the artists they listened to along with listen counts.

        Args:
            table (str): name of the temporary table.

        Returns:
            iterator (iter): a local iterator over the aggregated result, one row per user.
                Each row carries 'user_name' and 'artists', where 'artists' is a list of
                    {
                        'listen_count': int,
                        'artist_name': str,
                        'artist_msid': str,
                        'artist_mbids': list(str)
                    }
                sorted in descending order. Entries are compared field by field in the
                order shown above, so 'listen_count' is the primary sort key.
    """
    # Count listens per (user, artist). The reported artist_msid is NULL whenever
    # artist_mbids is non-empty; otherwise an empty msid is turned into NULL.
    query = """
        SELECT user_name
             , artist_name
             , CASE
                 WHEN cardinality(artist_mbids) > 0 THEN NULL
                 ELSE nullif(artist_msid, '')
               END as artist_msid
             , artist_mbids
             , count(artist_name) as listen_count
          FROM {table}
      GROUP BY user_name
             , artist_name
             , artist_msid
             , artist_mbids
    """.format(table=table)
    listen_counts = run_query(query)

    # listen_count is deliberately the first struct field: sort_array orders the
    # collected structs by their fields from left to right.
    artist_entry = struct("listen_count", "artist_name", "artist_msid", "artist_mbids")
    ranked_artists = sort_array(collect_list("artists"), asc=False).alias("artists")

    per_user = (
        listen_counts
        .withColumn("artists", artist_entry)
        .groupBy("user_name")
        .agg(ranked_artists)
    )
    return per_user.toLocalIterator()