diff --git a/.github/workflows/code_changes.yaml b/.github/workflows/code_changes.yaml index 1ba30bb5..e5169555 100644 --- a/.github/workflows/code_changes.yaml +++ b/.github/workflows/code_changes.yaml @@ -5,6 +5,7 @@ on: pull_request: branches: - main + - '0.x' # Maintenance branch for legacy 0.x releases paths: - policyengine/** diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..cc73968b 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,6 @@ +- bump: minor + changes: + changed: + - Replaced US state filtering with conditional file loading for states and Congressional districts + - Updated data source URLs for state and district files + - Modified GCS loading to support new versioning introduced by -core v3.21.0 \ No newline at end of file diff --git a/policyengine/outputs/macro/comparison/calculate_economy_comparison.py b/policyengine/outputs/macro/comparison/calculate_economy_comparison.py index 7e6457da..b9852eab 100644 --- a/policyengine/outputs/macro/comparison/calculate_economy_comparison.py +++ b/policyengine/outputs/macro/comparison/calculate_economy_comparison.py @@ -712,21 +712,21 @@ def uk_constituency_breakdown( baseline_hnet = baseline.household_net_income reform_hnet = reform.household_net_income - constituency_weights_path = download( + constituency_weights_local_path = download( gcs_bucket="policyengine-uk-data-private", - filepath="parliamentary_constituency_weights.h5", + gcs_key="parliamentary_constituency_weights.h5", ) - with h5py.File(constituency_weights_path, "r") as f: + with h5py.File(constituency_weights_local_path, "r") as f: weights = f["2025"][ ... ] # {2025: array(650, 100180) where cell i, j is the weight of household record i in constituency j} - constituency_names_path = download( + constituency_names_local_path = download( gcs_bucket="policyengine-uk-data-private", - filepath="constituencies_2024.csv", + gcs_key="constituencies_2024.csv", ) constituency_names = pd.read_csv( - constituency_names_path + constituency_names_local_path ) # columns code (constituency code), name (constituency name), x, y (geographic position) for i in range(len(constituency_names)): diff --git a/policyengine/simulation.py b/policyengine/simulation.py index f0733705..909bc717 100644 --- a/policyengine/simulation.py +++ b/policyengine/simulation.py @@ -6,7 +6,6 @@ from .utils.data.datasets import ( get_default_dataset, process_gs_path, - POLICYENGINE_DATASETS, DATASET_TIME_PERIODS, ) from policyengine_core.simulations import Simulation as CountrySimulation @@ -146,11 +145,23 @@ def _add_output_functions(self): ) def _set_data(self, file_address: str | None = None) -> None: + """Load and set the dataset for this simulation.""" + # Step 1: Resolve file address (if None, get default) + file_address = self._resolve_file_address(file_address) + print(f"Using dataset: {file_address}", file=sys.stderr) - # filename refers to file's unique name + extension; - # file_address refers to URI + filename + # Step 2: Acquire the file (download if PE dataset, or use local path) + filepath, version = self._acquire_dataset_file(file_address) + self.data_version = version - # If None is passed, user wants default dataset; get URL, then continue initializing. + # Step 3: Load into country-specific format + time_period = self._set_data_time_period(file_address) + self.options.data = self._load_dataset_for_country( + filepath, time_period + ) + + def _resolve_file_address(self, file_address: str | None) -> str: + """If no file address provided, get the default dataset for this country/region.""" if file_address is None: file_address = get_default_dataset( country=self.options.country, region=self.options.region @@ -159,31 +170,33 @@ def _set_data(self, file_address: str | None = None) -> None: f"No data provided, using default dataset: {file_address}", file=sys.stderr, ) + return file_address - if file_address not in POLICYENGINE_DATASETS: - # If it's a local file, no URI present and unable to infer version. - filename = file_address - version = None - + def _acquire_dataset_file( + self, file_address: str + ) -> tuple[str, str | None]: + """ + Get the dataset file, downloading from GCS if it's a GCS path. + Returns (filepath, version) where version is None for local files. + """ + if file_address.startswith("gs://"): + return self._set_data_from_gs(file_address) else: - # All official PolicyEngine datasets are stored in GCS; - # load accordingly - filename, version = self._set_data_from_gs(file_address) - self.data_version = version - - time_period = self._set_data_time_period(file_address) + # Local file - no download needed, no version available + return file_address, None - # UK needs custom loading + def _load_dataset_for_country( + self, filepath: str, time_period: int | None + ) -> Dataset: + """Load the dataset file using the appropriate country-specific loader.""" if self.options.country == "us": - self.options.data = Dataset.from_file( - filename, time_period=time_period - ) - else: + return Dataset.from_file(filepath, time_period=time_period) + elif self.options.country == "uk": from policyengine_uk.data import UKSingleYearDataset - self.options.data = UKSingleYearDataset( - file_path=filename, - ) + return UKSingleYearDataset(file_path=filepath) + else: + raise ValueError(f"Unsupported country: {self.options.country}") def _initialise_simulations(self): self.baseline_simulation = self._initialise_simulation( @@ -270,18 +283,12 @@ def _apply_region_to_simulation( time_period: TimePeriodType, ) -> CountrySimulation: if country == "us": - df = simulation.to_input_dataframe() - state_code = simulation.calculate( - "state_code_str", map_to="person" - ).values - if region == "city/nyc": - in_nyc = simulation.calculate("in_nyc", map_to="person").values - simulation = simulation_type(dataset=df[in_nyc], reform=reform) - elif "state/" in region: - state = region.split("/")[1] - simulation = simulation_type( - dataset=df[state_code == state.upper()], reform=reform - ) + simulation = self._apply_us_region_to_simulation( + simulation=simulation, + simulation_type=simulation_type, + region=region, + reform=reform, + ) elif country == "uk": if "country/" in region: region = region.split("/")[1] @@ -294,14 +301,14 @@ def _apply_region_to_simulation( ) elif "constituency/" in region: constituency = region.split("/")[1] - constituency_names_file_path = download( + constituency_names_local_path = download( gcs_bucket="policyengine-uk-data-private", - filepath="constituencies_2024.csv", + gcs_key="constituencies_2024.csv", ) - constituency_names_file_path = Path( - constituency_names_file_path + constituency_names_local_path = Path( + constituency_names_local_path ) - constituency_names = pd.read_csv(constituency_names_file_path) + constituency_names = pd.read_csv(constituency_names_local_path) if constituency in constituency_names.code.values: constituency_id = constituency_names[ constituency_names.code == constituency @@ -312,14 +319,14 @@ def _apply_region_to_simulation( ].index[0] else: raise ValueError( - f"Constituency {constituency} not found. See {constituency_names_file_path} for the list of available constituencies." + f"Constituency {constituency} not found. See {constituency_names_local_path} for the list of available constituencies." ) - weights_file_path = download( + weights_local_path = download( gcs_bucket="policyengine-uk-data-private", - filepath="parliamentary_constituency_weights.h5", + gcs_key="parliamentary_constituency_weights.h5", ) - with h5py.File(weights_file_path, "r") as f: + with h5py.File(weights_local_path, "r") as f: weights = f[str(time_period)][...] simulation.set_input( @@ -329,26 +336,26 @@ def _apply_region_to_simulation( ) elif "local_authority/" in region: la = region.split("/")[1] - la_names_file_path = download( + la_names_local_path = download( gcs_bucket="policyengine-uk-data-private", - filepath="local_authorities_2021.csv", + gcs_key="local_authorities_2021.csv", ) - la_names_file_path = Path(la_names_file_path) - la_names = pd.read_csv(la_names_file_path) + la_names_local_path = Path(la_names_local_path) + la_names = pd.read_csv(la_names_local_path) if la in la_names.code.values: la_id = la_names[la_names.code == la].index[0] elif la in la_names.name.values: la_id = la_names[la_names.name == la].index[0] else: raise ValueError( - f"Local authority {la} not found. See {la_names_file_path} for the list of available local authorities." + f"Local authority {la} not found. See {la_names_local_path} for the list of available local authorities." ) - weights_file_path = download( + weights_local_path = download( gcs_bucket="policyengine-uk-data-private", - filepath="local_authority_weights.h5", + gcs_key="local_authority_weights.h5", ) - with h5py.File(weights_file_path, "r") as f: + with h5py.File(weights_local_path, "r") as f: weights = f[str(self.time_period)][...] simulation.set_input( @@ -359,6 +366,40 @@ def _apply_region_to_simulation( return simulation + def _apply_us_region_to_simulation( + self, + simulation: CountryMicrosimulation, + simulation_type: type, + region: RegionType, + reform: ReformType | None, + ) -> CountrySimulation: + """Apply US-specific regional filtering to a simulation. + + Note: Most US regions (states, congressional districts) now use + scoped datasets rather than filtering. Only NYC still requires + filtering from the national dataset (and is still using the pooled + CPS by default). This should be replaced with an approach based on + the new datasets. + """ + if region == "city/nyc": + simulation = self._filter_us_simulation_by_nyc( + simulation=simulation, + simulation_type=simulation_type, + reform=reform, + ) + return simulation + + def _filter_us_simulation_by_nyc( + self, + simulation: CountryMicrosimulation, + simulation_type: type, + reform: ReformType | None, + ) -> CountrySimulation: + """Filter a US simulation to only include NYC households.""" + df = simulation.to_input_dataframe() + in_nyc = simulation.calculate("in_nyc", map_to="person").values + return simulation_type(dataset=df[in_nyc], reform=reform) + def check_model_version(self) -> None: """ Check the package versions of the simulation against the current package versions. @@ -402,19 +443,30 @@ def _set_data_time_period(self, file_address: str) -> Optional[int]: def _set_data_from_gs(self, file_address: str) -> tuple[str, str | None]: """ - Set the data from a GCS path and return the filename and version. + Download data from a GCS path and return the local path and version. + + Supports version specification in three ways (in priority order): + 1. Explicit data_version option: Simulation(data="gs://...", data_version="1.2.3") + 2. URL suffix: Simulation(data="gs://bucket/file.h5@1.2.3") + 3. None (latest): Simulation(data="gs://bucket/file.h5") + + Returns: + A tuple of (local_path, version) where: + - local_path: The local filesystem path where the file was saved + - version: The version string, or None if not available """ + bucket, gcs_key, url_version = process_gs_path(file_address) - bucket, filename = process_gs_path(file_address) - version = self.options.data_version + # Priority: explicit option > URL suffix > None (latest) + version = self.options.data_version or url_version - print(f"Downloading {filename} from bucket {bucket}", file=sys.stderr) + print(f"Downloading {gcs_key} from bucket {bucket}", file=sys.stderr) - filepath, version = download( - filepath=filename, + local_path, version = download( + gcs_key=gcs_key, gcs_bucket=bucket, version=version, return_version=True, ) - return filename, version + return local_path, version diff --git a/policyengine/utils/data/__init__.py b/policyengine/utils/data/__init__.py index 74b27825..ab1e193f 100644 --- a/policyengine/utils/data/__init__.py +++ b/policyengine/utils/data/__init__.py @@ -1,2 +1,2 @@ from .caching_google_storage_client import CachingGoogleStorageClient -from .simplified_google_storage_client import SimplifiedGoogleStorageClient +from .version_aware_storage_client import VersionAwareStorageClient diff --git a/policyengine/utils/data/caching_google_storage_client.py b/policyengine/utils/data/caching_google_storage_client.py index dc9dd8f5..4f4f8f6c 100644 --- a/policyengine/utils/data/caching_google_storage_client.py +++ b/policyengine/utils/data/caching_google_storage_client.py @@ -3,7 +3,7 @@ from pathlib import Path from policyengine_core.data.dataset import atomic_write import logging -from .simplified_google_storage_client import SimplifiedGoogleStorageClient +from .version_aware_storage_client import VersionAwareStorageClient from typing import Optional logger = logging.getLogger(__name__) @@ -11,12 +11,19 @@ class CachingGoogleStorageClient(AbstractContextManager): """ - Client for downloaded resources from a google storage bucket only when the CRC - of the blob changes. + Client for downloading resources from a Google Storage bucket with caching. + + Only downloads when the CRC of the blob changes, using a disk-based cache + to persist downloaded data across sessions. + + The client supports multiple versioning strategies via VersionAwareStorageClient: + - Generation-based: version is a GCS generation number + - Metadata-based: version is stored in blob.metadata["version"] + - Latest: when no version is specified, gets the latest blob """ def __init__(self): - self.client = SimplifiedGoogleStorageClient() + self.client = VersionAwareStorageClient() self.cache = diskcache.Cache() def _data_key( @@ -36,6 +43,17 @@ def download( ): """ Atomically write the latest version of the cloud storage blob to the target path. + + Args: + bucket: The GCS bucket name. + key: The blob path within the bucket. + target: The local file path to write the downloaded content to. + version: Optional version string. Can be a GCS generation number, + a metadata version string, or None for latest. + return_version: If True, return the version string of the downloaded blob. + + Returns: + The version string if return_version is True, otherwise None. """ if version is None: # If no version is specified, get the latest version from the cache diff --git a/policyengine/utils/data/datasets.py b/policyengine/utils/data/datasets.py index e4ed682b..a97f629b 100644 --- a/policyengine/utils/data/datasets.py +++ b/policyengine/utils/data/datasets.py @@ -1,12 +1,16 @@ """Mainly simulation options and parameters.""" -from typing import Tuple, Optional +from typing import Tuple, Optional, Literal + +from policyengine_core.tools.google_cloud import parse_gs_url + +US_DATA_BUCKET = "gs://policyengine-us-data" EFRS_2023 = "gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5" FRS_2023 = "gs://policyengine-uk-data-private/frs_2023_24.h5" -CPS_2023 = "gs://policyengine-us-data/cps_2023.h5" -CPS_2023_POOLED = "gs://policyengine-us-data/pooled_3_year_cps_2023.h5" -ECPS_2024 = "gs://policyengine-us-data/enhanced_cps_2024.h5" +CPS_2023 = f"{US_DATA_BUCKET}/cps_2023.h5" +CPS_2023_POOLED = f"{US_DATA_BUCKET}/pooled_3_year_cps_2023.h5" +ECPS_2024 = f"{US_DATA_BUCKET}/enhanced_cps_2024.h5" POLICYENGINE_DATASETS = [ EFRS_2023, @@ -25,26 +29,128 @@ def get_default_dataset( - country: str, region: str, version: Optional[str] = None + country: str, region: str | None, version: Optional[str] = None ) -> str: if country == "uk": return EFRS_2023 elif country == "us": - if region is not None and region != "us": - return CPS_2023_POOLED - else: - return CPS_2023 + return _get_default_us_dataset(region) raise ValueError( f"Unable to select a default dataset for country {country} and region {region}." ) -def process_gs_path(path: str) -> Tuple[str, str]: - """Process a GS path to return bucket and object.""" - if not path.startswith("gs://"): - raise ValueError(f"Invalid GS path: {path}") +def _get_default_us_dataset(region: str | None) -> str: + """Get the default dataset for a US region.""" + region_type = determine_us_region_type(region) + + if region_type == "nationwide": + return CPS_2023 + elif region_type == "city": + # TODO: Implement a better approach to this for our one + # city, New York City. + # Cities use the pooled CPS dataset + return CPS_2023_POOLED + + # For state and congressional_district, region is guaranteed to be non-None + assert region is not None + + if region_type == "state": + state_code = region.split("/")[1] + return get_us_state_dataset_path(state_code) + elif region_type == "congressional_district": + # Expected format: "congressional_district/CA-01" + district_str = region.split("/")[1] + state_code, district_num_str = district_str.split("-") + district_number = int(district_num_str) + return get_us_congressional_district_dataset_path( + state_code, district_number + ) + + raise ValueError(f"Unhandled US region type: {region_type}") + + +def process_gs_path(path: str) -> Tuple[str, str, Optional[str]]: + """ + Process a GS path to return bucket, object, and optional version. + + Supports: + - gs://bucket/path/file.h5 + - gs://bucket/path/file.h5@1.2.3 + + Args: + path: A GCS URL in the format gs://bucket/path/to/file[@version] + + Returns: + A tuple of (bucket, object_path, version) where version may be None. + + Raises: + ValueError: If the path is not a valid gs:// URL. + """ + return parse_gs_url(path) + + +def get_us_state_dataset_path(state_code: str) -> str: + """ + Get the GCS path for a US state-level dataset. + + Args: + state_code: Two-letter US state code (e.g., "CA", "NY"). + + Returns: + GCS path to the state dataset. + """ + return f"{US_DATA_BUCKET}/states/{state_code.upper()}.h5" + - path = path[5:] # Remove 'gs://' - bucket, obj = path.split("/", 1) - return bucket, obj +def get_us_congressional_district_dataset_path( + state_code: str, district_number: int +) -> str: + """ + Get the GCS path for a US Congressional district-level dataset. + + Note: This is a theorized schema. The exact format of the district + dataset filenames may change once the actual data is available. + + Args: + state_code: Two-letter US state code (e.g., "CA", "NY"). + district_number: District number (1-52). + + Returns: + GCS path to the Congressional district dataset. + """ + print(f"US_DATA_BUCKET in GET_DATASET_PATH: {US_DATA_BUCKET}") + return f"{US_DATA_BUCKET}/districts/{state_code.upper()}-{district_number:02d}.h5" + + +USRegionType = Literal["nationwide", "city", "state", "congressional_district"] + +US_REGION_PREFIXES = ("city", "state", "congressional_district") + + +def determine_us_region_type(region: str | None) -> USRegionType: + """ + Determine the type of US region from a region string. + + Args: + region: A region string (e.g., "us", "city/nyc", "state/CA", + "congressional_district/CA-01") or None. + + Returns: + One of "nationwide", "city", "state", or "congressional_district". + + Raises: + ValueError: If the region string has an unrecognized prefix. + """ + if region is None or region == "us": + return "nationwide" + + for prefix in US_REGION_PREFIXES: + if region.startswith(f"{prefix}/"): + return prefix + + raise ValueError( + f"Unrecognized US region format: '{region}'. " + f"Expected 'us', or one of the following prefixes: {list(US_REGION_PREFIXES)}" + ) diff --git a/policyengine/utils/data/simplified_google_storage_client.py b/policyengine/utils/data/simplified_google_storage_client.py deleted file mode 100644 index b7d80df0..00000000 --- a/policyengine/utils/data/simplified_google_storage_client.py +++ /dev/null @@ -1,107 +0,0 @@ -import asyncio -from policyengine_core.data.dataset import atomic_write -import logging -from google.cloud.storage import Client, Blob -from typing import Iterable, Optional - -logger = logging.getLogger(__name__) - - -class SimplifiedGoogleStorageClient: - """ - Class separating out just the interactions with google storage required to - cache downloaded files. - - Simplifies the dependent code and unit testing. - """ - - def __init__(self): - self.client = Client() - - def get_versioned_blob( - self, bucket_name: str, key: str, version: Optional[str] = None - ) -> Blob: - """ - Get a versioned blob from the specified bucket and key. - If version is None, returns the latest version of the blob. - """ - bucket = self.client.bucket(bucket_name) - if version is None: - return bucket.blob(key) - logging.debug( - "Searching {bucket_name}, {prefix}* for version {version}" - ) - versions: Iterable[Blob] = bucket.list_blobs(prefix=key, versions=True) - for v in versions: - if v.metadata is None: - continue # Skip blobs without metadata - if v.metadata.get("version") == version: - logging.info( - f"Blob {bucket_name}, {v.path} has version {version}" - ) - return v - logging.info(f"No version {version} found for {bucket_name}, {key}") - raise ValueError( - f"Could not find version {version} of blob {key} in bucket {bucket.name}" - ) - - def crc32c( - self, bucket_name: str, key: str, version: Optional[str] = None - ) -> Optional[str]: - """ - get the current CRC of the specified blob. None if it doesn't exist. - """ - logger.debug(f"Getting crc for {bucket_name}, {key}") - blob = self.get_versioned_blob(bucket_name, key, version) - - blob.reload() - logger.info(f"Crc for {bucket_name}, {key} is {blob.crc32c}") - return blob.crc32c - - def download( - self, bucket: str, key: str, version: Optional[str] = None - ) -> tuple[bytes, str]: - """ - get the blob content and associated CRC from google storage. - """ - logger.debug( - f"Downloading {bucket}, {key}{ ', version:' + version if version is not None else ''}" - ) - blob = self.get_versioned_blob(bucket, key, version) - - result = blob.download_as_bytes() - logger.info( - f"Downloaded {bucket}, {key}{ ', version:' + version if version is not None else ''}" - ) - # According to documentation blob.crc32c is updated as a side effect of - # downloading the content. As a result this should now be the crc of the downloaded - # content (i.e. there is not a race condition where it's getting the CRC from the cloud) - return (result, blob.crc32c) - - def _get_latest_version(self, bucket: str, key: str) -> Optional[str]: - """ - Get the latest version of a blob in the specified bucket and key. - If no version is specified, return None. - """ - logger.debug(f"Getting latest version of {bucket}, {key}") - blob = self.client.get_bucket(bucket).get_blob(key) - if blob is None: - logging.warning(f"No blob found in bucket {bucket} with key {key}") - return None - - if blob.metadata is None: - logging.warning( - f"No metadata found for blob {bucket}, {key}, so it has no version attached." - ) - return None - - version = blob.metadata.get("version") - if version is None: - logging.warning( - f"Blob {bucket}, {key} does not have a version in its metadata" - ) - return None - logging.info( - f"Metadata for blob {bucket}, {key} has version: {version}" - ) - return blob.metadata.get("version") diff --git a/policyengine/utils/data/version_aware_storage_client.py b/policyengine/utils/data/version_aware_storage_client.py new file mode 100644 index 00000000..0fdc144d --- /dev/null +++ b/policyengine/utils/data/version_aware_storage_client.py @@ -0,0 +1,207 @@ +""" +GCS client supporting multiple versioning strategies. + +This module provides a unified interface for downloading versioned blobs from +Google Cloud Storage, supporting both: +1. Generation-based versioning (GCS native object generations) +2. Metadata-based versioning (version string in blob metadata) +""" + +import logging +from typing import Optional + +from google.cloud.storage import Blob, Bucket, Client + +logger = logging.getLogger(__name__) + + +class VersionAwareStorageClient: + """ + GCS client supporting multiple versioning strategies. + + Versioning strategies: + 1. Generation-based: version is a GCS generation number (integer string) + 2. Metadata-based: version is stored in blob.metadata["version"] + 3. None: get the latest blob + + The client attempts to resolve versions in this order: + - If version looks like an integer, try generation-based first + - Fall back to metadata-based matching + - Raise an error if no matching version is found + """ + + def __init__(self): + self.client = Client() + + def get_blob( + self, + bucket_name: str, + key: str, + version: Optional[str] = None, + ) -> Blob: + """ + Get a blob, resolving version using the appropriate strategy. + + Args: + bucket_name: The GCS bucket name. + key: The blob path within the bucket. + version: Optional version string. Can be: + - None: get latest blob + - Integer string: treated as GCS generation number + - Other string: matched against blob metadata["version"] + + Returns: + The resolved Blob object. + + Raises: + ValueError: If a specific version is requested but not found. + """ + bucket = self.client.bucket(bucket_name) + + if version is None: + # No version specified: return latest + logger.debug( + f"No version specified for {bucket_name}/{key}, using latest" + ) + return bucket.blob(key) + + # Try generation-based first (if version looks like an integer) + if version.isdigit(): + logger.debug( + f"Version '{version}' looks like a generation number, " + f"trying generation-based lookup for {bucket_name}/{key}" + ) + try: + blob = bucket.blob(key, generation=int(version)) + # Verify the blob exists by reloading it + blob.reload() + logger.info( + f"Found blob {bucket_name}/{key} with generation {version}" + ) + return blob + except Exception as e: + logger.debug( + f"Generation-based lookup failed for {bucket_name}/{key}@{version}: {e}. " + f"Falling back to metadata-based lookup." + ) + + # Metadata-based: iterate versions and match + return self._get_blob_by_metadata_version(bucket, key, version) + + def _get_blob_by_metadata_version( + self, + bucket: Bucket, + key: str, + version: str, + ) -> Blob: + """ + Find a blob whose metadata["version"] matches the requested version. + + Args: + bucket: The GCS Bucket object. + key: The blob path within the bucket. + version: The version string to match in metadata. + + Returns: + The matching Blob object. + + Raises: + ValueError: If no blob with the matching version is found. + """ + logger.debug( + f"Searching for blob {bucket.name}/{key} with metadata version '{version}'" + ) + versions = bucket.list_blobs(prefix=key, versions=True) + for blob in versions: + if blob.metadata and blob.metadata.get("version") == version: + logger.info( + f"Found blob {bucket.name}/{key} with metadata version '{version}'" + ) + return blob + + raise ValueError( + f"No blob found with version '{version}' for {bucket.name}/{key}" + ) + + def crc32c( + self, + bucket_name: str, + key: str, + version: Optional[str] = None, + ) -> Optional[str]: + """ + Get the CRC32C checksum for a blob. + + Args: + bucket_name: The GCS bucket name. + key: The blob path within the bucket. + version: Optional version string. + + Returns: + The CRC32C checksum string, or None if blob doesn't exist. + """ + logger.debug(f"Getting CRC32C for {bucket_name}/{key}") + blob = self.get_blob(bucket_name, key, version) + blob.reload() + logger.info(f"CRC32C for {bucket_name}/{key} is {blob.crc32c}") + return blob.crc32c + + def download( + self, + bucket_name: str, + key: str, + version: Optional[str] = None, + ) -> tuple[bytes, str]: + """ + Download blob content and return (content, crc). + + Args: + bucket_name: The GCS bucket name. + key: The blob path within the bucket. + version: Optional version string. + + Returns: + A tuple of (content_bytes, crc32c_checksum). + """ + logger.debug( + f"Downloading {bucket_name}/{key}" + f"{', version: ' + version if version else ''}" + ) + blob = self.get_blob(bucket_name, key, version) + content = blob.download_as_bytes() + logger.info( + f"Downloaded {bucket_name}/{key}" + f"{', version: ' + version if version else ''}" + ) + # According to documentation, blob.crc32c is updated as a side effect of + # downloading the content. This should be the CRC of the downloaded + # content (avoiding race conditions with the cloud). + return (content, blob.crc32c) + + def _get_latest_version(self, bucket: str, key: str) -> Optional[str]: + """ + Get the latest version of a blob in the specified bucket and key. + If no version is specified, return None. + """ + logger.debug(f"Getting latest version of {bucket}, {key}") + blob = self.client.get_bucket(bucket).get_blob(key) + if blob is None: + logging.warning(f"No blob found in bucket {bucket} with key {key}") + return None + + if blob.metadata is None: + logging.warning( + f"No metadata found for blob {bucket}, {key}, so it has no version attached." + ) + return None + + version = blob.metadata.get("version") + if version is None: + logging.warning( + f"Blob {bucket}, {key} does not have a version in its metadata" + ) + return None + logging.info( + f"Metadata for blob {bucket}, {key} has version: {version}" + ) + return blob.metadata.get("version") diff --git a/policyengine/utils/data_download.py b/policyengine/utils/data_download.py index 3f33fbf7..a7f95329 100644 --- a/policyengine/utils/data_download.py +++ b/policyengine/utils/data_download.py @@ -1,25 +1,45 @@ -from pathlib import Path +""" +Download orchestration for GCS-hosted datasets. + +Terminology: +- gcs_key: The path to a file within a GCS bucket (e.g., "states/RI.h5") +- local_path: The full local filesystem path where a file is stored +""" + import logging -import os +from typing import Optional, Tuple + from policyengine.utils.google_cloud_bucket import download_file_from_gcs -from pydantic import BaseModel -import json -from typing import Tuple, Optional def download( - filepath: str, + gcs_key: str, gcs_bucket: str, version: Optional[str] = None, return_version: bool = False, -) -> Tuple[str, str] | str: +) -> Tuple[str, Optional[str]] | str: + """ + Download a file from Google Cloud Storage. + + Args: + gcs_key: The path to the file within the bucket (e.g., "states/RI.h5"). + gcs_bucket: The name of the GCS bucket. + version: Optional version string. Can be: + - A GCS generation number (integer string) + - A metadata version string (e.g., "1.2.3") + - None to get the latest version + return_version: If True, return a tuple of (local_path, version). + + Returns: + If return_version is True: (local_path, version) tuple + Otherwise: just the local_path string + """ logging.info("Using Google Cloud Storage for download.") - downloaded_version = download_file_from_gcs( + local_path, downloaded_version = download_file_from_gcs( bucket_name=gcs_bucket, - file_name=filepath, - destination_path=filepath, + gcs_key=gcs_key, version=version, ) if return_version: - return filepath, downloaded_version - return filepath + return local_path, downloaded_version + return local_path diff --git a/policyengine/utils/google_cloud_bucket.py b/policyengine/utils/google_cloud_bucket.py index a7cc6842..c89ea860 100644 --- a/policyengine/utils/google_cloud_bucket.py +++ b/policyengine/utils/google_cloud_bucket.py @@ -1,13 +1,30 @@ -from .data.caching_google_storage_client import CachingGoogleStorageClient -import asyncio +""" +High-level interface for downloading files from Google Cloud Storage. + +This module provides a singleton-based caching client that handles: +- CRC-based cache invalidation (only downloads when content changes) +- Atomic file writes (prevents partial/corrupted files) +- Multiple versioning strategies (generation-based or metadata-based) + +Terminology: +- gcs_key: The path to a file within a GCS bucket (e.g., "states/RI.h5") +- local_path: The full local filesystem path where a file is stored +- DATASETS_DIR: The local directory where all downloaded datasets are stored +""" + from pathlib import Path -from google.cloud.storage import Blob -from typing import Iterable, Optional +from typing import Optional, Tuple + +from .data.caching_google_storage_client import CachingGoogleStorageClient _caching_client: CachingGoogleStorageClient | None = None +# All downloaded datasets are stored in this directory +DATASETS_DIR = Path(".datasets") + -def _get_client(): +def _get_client() -> CachingGoogleStorageClient: + """Get or create the singleton caching client.""" global _caching_client if _caching_client is not None: return _caching_client @@ -15,34 +32,51 @@ def _get_client(): return _caching_client -def _clear_client(): +def _clear_client() -> None: + """Clear the singleton caching client (useful for testing).""" global _caching_client _caching_client = None def download_file_from_gcs( bucket_name: str, - file_name: str, - destination_path: str, + gcs_key: str, version: Optional[str] = None, -) -> str | None: +) -> Tuple[str, Optional[str]]: """ Download a file from Google Cloud Storage to a local path. + Uses a caching layer that only downloads when the file's CRC changes, + and writes files atomically to prevent corruption. + + Files are stored in the DATASETS_DIR (.datasets/) directory, preserving + the GCS key structure. For example: + - gcs_key="enhanced_cps_2024.h5" -> .datasets/enhanced_cps_2024.h5 + - gcs_key="states/RI.h5" -> .datasets/states/RI.h5 + - gcs_key="districts/CA-01.h5" -> .datasets/districts/CA-01.h5 + Args: - bucket_name (str): The name of the GCS bucket. - file_name (str): The name of the file in the GCS bucket. - destination_path (str): The local path where the file will be saved. + bucket_name: The name of the GCS bucket. + gcs_key: The path to the file within the bucket. + version: Optional version string. Can be: + - A GCS generation number (integer string) + - A metadata version string (e.g., "1.2.3") + - None to get the latest version Returns: - version (str): The version of the file that was downloaded, if available. + A tuple of (local_path, version) where: + - local_path: The local filesystem path where the file was saved + - version: The version string of the downloaded file, or None if + no version metadata is available """ + local_path = DATASETS_DIR / gcs_key + local_path.parent.mkdir(parents=True, exist_ok=True) version = _get_client().download( bucket_name, - file_name, - Path(destination_path), + gcs_key, + local_path, version=version, return_version=True, ) - return version + return str(local_path), version diff --git a/policyengine/utils/maps.py b/policyengine/utils/maps.py index a6963f9b..2f8b6efa 100644 --- a/policyengine/utils/maps.py +++ b/policyengine/utils/maps.py @@ -8,16 +8,16 @@ def get_location_options_table(location_type: str) -> pd.DataFrame: if location_type == "parliamentary_constituencies": - area_names_file_path = download( - repo="policyengine/policyengine-uk-data", - filepath="constituencies_2024.csv", + area_names_local_path = download( + gcs_bucket="policyengine-uk-data-private", + gcs_key="constituencies_2024.csv", ) elif location_type == "local_authorities": - area_names_file_path = download( - repo="policyengine/policyengine-uk-data", - filepath="local_authorities_2021.csv", + area_names_local_path = download( + gcs_bucket="policyengine-uk-data-private", + gcs_key="local_authorities_2021.csv", ) - df = pd.read_csv(area_names_file_path) + df = pd.read_csv(area_names_local_path) return df diff --git a/tests/utils/data/conftest.py b/tests/utils/data/conftest.py index 19cbcd9a..4d42ee01 100644 --- a/tests/utils/data/conftest.py +++ b/tests/utils/data/conftest.py @@ -5,36 +5,41 @@ class MockedStorageSupport: - def __init__(self, mock_simple_storage_client): - self.mock_simple_storage_client = mock_simple_storage_client + """Test support class for mocking the VersionAwareStorageClient.""" + + def __init__(self, mock_storage_client): + self.mock_storage_client = mock_storage_client def given_stored_data(self, data: str, crc: str): - self.mock_simple_storage_client.crc32c.return_value = crc - self.mock_simple_storage_client.download.return_value = ( + """Configure the mock to return specific data and CRC.""" + self.mock_storage_client.crc32c.return_value = crc + self.mock_storage_client.download.return_value = ( data.encode(), crc, ) - self.mock_simple_storage_client._get_latest_version.return_value = ( + self.mock_storage_client._get_latest_version.return_value = ( VALID_VERSION ) def given_crc_changes_on_download( self, data: str, initial_crc: str, download_crc: str ): - self.mock_simple_storage_client.crc32c.return_value = initial_crc - self.mock_simple_storage_client.download.return_value = ( + """Configure the mock where CRC differs between check and download.""" + self.mock_storage_client.crc32c.return_value = initial_crc + self.mock_storage_client.download.return_value = ( data.encode(), download_crc, ) - self.mock_simple_storage_client._get_latest_version.return_value = ( + self.mock_storage_client._get_latest_version.return_value = ( VALID_VERSION ) @pytest.fixture() def mocked_storage(): + """Fixture that mocks the VersionAwareStorageClient for testing.""" with patch( - "policyengine.utils.data.caching_google_storage_client.SimplifiedGoogleStorageClient", + "policyengine.utils.data.caching_google_storage_client.VersionAwareStorageClient", autospec=True, ) as mock_class: mock_instance = mock_class.return_value diff --git a/tests/utils/data/test_datasets.py b/tests/utils/data/test_datasets.py new file mode 100644 index 00000000..48ae5f5b --- /dev/null +++ b/tests/utils/data/test_datasets.py @@ -0,0 +1,82 @@ +"""Tests for datasets.py utilities.""" + +import pytest +from policyengine.utils.data.datasets import process_gs_path + + +class TestProcessGsPath: + """Tests for process_gs_path function.""" + + def test_basic_path(self): + """Test parsing a basic gs:// path without version.""" + bucket, path, version = process_gs_path( + "gs://my-bucket/path/to/file.h5" + ) + assert bucket == "my-bucket" + assert path == "path/to/file.h5" + assert version is None + + def test_path_with_version(self): + """Test parsing a gs:// path with @version suffix.""" + bucket, path, version = process_gs_path( + "gs://my-bucket/path/to/file.h5@1.2.3" + ) + assert bucket == "my-bucket" + assert path == "path/to/file.h5" + assert version == "1.2.3" + + def test_path_with_numeric_version(self): + """Test parsing a gs:// path with numeric version (GCS generation).""" + bucket, path, version = process_gs_path( + "gs://my-bucket/file.h5@1234567890" + ) + assert bucket == "my-bucket" + assert path == "file.h5" + assert version == "1234567890" + + def test_path_with_nested_directories(self): + """Test parsing a gs:// path with deeply nested directories.""" + bucket, path, version = process_gs_path( + "gs://policyengine-us-data/states/CA/districts/01.h5@2024.1.0" + ) + assert bucket == "policyengine-us-data" + assert path == "states/CA/districts/01.h5" + assert version == "2024.1.0" + + def test_invalid_path_no_gs_prefix(self): + """Test that non-gs:// paths raise ValueError.""" + with pytest.raises(ValueError) as exc_info: + process_gs_path("https://storage.googleapis.com/bucket/file.h5") + assert "Invalid gs:// URL format" in str(exc_info.value) + + def test_invalid_path_no_file(self): + """Test that paths without a file path raise ValueError.""" + with pytest.raises(ValueError) as exc_info: + process_gs_path("gs://bucket-only") + assert "Invalid gs:// URL format" in str(exc_info.value) + + def test_real_policyengine_paths(self): + """Test parsing actual PolicyEngine dataset paths.""" + # US data path + bucket, path, version = process_gs_path( + "gs://policyengine-us-data/cps_2023.h5" + ) + assert bucket == "policyengine-us-data" + assert path == "cps_2023.h5" + assert version is None + + # UK data path with version + bucket, path, version = process_gs_path( + "gs://policyengine-uk-data-private/enhanced_frs_2023_24.h5@1.0.0" + ) + assert bucket == "policyengine-uk-data-private" + assert path == "enhanced_frs_2023_24.h5" + assert version == "1.0.0" + + # State-level dataset + bucket, path, version = process_gs_path( + "gs://policyengine-us-data/states/CA.h5" + ) + assert bucket == "policyengine-us-data" + assert path == "states/CA.h5" + assert version is None diff --git a/tests/utils/data/test_google_cloud_bucket.py b/tests/utils/data/test_google_cloud_bucket.py index 71fd6ea4..3f358936 100644 --- a/tests/utils/data/test_google_cloud_bucket.py +++ b/tests/utils/data/test_google_cloud_bucket.py @@ -5,6 +5,7 @@ from policyengine.utils.google_cloud_bucket import ( download_file_from_gcs, _clear_client, + DATASETS_DIR, ) @@ -17,17 +18,25 @@ def setUp(self): autospec=True, ) def test_download_uses_storage_client(self, client_class): + """Test that download calls the caching client with correct arguments.""" client_instance = client_class.return_value - download_file_from_gcs( + client_instance.download.return_value = "1.0.0" + + local_path, version = download_file_from_gcs( "TEST_BUCKET", "TEST/FILE/NAME.TXT", - "TARGET/PATH", version=None, ) + + # Verify the local path is constructed correctly + expected_local_path = DATASETS_DIR / "TEST/FILE/NAME.TXT" + assert local_path == str(expected_local_path) + + # Verify the client was called with correct arguments client_instance.download.assert_called_with( "TEST_BUCKET", "TEST/FILE/NAME.TXT", - Path("TARGET/PATH"), + expected_local_path, version=None, return_version=True, ) @@ -37,10 +46,49 @@ def test_download_uses_storage_client(self, client_class): autospec=True, ) def test_download_only_creates_client_once(self, client_class): - download_file_from_gcs( - "TEST_BUCKET", "TEST/FILE/NAME.TXT", "TARGET/PATH", version=None + """Test that the singleton client is reused across calls.""" + client_instance = client_class.return_value + client_instance.download.return_value = "1.0.0" + + download_file_from_gcs("TEST_BUCKET", "file1.h5", version=None) + download_file_from_gcs("TEST_BUCKET", "file2.h5", version=None) + client_class.assert_called_once() + + @patch( + "policyengine.utils.google_cloud_bucket.CachingGoogleStorageClient", + autospec=True, + ) + def test_download_returns_local_path_and_version(self, client_class): + """Test that download returns both local path and version.""" + client_instance = client_class.return_value + client_instance.download.return_value = "2.3.4" + + local_path, version = download_file_from_gcs( + "my-bucket", + "states/CA.h5", + version="2.3.4", + ) + + assert local_path == str(DATASETS_DIR / "states/CA.h5") + assert version == "2.3.4" + + @patch( + "policyengine.utils.google_cloud_bucket.CachingGoogleStorageClient", + autospec=True, + ) + def test_download_creates_nested_directories(self, client_class): + """Test that nested GCS keys result in correct local paths.""" + client_instance = client_class.return_value + client_instance.download.return_value = None + + # Test nested path + local_path, _ = download_file_from_gcs( + "bucket", "districts/CA-01.h5", version=None ) - download_file_from_gcs( - "TEST_BUCKET", "TEST/FILE/NAME.TXT", "ANOTHER/PATH", version=None + assert local_path == str(DATASETS_DIR / "districts/CA-01.h5") + + # Test flat path + local_path, _ = download_file_from_gcs( + "bucket", "enhanced_cps_2024.h5", version=None ) - client_class.assert_called_once() + assert local_path == str(DATASETS_DIR / "enhanced_cps_2024.h5") diff --git a/tests/utils/data/test_simplified_google_storage_client.py b/tests/utils/data/test_simplified_google_storage_client.py deleted file mode 100644 index b61fbf5f..00000000 --- a/tests/utils/data/test_simplified_google_storage_client.py +++ /dev/null @@ -1,107 +0,0 @@ -from unittest.mock import patch, call -import pytest -from policyengine.utils.data import SimplifiedGoogleStorageClient - -VALID_VERSION = "1.2.3" - - -class TestSimplifiedGoogleStorageClient: - @patch( - "policyengine.utils.data.simplified_google_storage_client.Client", - autospec=True, - ) - def test_crc32c__gets_crc(self, mock_client_class): - mock_instance = mock_client_class.return_value - bucket = mock_instance.bucket.return_value - blob = bucket.blob.return_value - - blob.crc32c = "TEST_CRC" - - client = SimplifiedGoogleStorageClient() - assert client.crc32c("bucket_name", "content.txt") == "TEST_CRC" - mock_instance.bucket.assert_called_with("bucket_name") - bucket.blob.assert_called_with("content.txt") - blob.reload.assert_called() - - @patch( - "policyengine.utils.data.simplified_google_storage_client.Client", - autospec=True, - ) - def test_download__downloads_content(self, mock_client_class): - mock_instance = mock_client_class.return_value - bucket = mock_instance.bucket.return_value - blob = bucket.blob.return_value - - blob.download_as_bytes.return_value = "hello, world".encode() - blob.crc32c = "TEST_CRC" - - client = SimplifiedGoogleStorageClient() - [data, crc] = client.download("bucket", "blob.txt") - assert data == "hello, world".encode() - assert crc == "TEST_CRC" - - mock_instance.bucket.assert_called_with("bucket") - bucket.blob.assert_called_with("blob.txt") - - @patch( - "policyengine.utils.data.simplified_google_storage_client.Client", - autospec=True, - ) - def test_get_latest_version__returns_version_from_metadata( - self, mock_client_class - ): - mock_instance = mock_client_class.return_value - bucket = mock_instance.get_bucket.return_value - blob = bucket.get_blob.return_value - - # Test case where metadata exists with version - blob.metadata = {"version": VALID_VERSION} - - client = SimplifiedGoogleStorageClient() - result = client._get_latest_version("test_bucket", "test_key") - - assert result == VALID_VERSION - mock_instance.get_bucket.assert_called_with("test_bucket") - bucket.get_blob.assert_called_with("test_key") - - @patch( - "policyengine.utils.data.simplified_google_storage_client.Client", - autospec=True, - ) - def test_get_latest_version__returns_none_when_no_metadata( - self, mock_client_class - ): - mock_instance = mock_client_class.return_value - bucket = mock_instance.get_bucket.return_value - blob = bucket.get_blob.return_value - - # Test case where metadata is None - blob.metadata = None - - client = SimplifiedGoogleStorageClient() - result = client._get_latest_version("test_bucket", "test_key") - - assert result is None - mock_instance.get_bucket.assert_called_with("test_bucket") - bucket.get_blob.assert_called_with("test_key") - - @patch( - "policyengine.utils.data.simplified_google_storage_client.Client", - autospec=True, - ) - def test_get_latest_version__returns_none_when_no_version_in_metadata( - self, mock_client_class - ): - mock_instance = mock_client_class.return_value - bucket = mock_instance.get_bucket.return_value - blob = bucket.get_blob.return_value - - # Test case where metadata exists but no version field - blob.metadata = {"other_field": "value"} - - client = SimplifiedGoogleStorageClient() - result = client._get_latest_version("test_bucket", "test_key") - - assert result is None - mock_instance.get_bucket.assert_called_with("test_bucket") - bucket.get_blob.assert_called_with("test_key") diff --git a/tests/utils/data/test_version_aware_storage_client.py b/tests/utils/data/test_version_aware_storage_client.py new file mode 100644 index 00000000..38b6ebac --- /dev/null +++ b/tests/utils/data/test_version_aware_storage_client.py @@ -0,0 +1,227 @@ +"""Tests for VersionAwareStorageClient.""" + +from unittest.mock import MagicMock, patch, PropertyMock +import pytest +from policyengine.utils.data import VersionAwareStorageClient + +VALID_VERSION = "1.2.3" + + +class TestVersionAwareStorageClient: + """Tests for VersionAwareStorageClient.""" + + @patch( + "policyengine.utils.data.version_aware_storage_client.Client", + autospec=True, + ) + def test_crc32c__gets_crc(self, mock_client_class): + """Test that crc32c returns the blob's CRC.""" + mock_instance = mock_client_class.return_value + bucket = mock_instance.bucket.return_value + blob = bucket.blob.return_value + + blob.crc32c = "TEST_CRC" + + client = VersionAwareStorageClient() + assert client.crc32c("bucket_name", "content.txt") == "TEST_CRC" + mock_instance.bucket.assert_called_with("bucket_name") + bucket.blob.assert_called_with("content.txt") + blob.reload.assert_called() + + @patch( + "policyengine.utils.data.version_aware_storage_client.Client", + autospec=True, + ) + def test_download__downloads_content(self, mock_client_class): + """Test that download returns content and CRC.""" + mock_instance = mock_client_class.return_value + bucket = mock_instance.bucket.return_value + blob = bucket.blob.return_value + + blob.download_as_bytes.return_value = "hello, world".encode() + blob.crc32c = "TEST_CRC" + + client = VersionAwareStorageClient() + data, crc = client.download("bucket", "blob.txt") + assert data == "hello, world".encode() + assert crc == "TEST_CRC" + + mock_instance.bucket.assert_called_with("bucket") + bucket.blob.assert_called_with("blob.txt") + + @patch( + "policyengine.utils.data.version_aware_storage_client.Client", + autospec=True, + ) + def test_get_latest_version__returns_version_from_metadata( + self, mock_client_class + ): + """Test that _get_latest_version returns the version from blob metadata.""" + mock_instance = mock_client_class.return_value + bucket = mock_instance.get_bucket.return_value + blob = bucket.get_blob.return_value + + blob.metadata = {"version": VALID_VERSION} + + client = VersionAwareStorageClient() + result = client._get_latest_version("test_bucket", "test_key") + + assert result == VALID_VERSION + mock_instance.get_bucket.assert_called_with("test_bucket") + bucket.get_blob.assert_called_with("test_key") + + @patch( + "policyengine.utils.data.version_aware_storage_client.Client", + autospec=True, + ) + def test_get_latest_version__returns_none_when_no_metadata( + self, mock_client_class + ): + """Test that _get_latest_version returns None when blob has no metadata.""" + mock_instance = mock_client_class.return_value + bucket = mock_instance.get_bucket.return_value + blob = bucket.get_blob.return_value + + blob.metadata = None + + client = VersionAwareStorageClient() + result = client._get_latest_version("test_bucket", "test_key") + + assert result is None + mock_instance.get_bucket.assert_called_with("test_bucket") + bucket.get_blob.assert_called_with("test_key") + + @patch( + "policyengine.utils.data.version_aware_storage_client.Client", + autospec=True, + ) + def test_get_latest_version__returns_none_when_no_version_in_metadata( + self, mock_client_class + ): + """Test that _get_latest_version returns None when metadata has no version field.""" + mock_instance = mock_client_class.return_value + bucket = mock_instance.get_bucket.return_value + blob = bucket.get_blob.return_value + + blob.metadata = {"other_field": "value"} + + client = VersionAwareStorageClient() + result = client._get_latest_version("test_bucket", "test_key") + + assert result is None + mock_instance.get_bucket.assert_called_with("test_bucket") + bucket.get_blob.assert_called_with("test_key") + + @patch( + "policyengine.utils.data.version_aware_storage_client.Client", + autospec=True, + ) + def test_get_blob__no_version_returns_latest(self, mock_client_class): + """Test that get_blob with no version returns the latest blob.""" + mock_instance = mock_client_class.return_value + bucket = mock_instance.bucket.return_value + expected_blob = bucket.blob.return_value + + client = VersionAwareStorageClient() + result = client.get_blob("test_bucket", "test_key") + + assert result == expected_blob + mock_instance.bucket.assert_called_with("test_bucket") + bucket.blob.assert_called_with("test_key") + + @patch( + "policyengine.utils.data.version_aware_storage_client.Client", + autospec=True, + ) + def test_get_blob__generation_version_uses_generation( + self, mock_client_class + ): + """Test that numeric version strings are treated as GCS generations.""" + mock_instance = mock_client_class.return_value + bucket = mock_instance.bucket.return_value + expected_blob = bucket.blob.return_value + + client = VersionAwareStorageClient() + result = client.get_blob( + "test_bucket", "test_key", version="1234567890" + ) + + assert result == expected_blob + mock_instance.bucket.assert_called_with("test_bucket") + bucket.blob.assert_called_with("test_key", generation=1234567890) + + @patch( + "policyengine.utils.data.version_aware_storage_client.Client", + autospec=True, + ) + def test_get_blob__metadata_version_searches_blobs( + self, mock_client_class + ): + """Test that semantic version strings search blob metadata.""" + mock_instance = mock_client_class.return_value + bucket = mock_instance.bucket.return_value + + # Create mock blobs with different metadata versions + blob1 = MagicMock() + blob1.metadata = {"version": "1.0.0"} + blob2 = MagicMock() + blob2.metadata = {"version": "1.2.3"} + blob3 = MagicMock() + blob3.metadata = None + + bucket.list_blobs.return_value = [blob1, blob3, blob2] + + client = VersionAwareStorageClient() + result = client.get_blob("test_bucket", "test_key", version="1.2.3") + + assert result == blob2 + bucket.list_blobs.assert_called_with(prefix="test_key", versions=True) + + @patch( + "policyengine.utils.data.version_aware_storage_client.Client", + autospec=True, + ) + def test_get_blob__metadata_version_not_found_raises( + self, mock_client_class + ): + """Test that missing metadata version raises ValueError.""" + mock_instance = mock_client_class.return_value + bucket = mock_instance.bucket.return_value + + blob1 = MagicMock() + blob1.metadata = {"version": "1.0.0"} + + bucket.list_blobs.return_value = [blob1] + + client = VersionAwareStorageClient() + with pytest.raises(ValueError) as exc_info: + client.get_blob("test_bucket", "test_key", version="2.0.0") + + assert "No blob found with version '2.0.0'" in str(exc_info.value) + + @patch( + "policyengine.utils.data.version_aware_storage_client.Client", + autospec=True, + ) + def test_get_blob__generation_fallback_to_metadata( + self, mock_client_class + ): + """Test that generation lookup falls back to metadata if reload fails.""" + mock_instance = mock_client_class.return_value + bucket = mock_instance.bucket.return_value + + # Make generation-based lookup fail + generation_blob = bucket.blob.return_value + generation_blob.reload.side_effect = Exception("Generation not found") + + # Set up metadata-based lookup to succeed + metadata_blob = MagicMock() + metadata_blob.metadata = {"version": "999"} + bucket.list_blobs.return_value = [metadata_blob] + + client = VersionAwareStorageClient() + # "999" looks like a number, so it tries generation first, then falls back + result = client.get_blob("test_bucket", "test_key", version="999") + + assert result == metadata_blob + bucket.list_blobs.assert_called_with(prefix="test_key", versions=True)