diff --git a/rules/common.smk b/rules/common.smk index 5677f577..f24301c8 100644 --- a/rules/common.smk +++ b/rules/common.smk @@ -3,9 +3,10 @@ # SPDX-License-Identifier: MIT import copy +from functools import partial, lru_cache -def get_config(keys, config, default=None): +def get_config(config, keys, default=None): """Retrieve a nested value from a dictionary using a tuple of keys.""" value = config for key in keys: @@ -26,6 +27,27 @@ def merge_configs(base_config, scenario_config): return merged +@lru_cache +def scenario_config(scenario_name): + """Retrieve a scenario config based on the overrides from the scenario file.""" + return merge_configs(config, scenarios[scenario_name]) + + +def static_getter(wildcards, keys, default): + """Getter function for static config values.""" + return get_config(config, keys, default) + + +def dynamic_getter(wildcards, keys, default): + """Getter function for dynamic config values based on scenario.""" + scenario_name = wildcards.run + if scenario_name not in scenarios: + raise ValueError( + f"Scenario {scenario_name} not found in file {config['scenariofile']}." + ) + return get_config(scenario_config(scenario_name), keys, default) + + def config_provider(*keys, default=None): """Dynamically provide config values based on 'run' -> 'name'. @@ -33,25 +55,11 @@ def config_provider(*keys, default=None): params: my_param=config_provider("key1", "key2", default="some_default_value") """ - - def static_getter(wildcards): - """Getter function for static config values.""" - return get_config(keys, config, default) - - def dynamic_getter(wildcards): - """Getter function for dynamic config values based on scenario.""" - scenario_name = wildcards.run - if scenario_name not in scenarios: - raise ValueError( - f"Scenario {scenario_name} not found in file {config['scenariofile']}." - ) - merged_config = merge_configs(config, scenarios[scenario_name]) - return get_config(keys, merged_config, default) - + # Using functools.partial to freeze certain arguments in our getter functions. if config["run"].get("scenarios", False): - return dynamic_getter + return partial(dynamic_getter, keys=keys, default=default) else: - return static_getter + return partial(static_getter, keys=keys, default=default) def memory(w):