{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pypsa\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "from pathlib import Path\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.style.use(\"ggplot\")\n",
    "import pycountry\n",
    "import json\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Network files that can be analysed, keyed by a short alias.\n",
    "available_models = dict(\n",
    "    model_1=\"elec_s_37_ec_lv1.0_.nc\",\n",
    "    model_2=\"elec_s_37_ec_lv1.0_3H_withUC.nc\",\n",
    "    model_3=\"elec_s_37_ec_lv1.0_Co2L-noUC-noCo2price.nc\",\n",
    "    model_4=\"elec_s_37_ec_lv1.0_Ep.nc\",\n",
    "    model_5=\"elec_s_37_ec_lv1.0_Ep_new.nc\",\n",
    ")\n",
    "\n",
    "# Alias of the network used throughout this notebook.\n",
    "model_choice = \"model_5\"\n",
    "\n",
    "# Inputs and outputs are resolved two levels above the working directory.\n",
    "data_path = Path.cwd().joinpath(\"..\", \"..\")\n",
    "model_path = data_path.joinpath(available_models[model_choice])\n",
    "\n",
    "# JSON mapping that later cells iterate as (technology, carriers) pairs.\n",
    "with data_path.joinpath(\"generation_data\", \"generation_mapper_pypsa.json\").open(\"r\") as f:\n",
    "    pypsa_generation_mapper = json.load(f)\n",
    "\n",
    "# Plot folder named after the network file without its last three characters (\".nc\").\n",
    "plot_path = data_path.joinpath(\"plots\", available_models[model_choice][:-3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the plot folder together with any missing parent folders. An already\n",
    "# existing folder is accepted so that this cell can be re-run without an error.\n",
    "plot_path.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the selected network file; later cells read its components and time series through `n`.\n",
    "n = pypsa.Network(str(model_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def intersection(alist, blist):\n",
    "    total_list = list()\n",
    "    for val in alist:\n",
    "        if val in blist:\n",
    "            total_list.append(val)\n",
    "    return total_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the mapping loaded from generation_mapper_pypsa.json.\n",
    "pypsa_generation_mapper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First data column of the CSV, indexed by its first column; the values are\n",
    "# passed as plot colours further below.\n",
    "color_mapper = pd.read_csv(\"color_mapper.csv\", index_col=0).iloc[:, 0]\n",
    "\n",
    "# Colours for the extra categories used by the plots in this notebook.\n",
    "for label, colour in {\n",
    "    \"Others\": \"#D3D3D3\",\n",
    "    \"Storage Charge\": \"#51dbcc\",\n",
    "    \"Storage Discharge\": \"#51dbcc\",\n",
    "}.items():\n",
    "    color_mapper.loc[label] = colour"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the colour lookup including the categories added above.\n",
    "color_mapper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The first two characters of every generator name are taken as the country code;\n",
    "# `gen` holds the generator names without their first six characters.\n",
    "countries = set([col[:2] for col in n.generators_t.p.columns])\n",
    "gen = set([col[6:] for col in n.generators_t.p.columns])\n",
    "\n",
    "# Make sure the export folder exists so that to_csv does not fail on a fresh checkout.\n",
    "pypsa_data_path = data_path / \"pypsa_data\"\n",
    "pypsa_data_path.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Sorted iteration gives a reproducible processing order (set order is arbitrary).\n",
    "for country in sorted(countries):\n",
    "    df = pd.DataFrame(index=n.generators_t.p.index)\n",
    "    # All generators attached to a bus whose name contains the country code.\n",
    "    country_generation = n.generators.loc[n.generators.bus.str.contains(country)]\n",
    "\n",
    "    for key, gens in pypsa_generation_mapper.items():\n",
    "        # Generators of this country whose carrier belongs to the current technology.\n",
    "        curr_gen = country_generation.loc[\n",
    "            country_generation.carrier.apply(lambda carr: carr in gens)\n",
    "        ].index\n",
    "\n",
    "        if len(curr_gen):\n",
    "            # NOTE(review): this averages over the matching generators, while the\n",
    "            # comparison cell further below sums them -- confirm which is intended.\n",
    "            df[key] = n.generators_t.p[curr_gen].mean(axis=1)\n",
    "        else:\n",
    "            # No generator of this technology in the country: export a zero column.\n",
    "            df[key] = np.zeros(len(df))\n",
    "\n",
    "    df.to_csv(pypsa_data_path / (country + \".csv\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column names of the Europe-wide inflow accumulators created in the comparison cell below.\n",
    "total_inflow_cols = [\n",
    "    \"Solar\",\n",
    "    \"Wind Onshore\",\n",
    "    \"Nuclear\",\n",
    "    \"Lignite\",\n",
    "    \"Inflow Lines\",\n",
    "    \"Inflow Links\",\n",
    "    \"Wind Offshore\",\n",
    "    \"Biomass\",\n",
    "    \"Run of River\",\n",
    "    \"Hydro\",\n",
    "    \"Hard Coal\",\n",
    "    \"Gas\",\n",
    "    \"Oil\",\n",
    "]\n",
    "# Column names of the Europe-wide outflow accumulator.\n",
    "total_outflow_cols = [\"Outflow Links\", \"Outflow Lines\", \"Storage Charge\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): these two sets are not referenced by any other cell visible in this notebook -- confirm they are still needed.\n",
    "total_inflow_set = set()\n",
    "total_outflow_set = set()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "\n",
    "# Snapshots of the network; index of all PyPSA accumulators.\n",
    "index = n.generators_t.p.index\n",
    "\n",
    "# Accumulators that the country loop below adds each processed country to.\n",
    "pypsa_total_inflow = pd.DataFrame(\n",
    "    np.zeros((len(index), len(total_inflow_cols))),\n",
    "    index=index,\n",
    "    columns=total_inflow_cols,\n",
    ")\n",
    "# The German ENTSO-E file is read here only to provide the index of the ENTSO-E accumulator.\n",
    "entsoe_df = pd.read_csv(\n",
    "    data_path / \"harmonised_generation_data\" / (\"prepared_DE.csv\"),\n",
    "    parse_dates=True,\n",
    "    index_col=0,\n",
    ")\n",
    "entsoe_total_inflow = pd.DataFrame(\n",
    "    np.zeros((len(entsoe_df), len(total_inflow_cols))),\n",
    "    index=entsoe_df.index,\n",
    "    columns=total_inflow_cols,\n",
    ")\n",
    "pypsa_total_outflow = pd.DataFrame(\n",
    "    np.zeros((len(index), len(total_outflow_cols))),\n",
    "    index=index,\n",
    "    columns=total_outflow_cols,\n",
    ")\n",
    "# Start from zeros: a Series created without data holds only NaN, and adding the\n",
    "# per-country load to NaN would leave the total at NaN for every snapshot.\n",
    "total_load = pd.Series(0.0, index=index)\n",
    "\n",
    "# Build one comparison figure per country file found in pypsa_data.\n",
    "for num, country in enumerate(os.listdir(data_path / \"pypsa_data\")):\n",
    "    # country = \"DE.csv\"\n",
    "    # The first two characters of the file name are used as the country code.\n",
    "    cc = country[:2]\n",
    "\n",
    "    # Distinct buses of all generators whose bus name contains the country code.\n",
    "    country_buses = np.unique(\n",
    "        n.generators.loc[n.generators.bus.str.contains(cc)].bus.values\n",
    "    )\n",
    "    print(f\"Buses for country {country[:-4]}: \", country_buses)\n",
    "\n",
    "    # Countries with zero or several buses are skipped.\n",
    "    if not len(country_buses) == 1:\n",
    "        print(\"Current implementation is for one bus per country\")\n",
    "        print(f\"Skipping!\")\n",
    "        continue\n",
    "\n",
    "    bus = country_buses[0]\n",
    "\n",
    "    # The bare string below is a disabled read of the exported PyPSA CSV; it is never assigned or used.\n",
    "    \"\"\"    \n",
    "    pypsa_df = pd.read_csv(data_path / \"pypsa_data\" / country, parse_dates=True, index_col=0)\n",
    "    \"\"\"\n",
    "    try:\n",
    "        entsoe_df = pd.read_csv(\n",
    "            data_path / \"harmonised_generation_data\" / (\"prepared_\" + country),\n",
    "            parse_dates=True,\n",
    "            index_col=0,\n",
    "        )\n",
    "\n",
    "        # Strip the last six characters of every column name, drop the first row and\n",
    "        # scale by 1e-3 (the plots label the result as GW; presumably MW on disk -- TODO confirm).\n",
    "        entsoe_df.columns = [col[:-6] for col in entsoe_df.columns]\n",
    "        entsoe_df = entsoe_df.iloc[1:]\n",
    "        entsoe_df = entsoe_df.multiply(1e-3)\n",
    "    except FileNotFoundError:\n",
    "        # Countries without a harmonised ENTSO-E file are skipped.\n",
    "        continue\n",
    "\n",
    "    # 4x3 grid: rows 0-1 time series, row 2 duration curves and totals, row 3 scatter plots.\n",
    "    fig, axs = plt.subplots(4, 3, figsize=(20, 20))\n",
    "\n",
    "    axs[0, 0].set_title(pycountry.countries.get(alpha_2=country[:2]).name)\n",
    "\n",
    "    start = pd.Timestamp(\"2019-01-01\")  # for small time frame\n",
    "    end = pd.Timestamp(\"2019-01-14\")\n",
    "    # Resampling frequency of the whole-year panels.\n",
    "    coarse_freq = \"3d\"\n",
    "\n",
    "    # Number of largest technologies shown individually; the rest is grouped as Others.\n",
    "    num_techs_shown = 6\n",
    "\n",
    "    energy_inflow = pd.DataFrame(index=n.loads_t.p_set.index)\n",
    "    energy_outflow = pd.DataFrame(index=n.loads_t.p_set.index)\n",
    "\n",
    "    # add generation\n",
    "    country_gen = n.generators.loc[n.generators.bus == bus]\n",
    "\n",
    "    # One inflow column per technology: summed dispatch of the generators whose carrier is mapped to it.\n",
    "    for tech, pypsa_carrier in pypsa_generation_mapper.items():\n",
    "        gens = country_gen.loc[\n",
    "            country_gen.carrier.apply(lambda c: c in pypsa_carrier)\n",
    "        ].index\n",
    "        energy_inflow[tech] = n.generators_t.p[gens].sum(axis=1)\n",
    "\n",
    "    # add inflows from lines\n",
    "    lines0 = n.lines.loc[n.lines.bus0 == bus].index\n",
    "    lines1 = n.lines.loc[n.lines.bus1 == bus].index\n",
    "\n",
    "    # Net line flow at the bus: negated sum of p0 (bus is bus0) and of p1 (bus is bus1).\n",
    "    lines_flow = np.zeros(energy_inflow.shape[0])\n",
    "    if not lines0.empty:\n",
    "        lines_flow = -n.lines_t.p0[lines0].sum(axis=1)\n",
    "\n",
    "    if not lines1.empty:\n",
    "        lines_flow -= n.lines_t.p1[lines1].sum(axis=1)\n",
    "\n",
    "    # Positive net flow is stored as inflow, negative net flow as outflow.\n",
    "    energy_inflow[\"Inflow Lines\"] = np.maximum(np.zeros_like(lines_flow), lines_flow)\n",
    "    energy_outflow[\"Outflow Lines\"] = np.minimum(np.zeros_like(lines_flow), lines_flow)\n",
    "\n",
    "    # add inflows from links\n",
    "    links0 = n.links.loc[n.links.bus0 == bus].index\n",
    "    links1 = n.links.loc[n.links.bus1 == bus].index\n",
    "\n",
    "    # NOTE(review): PyPSA documents p1 = -efficiency * p0 for links, so scaling both p0 and p1\n",
    "    # by the efficiency again looks questionable whenever efficiency != 1 -- confirm.\n",
    "    links_flow = np.zeros(energy_inflow.shape[0])\n",
    "    if not links0.empty:\n",
    "        links_flow = (\n",
    "            -n.links_t.p0[links0]\n",
    "            .multiply(n.links.loc[links0, \"efficiency\"])\n",
    "            .sum(axis=1)\n",
    "        )\n",
    "\n",
    "    if not links1.empty:\n",
    "        links_flow -= (\n",
    "            n.links_t.p1[links1].multiply(n.links.loc[links1, \"efficiency\"]).sum(axis=1)\n",
    "        )\n",
    "\n",
    "    energy_inflow[\"Inflow Links\"] = np.maximum(np.zeros_like(links_flow), links_flow)\n",
    "    energy_outflow[\"Outflow Links\"] = np.minimum(np.zeros_like(links_flow), links_flow)\n",
    "\n",
    "    # Storage units at the bus: positive dispatch is added to the Hydro column,\n",
    "    # negative dispatch is recorded as Storage Charge.\n",
    "    # np.zeros_like(links_flow) only serves as a zero array of matching length here.\n",
    "    storage = n.storage_units.loc[n.storage_units.bus == bus].index\n",
    "    if not storage.empty:\n",
    "        storage_p = n.storage_units_t.p[storage].sum(axis=1).values\n",
    "        # energy_inflow[\"Storage Discharge\"] = np.maximum(np.zeros_like(links_flow), storage_p)\n",
    "        energy_inflow[\"Hydro\"] = energy_inflow[\"Hydro\"].values + np.maximum(\n",
    "            np.zeros_like(links_flow), storage_p\n",
    "        )\n",
    "        energy_outflow[\"Storage Charge\"] = np.minimum(\n",
    "            np.zeros_like(links_flow), storage_p\n",
    "        )\n",
    "\n",
    "    # Drop the last snapshot and scale by 1e-3 (plotted as GW below).\n",
    "    energy_inflow = energy_inflow.iloc[:-1].multiply(1e-3)\n",
    "    energy_outflow = energy_outflow.iloc[:-1].multiply(1e-3)\n",
    "    load = n.loads_t.p_set[bus].iloc[:-1].multiply(1e-3)\n",
    "\n",
    "    # Add this country to the Europe-wide accumulators; each accumulator is first\n",
    "    # restricted to the index of the current country's series.\n",
    "    total_load = total_load.loc[load.index]\n",
    "    total_load = total_load + load\n",
    "\n",
    "    pypsa_total_inflow = pypsa_total_inflow.loc[energy_inflow.index]\n",
    "    pypsa_total_inflow[energy_inflow.columns] = (\n",
    "        pypsa_total_inflow[energy_inflow.columns] + energy_inflow\n",
    "    )\n",
    "\n",
    "    pypsa_total_outflow = pypsa_total_outflow.loc[energy_outflow.index]\n",
    "    pypsa_total_outflow[energy_outflow.columns] = (\n",
    "        pypsa_total_outflow[energy_outflow.columns] + energy_outflow\n",
    "    )\n",
    "\n",
    "    # Missing ENTSO-E values count as zero in the total.\n",
    "    entsoe_total_inflow = entsoe_total_inflow.loc[entsoe_df.index]\n",
    "    entsoe_total_inflow[entsoe_df.columns] = entsoe_total_inflow[\n",
    "        entsoe_df.columns\n",
    "    ] + entsoe_df.fillna(0.0)\n",
    "\n",
    "    # Technologies ranked by their summed PyPSA inflow; the first num_techs_shown are shown individually.\n",
    "    show_techs = (\n",
    "        energy_inflow.sum()\n",
    "        .sort_values(ascending=False)\n",
    "        .iloc[:num_techs_shown]\n",
    "        .index.tolist()\n",
    "    )\n",
    "    # NOTE(review): `others` is not read anywhere in the visible code.\n",
    "    others = (\n",
    "        energy_inflow.sum()\n",
    "        .sort_values(ascending=False)\n",
    "        .iloc[num_techs_shown:]\n",
    "        .index.tolist()\n",
    "    )\n",
    "    # show_techs = entsoe_df.sum().sort_values(ascending=False).iloc[:num_techs_shown].index.tolist()\n",
    "\n",
    "    # Keep only technologies that also exist as ENTSO-E columns; all remaining columns form Others.\n",
    "    show_techs = intersection(show_techs, entsoe_df.columns.tolist())\n",
    "    entsoe_df[\"Others\"] = entsoe_df.drop(columns=show_techs).sum(axis=1)\n",
    "\n",
    "    # entsoe_df[show_techs + [\"Others\"]].loc[start:end].plot.area(ax=axs[0,0])\n",
    "    index = load.loc[start:end].index\n",
    "\n",
    "    # Overwrite the ENTSO-E index with the PyPSA load index so both sources share one\n",
    "    # time axis; this requires both to have the same number of rows.\n",
    "    entsoe_df.index = load.index\n",
    "\n",
    "    energy_inflow[\"Others\"] = energy_inflow.drop(columns=show_techs).sum(axis=1)\n",
    "\n",
    "    # plot timeframe\n",
    "    # Row 0: the start..end window; left panel ENTSO-E generation, middle panel PyPSA\n",
    "    # generation, both with the PyPSA load drawn as a dashed black line.\n",
    "    axs[0, 0].plot(\n",
    "        index,\n",
    "        load.loc[index].values,\n",
    "        linestyle=\"--\",\n",
    "        color=\"k\",\n",
    "        linewidth=2,\n",
    "        label=\"PyPSA Load\",\n",
    "    )\n",
    "    axs[0, 1].plot(\n",
    "        index, load.loc[index].values, linestyle=\"--\", color=\"k\", linewidth=2\n",
    "    )\n",
    "\n",
    "    # Inflows are stacked upwards; the (non-positive) outflows are stacked in a second call.\n",
    "    axs[0, 1].stackplot(\n",
    "        index,\n",
    "        *[energy_inflow[col].loc[index].values for col in show_techs + [\"Others\"]],\n",
    "        colors=color_mapper.loc[show_techs + [\"Others\"]].tolist(),\n",
    "    )\n",
    "    axs[0, 1].stackplot(\n",
    "        index,\n",
    "        *[energy_outflow[col].loc[index].values for col in energy_outflow.columns],\n",
    "        colors=color_mapper.loc[energy_outflow.columns].tolist(),\n",
    "        labels=energy_outflow.columns,\n",
    "    )\n",
    "\n",
    "    axs[0, 0].stackplot(\n",
    "        index,\n",
    "        *[entsoe_df[col].loc[index].values for col in show_techs + [\"Others\"]],\n",
    "        labels=show_techs + [\"Others\"],\n",
    "        colors=color_mapper.loc[show_techs + [\"Others\"]].tolist(),\n",
    "    )\n",
    "    # Dotted line: shown inflows plus outflows, i.e. the net balance at the bus.\n",
    "    axs[0, 1].plot(\n",
    "        index,\n",
    "        energy_inflow.loc[index][show_techs + [\"Others\"]].sum(axis=1).values\n",
    "        + energy_outflow.loc[index].sum(axis=1).values,\n",
    "        color=\"brown\",\n",
    "        linestyle=\":\",\n",
    "        linewidth=2,\n",
    "        label=\"Accum Gen\",\n",
    "    )\n",
    "\n",
    "    axs[0, 0].legend()\n",
    "    axs[0, 1].legend()\n",
    "\n",
    "    # plot whole year\n",
    "    # Row 1: same panels as row 0, but for the full series averaged per coarse_freq bin.\n",
    "\n",
    "    index = load.resample(coarse_freq).mean().index\n",
    "\n",
    "    axs[1, 0].plot(\n",
    "        index,\n",
    "        load.resample(coarse_freq).mean().values,\n",
    "        linestyle=\"--\",\n",
    "        color=\"k\",\n",
    "        linewidth=2,\n",
    "        label=\"PyPSA Load\",\n",
    "    )\n",
    "    axs[1, 1].plot(\n",
    "        index,\n",
    "        load.resample(coarse_freq).mean().values,\n",
    "        linestyle=\"--\",\n",
    "        color=\"k\",\n",
    "        linewidth=2,\n",
    "    )\n",
    "\n",
    "    axs[1, 1].stackplot(\n",
    "        index,\n",
    "        *[\n",
    "            energy_inflow[col].resample(coarse_freq).mean().values\n",
    "            for col in show_techs + [\"Others\"]\n",
    "        ],\n",
    "        colors=color_mapper.loc[show_techs + [\"Others\"]].tolist(),\n",
    "    )\n",
    "    axs[1, 1].stackplot(\n",
    "        index,\n",
    "        *[\n",
    "            energy_outflow[col].resample(coarse_freq).mean().values\n",
    "            for col in energy_outflow.columns\n",
    "        ],\n",
    "        colors=color_mapper.loc[energy_outflow.columns].tolist(),\n",
    "        labels=energy_outflow.columns,\n",
    "    )\n",
    "\n",
    "    axs[1, 0].stackplot(\n",
    "        index,\n",
    "        *[\n",
    "            entsoe_df[col].resample(coarse_freq).mean().values\n",
    "            for col in show_techs + [\"Others\"]\n",
    "        ],\n",
    "        colors=color_mapper.loc[show_techs + [\"Others\"]].tolist(),\n",
    "        labels=show_techs + [\"Others\"],\n",
    "    )\n",
    "\n",
    "    axs[1, 1].plot(\n",
    "        index,\n",
    "        energy_inflow.resample(coarse_freq)\n",
    "        .mean()[show_techs + [\"Others\"]]\n",
    "        .sum(axis=1)\n",
    "        .values\n",
    "        + energy_outflow.resample(coarse_freq).mean().sum(axis=1).values,\n",
    "        color=\"brown\",\n",
    "        linestyle=\":\",\n",
    "        linewidth=2,\n",
    "        label=\"Accum Gen\",\n",
    "    )\n",
    "\n",
    "    axs[1, 0].legend()\n",
    "    axs[1, 1].legend()\n",
    "\n",
    "    # Shared y-range of the four generation panels: from the lowest summed outflow\n",
    "    # to the highest summed inflow of either data source.\n",
    "    y_min = energy_outflow.sum(axis=1).min()\n",
    "    y_max = pd.concat(\n",
    "        [energy_inflow.sum(axis=1), entsoe_df.sum(axis=1)], ignore_index=True\n",
    "    ).max()\n",
    "\n",
    "    for ax in axs[:2, :2].flatten():\n",
    "        ax.set_ylim(y_min, y_max)\n",
    "\n",
    "    # Rows 0-2: the left column shows ENTSO-E generation, the middle column PyPSA generation.\n",
    "    for row in range(3):\n",
    "        axs[row, 0].set_ylabel(\"ENTSOE Gen and PyPSA Load [GW]\")\n",
    "        axs[row, 1].set_ylabel(\"PyPSA Gen and Load [GW]\")\n",
    "\n",
    "    # -------------------------- electricity prices comparison ----------------------------------\n",
    "    # Average marginal price over all buses whose name starts with the country code.\n",
    "    prices_col = [\n",
    "        col for col in n.buses_t.marginal_price.columns if col.startswith(country[:2])\n",
    "    ]\n",
    "    pypsa_prices_full = n.buses_t.marginal_price[prices_col].mean(axis=1)\n",
    "\n",
    "    full_index = pypsa_prices_full.index\n",
    "\n",
    "    coarse_pypsa_prices = pypsa_prices_full.resample(coarse_freq).mean()\n",
    "    pypsa_prices = pypsa_prices_full.loc[start:end]\n",
    "\n",
    "    axs[0, 2].plot(\n",
    "        pypsa_prices.index, pypsa_prices.values, label=\"PyPSA prices\", color=\"royalblue\"\n",
    "    )\n",
    "    axs[1, 2].plot(\n",
    "        coarse_pypsa_prices.index,\n",
    "        coarse_pypsa_prices.values,\n",
    "        label=\"PyPSA prices\",\n",
    "        color=\"royalblue\",\n",
    "    )\n",
    "\n",
    "    # Reset per country so that a missing price file never reuses the previous\n",
    "    # country's ENTSO-E prices (or raises NameError for the first country).\n",
    "    entsoe_prices = None\n",
    "    mean_abs_error = None\n",
    "\n",
    "    try:\n",
    "        entsoe_prices = pd.read_csv(\n",
    "            data_path / \"price_data\" / country,\n",
    "            index_col=0,\n",
    "            parse_dates=True,\n",
    "        ).iloc[:-1]\n",
    "    except FileNotFoundError:\n",
    "        # No ENTSO-E price file for this country: only the PyPSA prices are plotted.\n",
    "        pass\n",
    "\n",
    "    if entsoe_prices is not None:\n",
    "        # Put the ENTSO-E prices on the PyPSA snapshots; both must have the same length.\n",
    "        entsoe_prices.index = full_index\n",
    "        # The error is computed over the full series, not only the plotted window.\n",
    "        mean_abs_error = mean_absolute_error(\n",
    "            entsoe_prices.values,\n",
    "            pypsa_prices_full.values,\n",
    "        )\n",
    "\n",
    "        coarse_prices = entsoe_prices.resample(coarse_freq).mean()\n",
    "        entsoe_prices = entsoe_prices.loc[start:end]\n",
    "\n",
    "        axs[0, 2].plot(\n",
    "            entsoe_prices.index,\n",
    "            entsoe_prices.values,\n",
    "            label=\"ENTSOE prices\",\n",
    "            color=\"darkred\",\n",
    "        )\n",
    "        axs[1, 2].plot(\n",
    "            coarse_prices.index,\n",
    "            coarse_prices.values,\n",
    "            label=\"ENTSOE prices\",\n",
    "            color=\"darkred\",\n",
    "        )\n",
    "\n",
    "    # Upper y-limit: highest price inside the start..end window of the available sources.\n",
    "    if entsoe_prices is None:\n",
    "        upper_lim = pypsa_prices.max()\n",
    "    else:\n",
    "        upper_lim = pd.concat((entsoe_prices, pypsa_prices), axis=0).max().max()\n",
    "\n",
    "    for ax in axs[:2, 2]:\n",
    "        ax.set_ylim(0, upper_lim)\n",
    "        ax.set_ylabel(\"Electricity Prices [Euro/MWh]\")\n",
    "        ax.legend()\n",
    "\n",
    "    if mean_abs_error is not None:\n",
    "        axs[1, -1].set_title(f\"Mean Abs Error: {np.around(mean_abs_error, decimals=2)}\")\n",
    "\n",
    "    # remaining_cols = energy_inflow.drop(column    s=show_techs+[\"Others\"]).columns.tolist()\n",
    "    # axs[1,0].set_title(f\"Others: {remaining_cols}\")\n",
    "\n",
    "    # ------------------------------- duration curves ------------------------------\n",
    "    # Each technology is sorted on its own (largest value first), so one x position\n",
    "    # does not correspond to the same snapshot across the stacked technologies.\n",
    "\n",
    "    entsoe_ddf = entsoe_df[show_techs + [\"Others\"]].reset_index(drop=True)\n",
    "\n",
    "    entsoe_ddf = pd.concat(\n",
    "        [\n",
    "            entsoe_ddf[col].sort_values(ascending=False).reset_index(drop=True)\n",
    "            for col in entsoe_ddf.columns\n",
    "        ],\n",
    "        axis=1,\n",
    "    )\n",
    "\n",
    "    axs[2, 0].stackplot(\n",
    "        range(len(entsoe_ddf)),\n",
    "        *[entsoe_ddf[col].values for col in entsoe_ddf.columns],\n",
    "        colors=color_mapper.loc[entsoe_ddf.columns].tolist(),\n",
    "        labels=entsoe_ddf.columns,\n",
    "    )\n",
    "\n",
    "    pypsa_ddf = energy_inflow[show_techs + [\"Others\"]].reset_index(drop=True)\n",
    "    pypsa_ddf = pd.concat(\n",
    "        [\n",
    "            pypsa_ddf[col].sort_values(ascending=False).reset_index(drop=True)\n",
    "            for col in pypsa_ddf.columns\n",
    "        ],\n",
    "        axis=1,\n",
    "    )\n",
    "\n",
    "    axs[2, 1].stackplot(\n",
    "        range(len(pypsa_ddf)),\n",
    "        *[pypsa_ddf[col].values for col in pypsa_ddf.columns],\n",
    "        colors=color_mapper.loc[pypsa_ddf.columns].tolist(),\n",
    "        labels=pypsa_ddf.columns,\n",
    "    )\n",
    "\n",
    "    # Upper limit: sum of the per-technology maxima of whichever source is larger.\n",
    "    ylim_max = max([pypsa_ddf.max(axis=0).sum(), entsoe_ddf.max(axis=0).sum()])\n",
    "\n",
    "    # Outflows are sorted ascending (most negative first) and stacked on the PyPSA panel.\n",
    "    pypsa_ddf = energy_outflow.reset_index(drop=True)\n",
    "    pypsa_ddf = pd.concat(\n",
    "        [\n",
    "            pypsa_ddf[col].sort_values(ascending=True).reset_index(drop=True)\n",
    "            for col in pypsa_ddf.columns\n",
    "        ],\n",
    "        axis=1,\n",
    "    )\n",
    "\n",
    "    axs[2, 1].stackplot(\n",
    "        range(len(pypsa_ddf)),\n",
    "        *[pypsa_ddf[col].values for col in pypsa_ddf.columns],\n",
    "        colors=color_mapper.loc[pypsa_ddf.columns].tolist(),\n",
    "        labels=energy_outflow.columns,\n",
    "    )\n",
    "\n",
    "    ylim_min = energy_outflow.min(axis=0).sum()\n",
    "\n",
    "    for ax in axs[2, :2]:\n",
    "        ax.legend()\n",
    "        ax.set_ylim(ylim_min, ylim_max)\n",
    "\n",
    "    # Totals per technology over all snapshots for the bar chart in the right panel of row 2.\n",
    "    pypsa_totals = pd.concat(\n",
    "        [energy_inflow[show_techs + [\"Others\"]], energy_outflow], axis=1\n",
    "    ).sum()\n",
    "\n",
    "    entsoe_totals = entsoe_df.sum()\n",
    "    totals = pd.DataFrame(index=pypsa_totals.index)\n",
    "\n",
    "    # Technologies that only exist in PyPSA get an ENTSO-E total of zero.\n",
    "    for tech in pypsa_totals.index:\n",
    "        if tech not in entsoe_totals.index:\n",
    "            entsoe_totals.loc[tech] = 0.0\n",
    "\n",
    "    # NOTE(review): the wide `totals` frame filled here is overwritten by the concat\n",
    "    # directly below before it is read.\n",
    "    totals[\"Pypsa\"] = pypsa_totals\n",
    "    totals[\"Entsoe\"] = entsoe_totals\n",
    "    totals[\"Technology\"] = totals.index\n",
    "\n",
    "    # Long format (Source, Technology, Total Generation) as used by the seaborn bar plot below.\n",
    "    totals = pd.concat(\n",
    "        [\n",
    "            pd.DataFrame(\n",
    "                {\n",
    "                    \"Source\": [\"PyPSA\" for _ in range(len(pypsa_totals))],\n",
    "                    \"Technology\": pypsa_totals.index,\n",
    "                    \"Total Generation\": pypsa_totals.values,\n",
    "                }\n",
    "            ),\n",
    "            pd.DataFrame(\n",
    "                {\n",
    "                    \"Source\": [\"ENTSO-E\" for _ in range(len(entsoe_totals))],\n",
    "                    \"Technology\": entsoe_totals.index,\n",
    "                    \"Total Generation\": entsoe_totals.values,\n",
    "                }\n",
    "            ),\n",
    "        ],\n",
    "        axis=0,\n",
    "    )\n",
    "\n",
    "    sns.barplot(\n",
    "        data=totals,\n",
    "        x=\"Technology\",\n",
    "        y=\"Total Generation\",\n",
    "        hue=\"Source\",\n",
    "        ax=axs[2, 2],\n",
    "        palette=\"dark\",\n",
    "        alpha=0.6,\n",
    "        edgecolor=\"k\",\n",
    "    )\n",
    "\n",
    "    axs[2, 0].set_xlabel(\"Hours\")\n",
    "    axs[2, 1].set_xlabel(\"Hours\")\n",
    "    axs[2, 2].set_ylabel(\"Total Generation [GWh]\")\n",
    "    axs[2, 2].set_xticks(\n",
    "        axs[2, 2].get_xticks(), axs[2, 2].get_xticklabels(), rotation=45, ha=\"right\"\n",
    "    )\n",
    "\n",
    "    # Pearson correlation per technology between PyPSA and ENTSO-E generation,\n",
    "    # best correlated technology first; technologies without a valid value are dropped.\n",
    "    corrs = (\n",
    "        energy_inflow.corrwith(entsoe_df)\n",
    "        .drop(index=\"Others\")\n",
    "        .dropna()\n",
    "        .sort_values(ascending=False)\n",
    "    )\n",
    "\n",
    "    # Show the two best and the single worst correlated technologies. With fewer than\n",
    "    # three valid correlations each technology is shown once, and an empty result no\n",
    "    # longer raises an IndexError.\n",
    "    if len(corrs) >= 3:\n",
    "        scatter_techs = corrs.index[:2].tolist() + [corrs.index[-1]]\n",
    "    else:\n",
    "        scatter_techs = corrs.index.tolist()\n",
    "\n",
    "    for col, ax in zip(scatter_techs, axs[3]):\n",
    "        ax.scatter(\n",
    "            entsoe_df[col].values,\n",
    "            energy_inflow[col].values,\n",
    "            color=\"darkred\",\n",
    "            alpha=0.5,\n",
    "            s=20,\n",
    "            edgecolor=\"k\",\n",
    "        )\n",
    "        ax.set_title(f\"{col}; Pearson Corr {np.around(corrs.loc[col], decimals=4)}\")\n",
    "        ax.set_xlabel(\"ENTSO-E Generation [GW]\")\n",
    "        ax.set_ylabel(\"PyPSA-Eur Generation [GW]\")\n",
    "\n",
    "    for ax in axs[:2].flatten():\n",
    "        ax.set_xlabel(\"Datetime\")\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(plot_path / (cc + \".pdf\"))\n",
    "\n",
    "    plt.show()\n",
    "    # Release the figure so that the figures of all countries do not pile up in memory.\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the next cell reads these four CSV files from the working directory.\n",
    "# The exports below are disabled, so the files must already exist -- confirm before a fresh run.\n",
    "# pypsa_total_inflow.to_csv(\"total_inflow_pypsa.csv\")\n",
    "# pypsa_total_outflow.to_csv(\"total_outflow_pypsa.csv\")\n",
    "# entsoe_total_inflow.to_csv(\"total_inflow_entsoe.csv\")\n",
    "# total_load.to_csv(\"total_load.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import mean_absolute_error\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.style.use(\"ggplot\")\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "\n",
    "entsoe_total_inflow = pd.read_csv(\n",
    "    \"total_inflow_entsoe.csv\", index_col=0, parse_dates=True\n",
    ")\n",
    "pypsa_total_inflow = pd.read_csv(\n",
    "    \"total_inflow_pypsa.csv\", index_col=0, parse_dates=True\n",
    ")\n",
    "pypsa_total_outflow = pd.read_csv(\n",
    "    \"total_outflow_pypsa.csv\", index_col=0, parse_dates=True\n",
    ")\n",
    "total_load = pd.read_csv(\"total_load.csv\", index_col=0, parse_dates=True)\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(10, 6))\n",
    "\n",
    "# pypsa_totals = pd.concat([pypsa_total_inflow, pypsa_total_outflow], axis=1).sum() * 1e-3\n",
    "pypsa_totals = (\n",
    "    pypsa_total_inflow.drop(columns=[\"Inflow Lines\", \"Inflow Links\"]).sum() * 1e-3\n",
    ")\n",
    "\n",
    "entsoe_totals = (\n",
    "    entsoe_total_inflow.drop(columns=[\"Inflow Lines\", \"Inflow Links\"]).sum() * 1e-3\n",
    ")\n",
    "totals = pd.DataFrame(index=pypsa_totals.index)\n",
    "\n",
    "for tech in pypsa_totals.index:\n",
    "    if tech not in entsoe_totals.index:\n",
    "        entsoe_totals.loc[tech] = 0.0\n",
    "\n",
    "totals[\"Pypsa\"] = pypsa_totals\n",
    "totals[\"Entsoe\"] = entsoe_totals\n",
    "totals[\"Technology\"] = totals.index\n",
    "\n",
    "totals = pd.concat(\n",
    "    [\n",
    "        pd.DataFrame(\n",
    "            {\n",
    "                \"Source\": [\"PyPSA\" for _ in range(len(pypsa_totals))],\n",
    "                \"Technology\": pypsa_totals.index,\n",
    "                \"Total Generation\": pypsa_totals.values,\n",
    "            }\n",
    "        ),\n",
    "        pd.DataFrame(\n",
    "            {\n",
    "                \"Source\": [\"ENTSO-E\" for _ in range(len(entsoe_totals))],\n",
    "                \"Technology\": entsoe_totals.index,\n",
    "                \"Total Generation\": entsoe_totals.values,\n",
    "            }\n",
    "        ),\n",
    "    ],\n",
    "    axis=0,\n",
    ")\n",
    "\n",
    "\n",
    "sns.barplot(\n",
    "    data=totals,\n",
    "    x=\"Technology\",\n",
    "    y=\"Total Generation\",\n",
    "    hue=\"Source\",\n",
    "    ax=ax,\n",
    "    palette=\"dark\",\n",
    "    alpha=0.6,\n",
    "    edgecolor=\"k\",\n",
    ")\n",
    "ax.set_ylabel(\"Total Generation [TWh]\")\n",
    "ax.set_xticks(ax.get_xticks(), ax.get_xticklabels(), rotation=45, ha=\"right\")\n",
    "plt.savefig(plot_path / \"EuropeTotalGeneration.pdf\")\n",
    "plt.show()\n",
    "\n",
    "pypsa_total_inflow = pypsa_total_inflow.drop(columns=[\"Inflow Links\", \"Inflow Lines\"])\n",
    "pypsa_total_outflow = pypsa_total_outflow.drop(\n",
    "    columns=[\"Outflow Links\", \"Outflow Lines\"]\n",
    ")\n",
    "entsoe_total_inflow = entsoe_total_inflow.drop(columns=[\"Inflow Links\", \"Inflow Lines\"])\n",
    "\n",
    "start = pd.Timestamp(\"2019-01-01\")  # for small time frame\n",
    "end = pd.Timestamp(\"2019-01-14\")\n",
    "coarse_freq = \"3d\"\n",
    "\n",
    "index = load.loc[start:end].index\n",
    "cols = pypsa_total_inflow.std(axis=0).sort_values(ascending=True).index\n",
    "cols_out = pypsa_total_outflow.std(axis=0).sort_values(ascending=False).index\n",
    "\n",
    "fig, axs = plt.subplots(4, 2, figsize=(20, 30))\n",
    "\n",
    "axs[0, 0].stackplot(\n",
    "    index,\n",
    "    *[entsoe_total_inflow[col].loc[start:end].values for col in cols],\n",
    "    colors=color_mapper.loc[cols].tolist(),\n",
    ")\n",
    "\n",
    "axs[0, 1].stackplot(\n",
    "    index,\n",
    "    *[pypsa_total_inflow[col].loc[start:end].values for col in cols],\n",
    "    colors=color_mapper.loc[cols].tolist(),\n",
    ")\n",
    "axs[0, 1].stackplot(\n",
    "    index,\n",
    "    *[pypsa_total_outflow[col].loc[start:end].values for col in cols_out],\n",
    "    colors=color_mapper.loc[cols_out].tolist(),\n",
    ")\n",
    "\n",
    "entsoe_total_inflow = entsoe_total_inflow.resample(coarse_freq).mean()\n",
    "pypsa_total_inflow = pypsa_total_inflow.resample(coarse_freq).mean()\n",
    "pypsa_total_outflow = pypsa_total_outflow.resample(coarse_freq).mean()\n",
    "\n",
    "index = pypsa_total_inflow.index\n",
    "\n",
    "axs[1, 0].stackplot(\n",
    "    index,\n",
    "    *[entsoe_total_inflow[col].values for col in cols],\n",
    "    colors=color_mapper.loc[cols].tolist(),\n",
    ")\n",
    "\n",
    "axs[1, 1].stackplot(\n",
    "    index,\n",
    "    *[pypsa_total_inflow[col].values for col in cols],\n",
    "    colors=color_mapper.loc[cols].tolist(),\n",
    ")\n",
    "axs[1, 1].stackplot(\n",
    "    index,\n",
    "    *[pypsa_total_outflow[col].values for col in cols_out],\n",
    "    colors=color_mapper.loc[cols_out].tolist(),\n",
    ")\n",
    "\n",
    "for ax in axs[:3].flatten():\n",
    "    ax.set_ylim(-100, 400)\n",
    "\n",
    "\n",
    "total_entsoe_ddf = pd.concat(\n",
    "    [\n",
    "        entsoe_total_inflow[col].sort_values(ascending=False).reset_index(drop=True)\n",
    "        for col in entsoe_total_inflow.columns\n",
    "    ],\n",
    "    axis=1,\n",
    ")\n",
    "axs[2, 0].stackplot(\n",
    "    range(len(total_entsoe_ddf)),\n",
    "    *[total_entsoe_ddf[col].values for col in total_entsoe_ddf.columns],\n",
    "    colors=color_mapper.loc[total_entsoe_ddf.columns].tolist(),\n",
    "    labels=total_entsoe_ddf.columns,\n",
    ")\n",
    "\n",
    "total_pypsa_ddf = pd.concat(\n",
    "    [\n",
    "        pypsa_total_inflow[col].sort_values(ascending=False).reset_index(drop=True)\n",
    "        for col in pypsa_total_inflow.columns\n",
    "    ],\n",
    "    axis=1,\n",
    ")\n",
    "axs[2, 1].stackplot(\n",
    "    range(len(total_pypsa_ddf)),\n",
    "    *[total_pypsa_ddf[col].values for col in total_pypsa_ddf.columns],\n",
    "    colors=color_mapper.loc[total_pypsa_ddf.columns].tolist(),\n",
    "    labels=total_pypsa_ddf.columns,\n",
    ")\n",
    "\n",
    "total_pypsa_ddf = pd.concat(\n",
    "    [\n",
    "        pypsa_total_outflow[col].sort_values(ascending=False).reset_index(drop=True)\n",
    "        for col in pypsa_total_outflow.columns\n",
    "    ],\n",
    "    axis=1,\n",
    ")\n",
    "axs[2, 1].stackplot(\n",
    "    range(len(total_pypsa_ddf)),\n",
    "    *[total_pypsa_ddf[col].values for col in total_pypsa_ddf.columns],\n",
    "    colors=color_mapper.loc[total_pypsa_ddf.columns].tolist(),\n",
    "    labels=total_pypsa_ddf.columns,\n",
    ")\n",
    "axs[2, 0].legend(\n",
    "    loc=\"upper center\", bbox_to_anchor=(0.5, -0.05), fancybox=True, shadow=True, ncol=5\n",
    ")\n",
    "\n",
    "total_prices = (\n",
    "    n.buses_t.marginal_price.multiply(n.loads_t.p_set)\n",
    "    .sum(axis=1)\n",
    "    .divide(n.loads_t.p_set.sum(axis=1))\n",
    ")\n",
    "\n",
    "total_entsoe_prices = None\n",
    "\n",
    "for num, country in tqdm(enumerate(os.listdir(data_path / \"pypsa_data\"))):\n",
    "    cc = country[:2]\n",
    "    country_buses = np.unique(\n",
    "        n.generators.loc[n.generators.bus.str.contains(cc)].bus.values\n",
    "    )\n",
    "\n",
    "    if not len(country_buses) == 1:\n",
    "        continue\n",
    "\n",
    "    bus = country_buses[0]\n",
    "\n",
    "    try:\n",
    "        entsoe_prices = pd.read_csv(\n",
    "            data_path / \"price_data\" / country,\n",
    "            index_col=0,\n",
    "            parse_dates=True,\n",
    "        ).iloc[:-1]\n",
    "        entsoe_prices.index = n.loads_t.p_set.index\n",
    "\n",
    "        def make_tz_time(time):\n",
    "            return pd.Timestamp(time).tz_convert(\"utc\")\n",
    "\n",
    "    except FileNotFoundError:\n",
    "        continue\n",
    "\n",
    "    if total_entsoe_prices is None:\n",
    "        total_entsoe_prices = pd.Series(\n",
    "            np.zeros(len(entsoe_prices)), index=entsoe_prices.index\n",
    "        )\n",
    "\n",
    "    total_entsoe_prices += entsoe_prices.iloc[:, 0] * n.loads_t.p_set[bus]\n",
    "\n",
    "total_entsoe_prices /= n.loads_t.p_set.sum(axis=1)\n",
    "\n",
    "error = np.around(\n",
    "    mean_absolute_error(total_entsoe_prices.values, total_prices.values), decimals=2\n",
    ")\n",
    "\n",
    "axs[3, 0].plot(\n",
    "    total_prices.loc[start:end].index,\n",
    "    total_prices.loc[start:end].values,\n",
    "    label=\"PyPSA Marginal Price\",\n",
    ")\n",
    "axs[3, 0].plot(\n",
    "    total_prices.loc[start:end].index,\n",
    "    total_entsoe_prices.loc[start:end].values,\n",
    "    label=\"ENTSO-E\",\n",
    ")\n",
    "axs[3, 1].set_title(f\"Mean Abs Error {error} [Euro/MWh]\")\n",
    "axs[3, 0].legend()\n",
    "\n",
    "total_prices = total_prices.resample(coarse_freq).mean()\n",
    "total_entsoe_prices = total_entsoe_prices.resample(coarse_freq).mean()\n",
    "\n",
    "axs[3, 1].plot(total_prices.index, total_prices.values, label=f\"PyPSA Marginal Price\")\n",
    "axs[3, 1].plot(total_prices.index, total_entsoe_prices.values, label=\"ENTSO-E Price\")\n",
    "\n",
    "axs[3, 1].legend()\n",
    "\n",
    "for ax in axs[:3, 0]:\n",
    "    ax.set_ylabel(\"ENTSO-E Generation [GWh]\")\n",
    "for ax in axs[:3, 1]:\n",
    "    ax.set_ylabel(\"PyPSA Generation [GWh]\")\n",
    "for ax in axs[:2].flatten():\n",
    "    ax.set_xlabel(\"Datetime\")\n",
    "for ax in axs[3]:\n",
    "    ax.set_xlabel(\"Datetime\")\n",
    "    ax.set_ylabel(\"Cost of Electricity [Euro/MWh]\")\n",
    "for ax in axs[2]:\n",
    "    ax.set_xlabel(\"Hour\")\n",
    "\n",
    "plt.savefig(plot_path / \"EuropeDashboard.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "",
   "language": "python",
   "name": ""
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}