From d48500fc3b49eb93f3440071b33069c8de334ada Mon Sep 17 00:00:00 2001
From: lisazeyen
Date: Wed, 31 Jul 2024 16:06:16 +0200
Subject: [PATCH] use ffill and bfill

---
 scripts/build_energy_totals.py | 29 +++++++++++++++++++++--------
 1 file changed, 21 insertions(+), 8 deletions(-)

diff --git a/scripts/build_energy_totals.py b/scripts/build_energy_totals.py
index 44bb470f..f4905927 100644
--- a/scripts/build_energy_totals.py
+++ b/scripts/build_energy_totals.py
@@ -587,7 +587,8 @@ def build_idees(countries: List[str]) -> pd.DataFrame:
 
 def fill_missing_years(fill_values: pd.Series) -> pd.Series:
     """
-    Fill missing years for some countries by mean over the other years.
+    Fill missing years for some countries by first using forward fill (ffill)
+    and then backward fill (bfill).
 
     Parameters
     ----------
@@ -598,16 +599,28 @@ def fill_missing_years(fill_values: pd.Series) -> pd.Series:
     Returns
     -------
     pd.Series
-        A pandas Series with zero values replaced by the mean value of the corresponding
-        country.
+        A pandas Series with zero values replaced by the forward-filled and
+        backward-filled values of the corresponding country.
 
     Notes
     -----
-    - The function groups the data by the 'country' level and computes the mean for each group.
-    - Zero values in the original Series are replaced by the mean value of their respective country group.
+    - The function groups the data by the 'country' level and performs forward fill
+      and backward fill to fill zero values.
+    - Zero values in the original Series are replaced by the ffilled and bfilled
+      value of their respective country group.
+    - Both fills are applied per country, so values are never carried over
+      from one country to another; a country without any non-zero value is
+      left as NaN.
     """
-    means = fill_values.groupby(level="country").transform("mean")
-    return fill_values.where(fill_values != 0, means)
+    # Treat zero values as missing so that they are filled below
+    fill_values = fill_values.where(fill_values != 0)
+
+    # Forward fill and then backward fill, each within its own country group,
+    # so that a gap is never filled with the values of another country
+    fill_values = fill_values.groupby(level="country").ffill()
+    fill_values = fill_values.groupby(level="country").bfill()
+
+    return fill_values
 
 
 def build_energy_totals(
@@ -724,5 +737,5 @@ def build_energy_totals(
                 eurostat.loc[slicer, eurostat_fuels[fuel]].groupby(level=[0, 1]).sum()
             )
-            # fill missing years for some countries by mean over the other years
+            # fill missing years for some countries by ffill and bfill per country
             fill_values = fill_missing_years(fill_values)
             df.loc[to_fill, f"{fuel} {sector}"] = fill_values