diff --git a/doc/release_notes.rst b/doc/release_notes.rst index d7931f0e..7b1b6d73 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -54,6 +54,9 @@ Upcoming Release reconnected to the main Ukrainian grid with the configuration option `reconnect_crimea`. +* Validate downloads from Zenodo using MD5 checksums. This identifies corrupted + or incomplete downloads. + **Bugs and Compatibility** diff --git a/rules/common.smk b/rules/common.smk index d3416050..a1537c10 100644 --- a/rules/common.smk +++ b/rules/common.smk @@ -2,6 +2,9 @@ # # SPDX-License-Identifier: MIT +import os, sys +sys.path.insert(0, os.path.abspath("scripts")) +from _helpers import validate_checksum def memory(w): factor = 3.0 diff --git a/rules/retrieve.smk b/rules/retrieve.smk index 4fe0cd7b..e2e63427 100644 --- a/rules/retrieve.smk +++ b/rules/retrieve.smk @@ -77,6 +77,7 @@ if config["enable"]["retrieve"] and config["enable"].get("retrieve_cutout", True retries: 2 run: move(input[0], output[0]) + validate_checksum(output[0], input[0]) if config["enable"]["retrieve"] and config["enable"].get("retrieve_cost_data", True): @@ -113,7 +114,7 @@ if config["enable"]["retrieve"] and config["enable"].get( static=True, ), output: - protected(RESOURCES + "natura.tiff"), + RESOURCES + "natura.tiff", log: LOGS + "retrieve_natura_raster.log", resources: @@ -121,6 +122,7 @@ if config["enable"]["retrieve"] and config["enable"].get( retries: 2 run: move(input[0], output[0]) + validate_checksum(output[0], input[0]) if config["enable"]["retrieve"] and config["enable"].get( @@ -226,6 +228,7 @@ if config["enable"]["retrieve"]: retries: 2 run: move(input[0], output[0]) + validate_checksum(output[0], input[0]) if config["enable"]["retrieve"]: @@ -243,6 +246,7 @@ if config["enable"]["retrieve"]: + "Copernicus_LC100_global_v3.0.1_2019-nrt_Discrete-Classification-map_EPSG-4326.tif", run: move(input[0], output[0]) + validate_checksum(output[0], input[0]) if config["enable"]["retrieve"]: diff --git 
def _parse_zenodo_url(file_url):
    """Split a Zenodo file URL into ``(record_id, filename)``."""
    from urllib.parse import unquote, urlsplit

    # Callers may pass path-like objects or URLs without a scheme
    # (e.g. "zenodo.org/record/123/files/x.tif"), so coerce to str and only
    # look at the path component; this also drops "?download=1" and fragments.
    path_parts = [part for part in urlsplit(str(file_url)).path.split("/") if part]
    # Zenodo serves both the legacy "/record/<id>" and the newer
    # "/records/<id>" URL layouts.
    for marker in ("record", "records"):
        if marker in path_parts[:-1]:
            record_id = path_parts[path_parts.index(marker) + 1]
            # File keys in the API response are not percent-encoded.
            return record_id, unquote(path_parts[-1])
    raise ValueError(f"Cannot extract a Zenodo record id from URL: {file_url}")


def get_checksum_from_zenodo(file_url):
    """
    Look up the checksum Zenodo publishes for a file.

    Parameters
    ----------
    file_url : str
        URL of a file on Zenodo, e.g.
        ``https://zenodo.org/record/12345/files/example.txt``.

    Returns
    -------
    str or None
        Checksum in the format ``'hash_type:checksum_value'`` as reported by
        the Zenodo records API, or None if the record lists no file with a
        matching name.

    Raises
    ------
    ValueError
        If no record id can be extracted from `file_url`.
    requests.HTTPError
        If the Zenodo API request fails.
    """
    # Imported here so that validating against an explicit checksum does not
    # require network dependencies.
    import requests

    record_id, filename = _parse_zenodo_url(file_url)

    response = requests.get(f"https://zenodo.org/api/records/{record_id}", timeout=30)
    response.raise_for_status()

    for entry in response.json().get("files") or []:
        if entry.get("key") == filename:
            return entry.get("checksum")
    return None


def validate_checksum(file_path, zenodo_url=None, checksum=None):
    """
    Validate file checksum against provided or Zenodo-retrieved checksum.

    Calculates the hash of a file using 64KB chunks. Compares it against a
    given checksum or one from a Zenodo URL. If `zenodo_url` is given, the
    checksum published by Zenodo takes precedence; an explicit `checksum` is
    only used as a fallback when Zenodo does not list the file.

    Parameters
    ----------
    file_path : str
        Path to the file for checksum validation.
    zenodo_url : str, optional
        URL of the file on Zenodo to fetch the checksum.
    checksum : str, optional
        Checksum (format 'hash_type:checksum_value') for validation.

    Raises
    ------
    AssertionError
        If the checksum does not match, if neither `checksum` nor
        `zenodo_url` is provided, or if no checksum could be determined.
    ValueError
        If the checksum is not in the format 'hash_type:checksum_value' or
        names a hash type unknown to `hashlib`.

    Examples
    --------
    >>> validate_checksum('/path/to/file', checksum='md5:abc123...')
    >>> validate_checksum('/path/to/file', zenodo_url='https://zenodo.org/record/12345/files/example.txt')

    If the checksum is invalid, an AssertionError will be raised.
    """
    import hashlib

    # Raise explicitly instead of using `assert` statements: asserts are
    # stripped under `python -O`, which would silently disable validation.
    if not (checksum or zenodo_url):
        raise AssertionError("Either checksum or zenodo_url must be provided")
    if zenodo_url:
        checksum = get_checksum_from_zenodo(zenodo_url) or checksum
    if not checksum:
        raise AssertionError(
            f"No checksum available for {zenodo_url}; cannot validate {file_path}."
        )

    hash_type, expected = checksum.split(":", 1)
    hasher = hashlib.new(hash_type.strip().lower())
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(65536), b""):  # 64kb chunks
            hasher.update(chunk)

    # hexdigest() is always lower case; normalise the expected value to match.
    if hasher.hexdigest() != expected.strip().lower():
        raise AssertionError(
            "Checksum is invalid. This may be due to an incomplete download. Delete the file and re-execute the rule."
        )
zipfile.ZipFile(zip_fn).extractall(to_fn) diff --git a/scripts/retrieve_sector_databundle.py b/scripts/retrieve_sector_databundle.py index 0d172c8d..cb6cc969 100644 --- a/scripts/retrieve_sector_databundle.py +++ b/scripts/retrieve_sector_databundle.py @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) import tarfile from pathlib import Path -from _helpers import configure_logging, progress_retrieve +from _helpers import configure_logging, progress_retrieve, validate_checksum if __name__ == "__main__": if "snakemake" not in globals(): @@ -34,6 +34,8 @@ if __name__ == "__main__": disable_progress = snakemake.config["run"].get("disable_progressbar", False) progress_retrieve(url, tarball_fn, disable=disable_progress) + validate_checksum(tarball_fn, url) + logger.info("Extracting databundle.") tarfile.open(tarball_fn).extractall(to_fn)