How to Filter Duplicate Rows in a PySpark DataFrame: The Ultimate Guide
Diving Straight into Filtering Duplicate Rows in a PySpark DataFrame
Duplicate rows in a dataset can skew analyses, inflate storage costs, and complicate ETL pipelines. For data engineers working with Apache Spark, identifying and filtering duplicate rows in a PySpark DataFrame is a common task, whether you're cleaning raw data, preparing datasets for machine learning, or ensuring data integrity. PySpark offers intuitive methods to detect and handle duplicates, making it easier to maintain high-quality data. This guide is crafted for data engineers with intermediate PySpark knowledge, walking you through the process of filtering duplicates with practical examples. If you're new to PySpark, check out our PySpark Fundamentals to get started.
In this guide, we'll explore the basics of filtering duplicates, advanced deduplication techniques, handling nested data, using SQL expressions, and optimizing performance. Each section includes clear code examples, outputs, and common pitfalls to help you master duplicate filtering in PySpark. Let's dive in.
Understanding Duplicate Row Filtering in PySpark
Filtering duplicates in PySpark means identifying rows that are identical across all columns, or across a subset of columns, and keeping only one copy of each. The two primary methods are distinct(), which deduplicates on all columns, and dropDuplicates(), which does the same but can also take a subset of columns (called with no arguments, it behaves exactly like distinct()). Both leverage Spark's distributed computing power to handle large datasets efficiently.
Basic Duplicate Filtering Example
Let’s say we have an employee dataset with some duplicate entries, and we want to keep only unique rows. Here’s how we can do it using distinct().
from pyspark.sql import SparkSession
# Initialize Spark session
spark = SparkSession.builder.appName("DuplicateFilterExample").getOrCreate()
# Create employees DataFrame with duplicates
employees_data = [
(1, "Alice", 30, 50000, 101),
(2, "Bob", 25, 45000, 102),
(1, "Alice", 30, 50000, 101), # Duplicate
(3, "Charlie", 35, 60000, 103)
]
employees = spark.createDataFrame(employees_data, ["employee_id", "name", "age", "salary", "dept_id"])
# Remove duplicates using distinct()
filtered_df = employees.distinct()
# Show results
filtered_df.show()
# Output:
# +-----------+-------+---+------+-------+
# |employee_id| name|age|salary|dept_id|
# +-----------+-------+---+------+-------+
# | 1| Alice| 30| 50000| 101|
# | 2| Bob| 25| 45000| 102|
# | 3|Charlie| 35| 60000| 103|
# +-----------+-------+---+------+-------+
# Validate row count
assert filtered_df.count() == 3, "Expected 3 unique rows after filtering"
What’s Happening Here? The distinct() method compares rows across all columns and keeps a single copy of each unique row. It’s like saying, “Give me one copy of every unique record.” In this case, the duplicate row for Alice (employee_id=1) is removed, leaving three unique rows. Note that a Spark DataFrame has no inherent row order, so there is no meaningful “first” occurrence; you simply get one copy of each row. This is perfect for quick deduplication when you want every row to be unique across all columns.
Key Methods:
- distinct(): Removes duplicate rows based on all columns, keeping one copy of each unique row.
- dropDuplicates(): Similar to distinct(), but allows specifying a subset of columns for deduplication.
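If you want to see which rows are duplicated before dropping anything, a quick group-and-count does the job. This is a minimal sketch (not part of the original example) that groups on every column of the employees DataFrame and keeps only groups that appear more than once.
from pyspark.sql import functions as F
# Count occurrences of each full row; a count above 1 means the row is duplicated
duplicate_rows = employees.groupBy(employees.columns).count().filter(F.col("count") > 1)
duplicate_rows.show()
# Expect a single group here: the Alice row (employee_id=1) with count = 2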
Common Mistake: Assuming distinct() considers only primary keys.
# Incorrect assumption: Expecting distinct to consider only employee_id
filtered_df = employees.distinct() # Considers all columns, not just employee_id
# Fix: Use dropDuplicates with specific columns if needed
filtered_df = employees.dropDuplicates(["employee_id"])
Error Output: No error, but unexpected results if you expect deduplication on a subset of columns.
Fix: Use dropDuplicates() with specific columns for targeted deduplication, as we’ll explore next.
Filtering Duplicates Based on Specific Columns
Sometimes, you don’t care about duplicates across all columns—just a few key ones. For example, you might want to ensure each employee_id appears only once, even if other columns differ. This is where dropDuplicates() shines, letting you specify which columns to check for duplicates.
Example: Deduplicating by Employee ID
Let’s filter duplicates based on employee_id, keeping the first occurrence.
# Remove duplicates based on employee_id
filtered_df = employees.dropDuplicates(["employee_id"])
# Show results
filtered_df.show()
# Output:
# +-----------+-------+---+------+-------+
# |employee_id| name|age|salary|dept_id|
# +-----------+-------+---+------+-------+
# | 1| Alice| 30| 50000| 101|
# | 2| Bob| 25| 45000| 102|
# | 3|Charlie| 35| 60000| 103|
# +-----------+-------+---+------+-------+
# Validate
assert filtered_df.count() == 3, "Expected 3 unique employee IDs"
What’s Going On? The dropDuplicates(["employee_id"]) method checks for unique values in the employee_id column and keeps exactly one row for each unique employee_id. Even though we had two rows for employee_id=1, only one is retained; which of the duplicates survives is not guaranteed unless you impose an ordering yourself (see the sketch below). This is super useful when you’re cleaning data and want to enforce uniqueness on a key column, like an ID.
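One caveat: dropDuplicates() does not let you choose which of the duplicate rows survives. When that matters (say you want the row with the lowest salary per employee_id), a window function makes the choice explicit. Here is a minimal sketch of that pattern; the salary ordering is just an illustrative assumption.
from pyspark.sql import Window
from pyspark.sql import functions as F
# Rank rows within each employee_id by salary and keep only the top-ranked row
w = Window.partitionBy("employee_id").orderBy(F.col("salary").asc())
deterministic_df = employees.withColumn("rn", F.row_number().over(w)) \
    .filter(F.col("rn") == 1) \
    .drop("rn")
deterministic_df.show()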
Common Mistake: Specifying non-existent columns.
# Incorrect: Non-existent column
filtered_df = employees.dropDuplicates(["emp_id"]) # Raises AnalysisException
# Fix: Verify column names
employees.printSchema()
filtered_df = employees.dropDuplicates(["employee_id"])
Error Output: AnalysisException: cannot resolve 'emp_id'.
Fix: Double-check column names using printSchema() to avoid typos or incorrect references.
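A lightweight guard for pipelines is to validate the deduplication columns against df.columns before calling dropDuplicates(). This is a small sketch; the dedup_cols list is just an illustration.
# Validate the deduplication columns up front instead of failing mid-job
dedup_cols = ["employee_id"]
missing = [c for c in dedup_cols if c not in employees.columns]
assert not missing, f"Columns not found in DataFrame: {missing}"
filtered_df = employees.dropDuplicates(dedup_cols)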
Filtering Duplicates in Nested Data
Duplicate filtering gets trickier with nested data, like structs, where the field you care about lives inside another column. Dot notation (for example, contact.email) works in select() and filter(), but dropDuplicates() generally resolves only top-level column names, so the usual trick is to pull the nested field into a temporary column first, or to use a window function.
Example: Deduplicating by Nested Email
Imagine our employees DataFrame has a contact struct with email and phone. We want to keep only the first row for each unique email.
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
# Create employees with nested contact data
schema = StructType([
StructField("employee_id", IntegerType()),
StructField("name", StringType()),
StructField("contact", StructType([
StructField("email", StringType()),
StructField("phone", StringType())
])),
StructField("dept_id", IntegerType())
])
employees_data = [
(1, "Alice", {"email": "alice@company.com", "phone": "123-456-7890"}, 101),
(2, "Bob", {"email": "bob@company.com", "phone": "234-567-8901"}, 102),
(3, "Alicia", {"email": "alice@company.com", "phone": "345-678-9012"}, 101)
]
employees = spark.createDataFrame(employees_data, schema)
from pyspark.sql.functions import col
# dropDuplicates() expects top-level column names, so extract the nested email
# into a helper column, deduplicate on it, then drop the helper
filtered_df = employees.withColumn("email", col("contact.email")) \
    .dropDuplicates(["email"]) \
    .drop("email")
# Show results
filtered_df.show()
# Output:
# +-----------+-----+--------------------+-------+
# |employee_id| name| contact|dept_id|
# +-----------+-----+--------------------+-------+
# | 1|Alice|{alice@company.co...| 101|
# | 2| Bob|{bob@company.com,...| 102|
# +-----------+-----+--------------------+-------+
# Validate
assert filtered_df.count() == 2, "Expected 2 unique emails"
What’s Happening? Because dropDuplicates() works on top-level columns, we first copy contact.email into a temporary email column with withColumn(), deduplicate on that column, and then drop the helper. The row for Alicia is removed because her email matches Alice’s, leaving one row per unique email. This pattern is handy for datasets with nested structures, like JSON data, where you need to enforce uniqueness on a specific field.
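If you would rather not add a helper column, a window function can partition directly on the nested field, since Window.partitionBy() accepts column expressions. Here is a sketch of that alternative, ordering by employee_id so the lowest ID wins.
from pyspark.sql import Window
from pyspark.sql import functions as F
# Partition directly on the nested field and keep one row per email
w = Window.partitionBy(F.col("contact.email")).orderBy("employee_id")
filtered_df = employees.withColumn("rn", F.row_number().over(w)) \
    .filter(F.col("rn") == 1) \
    .drop("rn")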
Common Mistake: Passing the nested path straight to dropDuplicates().
# Incorrect: dropDuplicates matches only top-level column names
filtered_df = employees.dropDuplicates(["contact.email"]) # Raises AnalysisException
# Fix: Extract the nested field into a top-level column first
filtered_df = employees.withColumn("email", col("contact.email")) \
    .dropDuplicates(["email"]) \
    .drop("email")
Error Output: AnalysisException: Cannot resolve column name "contact.email" among (employee_id, name, contact, dept_id) (exact wording varies by Spark version).
Fix: Remember that dropDuplicates() only sees top-level columns; check the schema with printSchema() and materialize the nested field (or use a window function) before deduplicating.
Filtering Duplicates with SQL Expressions
If you’re more comfortable with SQL, PySpark’s SQL module lets you filter duplicates using familiar SQL syntax. By registering a DataFrame as a temporary view, you can use DISTINCT or ROW_NUMBER() to handle duplicates.
Example: SQL-Based Deduplication
Let’s run a SQL DISTINCT over the nested employees DataFrame we just built.
# Register DataFrame as a temporary view
employees.createOrReplaceTempView("employees")
# SQL query to remove duplicates
filtered_df = spark.sql("""
SELECT DISTINCT *
FROM employees
""")
# Show results
filtered_df.show()
# Output (row order may vary):
# +-----------+------+--------------------+-------+
# |employee_id|  name|             contact|dept_id|
# +-----------+------+--------------------+-------+
# |          1| Alice|{alice@company.co...|    101|
# |          2|   Bob|{bob@company.com,...|    102|
# |          3|Alicia|{alice@company.co...|    101|
# +-----------+------+--------------------+-------+
# Validate
assert filtered_df.count() == 3
What’s Going On? The DISTINCT keyword returns only unique rows, considering all columns, and is equivalent to distinct() in the DataFrame API. Because this nested DataFrame contains no full-row duplicates (Alicia’s row differs from Alice’s in name and phone), all three rows come back. For column-level deduplication, such as one row per contact.email, use ROW_NUMBER() as shown below.
Example: SQL Deduplication by Specific Column
Let’s deduplicate by contact.email using ROW_NUMBER().
# SQL query with ROW_NUMBER
filtered_df = spark.sql("""
SELECT employee_id, name, contact, dept_id
FROM (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY contact.email ORDER BY employee_id) AS rn
FROM employees
) t
WHERE rn = 1
""")
# Show results
filtered_df.show()
# Output:
# +-----------+-----+--------------------+-------+
# |employee_id| name| contact|dept_id|
# +-----------+-----+--------------------+-------+
# | 1|Alice|{alice@company.co...| 101|
# | 2| Bob|{bob@company.com,...| 102|
# +-----------+-----+--------------------+-------+
# Validate
assert filtered_df.count() == 2
What’s Happening? The ROW_NUMBER() function assigns a unique number to each row within groups of identical contact.email values, ordered by employee_id. We keep only rows where rn = 1, effectively deduplicating by email. This approach gives you fine-grained control, especially when you need to choose which duplicate to keep based on sorting criteria.
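For example, flipping the ORDER BY direction changes which duplicate survives. Here is a sketch that keeps the highest employee_id per email (Alicia instead of Alice); the choice of ordering column is just an illustration.
# Keep the row with the highest employee_id per email by ordering descending
latest_df = spark.sql("""
SELECT employee_id, name, contact, dept_id
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY contact.email ORDER BY employee_id DESC) AS rn
FROM employees
) t
WHERE rn = 1
""")
latest_df.show()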
Common Mistake: Filtering on the window column at the same query level.
# Incorrect: WHERE is evaluated before the SELECT aliases exist, so rn cannot be resolved
spark.sql("""
SELECT *, ROW_NUMBER() OVER (PARTITION BY contact.email ORDER BY employee_id) AS rn
FROM employees
WHERE rn = 1
""") # Raises AnalysisException
# Fix: Compute rn in a subquery, then filter in the outer query
spark.sql("""
SELECT *
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY contact.email ORDER BY employee_id) AS rn
FROM employees
) t
WHERE rn = 1
""")
Error Output: AnalysisException: cannot resolve 'rn', because the alias is not visible to the WHERE clause.
Fix: Wrap the ROW_NUMBER() calculation in a subquery with an alias (e.g., t), or in a CTE, and filter on rn in the outer query.
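The same logic can read a little cleaner as a common table expression; this is an equivalent sketch.
# Equivalent deduplication using a CTE instead of a nested subquery
filtered_df = spark.sql("""
WITH ranked AS (
SELECT *, ROW_NUMBER() OVER (PARTITION BY contact.email ORDER BY employee_id) AS rn
FROM employees
)
SELECT employee_id, name, contact, dept_id
FROM ranked
WHERE rn = 1
""")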
Optimizing Duplicate Filtering Performance
Filtering duplicates on large datasets can be resource-intensive due to the need to compare rows across partitions. Here are four practical ways to boost performance.
- Select Relevant Columns: Reduce data shuffling by selecting only the columns needed for deduplication or downstream tasks.
- Filter Early: Apply any preliminary filters (e.g., by date or department) before deduplication to shrink the dataset.
- Partition Data: Repartition the DataFrame by the deduplication columns to minimize shuffling (see the sketch after this list).
- Cache Results: Cache the deduplicated DataFrame if it will be reused in multiple operations.
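The repartitioning tip can look like the sketch below. Whether it pays off depends on your data volume and on whether later operations group or join on the same key, so treat it as something to benchmark rather than a guaranteed win.
# Colocate rows that share an employee_id before deduplicating
pre_partitioned = employees.repartition("employee_id").dropDuplicates(["employee_id"])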
Example: Optimized Deduplication
Let’s deduplicate by employee_id with optimizations.
from pyspark.sql.functions import col
# Rebuild the flat employees DataFrame from the first example ("employees" now holds the nested one)
employees_flat = spark.createDataFrame(
    [(1, "Alice", 30, 50000, 101), (2, "Bob", 25, 45000, 102),
     (1, "Alice", 30, 50000, 101), (3, "Charlie", 35, 60000, 103)],
    ["employee_id", "name", "age", "salary", "dept_id"])
# Filter early, select only the needed columns, then deduplicate
optimized_df = employees_flat.filter(col("dept_id") == 101) \
    .select("employee_id", "name", "salary") \
    .dropDuplicates(["employee_id"])
# Cache the result for reuse in later operations
optimized_df.cache()
# Show results
optimized_df.show()
# Output:
# +-----------+-----+------+
# |employee_id| name|salary|
# +-----------+-----+------+
# | 1|Alice| 50000|
# +-----------+-----+------+
# Validate
assert optimized_df.count() == 1
What’s Going On? We first filter for dept_id=101 to reduce the dataset, select only employee_id, name, and salary to minimize data processed, and then deduplicate by employee_id. Caching the result ensures faster access for subsequent operations. This approach is ideal for large-scale ETL pipelines where performance matters.
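One small follow-up worth adding: once the cached DataFrame is no longer needed, release the memory explicitly.
# Release the cached data once downstream work is finished
optimized_df.unpersist()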
Wrapping Up Your Duplicate Filtering Mastery
Filtering duplicate rows in PySpark DataFrames is a key skill for maintaining clean, reliable datasets. Whether you’re using distinct() for full-row deduplication, dropDuplicates() for specific columns, SQL expressions for flexibility, or optimizing for performance, you now have the tools to tackle duplicates effectively. Try these techniques in your next Spark project and let us know how it goes on X. For more DataFrame operations, check out DataFrame Transformations.
More Spark Resources to Keep You Going
Published: April 17, 2025