helpers: allow to overwrite config in mock_snakemake and allow mock_snakemake to be called from repo root

This commit is contained in:
Fabian 2023-03-09 12:45:44 +01:00
parent abe4df543e
commit 5356feab49

View File

@ -295,7 +295,7 @@ def get_aggregation_strategies(aggregation_strategies):
return bus_strategies, generator_strategies
def mock_snakemake(rulename, **wildcards):
def mock_snakemake(rulename, configfiles=[], **wildcards):
"""
This function is expected to be executed from the 'scripts'-directory of '
the snakemake project. It returns a snakemake.script.Snakemake object,
@ -307,6 +307,8 @@ def mock_snakemake(rulename, **wildcards):
----------
rulename: str
name of the rule for which the snakemake object should be generated
configfiles: list, str
list of configfiles to be used to update the config
**wildcards:
keyword arguments fixing the wildcards. Only necessary if wildcards are
needed.
@ -319,46 +321,66 @@ def mock_snakemake(rulename, **wildcards):
from snakemake.script import Snakemake
script_dir = Path(__file__).parent.resolve()
assert (
Path.cwd().resolve() == script_dir
), f"mock_snakemake has to be run from the repository scripts directory {script_dir}"
os.chdir(script_dir.parent)
for p in sm.SNAKEFILE_CHOICES:
if os.path.exists(p):
snakefile = p
break
kwargs = dict(rerun_triggers=[]) if parse(sm.__version__) > Version("7.7.0") else {}
workflow = sm.Workflow(snakefile, overwrite_configfiles=[], **kwargs)
workflow.include(snakefile)
workflow.global_resources = {}
rule = workflow.get_rule(rulename)
dag = sm.dag.DAG(workflow, rules=[rule])
wc = Dict(wildcards)
job = sm.jobs.Job(rule, dag, wc)
root_dir = script_dir.parent
def make_accessable(*ios):
for io in ios:
for i in range(len(io)):
io[i] = os.path.abspath(io[i])
user_in_script_dir = Path.cwd().resolve() == script_dir
if user_in_script_dir:
os.chdir(root_dir)
elif Path.cwd().resolve() != root_dir:
raise RuntimeError(
"mock_snakemake has to be run from the repository root"
f" {root_dir} or scripts directory {script_dir}"
)
try:
for p in sm.SNAKEFILE_CHOICES:
if os.path.exists(p):
snakefile = p
break
kwargs = (
dict(rerun_triggers=[]) if parse(sm.__version__) > Version("7.7.0") else {}
)
workflow = sm.Workflow(snakefile, **kwargs)
workflow.include(snakefile)
make_accessable(job.input, job.output, job.log)
snakemake = Snakemake(
job.input,
job.output,
job.params,
job.wildcards,
job.threads,
job.resources,
job.log,
job.dag.workflow.config,
job.rule.name,
None,
)
# create log and output dir if not existent
for path in list(snakemake.log) + list(snakemake.output):
Path(path).parent.mkdir(parents=True, exist_ok=True)
if isinstance(configfiles, str):
configfiles = [configfiles]
if configfiles:
for f in configfiles:
if not os.path.exists(f):
raise FileNotFoundError(f"Config file {f} does not exist.")
workflow.configfile(f)
os.chdir(script_dir)
workflow.global_resources = {}
rule = workflow.get_rule(rulename)
dag = sm.dag.DAG(workflow, rules=[rule])
wc = Dict(wildcards)
job = sm.jobs.Job(rule, dag, wc)
def make_accessable(*ios):
for io in ios:
for i in range(len(io)):
io[i] = os.path.abspath(io[i])
make_accessable(job.input, job.output, job.log)
snakemake = Snakemake(
job.input,
job.output,
job.params,
job.wildcards,
job.threads,
job.resources,
job.log,
job.dag.workflow.config,
job.rule.name,
None,
)
# create log and output dir if not existent
for path in list(snakemake.log) + list(snakemake.output):
Path(path).parent.mkdir(parents=True, exist_ok=True)
finally:
if user_in_script_dir:
os.chdir(script_dir)
return snakemake