Merge pull request #821 from PyPSA/md5-checksums

validate checksums for zenodo downloads
Fabian Neumann 2024-01-02 17:23:06 +01:00 committed by GitHub
commit fac257ca97
7 changed files with 83 additions and 3 deletions

View File

@@ -54,6 +54,9 @@ Upcoming Release
  reconnected to the main Ukrainian grid with the configuration option
  `reconnect_crimea`.

* Validate downloads from Zenodo using MD5 checksums. This identifies corrupted
  or incomplete downloads.

**Bugs and Compatibility**
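For users, the new check amounts to a single call after each download; a minimal sketch of what the bullet above means in practice (path and record URL invented for illustration):

from _helpers import validate_checksum

# raises AssertionError if the file's MD5 differs from the checksum
# published on the Zenodo record (path and URL are illustrative)
validate_checksum(
    "data/example.nc",
    zenodo_url="https://zenodo.org/record/12345/files/example.nc",
)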

View File

@@ -2,6 +2,11 @@
#
# SPDX-License-Identifier: MIT
import os, sys
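# make scripts/ importable so rule files can use shared helpers from _helpers.py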
sys.path.insert(0, os.path.abspath("scripts"))
from _helpers import validate_checksum

def memory(w):
    factor = 3.0

View File

@@ -77,6 +77,7 @@ if config["enable"]["retrieve"] and config["enable"].get("retrieve_cutout", True
        retries: 2
        run:
            move(input[0], output[0])
            validate_checksum(output[0], input[0])

if config["enable"]["retrieve"] and config["enable"].get("retrieve_cost_data", True):
@@ -121,6 +122,7 @@ if config["enable"]["retrieve"] and config["enable"].get(
        retries: 2
        run:
            move(input[0], output[0])
            validate_checksum(output[0], input[0])

if config["enable"]["retrieve"] and config["enable"].get(
@@ -226,6 +228,7 @@ if config["enable"]["retrieve"]:
        retries: 2
        run:
            move(input[0], output[0])
            validate_checksum(output[0], input[0])

if config["enable"]["retrieve"]:
@@ -242,6 +245,7 @@ if config["enable"]["retrieve"]:
            "data/Copernicus_LC100_global_v3.0.1_2019-nrt_Discrete-Classification-map_EPSG-4326.tif",
        run:
            move(input[0], output[0])
            validate_checksum(output[0], input[0])

if config["enable"]["retrieve"]:
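For orientation, each touched rule follows the same shape. A minimal sketch of that pattern (rule name, URL, and paths invented for illustration, assuming the Snakemake HTTP remote provider and shutil.move that such retrieval rules typically rely on):

from shutil import move

from snakemake.remote.HTTP import RemoteProvider as HTTPRemoteProvider

HTTP = HTTPRemoteProvider()

rule retrieve_example:
    input:
        HTTP.remote("zenodo.org/record/12345/files/example.nc", keep_local=True),
    output:
        "data/example.nc",
    retries: 2
    run:
        move(input[0], output[0])
        # the staged input path mirrors the Zenodo URL, so it doubles as the
        # zenodo_url argument for looking up the published MD5 checksum
        validate_checksum(output[0], input[0])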

View File

@@ -4,6 +4,7 @@
# SPDX-License-Identifier: MIT
import contextlib
import hashlib
import logging
import os
import urllib
@@ -11,6 +12,7 @@ from pathlib import Path
import pandas as pd
import pytz
import requests
import yaml
from pypsa.components import component_attrs, components
from pypsa.descriptors import Dict
@@ -318,3 +320,63 @@ def update_config_with_sector_opts(config, sector_opts):
        if o.startswith("CF+"):
            l = o.split("+")[1:]
            update_config(config, parse(l))

def get_checksum_from_zenodo(file_url):
    parts = file_url.split("/")
    record_id = parts[parts.index("record") + 1]
    filename = parts[-1]

    response = requests.get(f"https://zenodo.org/api/records/{record_id}", timeout=30)
    response.raise_for_status()
    data = response.json()

    for file in data["files"]:
        if file["key"] == filename:
            return file["checksum"]
    return None
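Illustrative usage, with an invented record ID and a shortened hash; Zenodo's records API reports checksums in 'hash_type:value' form:

>>> get_checksum_from_zenodo("https://zenodo.org/record/12345/files/example.txt")
'md5:abc123...'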

def validate_checksum(file_path, zenodo_url=None, checksum=None):
    """
    Validate a file's checksum against a provided or Zenodo-retrieved checksum.

    Computes the hash of the file in 64 KB chunks and compares it against a
    given checksum or one retrieved from a Zenodo record.

    Parameters
    ----------
    file_path : str
        Path to the file for checksum validation.
    zenodo_url : str, optional
        URL of the file on Zenodo from which to fetch the checksum.
    checksum : str, optional
        Checksum (format 'hash_type:checksum_value') for validation.

    Raises
    ------
    AssertionError
        If the checksum does not match, or if neither `checksum` nor
        `zenodo_url` is provided.

    Examples
    --------
    >>> validate_checksum("/path/to/file", checksum="md5:abc123...")
    >>> validate_checksum(
    ...     "/path/to/file",
    ...     zenodo_url="https://zenodo.org/record/12345/files/example.txt",
    ... )

    If the checksum is invalid, an AssertionError will be raised.
    """
    assert checksum or zenodo_url, "Either checksum or zenodo_url must be provided"
    if zenodo_url:
        checksum = get_checksum_from_zenodo(zenodo_url)
        # guard against a Zenodo record that does not list this file, so the
        # failure surfaces as the documented AssertionError rather than an
        # AttributeError on None.split
        assert checksum, f"No Zenodo checksum found for {file_path}"
    hash_type, checksum = checksum.split(":")
    hasher = hashlib.new(hash_type)
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(65536), b""):  # 64 KB chunks
            hasher.update(chunk)
    calculated_checksum = hasher.hexdigest()
    assert (
        calculated_checksum == checksum
    ), "Checksum is invalid. This may be due to an incomplete download. Delete the file and re-execute the rule."

View File

@@ -36,7 +36,7 @@ import logging
import tarfile
from pathlib import Path
-from _helpers import configure_logging, progress_retrieve
+from _helpers import configure_logging, progress_retrieve, validate_checksum
logger = logging.getLogger(__name__)
@@ -65,6 +65,8 @@ if __name__ == "__main__":
    disable_progress = snakemake.config["run"].get("disable_progressbar", False)
    progress_retrieve(url, tarball_fn, disable=disable_progress)
    validate_checksum(tarball_fn, url)

    logger.info("Extracting databundle.")
    tarfile.open(tarball_fn).extractall(to_fn)
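The same three-step pattern, download via progress_retrieve, verify via validate_checksum, then extract, repeats in the two retrieval scripts below; only the archive handling (zipfile vs. tarfile) differs.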

View File

@@ -11,7 +11,7 @@ import logging
import zipfile
from pathlib import Path
-from _helpers import progress_retrieve
+from _helpers import progress_retrieve, validate_checksum
logger = logging.getLogger(__name__)
@@ -35,6 +35,8 @@ if __name__ == "__main__":
    disable_progress = snakemake.config["run"].get("disable_progressbar", False)
    progress_retrieve(url, zip_fn, disable=disable_progress)
    validate_checksum(zip_fn, url)

    logger.info("Extracting databundle.")
    zipfile.ZipFile(zip_fn).extractall(to_fn)

View File

@@ -13,7 +13,7 @@ logger = logging.getLogger(__name__)
import tarfile
from pathlib import Path
-from _helpers import configure_logging, progress_retrieve
+from _helpers import configure_logging, progress_retrieve, validate_checksum
if __name__ == "__main__":
if "snakemake" not in globals():
@@ -34,6 +34,8 @@ if __name__ == "__main__":
    disable_progress = snakemake.config["run"].get("disable_progressbar", False)
    progress_retrieve(url, tarball_fn, disable=disable_progress)
    validate_checksum(tarball_fn, url)

    logger.info("Extracting databundle.")
    tarfile.open(tarball_fn).extractall(to_fn)