How to Remove Duplicates in PySpark: A Step-by-Step Guide

In the age of big data, ensuring data quality is more important than ever. One common challenge data practitioners face is dealing with duplicate rows. While a few duplicate entries may seem benign, in a dataset with millions of records they can significantly skew analytical results.

Enter PySpark, a powerful tool designed for large-scale data processing. One of its strengths lies in its ability to manage and cleanse data efficiently. In this guide, we'll walk through the process of removing duplicate rows in PySpark, ensuring your dataset's integrity.

Setting the Stage: A Sample DataFrame

Before we dive into the nitty-gritty of deduplication, let's begin with a sample DataFrame that mimics the real-world scenario of having duplicate records:

from pyspark.sql import Row, SparkSession

# Create (or reuse) a SparkSession; in spark-shell and most notebooks it already exists as `spark`
spark = SparkSession.builder.getOrCreate()

sample_data = [
    Row(id=1, name="John", city="New York"),
    Row(id=2, name="Anna", city="Los Angeles"),
    Row(id=3, name="Mike", city="Chicago"),
    Row(id=1, name="John", city="New York"),
    Row(id=4, name="Sara", city="Houston"),
    Row(id=2, name="Anna", city="Los Angeles")
]

df = spark.createDataFrame(sample_data)
df.show() 
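
If you run this in a PySpark shell or notebook, df.show() should print something along these lines (row order may vary):

+---+----+-----------+
| id|name|       city|
+---+----+-----------+
|  1|John|   New York|
|  2|Anna|Los Angeles|
|  3|Mike|    Chicago|
|  1|John|   New York|
|  4|Sara|    Houston|
|  2|Anna|Los Angeles|
+---+----+-----------+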

Our dataset clearly contains duplicate entries for "John" and "Anna".

Removing Duplicates: The Direct Approach

PySpark's DataFrame API provides a straightforward method called dropDuplicates to help us quickly remove duplicate rows:

cleaned_df = df.dropDuplicates() 
cleaned_df.show() 

With this one-liner, our dataset is already looking much neater:

+---+----+-----------+
| id|name|       city|
+---+----+-----------+
|  1|John|   New York|
|  2|Anna|Los Angeles|
|  3|Mike|    Chicago|
|  4|Sara|    Houston|
+---+----+-----------+


Column-Specific Deduplication

But what if you only want to remove duplicates based on specific columns? PySpark's got you covered:

cleaned_df_id = df.dropDuplicates(["id"]) 
cleaned_df_id.show() 

Here, we've deduplicated based on the id column, giving you granular control over the process.
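
Note that when two rows share an id but differ in other columns, dropDuplicates(["id"]) keeps one of them arbitrarily. A minimal sketch with hypothetical data (the conflicting city value is made up for illustration):

from pyspark.sql import Row

# Hypothetical rows: the same id appears with two different cities
conflicting_data = [
    Row(id=1, name="John", city="New York"),
    Row(id=1, name="John", city="Boston"),
]

conflicting_df = spark.createDataFrame(conflicting_data)

# Only one of the two rows survives, and which one is not guaranteed
conflicting_df.dropDuplicates(["id"]).show()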

Preserving Order While Deduplicating

In scenarios where the order of records matters, deduplication should be approached with more care. Thankfully, PySpark lets you combine sorting with deduplication:

ordered_df = df.orderBy("name") 
cleaned_ordered_df = ordered_df.dropDuplicates(["id"]) 
cleaned_ordered_df.show() 

This sorts the data by name before deduplicating. Be aware, though, that Spark does not guarantee which duplicate row survives once the data is shuffled for dropDuplicates; if you need a strict rule such as "keep the first row per id after sorting", the window-function approach covered later in this guide is the more reliable option.

Using distinct()

Another straightforward way to eliminate duplicate rows is the distinct() method, which returns a new DataFrame containing only unique rows; it is equivalent to calling dropDuplicates() without a column subset:

distinct_df = df.distinct() 
distinct_df.show() 
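
As a quick sanity check, it can be useful to compare row counts before and after deduplication; a minimal sketch:

# Compare row counts before and after deduplication
total_rows = df.count()
unique_rows = df.distinct().count()

print(f"Removed {total_rows - unique_rows} duplicate rows")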


Deduplication with dropDuplicates Using a Subset and orderBy

Sometimes, you'd want to keep a particular record among the duplicates. For instance, you might want to keep the latest or the earliest record based on a timestamp:

from pyspark.sql import functions as F 

# Assuming our DataFrame has a 'timestamp' column 
df_with_time = df.withColumn("timestamp", F.current_timestamp()) 

# Keeping the latest record
latest_records = df_with_time.orderBy("timestamp", ascending=False).dropDuplicates(["id"])
latest_records.show() 

# Keeping the earliest record 
earliest_records = df_with_time.orderBy("timestamp").dropDuplicates(["id"]) 
earliest_records.show() 
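
One caveat: F.current_timestamp() stamps every row with the same value, so the snippet above is mostly a template. Here is a sketch with hypothetical event data carrying its own timestamps (the updated_at column name and the dates are made up for illustration):

from pyspark.sql import Row, functions as F

# Hypothetical event data with explicit timestamps
events = [
    Row(id=1, name="John", updated_at="2024-01-01 09:00:00"),
    Row(id=1, name="John", updated_at="2024-03-15 17:30:00"),
    Row(id=2, name="Anna", updated_at="2024-02-10 12:00:00"),
]

events_df = spark.createDataFrame(events).withColumn(
    "updated_at", F.to_timestamp("updated_at")
)

# Keep the most recent record per id
latest_per_id = events_df.orderBy("updated_at", ascending=False).dropDuplicates(["id"])
latest_per_id.show()

As noted earlier, if you need a hard guarantee about which record is kept, prefer the window-function approach described below.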


Using groupBy and Aggregate Functions

When you want to remove duplicates but also need some aggregate data, the combination of groupBy and aggregate functions comes in handy:

# Assuming our DataFrame has an 'age' column (our sample above does not),
# we could compute the average age of individuals sharing the same name
agg_df = df.groupBy("name").agg(F.avg("age").alias("average_age"))
agg_df.show() 
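
A variant that runs against the sample DataFrame as-is is to count how often each full row occurs, which both deduplicates and flags the offenders; a minimal sketch:

from pyspark.sql import functions as F

# Count occurrences of each full row; a count greater than 1 marks a duplicate
row_counts = df.groupBy("id", "name", "city").count()
row_counts.filter(F.col("count") > 1).show()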


Window Functions for More Complex Deduplication

Window functions provide a way to perform calculations across a set of rows related to the current row. This is especially useful for more complex deduplication requirements:

from pyspark.sql.window import Window 
    
# Assuming our DataFrame has a 'score' column, and we want to keep the row with the highest score for each ID 
window_spec = Window.partitionBy("id").orderBy(F.desc("score")) 

# Using the rank function to assign a rank to each row within a window partition 
df_with_rank = df.withColumn("rank", F.rank().over(window_spec)) 

# Filtering rows with rank = 1 will give us the highest score for each ID 
highest_scores = df_with_rank.filter(F.col("rank") == 1).drop("rank") 
highest_scores.show() 
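
One caveat: rank() assigns the same rank to ties, so two rows with an identical top score for the same id would both pass the rank == 1 filter. If you need exactly one row per id, row_number() is the usual substitute; a minimal sketch under the same assumed score column:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# row_number() numbers rows uniquely within each partition, so the filter keeps
# exactly one row per id even when scores tie
window_spec = Window.partitionBy("id").orderBy(F.desc("score"))

one_per_id = (
    df.withColumn("row_num", F.row_number().over(window_spec))
      .filter(F.col("row_num") == 1)
      .drop("row_num")
)
one_per_id.show()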


Using exceptAll

If you have two DataFrames and want to remove from the first every row that also appears in the second (matching on entire rows), exceptAll does the job:

# Let's assume df1 and df2 are our DataFrames, and we want to remove all rows from df1 that exist in df2 
result = df1.exceptAll(df2) 
result.show() 
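
To make the semantics concrete, here is a minimal sketch with two small, hypothetical DataFrames. Unlike subtract, exceptAll respects multiplicity: if a row appears twice in df1 and once in df2, one copy remains.

# Hypothetical DataFrames for illustration
df1 = spark.createDataFrame(
    [(1, "John"), (1, "John"), (2, "Anna"), (3, "Mike")], ["id", "name"]
)
df2 = spark.createDataFrame([(1, "John"), (3, "Mike")], ["id", "name"])

# Rows of df1 minus rows of df2, keeping leftover duplicate copies
df1.exceptAll(df2).show()  # leaves one (1, "John") row and the (2, "Anna") row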


Conclusion

The PySpark framework offers numerous tools and techniques for handling duplicates, ranging from simple one-liners to more advanced methods using window functions. Your choice of method largely depends on the specific needs of your dataset and the nature of the duplicates. By mastering these techniques, you ensure that your data-driven decisions are based on clean and accurate data.