# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: : 2017-2022 The PyPSA-Eur Authors
#
# SPDX-License-Identifier: MIT

"""
Creates summaries of aggregated energy, capacities, costs, prices and further metrics as ``.csv`` files.

Relevant Settings
-----------------

.. code:: yaml

    costs:
        year:
        version:
        fill_values:
        marginal_cost:
        capital_cost:

    electricity:
        max_hours:

.. seealso::
    Documentation of the configuration file ``config.yaml`` at
    :ref:`costs_cf`, :ref:`electricity_cf`

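Inputs
------

- The cost data file given as the rule's first input, read with ``load_costs``.
- Networks ``results/networks/elec_s{simpl}_{clusters}_ec_l{ll}_{opts}.nc`` for every selected wildcard combination (a wildcard value of ``all`` expands to the values listed under ``scenario`` in the configuration).

Outputs
-------

- ``results/summaries/...``: a directory holding one ``.csv`` file per summary: ``costs``, ``curtailment``, ``energy``, ``capacity``, ``supply``, ``supply_energy``, ``prices``, ``weighted_prices`` and ``metrics``.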

Description
-----------

The following Snakemake target summarizes the results in separate ``.csv`` files:

.. code:: bash

    snakemake results/summaries/elec_s_all_lall_Co2L-3H_all
                                       - clusters
                                           - line volume or cost cap
                                                - options
                                                        - all countries

The line volume/cost cap field can be set to one of the following:

* ``lv1.25`` for a line volume extension by 25%
* ``lc1.25`` for a line cost extension by 25%
* ``lall`` for all evaluated caps
* ``lvall`` for all line volume caps
* ``lcall`` for all line cost caps

Replacing ``/summaries/`` with ``/plots/`` in the target path creates colored maps of the results instead.
"""

import logging
import os

import pandas as pd
import pypsa
from _helpers import configure_logging
from add_electricity import load_costs, update_transmission_costs

# Shorthand for addressing MultiIndex rows in ``.loc`` lookups.
idx = pd.IndexSlice

logger = logging.getLogger(__name__)

# Prefix of the optimised nominal-capacity column ("<prefix>_nom_opt") for
# components whose capacity is not ``p_nom_opt``; callers fall back to "p".
opt_name = dict(Store="e", Line="s", Transformer="s")


def _add_indexed_rows(df, raw_index):
    new_index = df.index.union(pd.MultiIndex.from_product(raw_index))
    if isinstance(new_index, pd.Index):
        new_index = pd.MultiIndex.from_tuples(new_index)

    return df.reindex(new_index)


def assign_carriers(n):
    """
    Normalise the ``carrier`` columns of ``n`` in place so that the summaries
    group components consistently.

    - Loads without a ``carrier`` column get one. The default is
      "electricity", overridden by "transport", "heat" or "urban heat" when
      the load name contains that string. Later matches win, so "urban heat"
      overrides "heat".
    - Storage units with carrier "hydro" or "PHS" are merged into "hydro+PHS".
    - AC lines and DC links are both relabelled "lines".
    - A store named "EU gas Store" with an empty carrier becomes "gas Store".

    Parameters
    ----------
    n : pypsa.Network
        Network whose component tables are modified in place.
    """
    if "carrier" not in n.loads:
        n.loads["carrier"] = "electricity"
        for carrier in ["transport", "heat", "urban heat"]:
            n.loads.loc[n.loads.index.str.contains(carrier), "carrier"] = carrier

    # Assign the replaced column back instead of calling
    # ``replace(..., inplace=True)`` on a column selection. Under pandas
    # copy-on-write such a chained in-place call only changes a temporary copy.
    n.storage_units["carrier"] = n.storage_units["carrier"].replace(
        {"hydro": "hydro+PHS", "PHS": "hydro+PHS"}
    )

    if "carrier" not in n.lines:
        n.lines["carrier"] = "AC"

    n.lines["carrier"] = n.lines["carrier"].replace({"AC": "lines"})

    if n.links.empty:
        n.links["carrier"] = pd.Series(dtype=str)
    n.links["carrier"] = n.links["carrier"].replace({"DC": "lines"})

    # Use the exact same name for the membership test and the lookup.
    gas_store = "EU gas Store"
    if gas_store in n.stores.index and n.stores.at[gas_store, "carrier"] == "":
        n.stores.at[gas_store, "carrier"] = "gas Store"


def calculate_costs(n, label, costs):
    """
    Add capital and marginal costs per component type and carrier to column
    ``label`` of ``costs``.

    Rows are indexed by ``(component list name, "capital" | "marginal",
    carrier)``.

    - Capital cost is ``capital_cost * <x>_nom_opt``, where ``<x>`` comes from
      ``opt_name`` and defaults to "p".
    - Marginal cost is the dispatch weighted by
      ``n.snapshot_weightings.generators`` times ``marginal_cost``. Links use
      ``p0``; storage units count only positive ``p``; other one-ports use
      ``p``.
    - Passive branches (lines, transformers) only get capital rows.

    Parameters
    ----------
    n : pypsa.Network
        Network to summarise.
    label : hashable
        Column of ``costs`` to fill (one per network).
    costs : pandas.DataFrame
        Summary table accumulated across networks.

    Returns
    -------
    pandas.DataFrame
        ``costs`` extended with any new rows and the values for ``label``.
    """
    weightings = n.snapshot_weightings.generators

    # ``^`` binds tighter than ``|``: this is branches plus the controllable
    # one-ports without "Load".
    for c in n.iterate_components(
        n.branch_components | n.controllable_one_port_components ^ {"Load"}
    ):
        capital_costs = c.df.capital_cost * c.df[opt_name.get(c.name, "p") + "_nom_opt"]
        capital_costs_grouped = capital_costs.groupby(c.df.carrier).sum()

        # Index tuple(s) indicating the newly to-be-added row(s)
        raw_index = ([c.list_name], ["capital"], list(capital_costs_grouped.index))
        costs = _add_indexed_rows(costs, raw_index)
        costs.loc[idx[raw_index], label] = capital_costs_grouped.values

        # Passive branches (lines and transformers) have no marginal cost and
        # no single ``p`` time series, only ``p0``/``p1``.
        if c.name in n.passive_branch_components:
            continue

        if c.name == "Link":
            p = c.pnl.p0.multiply(weightings, axis=0).sum()
        elif c.name == "StorageUnit":
            # Only positive dispatch is charged; negative values are zeroed.
            p_all = c.pnl.p.multiply(weightings, axis=0)
            p_all[p_all < 0.0] = 0.0
            p = p_all.sum()
        else:
            p = c.pnl.p.multiply(weightings, axis=0).sum()

        marginal_costs_grouped = (p * c.df.marginal_cost).groupby(c.df.carrier).sum()

        raw_index = ([c.list_name], ["marginal"], list(marginal_costs_grouped.index))
        costs = _add_indexed_rows(costs, raw_index)
        costs.loc[idx[raw_index], label] = marginal_costs_grouped.values

    return costs


def calculate_curtailment(n, label, curtailment):
    """
    Store the curtailment share of each generator carrier in column ``label``
    of ``curtailment``, in percent rounded to three decimals.

    The share is ``(available - used) / available``. Available energy is
    ``p_max_pu * p_nom_opt`` and used energy is ``p``; both are weighted with
    ``n.snapshot_weightings.generators``. Only generators with a
    time-varying ``p_max_pu`` define availability, so carriers without one
    get NaN.

    Parameters
    ----------
    n : pypsa.Network
        Network to summarise.
    label : hashable
        Column of ``curtailment`` to fill.
    curtailment : pandas.DataFrame
        Summary table accumulated across networks.

    Returns
    -------
    pandas.DataFrame
        ``curtailment`` extended with any new carriers and the values for
        ``label``.
    """
    weightings = n.snapshot_weightings.generators
    p_max_pu = n.generators_t.p_max_pu

    # Restrict capacities to generators that have availability time series.
    # Otherwise column alignment adds all-NaN columns that sum to 0, which
    # yields +/-inf shares.
    p_nom_opt = n.generators.p_nom_opt.reindex(p_max_pu.columns)

    avail = (
        p_max_pu.multiply(p_nom_opt, axis=1)
        .multiply(weightings, axis=0)
        .sum()
        .groupby(n.generators.carrier)
        .sum()
    )
    used = (
        n.generators_t.p.multiply(weightings, axis=0)
        .sum()
        .groupby(n.generators.carrier)
        .sum()
    )

    share = (((avail - used) / avail) * 100).round(3)

    # Grow the index first. Once the frame has rows, assigning a Series to a
    # column only keeps labels already present, which would drop carriers
    # first seen in a later network.
    curtailment = curtailment.reindex(curtailment.index.union(share.index))
    curtailment[label] = share

    return curtailment


def calculate_energy(n, label, energy):
    """
    Add the snapshot-weighted energy per component type and carrier to
    column ``label`` of ``energy``.

    - One-port components contribute the weighted sum of ``p`` times their
      ``sign``. Storage units and stores use
      ``n.snapshot_weightings.stores``; the other one-ports use
      ``n.snapshot_weightings.generators``.
    - Branch components contribute ``-sum(p1) - sum(p0)``, weighted with
      ``n.snapshot_weightings.generators``.
    """
    # Which snapshot-weighting column applies to each one-port type; any
    # component not listed here is handled as a branch (p0/p1).
    one_port_weighting = {
        "Generator": "generators",
        "Load": "generators",
        "ShuntImpedance": "generators",
        "StorageUnit": "stores",
        "Store": "stores",
    }

    for comp in n.iterate_components(n.one_port_components | n.branch_components):
        weighting_name = one_port_weighting.get(comp.name)

        if weighting_name is not None:
            weights = n.snapshot_weightings[weighting_name]
            per_asset = (
                comp.pnl.p.multiply(weights, axis=0).sum().multiply(comp.df.sign)
            )
        else:
            weights = n.snapshot_weightings.generators
            end1_total = comp.pnl.p1.multiply(weights, axis=0).sum()
            end0_total = comp.pnl.p0.multiply(weights, axis=0).sum()
            per_asset = -end1_total - end0_total

        per_carrier = per_asset.groupby(comp.df.carrier).sum()
        energy = include_in_summary(energy, [comp.list_name], label, per_carrier)

    return energy


def include_in_summary(summary, multiindexprefix, label, item):
    """
    Write the values of ``item`` into column ``label`` of ``summary``.

    The rows are the product of ``multiindexprefix`` (labels of the leading
    index level(s)) and the index of ``item``. Missing rows are created
    first.

    Returns the (possibly reindexed) summary table.
    """
    levels = multiindexprefix, list(item.index)
    extended = _add_indexed_rows(summary, levels)
    extended.loc[idx[levels], label] = item.values
    return extended


def calculate_capacity(n, label, capacity):
    """
    Add optimised capacities per component type and carrier to column
    ``label`` of ``capacity``.

    - One-ports use ``|p_nom_opt * sign|``, or ``|e_nom_opt * sign|`` when
      they have no ``p_nom_opt`` column; components with neither are skipped.
    - Passive branches use ``s_nom_opt``.
    - Controllable branches use ``p_nom_opt``.
    """
    for comp in n.iterate_components(n.one_port_components):
        if "p_nom_opt" in comp.df.columns:
            size_attr = "p_nom_opt"
        elif "e_nom_opt" in comp.df.columns:
            size_attr = "e_nom_opt"
        else:
            continue  # no optimised-capacity column to report

        totals = (
            comp.df[size_attr].multiply(comp.df.sign).abs().groupby(comp.df.carrier).sum()
        )
        capacity = include_in_summary(capacity, [comp.list_name], label, totals)

    branch_groups = (
        (n.passive_branch_components, "s_nom_opt"),
        (n.controllable_branch_components, "p_nom_opt"),
    )
    for components, size_attr in branch_groups:
        for comp in n.iterate_components(components):
            totals = comp.df[size_attr].groupby(comp.df.carrier).sum()
            capacity = include_in_summary(capacity, [comp.list_name], label, totals)

    return capacity


def calculate_supply(n, label, supply):
    """
    Calculate the maximum dispatch of each component at the buses of each
    bus carrier.

    For every bus carrier ``i``, each connected component's maximum over
    snapshots is summed per component carrier and written to rows
    ``(i, component list name, carrier)`` of column ``label``.

    - One-ports use ``p`` times ``sign``.
    - Branches use a sign-compensated ``p0`` / ``p1`` at the end attached to
      the bus.

    Parameters
    ----------
    n : pypsa.Network
        Network to summarise.
    label : hashable
        Column of ``supply`` to fill.
    supply : pandas.DataFrame
        Summary table accumulated across networks.

    Returns
    -------
    pandas.DataFrame
        ``supply`` extended with any new rows and the values for ``label``.
    """
    for bus_carrier in n.buses.carrier.unique():
        # Boolean mask over all buses: True for buses of this carrier.
        bus_map = n.buses.carrier == bus_carrier

        for c in n.iterate_components(n.one_port_components):
            items = c.df.index[c.df.bus.map(bus_map)]

            if len(items) == 0 or c.pnl.p.empty:
                continue

            s = (
                c.pnl.p[items]
                .max()
                .multiply(c.df.loc[items, "sign"])
                .groupby(c.df.loc[items, "carrier"])
                .sum()
            )

            raw_index = ([bus_carrier], [c.list_name], list(s.index))
            supply = _add_indexed_rows(supply, raw_index)
            supply.loc[idx[raw_index], label] = s.values

        for c in n.iterate_components(n.branch_components):
            for end in ["0", "1"]:
                items = c.df.index[c.df["bus" + end].map(bus_map)]

                if len(items) == 0 or c.pnl["p" + end].empty:
                    continue

                # lots of sign compensation for direction and to do maximums
                s = (-1) ** (1 - int(end)) * (
                    (-1) ** int(end) * c.pnl["p" + end][items]
                ).max().groupby(c.df.loc[items, "carrier"]).sum()

                # Same helper as for one-ports, so a still-empty table also
                # ends up with a proper MultiIndex.
                raw_index = ([bus_carrier], [c.list_name], list(s.index))
                supply = _add_indexed_rows(supply, raw_index)
                supply.loc[idx[raw_index], label] = s.values

    return supply


def calculate_supply_energy(n, label, supply_energy):
    """
    Calculate the total snapshot-weighted dispatch of each component at the
    buses of each bus carrier.

    For every bus carrier ``i``, each connected component's dispatch is
    weighted with ``n.snapshot_weightings.generators``, summed over
    snapshots and then per component carrier. The result goes to rows
    ``(i, component list name, carrier)`` of column ``label``.

    - One-ports use ``p`` times ``sign``.
    - Branches use ``-p0`` / ``-p1`` at the end attached to the bus.

    Parameters
    ----------
    n : pypsa.Network
        Network to summarise.
    label : hashable
        Column of ``supply_energy`` to fill.
    supply_energy : pandas.DataFrame
        Summary table accumulated across networks.

    Returns
    -------
    pandas.DataFrame
        ``supply_energy`` extended with any new rows and the values for
        ``label``.
    """
    # Weight by snapshot duration so the totals are energies rather than sums
    # of power values; this matters when snapshots represent several hours.
    weightings = n.snapshot_weightings.generators

    for bus_carrier in n.buses.carrier.unique():
        # Boolean mask over all buses: True for buses of this carrier.
        bus_map = n.buses.carrier == bus_carrier

        for c in n.iterate_components(n.one_port_components):
            items = c.df.index[c.df.bus.map(bus_map)]

            if len(items) == 0 or c.pnl.p.empty:
                continue

            s = (
                c.pnl.p[items]
                .multiply(weightings, axis=0)
                .sum()
                .multiply(c.df.loc[items, "sign"])
                .groupby(c.df.loc[items, "carrier"])
                .sum()
            )

            raw_index = ([bus_carrier], [c.list_name], list(s.index))
            supply_energy = _add_indexed_rows(supply_energy, raw_index)
            supply_energy.loc[idx[raw_index], label] = s.values

        for c in n.iterate_components(n.branch_components):
            for end in ["0", "1"]:
                items = c.df.index[c.df["bus" + end].map(bus_map)]

                if len(items) == 0 or c.pnl["p" + end].empty:
                    continue

                s = (-1) * (
                    c.pnl["p" + end][items]
                    .multiply(weightings, axis=0)
                    .sum()
                    .groupby(c.df.loc[items, "carrier"])
                    .sum()
                )

                raw_index = ([bus_carrier], [c.list_name], list(s.index))
                supply_energy = _add_indexed_rows(supply_energy, raw_index)
                supply_energy.loc[idx[raw_index], label] = s.values

    return supply_energy


def calculate_metrics(n, label, metrics):
    """
    Record network-wide metrics for ``label`` in ``metrics``.

    Rows written:

    - ``line_volume_DC``: sum of ``length * p_nom_opt`` over links with
      carrier "DC".
    - ``line_volume_AC``: sum of ``length * s_nom_opt`` over all lines.
    - ``line_volume``: AC plus DC volume.
    - ``line_volume_limit`` / ``line_volume_shadow``: copied from the network
      attributes ``line_volume_limit`` / ``line_volume_limit_dual`` when
      present.
    - ``co2_shadow``: ``mu`` of the "CO2Limit" global constraint when present.

    Rows without an available value are left untouched.
    """
    metric_rows = pd.Index(
        [
            "line_volume",
            "line_volume_limit",
            "line_volume_AC",
            "line_volume_DC",
            "line_volume_shadow",
            "co2_shadow",
        ]
    )
    metrics = metrics.reindex(metrics.index.union(metric_rows))

    links, lines = n.links, n.lines
    is_dc = links.carrier == "DC"
    metrics.at["line_volume_DC", label] = links.length.mul(links.p_nom_opt)[is_dc].sum()
    metrics.at["line_volume_AC", label] = lines.length.mul(lines.s_nom_opt).sum()
    metrics.at["line_volume", label] = metrics.loc[
        ["line_volume_AC", "line_volume_DC"], label
    ].sum()

    # Metric row -> optional network attribute it is copied from.
    optional_attributes = {
        "line_volume_limit": "line_volume_limit",
        "line_volume_shadow": "line_volume_limit_dual",
    }
    for row, attr in optional_attributes.items():
        if hasattr(n, attr):
            metrics.at[row, label] = getattr(n, attr)

    constraints = n.global_constraints
    if "CO2Limit" in constraints.index:
        metrics.at["co2_shadow", label] = constraints.at["CO2Limit", "mu"]

    return metrics


def calculate_prices(n, label, prices):
    """
    Store the time-averaged marginal price per bus type in column ``label``
    of ``prices``.

    The bus type is whatever follows the first three characters of the bus
    name; an empty remainder is reported as "electricity". Prices are first
    averaged over snapshots per bus, then over the buses of each type. They
    are not load-weighted.
    """
    # NOTE(review): assumes bus names are a three-character prefix followed by
    # an optional type suffix - confirm against the naming of the networks
    # being summarised.
    bus_names = n.buses.index
    bus_type = pd.Series(bus_names.str[3:], index=bus_names).replace("", "electricity")

    type_labels = bus_type.value_counts().index
    prices = prices.reindex(prices.index.union(type_labels))

    logger.warning("Prices are time-averaged, not load-weighted")
    mean_price_per_bus = n.buses_t.marginal_price.mean()
    prices[label] = mean_price_per_bus.groupby(bus_type).mean()

    return prices


def calculate_weighted_prices(n, label, weighted_prices):
    """
    Store load-weighted average marginal prices per bus type in column
    ``label`` of ``weighted_prices``.

    For each carrier in ``link_loads`` the "load" at its buses is built up as
    follows:

    - Start from the ``p_set`` of the loads there, or zero for "H2" and
      "gas" buses.
    - Add the ``p0`` withdrawal of the listed link technologies, grouped by
      ``bus0``.
    - For "H2", also add the charging (negative ``p``) of the
      ``"<bus> Store"`` stores.

    The weighted price is ``sum(load * price) / sum(load)`` over all
    snapshots and buses of that type.

    The "space heat" / "space urban heat" entries need heat demand time
    series that this script does not load. When such buses exist they are
    skipped with a warning and keep NaN.

    Parameters
    ----------
    n : pypsa.Network
        Network to summarise.
    label : hashable
        Column of ``weighted_prices`` to fill.
    weighted_prices : pandas.DataFrame
        Summary table accumulated across networks.

    Returns
    -------
    pandas.DataFrame
        ``weighted_prices`` reindexed to the fixed carrier list, with values
        for ``label``.
    """
    logger.warning("Weighted prices don't include storage units as loads")

    weighted_prices = weighted_prices.reindex(
        pd.Index(
            [
                "electricity",
                "heat",
                "space heat",
                "urban heat",
                "space urban heat",
                "gas",
                "H2",
            ]
        )
    )

    link_loads = {
        "electricity": [
            "heat pump",
            "resistive heater",
            "battery charger",
            "H2 Electrolysis",
        ],
        "heat": ["water tanks charger"],
        "urban heat": ["water tanks charger"],
        "space heat": [],
        "space urban heat": [],
        "gas": ["OCGT", "gas boiler", "CHP electric", "CHP heat"],
        "H2": ["Sabatier", "H2 Fuel Cell"],
    }

    for carrier, techs in link_loads.items():
        is_space = carrier[:5] == "space"

        if carrier == "electricity":
            suffix = ""
        elif is_space:
            suffix = carrier[5:]
        else:
            suffix = " " + carrier

        buses = n.buses.index[n.buses.index.str[2:] == suffix]

        if buses.empty:
            continue

        if is_space:
            # The heat demand time series needed for space heating loads are
            # not available in this script, so no weighted price is computed.
            logger.warning(
                "Skipping weighted price for '%s': no heat demand data available",
                carrier,
            )
            continue

        if carrier in ["H2", "gas"]:
            load = pd.DataFrame(index=n.snapshots, columns=buses, data=0.0)
        else:
            load = n.loads_t.p_set[buses]

        for tech in techs:
            names = n.links.index[n.links.index.to_series().str[-len(tech) :] == tech]

            if names.empty:
                continue

            # Sum link withdrawals per bus0. Grouping the transposed frame
            # replaces the deprecated ``groupby(..., axis=1)``.
            load += (
                n.links_t.p0[names].T.groupby(n.links.loc[names, "bus0"]).sum().T
            )

        # Add H2 Store when charging
        if carrier == "H2":
            store_names = buses + " Store"
            stores = (
                n.stores_t.p[store_names]
                .T.groupby(n.stores.loc[store_names, "bus"])
                .sum()
                .T
            )
            stores[stores > 0.0] = 0.0
            load += -stores

        weighted_prices.loc[carrier, label] = (
            load * n.buses_t.marginal_price[buses]
        ).sum().sum() / load.sum().sum()

    return weighted_prices


# Names of the summaries, in computation order. Each ``name`` is produced by
# the module-level function ``calculate_<name>`` and saved as ``<name>.csv``.
outputs = """
    costs curtailment energy capacity supply supply_energy
    prices weighted_prices metrics
""".split()


def make_summaries(networks_dict, paths, config, country="all"):
    """
    Compute every summary listed in ``outputs`` for every network.

    Networks whose file is missing or cannot be opened (``OSError``) are
    skipped with a warning.

    Parameters
    ----------
    networks_dict : dict
        Maps ``(simpl, clusters, ll, opts)`` tuples to network file paths.
        The tuples become the columns of every summary table.
    paths : sequence of str
        Rule inputs; ``paths[0]`` is passed to ``load_costs``.
    config : dict
        Configuration providing the ``costs`` and ``electricity`` sections.
    country : str, default "all"
        If not "all", each network is restricted to the buses whose
        ``country`` equals this value.

    Returns
    -------
    dict of str to pandas.DataFrame
        One table per name in ``outputs``.
    """
    columns = pd.MultiIndex.from_tuples(
        networks_dict.keys(), names=["simpl", "clusters", "ll", "opts"]
    )

    dfs = {output: pd.DataFrame(columns=columns, dtype=float) for output in outputs}

    for label, filename in networks_dict.items():
        # Report progress via logging instead of print so that messages end
        # up in the rule's log file.
        logger.info("Summarising %s from %s", label, filename)
        if not os.path.exists(filename):
            logger.warning("Skipping %s: file does not exist", filename)
            continue

        try:
            n = pypsa.Network(filename)
        except OSError:
            logger.warning("Skipping %s", filename)
            continue

        if country != "all":
            n = n[n.buses.country == country]

        # Costs depend on the modelled period, so they are loaded per network.
        Nyears = n.snapshot_weightings.objective.sum() / 8760.0
        costs = load_costs(paths[0], config["costs"], config["electricity"], Nyears)
        update_transmission_costs(n, costs)

        assign_carriers(n)

        for output in outputs:
            dfs[output] = globals()["calculate_" + output](n, label, dfs[output])

    return dfs


def to_csv(dfs, dir):
    """Write every table in ``dfs`` to ``<dir>/<key>.csv``, creating ``dir`` if needed."""
    os.makedirs(dir, exist_ok=True)
    for name in dfs:
        target = os.path.join(dir, "{}.csv".format(name))
        dfs[name].to_csv(target)


if __name__ == "__main__":
    if "snakemake" in globals():
        network_dir = os.path.join("results", "networks")
    else:
        from _helpers import mock_snakemake

        snakemake = mock_snakemake(
            "make_summary",
            simpl="",
            clusters="5",
            ll="copt",
            opts="Co2L-24H",
            country="all",
        )
        network_dir = os.path.join("..", "results", "networks")
    configure_logging(snakemake)

    config = snakemake.config
    wildcards = snakemake.wildcards

    def expand_from_wildcard(key, config):
        # A wildcard value of "all" expands to every value in scenario[key].
        value = getattr(wildcards, key)
        if value == "all":
            return config["scenario"][key]
        return [value]

    if not wildcards.ll.endswith("all"):
        ll_values = [wildcards.ll]
    else:
        ll_values = config["scenario"]["ll"]
        # "vall" / "call": keep only the caps starting with that letter.
        if len(wildcards.ll) == 4:
            ll_values = [value for value in ll_values if value[0] == wildcards.ll[0]]

    # Insertion order (simpl, clusters, ll, opts) fixes the column order of
    # the summary tables.
    networks_dict = {}
    for simpl in expand_from_wildcard("simpl", config):
        for clusters in expand_from_wildcard("clusters", config):
            for ll_value in ll_values:
                for opts in expand_from_wildcard("opts", config):
                    networks_dict[(simpl, clusters, ll_value, opts)] = os.path.join(
                        network_dir,
                        f"elec_s{simpl}_{clusters}_ec_l{ll_value}_{opts}.nc",
                    )

    dfs = make_summaries(
        networks_dict, snakemake.input, config, country=wildcards.country
    )

    to_csv(dfs, snakemake.output[0])