validate checksums for zenodo downloads

This commit is contained in:
Fabian Neumann 2023-12-29 12:34:14 +01:00
parent 6ee82e030f
commit 71985d5e3a
7 changed files with 79 additions and 4 deletions

View File

@ -54,6 +54,9 @@ Upcoming Release
reconnected to the main Ukrainian grid with the configuration option
`reconnect_crimea`.
* Validate downloads from Zenodo using MD5 checksums. This identifies corrupted
or incomplete downloads.
**Bugs and Compatibility**

View File

@ -2,6 +2,9 @@
# #
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import os, sys
sys.path.insert(0, os.path.abspath("scripts"))
from _helpers import validate_checksum
def memory(w): def memory(w):
factor = 3.0 factor = 3.0

View File

@ -77,6 +77,7 @@ if config["enable"]["retrieve"] and config["enable"].get("retrieve_cutout", True
retries: 2 retries: 2
run: run:
move(input[0], output[0]) move(input[0], output[0])
validate_checksum(output[0], input[0])
if config["enable"]["retrieve"] and config["enable"].get("retrieve_cost_data", True): if config["enable"]["retrieve"] and config["enable"].get("retrieve_cost_data", True):
@ -113,7 +114,7 @@ if config["enable"]["retrieve"] and config["enable"].get(
static=True, static=True,
), ),
output: output:
protected(RESOURCES + "natura.tiff"), RESOURCES + "natura.tiff",
log: log:
LOGS + "retrieve_natura_raster.log", LOGS + "retrieve_natura_raster.log",
resources: resources:
@ -121,6 +122,7 @@ if config["enable"]["retrieve"] and config["enable"].get(
retries: 2 retries: 2
run: run:
move(input[0], output[0]) move(input[0], output[0])
validate_checksum(output[0], input[0])
if config["enable"]["retrieve"] and config["enable"].get( if config["enable"]["retrieve"] and config["enable"].get(
@ -226,6 +228,7 @@ if config["enable"]["retrieve"]:
retries: 2 retries: 2
run: run:
move(input[0], output[0]) move(input[0], output[0])
validate_checksum(output[0], input[0])
if config["enable"]["retrieve"]: if config["enable"]["retrieve"]:
@ -243,6 +246,7 @@ if config["enable"]["retrieve"]:
+ "Copernicus_LC100_global_v3.0.1_2019-nrt_Discrete-Classification-map_EPSG-4326.tif", + "Copernicus_LC100_global_v3.0.1_2019-nrt_Discrete-Classification-map_EPSG-4326.tif",
run: run:
move(input[0], output[0]) move(input[0], output[0])
validate_checksum(output[0], input[0])
if config["enable"]["retrieve"]: if config["enable"]["retrieve"]:

View File

@ -4,6 +4,7 @@
# SPDX-License-Identifier: MIT # SPDX-License-Identifier: MIT
import contextlib import contextlib
import hashlib
import logging import logging
import os import os
import urllib import urllib
@ -11,6 +12,7 @@ from pathlib import Path
import pandas as pd import pandas as pd
import pytz import pytz
import requests
import yaml import yaml
from pypsa.components import component_attrs, components from pypsa.components import component_attrs, components
from pypsa.descriptors import Dict from pypsa.descriptors import Dict
@ -318,3 +320,60 @@ def update_config_with_sector_opts(config, sector_opts):
if o.startswith("CF+"): if o.startswith("CF+"):
l = o.split("+")[1:] l = o.split("+")[1:]
update_config(config, parse(l)) update_config(config, parse(l))
def get_checksum_from_zenodo(file_url):
    """
    Retrieve the checksum of a file hosted on Zenodo.

    Queries the Zenodo records API for the record containing the file and
    returns the checksum entry of the matching file.

    Parameters
    ----------
    file_url : str
        Direct Zenodo file URL, e.g.
        'https://zenodo.org/record/12345/files/example.txt' (legacy) or
        'https://zenodo.org/records/12345/files/example.txt' (current).

    Returns
    -------
    str or None
        Checksum in 'hash_type:checksum_value' format, or None if the
        file is not listed in the record.

    Raises
    ------
    ValueError
        If the URL does not contain a Zenodo record identifier.
    requests.HTTPError
        If the Zenodo API request fails.
    """
    parts = file_url.split("/")
    # Zenodo migrated from '/record/<id>' to '/records/<id>' URLs; accept both.
    for segment in ("record", "records"):
        if segment in parts:
            record_id = parts[parts.index(segment) + 1]
            break
    else:
        raise ValueError(f"Not a Zenodo file URL: {file_url}")
    filename = parts[-1]

    response = requests.get(f"https://zenodo.org/api/records/{record_id}", timeout=30)
    response.raise_for_status()
    data = response.json()

    for file in data["files"]:
        if file["key"] == filename:
            return file["checksum"]
    return None
def validate_checksum(file_path, zenodo_url=None, checksum=None):
    """
    Validate file checksum against provided or Zenodo-retrieved checksum.

    Calculates the hash of a file using 64KB chunks. Compares it against a
    given checksum or one fetched from a Zenodo URL.

    Parameters
    ----------
    file_path : str
        Path to the file for checksum validation.
    zenodo_url : str, optional
        URL of the file on Zenodo to fetch the checksum.
    checksum : str, optional
        Checksum (format 'hash_type:checksum_value') for validation.

    Raises
    ------
    AssertionError
        If the checksum does not match, if neither `checksum` nor
        `zenodo_url` is provided, or if no checksum could be retrieved
        from Zenodo for the given URL.

    Examples
    --------
    >>> validate_checksum('/path/to/file', checksum='md5:abc123...')
    >>> validate_checksum('/path/to/file', zenodo_url='https://zenodo.org/record/12345/files/example.txt')

    If the checksum is invalid, an AssertionError will be raised.
    """
    # Raise explicitly instead of using `assert`: assert statements are
    # stripped under `python -O`, which would silently disable validation.
    # AssertionError is kept so existing callers see the same exception type.
    if not (checksum or zenodo_url):
        raise AssertionError("Either checksum or zenodo_url must be provided")
    if zenodo_url:
        checksum = get_checksum_from_zenodo(zenodo_url)
        if checksum is None:
            # Previously this fell through to None.split(":") -> AttributeError.
            raise AssertionError(f"No checksum found on Zenodo for {zenodo_url}")
    hash_type, checksum = checksum.split(":")
    hasher = hashlib.new(hash_type)
    with open(file_path, "rb") as f:
        # Hash in 64KB chunks to keep memory bounded on large downloads.
        for chunk in iter(lambda: f.read(65536), b""):
            hasher.update(chunk)
    calculated_checksum = hasher.hexdigest()
    if calculated_checksum != checksum:
        raise AssertionError(
            "Checksum is invalid. This may be due to an incomplete download. Delete the file and re-execute the rule."
        )

View File

@ -36,7 +36,7 @@ import logging
import tarfile import tarfile
from pathlib import Path from pathlib import Path
from _helpers import configure_logging, progress_retrieve from _helpers import configure_logging, progress_retrieve, validate_checksum
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -65,6 +65,8 @@ if __name__ == "__main__":
disable_progress = snakemake.config["run"].get("disable_progressbar", False) disable_progress = snakemake.config["run"].get("disable_progressbar", False)
progress_retrieve(url, tarball_fn, disable=disable_progress) progress_retrieve(url, tarball_fn, disable=disable_progress)
validate_checksum(tarball_fn, url)
logger.info("Extracting databundle.") logger.info("Extracting databundle.")
tarfile.open(tarball_fn).extractall(to_fn) tarfile.open(tarball_fn).extractall(to_fn)

View File

@ -11,7 +11,7 @@ import logging
import zipfile import zipfile
from pathlib import Path from pathlib import Path
from _helpers import progress_retrieve from _helpers import progress_retrieve, validate_checksum
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,6 +35,8 @@ if __name__ == "__main__":
disable_progress = snakemake.config["run"].get("disable_progressbar", False) disable_progress = snakemake.config["run"].get("disable_progressbar", False)
progress_retrieve(url, zip_fn, disable=disable_progress) progress_retrieve(url, zip_fn, disable=disable_progress)
validate_checksum(zip_fn, url)
logger.info("Extracting databundle.") logger.info("Extracting databundle.")
zipfile.ZipFile(zip_fn).extractall(to_fn) zipfile.ZipFile(zip_fn).extractall(to_fn)

View File

@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
import tarfile import tarfile
from pathlib import Path from pathlib import Path
from _helpers import configure_logging, progress_retrieve from _helpers import configure_logging, progress_retrieve, validate_checksum
if __name__ == "__main__": if __name__ == "__main__":
if "snakemake" not in globals(): if "snakemake" not in globals():
@ -34,6 +34,8 @@ if __name__ == "__main__":
disable_progress = snakemake.config["run"].get("disable_progressbar", False) disable_progress = snakemake.config["run"].get("disable_progressbar", False)
progress_retrieve(url, tarball_fn, disable=disable_progress) progress_retrieve(url, tarball_fn, disable=disable_progress)
validate_checksum(tarball_fn, url)
logger.info("Extracting databundle.") logger.info("Extracting databundle.")
tarfile.open(tarball_fn).extractall(to_fn) tarfile.open(tarball_fn).extractall(to_fn)