diff --git a/scripts/_helpers.py b/scripts/_helpers.py index ceeb69f9..f7f1c557 100644 --- a/scripts/_helpers.py +++ b/scripts/_helpers.py @@ -295,7 +295,7 @@ def get_aggregation_strategies(aggregation_strategies): return bus_strategies, generator_strategies -def mock_snakemake(rulename, **wildcards): +def mock_snakemake(rulename, configfiles=[], **wildcards): """ This function is expected to be executed from the 'scripts'-directory of ' the snakemake project. It returns a snakemake.script.Snakemake object, @@ -307,6 +307,8 @@ def mock_snakemake(rulename, **wildcards): ---------- rulename: str name of the rule for which the snakemake object should be generated + configfiles: list, str + list of configfiles to be used to update the config **wildcards: keyword arguments fixing the wildcards. Only necessary if wildcards are needed. @@ -319,46 +321,66 @@ def mock_snakemake(rulename, **wildcards): from snakemake.script import Snakemake script_dir = Path(__file__).parent.resolve() - assert ( - Path.cwd().resolve() == script_dir - ), f"mock_snakemake has to be run from the repository scripts directory {script_dir}" - os.chdir(script_dir.parent) - for p in sm.SNAKEFILE_CHOICES: - if os.path.exists(p): - snakefile = p - break - kwargs = dict(rerun_triggers=[]) if parse(sm.__version__) > Version("7.7.0") else {} - workflow = sm.Workflow(snakefile, overwrite_configfiles=[], **kwargs) - workflow.include(snakefile) - workflow.global_resources = {} - rule = workflow.get_rule(rulename) - dag = sm.dag.DAG(workflow, rules=[rule]) - wc = Dict(wildcards) - job = sm.jobs.Job(rule, dag, wc) + root_dir = script_dir.parent - def make_accessable(*ios): - for io in ios: - for i in range(len(io)): - io[i] = os.path.abspath(io[i]) + user_in_script_dir = Path.cwd().resolve() == script_dir + if user_in_script_dir: + os.chdir(root_dir) + elif Path.cwd().resolve() != root_dir: + raise RuntimeError( + "mock_snakemake has to be run from the repository root" + f" {root_dir} or scripts directory {script_dir}" + ) + try: + for p in sm.SNAKEFILE_CHOICES: + if os.path.exists(p): + snakefile = p + break + kwargs = ( + dict(rerun_triggers=[]) if parse(sm.__version__) > Version("7.7.0") else {} + ) + workflow = sm.Workflow(snakefile, **kwargs) + workflow.include(snakefile) - make_accessable(job.input, job.output, job.log) - snakemake = Snakemake( - job.input, - job.output, - job.params, - job.wildcards, - job.threads, - job.resources, - job.log, - job.dag.workflow.config, - job.rule.name, - None, - ) - # create log and output dir if not existent - for path in list(snakemake.log) + list(snakemake.output): - Path(path).parent.mkdir(parents=True, exist_ok=True) + if isinstance(configfiles, str): + configfiles = [configfiles] + if configfiles: + for f in configfiles: + if not os.path.exists(f): + raise FileNotFoundError(f"Config file {f} does not exist.") + workflow.configfile(f) - os.chdir(script_dir) + workflow.global_resources = {} + rule = workflow.get_rule(rulename) + dag = sm.dag.DAG(workflow, rules=[rule]) + wc = Dict(wildcards) + job = sm.jobs.Job(rule, dag, wc) + + def make_accessable(*ios): + for io in ios: + for i in range(len(io)): + io[i] = os.path.abspath(io[i]) + + make_accessable(job.input, job.output, job.log) + snakemake = Snakemake( + job.input, + job.output, + job.params, + job.wildcards, + job.threads, + job.resources, + job.log, + job.dag.workflow.config, + job.rule.name, + None, + ) + # create log and output dir if not existent + for path in list(snakemake.log) + list(snakemake.output): + Path(path).parent.mkdir(parents=True, exist_ok=True) + + finally: + if user_in_script_dir: + os.chdir(script_dir) return snakemake