common: make cache config_getter, use partial functions

This commit is contained in:
Fabian 2023-08-16 10:50:09 +02:00
parent b9f3df3856
commit b3a6e2c281

View File

@ -3,9 +3,10 @@
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import copy import copy
from functools import partial, lru_cache
def get_config(keys, config, default=None): def get_config(config, keys, default=None):
"""Retrieve a nested value from a dictionary using a tuple of keys.""" """Retrieve a nested value from a dictionary using a tuple of keys."""
value = config value = config
for key in keys: for key in keys:
@ -26,6 +27,27 @@ def merge_configs(base_config, scenario_config):
return merged return merged
@lru_cache
def scenario_config(scenario_name):
"""Retrieve a scenario config based on the overrides from the scenario file."""
return merge_configs(config, scenarios[scenario_name])
def static_getter(wildcards, keys, default):
"""Getter function for static config values."""
return get_config(config, keys, default)
def dynamic_getter(wildcards, keys, default):
"""Getter function for dynamic config values based on scenario."""
scenario_name = wildcards.run
if scenario_name not in scenarios:
raise ValueError(
f"Scenario {scenario_name} not found in file {config['scenariofile']}."
)
return get_config(scenario_config(scenario_name), keys, default)
def config_provider(*keys, default=None): def config_provider(*keys, default=None):
"""Dynamically provide config values based on 'run' -> 'name'. """Dynamically provide config values based on 'run' -> 'name'.
@ -33,25 +55,11 @@ def config_provider(*keys, default=None):
params: params:
my_param=config_provider("key1", "key2", default="some_default_value") my_param=config_provider("key1", "key2", default="some_default_value")
""" """
# Using functools.partial to freeze certain arguments in our getter functions.
def static_getter(wildcards):
"""Getter function for static config values."""
return get_config(keys, config, default)
def dynamic_getter(wildcards):
"""Getter function for dynamic config values based on scenario."""
scenario_name = wildcards.run
if scenario_name not in scenarios:
raise ValueError(
f"Scenario {scenario_name} not found in file {config['scenariofile']}."
)
merged_config = merge_configs(config, scenarios[scenario_name])
return get_config(keys, merged_config, default)
if config["run"].get("scenarios", False): if config["run"].get("scenarios", False):
return dynamic_getter return partial(dynamic_getter, keys=keys, default=default)
else: else:
return static_getter return partial(static_getter, keys=keys, default=default)
def memory(w): def memory(w):