locate project dir if pypsa-eur is a submodule

This commit is contained in:
Michael Lindner 2024-01-16 16:50:54 +01:00
parent 8ccd145a19
commit 3f8a55992c

View File

@ -223,7 +223,12 @@ def progress_retrieve(url, file, disable=False):
urllib.request.urlretrieve(url, file, reporthook=update_to) urllib.request.urlretrieve(url, file, reporthook=update_to)
def mock_snakemake(rulename, root_dir=None, configfiles=[], **wildcards): def mock_snakemake(
rulename,
root_dir=None,
configfiles=[],
submodule_dir="workflow/submodules/pypsa-eur",
**wildcards):
""" """
This function is expected to be executed from the 'scripts'-directory of ' This function is expected to be executed from the 'scripts'-directory of '
the snakemake project. It returns a snakemake.script.Snakemake object, the snakemake project. It returns a snakemake.script.Snakemake object,
@ -239,6 +244,9 @@ def mock_snakemake(rulename, root_dir=None, configfiles=[], **wildcards):
path to the root directory of the snakemake project path to the root directory of the snakemake project
configfiles: list, str configfiles: list, str
list of configfiles to be used to update the config list of configfiles to be used to update the config
submodule_dir: str, Path
in case PyPSA-Eur is used as a submodule, submodule_dir is
the path of pypsa-eur relative to the project directory.
**wildcards: **wildcards:
keyword arguments fixing the wildcards. Only necessary if wildcards are keyword arguments fixing the wildcards. Only necessary if wildcards are
needed. needed.
@ -257,7 +265,10 @@ def mock_snakemake(rulename, root_dir=None, configfiles=[], **wildcards):
root_dir = Path(root_dir).resolve() root_dir = Path(root_dir).resolve()
user_in_script_dir = Path.cwd().resolve() == script_dir user_in_script_dir = Path.cwd().resolve() == script_dir
if user_in_script_dir: if str(submodule_dir) in __file__:
# the submodule_dir path is only need to locate the project dir
os.chdir(Path(__file__[:__file__.find(str(submodule_dir))]))
elif user_in_script_dir:
os.chdir(root_dir) os.chdir(root_dir)
elif Path.cwd().resolve() != root_dir: elif Path.cwd().resolve() != root_dir:
raise RuntimeError( raise RuntimeError(