How to get all occurrences of duplicate records in a PySpark DataFrame based on specific columns?
Problem Description:
I need to find all occurrences of duplicate records in a PySpark DataFrame. Following is the sample dataset:
# Prepare Data
data = [("A", "A", 1),
("A", "A", 2),
("A", "A", 3),
("A", "B", 4),
("A", "B", 5),
("A", "C", 6),
("A", "D", 7),
("A", "E", 8),
]
# Create DataFrame
columns= ["col_1", "col_2", "col_3"]
df = spark.createDataFrame(data = data, schema = columns)
df.show(truncate=False)
When I try the following code:
primary_key = ['col_1', 'col_2']
duplicate_records = df.exceptAll(df.dropDuplicates(primary_key))
duplicate_records.show()
The output will be:
+-----+-----+-----+
|col_1|col_2|col_3|
+-----+-----+-----+
|    A|    A|    2|
|    A|    A|    3|
|    A|    B|    5|
+-----+-----+-----+
As you can see, I don't get all occurrences of the duplicate records based on the primary key, because df.dropDuplicates(primary_key) keeps one instance of each duplicate group, which exceptAll then removes from the result. The 1st and the 4th records of the dataset should also be in the output, as illustrated below.
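For reference, the desired result would keep every row whose (col_1, col_2) combination occurs more than once, i.e. something like:
+-----+-----+-----+
|col_1|col_2|col_3|
+-----+-----+-----+
|    A|    A|    1|
|    A|    A|    2|
|    A|    A|    3|
|    A|    B|    4|
|    A|    B|    5|
+-----+-----+-----+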
Any idea to solve this issue?
Solution – 1
The reason you can't see the 1st and the 4th records is that dropDuplicates keeps one row from each group of duplicates. See the code below:
primary_key = ['col_1', 'col_2']
df.dropDuplicates(primary_key).show()
+-----+-----+-----+
|col_1|col_2|col_3|
+-----+-----+-----+
|    A|    A|    1|
|    A|    B|    4|
|    A|    C|    6|
|    A|    D|    7|
|    A|    E|    8|
+-----+-----+-----+
For your task, you can extract the duplicated keys and join them back to your main DataFrame:
from pyspark.sql import functions as F

# Keys (col_1, col_2) that occur more than once
duplicated_keys = (
    df
    .groupby(primary_key)
    .count()
    .filter(F.col('count') > 1)
    .drop('count')
)

# Keep every row of df whose key is in the duplicated set
(
    df
    .join(F.broadcast(duplicated_keys), primary_key)
).show()
+-----+-----+-----+
|col_1|col_2|col_3|
+-----+-----+-----+
|    A|    A|    1|
|    A|    A|    2|
|    A|    A|    3|
|    A|    B|    4|
|    A|    B|    5|
+-----+-----+-----+
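Note that F.broadcast is only a hint to the optimizer; it fits here because the set of duplicated keys is usually small. If you don't want to force the join strategy, a plain inner join returns the same rows:
df.join(duplicated_keys, on=primary_key, how='inner').show()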
Solution – 2
Here are my 2 cents
We can achieve this using a Window function.
Create the DataFrame:
data = [("A", "A", 1), ("A", "A", 2), ("A", "A", 3), ("A", "B", 4), ("A", "B", 5), ("A", "C", 6), ("A", "D", 7), ("A", "E", 8), ] columns= ["col_1", "col_2", "col_3"] df = spark.createDataFrame(data = data, schema = columns) df.show(truncate=False)
Use a Window function partitioned by the primary key to compute the count per key, keep only the rows whose count is greater than 1, and then drop the count column.
from pyspark.sql.window import Window
from pyspark.sql.functions import count

primary_key = ['col_1', 'col_2']
windowSpec = Window.partitionBy(primary_key).orderBy(primary_key)

df.withColumn('CountColumns', count('*').over(windowSpec)) \
  .filter('CountColumns > 1') \
  .drop('CountColumns') \
  .show()
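As a side note, the orderBy is not strictly needed here: the ordering columns are the same as the partitioning columns, so every row in a partition is a peer and the count already covers the whole partition. A minimal sketch of the same approach without the orderBy, using the imports above:
windowSpec = Window.partitionBy(primary_key)

# Count rows per (col_1, col_2) key and keep only keys that appear more than once
df.withColumn('CountColumns', count('*').over(windowSpec)) \
  .filter('CountColumns > 1') \
  .drop('CountColumns') \
  .show()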