# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: : 2017-2023 The PyPSA-Eur Authors
#
# SPDX-License-Identifier: MIT

"""
Solves the linear optimal dispatch problem in hourly resolution using the
capacities obtained from the previous capacity expansion in rule
:mod:`solve_network`.

Relevant Settings
-----------------

.. code:: yaml

    solving:
        tmpdir:
        options:
            formulation:
            clip_p_max_pu:
            load_shedding:
            noisy_costs:
            nhours:
            min_iterations:
            max_iterations:
        solver:
            name:
            (solveroptions):

.. seealso::
    Documentation of the configuration file ``config.yaml`` at
    :ref:`solving_cf`

Inputs
------

- ``networks/elec_s{simpl}_{clusters}.nc``: confer :ref:`cluster`
- ``results/networks/elec_s{simpl}_{clusters}_ec_l{ll}_{opts}.nc``: confer :ref:`solve`

Outputs
-------

- ``results/networks/elec_s{simpl}_{clusters}_ec_l{ll}_{opts}_op.nc``: Solved PyPSA network for optimal dispatch including optimisation results

Description
-----------
"""

import logging
from pathlib import Path

import numpy as np
import pypsa
from _helpers import configure_logging
from solve_network import prepare_network, solve_network
from vresutils.benchmark import memory_logger

logger = logging.getLogger(__name__)


def set_parameters_from_optimized(n, n_optim):
    """
    Copy optimised capacities from ``n_optim`` into ``n`` and fix them.

    Lines with a non-empty ``type`` receive ``num_parallel`` from ``n_optim``
    and have ``s_nom`` recomputed as
    ``sqrt(3) * i_nom(type) * v_nom(bus0) * num_parallel``. Lines with an
    empty ``type`` receive ``s_nom``, ``r`` and ``x`` from ``n_optim``.
    Extendable links, generators and storage units receive ``p_nom`` from
    ``p_nom_opt``; extendable stores receive ``e_nom`` from ``e_nom_opt``.
    Components that are absent from ``n_optim`` are filled with ``0.0``.
    Afterwards all lines and every updated component are marked as
    non-extendable.

    Parameters
    ----------
    n : pypsa.Network
        Network to update; it is modified in place.
    n_optim : pypsa.Network
        Network holding the optimised values to read from.

    Returns
    -------
    pypsa.Network
        The same object as ``n``.
    """
    lines = n.lines
    optimized_lines = n_optim.lines

    # Typed lines: take the number of parallel circuits from the optimised
    # network, then derive s_nom from the line type and the bus0 voltage.
    typed = lines.index[lines.type != ""]
    lines.loc[typed, "num_parallel"] = optimized_lines["num_parallel"].reindex(
        typed, fill_value=0.0
    )
    i_nom = lines["type"].map(n.line_types.i_nom)
    v_nom = lines.bus0.map(n.buses.v_nom)
    # num_parallel is read only after it has been updated above.
    lines.loc[typed, "s_nom"] = np.sqrt(3) * i_nom * v_nom * lines.num_parallel

    # Untyped lines: take the electrical parameters as they are.
    untyped = lines.index[lines.type == ""]
    for column in ("s_nom", "r", "x"):
        lines.loc[untyped, column] = optimized_lines[column].reindex(
            untyped, fill_value=0.0
        )
    lines["s_nom_extendable"] = False

    # Remaining components: (component list name, nominal attribute).
    for list_name, nom in (
        ("links", "p_nom"),
        ("generators", "p_nom"),
        ("storage_units", "p_nom"),
        ("stores", "e_nom"),
    ):
        components = getattr(n, list_name)
        optimized = getattr(n_optim, list_name)
        extendable_flag = nom + "_extendable"

        extendable = components.index[components[extendable_flag]]
        components.loc[extendable, nom] = optimized[nom + "_opt"].reindex(
            extendable, fill_value=0.0
        )
        components.loc[extendable, extendable_flag] = False

    return n


if __name__ == "__main__":
    # Build a mocked ``snakemake`` object when none is present in the global
    # namespace (presumably when the script is run outside of Snakemake).
    if "snakemake" not in globals():
        from _helpers import mock_snakemake

        snakemake = mock_snakemake(
            "solve_operations_network",
            simpl="",
            clusters="5",
            ll="copt",
            opts="Co2L-BAU-24H",
        )
    configure_logging(snakemake)

    solving = snakemake.config["solving"]

    # Make sure the optional solver working directory exists.
    tmpdir = solving.get("tmpdir")
    if tmpdir is not None:
        Path(tmpdir).mkdir(parents=True, exist_ok=True)

    # Load the unprepared network first, then the optimised one, and transfer
    # the optimised capacities. The optimised network is only a temporary and
    # is released as soon as the call returns.
    n = set_parameters_from_optimized(
        pypsa.Network(snakemake.input.unprepared),
        pypsa.Network(snakemake.input.optimized),
    )

    opts = snakemake.wildcards.opts.split("-")
    # Force ``skip_iterations`` off for this run, regardless of the config.
    solving["options"]["skip_iterations"] = False

    # ``memory`` is an optional log target; None is passed when it is absent.
    memory_log = getattr(snakemake.log, "memory", None)
    with memory_logger(filename=memory_log, interval=30.0) as mem:
        n = prepare_network(n, solving["options"])
        n = solve_network(
            n,
            snakemake.config,
            opts,
            solver_dir=tmpdir,
            solver_logfile=snakemake.log.solver,
        )
        # Store the full config plus the wildcards of this run with the result.
        n.meta = {**snakemake.config, "wildcards": dict(snakemake.wildcards)}
        n.export_to_netcdf(snakemake.output[0])

    logger.info("Maximum memory usage: {}".format(mem.mem_usage))