def user_transformation(delta_dict):
    from typing import List, Optional
    from datetime import datetime
    from pyspark.sql import DataFrame
    from pyspark.sql import functions as F
    from pyspark.sql import types as T

    # SCD Type 2 control columns
    AVAILABLE_FROM = "_available_from"
    AVAILABLE_UNTIL = "_available_until"
    DELETED = "_deleted"
    # Timestamp of the sync that produced each row version
    NEKT_SYNC_AT = "_nekt_sync_at"
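    # Each row version is valid from _available_from until _available_until;
    # the current version of a key has a NULL _available_until. Hard deletes in
    # the source are kept as tombstones with _deleted = True rather than being
    # dropped from the output.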

    def calculate_diffs(*, input_df: DataFrame, output_df: DataFrame, primary_keys: List[str], sync_timestamp: datetime) -> DataFrame:
        """Diff the latest input snapshot against the current SCD rows and apply
        the resulting additions, modifications, and deletions to the output."""
        # Only currently active rows (open-ended and not deleted) take part in the diff.
        output_current_rows_df = output_df.filter((F.col(AVAILABLE_UNTIL).isNull()) & (~F.col(DELETED)))
        # Classify each primary key by joining input against the active output rows.
        added_rows_df = input_df.join(output_current_rows_df, on=primary_keys, how="left_anti").select(input_df["*"])
        deleted_rows_df = output_current_rows_df.join(input_df, on=primary_keys, how="left_anti").select(output_df["*"])
        common_rows_df = output_current_rows_df.join(input_df, on=primary_keys, how="inner").select(output_df["*"])
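        # added:   in the input but not among the active output rows (left_anti)
        # deleted: active in the output but gone from the input (left_anti)
        # common:  present on both sides; inspected below for column-level changes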

        # Exclude the SCD control columns from the change comparison.
        exclude_columns = [AVAILABLE_FROM, AVAILABLE_UNTIL, DELETED, NEKT_SYNC_AT]
        compare_columns = [col for col in input_df.columns if col not in exclude_columns]
        # Build the "row changed" condition dynamically: any compared column differing
        # marks the row as modified. eqNullSafe treats NULL as equal to NULL, so
        # NULL-to-value changes are detected instead of being silently filtered out.
        modification_condition = F.lit(False)
        for col in compare_columns:
            modification_condition = modification_condition | ~output_df[col].eqNullSafe(input_df[col])
        modified_rows_df = common_rows_df.filter(modification_condition).select(output_df["*"])
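        # Example: with a plain "!=", NULL != 'x' evaluates to NULL, and filter()
        # would drop the row; eqNullSafe makes that compare as a real change.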

        # For modified keys: close out the old version, keep everything else,
        # and append the new version taken from the input. Joining on the primary
        # keys plus _nekt_sync_at pins the join to the exact row version.
        primary_keys_with_sync_key = primary_keys + [NEKT_SYNC_AT]
        if modified_rows_df.limit(1).count() > 0:
            updated_rows = (
                output_df.join(modified_rows_df, primary_keys_with_sync_key, how="inner")
                .select(output_df["*"])
                .withColumn(AVAILABLE_UNTIL, F.lit(sync_timestamp).cast(T.TimestampType()))
            )
            non_updated_rows = output_df.join(modified_rows_df, primary_keys_with_sync_key, how="left_anti")
            new_rows = (
                input_df.join(modified_rows_df.select(*primary_keys), on=primary_keys)
                .withColumn(AVAILABLE_FROM, F.lit(sync_timestamp))
                .withColumn(AVAILABLE_UNTIL, F.lit(None).cast(T.TimestampType()))
                .withColumn(DELETED, F.lit(False))
            )
            output_df = updated_rows.unionByName(non_updated_rows).unionByName(new_rows)
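            # After this union, output_df holds the closed-out old versions, the
            # untouched rows, and fresh open-ended versions of the modified keys.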

        # Added rows enter as new open-ended versions.
        if added_rows_df.limit(1).count() > 0:
            added_rows_with_timestamp = (
                added_rows_df
                .withColumn(AVAILABLE_FROM, F.lit(sync_timestamp))
                .withColumn(AVAILABLE_UNTIL, F.lit(None).cast(T.TimestampType()))
                .withColumn(DELETED, F.lit(False))
            )
            output_df = output_df.unionByName(added_rows_with_timestamp)

        # Deleted rows are closed out and tombstoned rather than removed.
        if deleted_rows_df.limit(1).count() > 0:
            deleted_rows_with_timestamp = (
                deleted_rows_df
                .withColumn(AVAILABLE_UNTIL, F.lit(sync_timestamp))
                .withColumn(DELETED, F.lit(True))
            )
            non_updated_rows = output_df.join(deleted_rows_df, primary_keys_with_sync_key, how="left_anti")
            output_df = non_updated_rows.unionByName(deleted_rows_with_timestamp)
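        # At this point output_df carries the full history: closed-out versions,
        # tombstoned deletes, and one open-ended current version per live key.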

        # Normalize _available_until to a proper timestamp type.
        output_df = output_df.withColumn(AVAILABLE_UNTIL, output_df[AVAILABLE_UNTIL].cast(T.TimestampType()))
        # Note: count() forces a full evaluation of the plan; it is only used for logging.
        record_count = output_df.count()
        transformation_user_logger.info(f"Record count: {record_count}")
        return output_df

    def apply_scd_type2(*, input_delta_table, output_delta_table, primary_keys: List[str], input_delta_version: Optional[int] = None) -> DataFrame:
        """Read the latest version of the input Delta table and fold it into the
        SCD Type 2 output, bootstrapping the output on the first run."""
        output_df = output_delta_table.toDF() if output_delta_table is not None else None
        # Resolve the latest version and its commit timestamp from the Delta history.
        history_df = input_delta_table.history()
        delta_path = input_delta_table.detail().select("location").first()[0]
        latest_history_entry = history_df.orderBy("version", ascending=False).select("version", "timestamp").first()
        latest_version = latest_history_entry["version"]
        latest_version_timestamp = latest_history_entry["timestamp"]
        latest_version_df = spark.read.format("delta").option("versionAsOf", latest_version).load(delta_path)
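        # versionAsOf pins the read to the same snapshot the history entry
        # describes, so the data and its commit timestamp stay consistent even
        # if new commits land while this transformation is running.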

        if output_df is None:
            # First run: bootstrap the output from the snapshot with open-ended validity.
            output_df = (
                latest_version_df
                .withColumn(AVAILABLE_FROM, F.lit(latest_version_timestamp))
                .withColumn(AVAILABLE_UNTIL, F.lit(None).cast(T.TimestampType()))
                .withColumn(DELETED, F.lit(False))
            )
            return output_df

        output_df = calculate_diffs(
            input_df=latest_version_df,
            output_df=output_df,
            primary_keys=primary_keys,
            sync_timestamp=latest_version_timestamp,
        )
        return output_df

    # Load the output (trusted layer) and input (raw layer) Delta tables.
    output_table_delta: DeltaTable = delta_dict.get("trusted").get("output_table")
    input_table_delta: DeltaTable = delta_dict.get("raw").get("input_table")
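    # If the output table entry is missing (e.g., on the first run), .get()
    # returns None and apply_scd_type2 bootstraps the output from the input.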

    # Apply SCD Type 2 keyed on the "id" column.
    new_df = apply_scd_type2(
        input_delta_table=input_table_delta,
        output_delta_table=output_table_delta,
        primary_keys=["id"],
    )
    return new_df