1 change: 1 addition & 0 deletions .github/workflows/code_changes.yaml
@@ -5,6 +5,7 @@ on:
pull_request:
branches:
- main
- '0.x' # Maintenance branch for legacy 0.x releases

paths:
- policyengine/**
6 changes: 6 additions & 0 deletions changelog_entry.yaml
@@ -0,0 +1,6 @@
- bump: minor
changes:
changed:
- Replaced US state filtering with conditional file loading for states and Congressional districts
- Updated data source URLs for state and district files
- Modified GCS loading to support new versioning introduced by -core v3.21.0
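The new versioning is exercised through the `Simulation` constructor. A minimal sketch of the three supported forms, based on the `_set_data_from_gs` docstring later in this diff (the bucket and file names here are illustrative, not real datasets):

```python
from policyengine import Simulation

# 1. Explicit option (highest priority)
sim = Simulation(data="gs://example-bucket/dataset.h5", data_version="1.2.3")

# 2. Version suffix on the URL
sim = Simulation(data="gs://example-bucket/dataset.h5@1.2.3")

# 3. No version given: the latest blob is downloaded
sim = Simulation(data="gs://example-bucket/dataset.h5")
```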
@@ -712,21 +712,21 @@ def uk_constituency_breakdown(
baseline_hnet = baseline.household_net_income
reform_hnet = reform.household_net_income

constituency_weights_path = download(
constituency_weights_local_path = download(
gcs_bucket="policyengine-uk-data-private",
filepath="parliamentary_constituency_weights.h5",
gcs_key="parliamentary_constituency_weights.h5",
)
with h5py.File(constituency_weights_path, "r") as f:
with h5py.File(constituency_weights_local_path, "r") as f:
weights = f["2025"][
...
] # {2025: array(650, 100180) where cell i, j is the weight of household record i in constituency j}

constituency_names_path = download(
constituency_names_local_path = download(
gcs_bucket="policyengine-uk-data-private",
filepath="constituencies_2024.csv",
gcs_key="constituencies_2024.csv",
)
constituency_names = pd.read_csv(
constituency_names_path
constituency_names_local_path
) # columns code (constituency code), name (constituency name), x, y (geographic position)

for i in range(len(constituency_names)):
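As the inline comment above notes, the weights array maps household records to constituencies. A hedged sketch of how such a matrix aggregates a household-level quantity to constituency level (stand-in arrays; indexing follows the (650, 100180) shape, i.e. one row per constituency):

```python
import numpy as np

# Real shape is (650 constituencies, 100_180 households); small stand-ins here.
n_constituencies, n_households = 650, 1_000
weights = np.random.rand(n_constituencies, n_households)   # weights[c, h]: weight of household h in constituency c
household_net_income = np.random.rand(n_households)        # stand-in for a household-level result

# One matrix-vector product yields the per-constituency totals.
constituency_totals = weights @ household_net_income  # shape: (650,)
```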
170 changes: 111 additions & 59 deletions policyengine/simulation.py
@@ -6,7 +6,6 @@
from .utils.data.datasets import (
get_default_dataset,
process_gs_path,
POLICYENGINE_DATASETS,
DATASET_TIME_PERIODS,
)
from policyengine_core.simulations import Simulation as CountrySimulation
@@ -146,11 +145,23 @@ def _add_output_functions(self):
)

def _set_data(self, file_address: str | None = None) -> None:
"""Load and set the dataset for this simulation."""
# Step 1: Resolve file address (if None, get default)
file_address = self._resolve_file_address(file_address)
print(f"Using dataset: {file_address}", file=sys.stderr)

# filename refers to file's unique name + extension;
# file_address refers to URI + filename
# Step 2: Acquire the file (download if PE dataset, or use local path)
filepath, version = self._acquire_dataset_file(file_address)
self.data_version = version

# If None is passed, user wants default dataset; get URL, then continue initializing.
# Step 3: Load into country-specific format
time_period = self._set_data_time_period(file_address)
self.options.data = self._load_dataset_for_country(
filepath, time_period
)

def _resolve_file_address(self, file_address: str | None) -> str:
"""If no file address provided, get the default dataset for this country/region."""
if file_address is None:
file_address = get_default_dataset(
country=self.options.country, region=self.options.region
@@ -159,31 +170,33 @@ def _set_data(self, file_address: str | None = None) -> None:
f"No data provided, using default dataset: {file_address}",
file=sys.stderr,
)
return file_address

if file_address not in POLICYENGINE_DATASETS:
# If it's a local file, no URI present and unable to infer version.
filename = file_address
version = None

def _acquire_dataset_file(
self, file_address: str
) -> tuple[str, str | None]:
"""
Get the dataset file, downloading from GCS if it's a GCS path.
Returns (filepath, version) where version is None for local files.
"""
if file_address.startswith("gs://"):
return self._set_data_from_gs(file_address)
else:
# All official PolicyEngine datasets are stored in GCS;
# load accordingly
filename, version = self._set_data_from_gs(file_address)
self.data_version = version

time_period = self._set_data_time_period(file_address)
# Local file - no download needed, no version available
return file_address, None

# UK needs custom loading
def _load_dataset_for_country(
self, filepath: str, time_period: int | None
) -> Dataset:
"""Load the dataset file using the appropriate country-specific loader."""
if self.options.country == "us":
self.options.data = Dataset.from_file(
filename, time_period=time_period
)
else:
return Dataset.from_file(filepath, time_period=time_period)
elif self.options.country == "uk":
from policyengine_uk.data import UKSingleYearDataset

self.options.data = UKSingleYearDataset(
file_path=filename,
)
return UKSingleYearDataset(file_path=filepath)
else:
raise ValueError(f"Unsupported country: {self.options.country}")

def _initialise_simulations(self):
self.baseline_simulation = self._initialise_simulation(
@@ -270,18 +283,12 @@ def _apply_region_to_simulation(
time_period: TimePeriodType,
) -> CountrySimulation:
if country == "us":
df = simulation.to_input_dataframe()
state_code = simulation.calculate(
"state_code_str", map_to="person"
).values
if region == "city/nyc":
in_nyc = simulation.calculate("in_nyc", map_to="person").values
simulation = simulation_type(dataset=df[in_nyc], reform=reform)
elif "state/" in region:
state = region.split("/")[1]
simulation = simulation_type(
dataset=df[state_code == state.upper()], reform=reform
)
simulation = self._apply_us_region_to_simulation(
simulation=simulation,
simulation_type=simulation_type,
region=region,
reform=reform,
)
elif country == "uk":
if "country/" in region:
region = region.split("/")[1]
@@ -294,14 +301,14 @@
)
elif "constituency/" in region:
constituency = region.split("/")[1]
constituency_names_file_path = download(
constituency_names_local_path = download(
gcs_bucket="policyengine-uk-data-private",
filepath="constituencies_2024.csv",
gcs_key="constituencies_2024.csv",
)
constituency_names_file_path = Path(
constituency_names_file_path
constituency_names_local_path = Path(
constituency_names_local_path
)
constituency_names = pd.read_csv(constituency_names_file_path)
constituency_names = pd.read_csv(constituency_names_local_path)
if constituency in constituency_names.code.values:
constituency_id = constituency_names[
constituency_names.code == constituency
@@ -312,14 +319,14 @@
].index[0]
else:
raise ValueError(
f"Constituency {constituency} not found. See {constituency_names_file_path} for the list of available constituencies."
f"Constituency {constituency} not found. See {constituency_names_local_path} for the list of available constituencies."
)
weights_file_path = download(
weights_local_path = download(
gcs_bucket="policyengine-uk-data-private",
filepath="parliamentary_constituency_weights.h5",
gcs_key="parliamentary_constituency_weights.h5",
)

with h5py.File(weights_file_path, "r") as f:
with h5py.File(weights_local_path, "r") as f:
weights = f[str(time_period)][...]

simulation.set_input(
@@ -329,26 +336,26 @@
)
elif "local_authority/" in region:
la = region.split("/")[1]
la_names_file_path = download(
la_names_local_path = download(
gcs_bucket="policyengine-uk-data-private",
filepath="local_authorities_2021.csv",
gcs_key="local_authorities_2021.csv",
)
la_names_file_path = Path(la_names_file_path)
la_names = pd.read_csv(la_names_file_path)
la_names_local_path = Path(la_names_local_path)
la_names = pd.read_csv(la_names_local_path)
if la in la_names.code.values:
la_id = la_names[la_names.code == la].index[0]
elif la in la_names.name.values:
la_id = la_names[la_names.name == la].index[0]
else:
raise ValueError(
f"Local authority {la} not found. See {la_names_file_path} for the list of available local authorities."
f"Local authority {la} not found. See {la_names_local_path} for the list of available local authorities."
)
weights_file_path = download(
weights_local_path = download(
gcs_bucket="policyengine-uk-data-private",
filepath="local_authority_weights.h5",
gcs_key="local_authority_weights.h5",
)

with h5py.File(weights_file_path, "r") as f:
with h5py.File(weights_local_path, "r") as f:
weights = f[str(self.time_period)][...]

simulation.set_input(
Expand All @@ -359,6 +366,40 @@ def _apply_region_to_simulation(

return simulation
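The two truncated `set_input` calls above follow the same re-weighting pattern. A hedged reconstruction of the idea (the diff collapses those lines, so the variable name and exact arguments are assumptions):

```python
# Replace every household's weight with its weight in the selected
# constituency, turning the national microdata into an area-level simulation.
simulation.set_input(
    "household_weight",        # assumed PolicyEngine UK variable name
    time_period,
    weights[constituency_id],  # one row of the constituency-by-household matrix
)
```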

def _apply_us_region_to_simulation(
self,
simulation: CountryMicrosimulation,
simulation_type: type,
region: RegionType,
reform: ReformType | None,
) -> CountrySimulation:
"""Apply US-specific regional filtering to a simulation.

Note: Most US regions (states, congressional districts) now use
scoped datasets rather than filtering. Only NYC still requires
filtering from the national dataset (and is still using the pooled
CPS by default). This should be replaced with an approach based on
the new datasets.
"""
if region == "city/nyc":
simulation = self._filter_us_simulation_by_nyc(
simulation=simulation,
simulation_type=simulation_type,
reform=reform,
)
return simulation

def _filter_us_simulation_by_nyc(
self,
simulation: CountryMicrosimulation,
simulation_type: type,
reform: ReformType | None,
) -> CountrySimulation:
"""Filter a US simulation to only include NYC households."""
df = simulation.to_input_dataframe()
in_nyc = simulation.calculate("in_nyc", map_to="person").values
return simulation_type(dataset=df[in_nyc], reform=reform)

def check_model_version(self) -> None:
"""
Check the package versions of the simulation against the current package versions.
@@ -402,19 +443,30 @@ def _set_data_time_period(self, file_address: str) -> Optional[int]:

def _set_data_from_gs(self, file_address: str) -> tuple[str, str | None]:
"""
Set the data from a GCS path and return the filename and version.
Download data from a GCS path and return the local path and version.

Supports version specification in three ways (in priority order):
1. Explicit data_version option: Simulation(data="gs://...", data_version="1.2.3")
2. URL suffix: Simulation(data="gs://bucket/file.h5@1.2.3")
3. None (latest): Simulation(data="gs://bucket/file.h5")

Returns:
A tuple of (local_path, version) where:
- local_path: The local filesystem path where the file was saved
- version: The version string, or None if not available
"""
bucket, gcs_key, url_version = process_gs_path(file_address)

bucket, filename = process_gs_path(file_address)
version = self.options.data_version
# Priority: explicit option > URL suffix > None (latest)
version = self.options.data_version or url_version

print(f"Downloading {filename} from bucket {bucket}", file=sys.stderr)
print(f"Downloading {gcs_key} from bucket {bucket}", file=sys.stderr)

filepath, version = download(
filepath=filename,
local_path, version = download(
gcs_key=gcs_key,
gcs_bucket=bucket,
version=version,
return_version=True,
)

return filename, version
return local_path, version
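For reference, the `@version` suffix handling implied by the docstring can be pictured as below. This is only a sketch of what `process_gs_path` is expected to return for each address form; the actual implementation lives in `policyengine/utils/data/datasets.py` and may differ:

```python
def process_gs_path_sketch(file_address: str) -> tuple[str, str, str | None]:
    """Split 'gs://bucket/key[@version]' into (bucket, key, version)."""
    assert file_address.startswith("gs://")
    remainder = file_address[len("gs://"):]
    bucket, _, key_and_version = remainder.partition("/")
    key, _, version = key_and_version.partition("@")
    return bucket, key, version or None

# process_gs_path_sketch("gs://bucket/file.h5@1.2.3") -> ("bucket", "file.h5", "1.2.3")
# process_gs_path_sketch("gs://bucket/file.h5")       -> ("bucket", "file.h5", None)
```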
2 changes: 1 addition & 1 deletion policyengine/utils/data/__init__.py
@@ -1,2 +1,2 @@
from .caching_google_storage_client import CachingGoogleStorageClient
from .simplified_google_storage_client import SimplifiedGoogleStorageClient
from .version_aware_storage_client import VersionAwareStorageClient
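Any downstream code that imported the old client by name needs the matching one-line change, e.g.:

```python
# Before: from policyengine.utils.data import SimplifiedGoogleStorageClient
from policyengine.utils.data import VersionAwareStorageClient
```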
26 changes: 22 additions & 4 deletions policyengine/utils/data/caching_google_storage_client.py
@@ -3,20 +3,27 @@
from pathlib import Path
from policyengine_core.data.dataset import atomic_write
import logging
from .simplified_google_storage_client import SimplifiedGoogleStorageClient
from .version_aware_storage_client import VersionAwareStorageClient
from typing import Optional

logger = logging.getLogger(__name__)


class CachingGoogleStorageClient(AbstractContextManager):
"""
Client for downloaded resources from a google storage bucket only when the CRC
of the blob changes.
Client for downloading resources from a Google Storage bucket with caching.

Only downloads when the CRC of the blob changes, using a disk-based cache
to persist downloaded data across sessions.

The client supports multiple versioning strategies via VersionAwareStorageClient:
- Generation-based: version is a GCS generation number
- Metadata-based: version is stored in blob.metadata["version"]
- Latest: when no version is specified, gets the latest blob
"""

def __init__(self):
self.client = SimplifiedGoogleStorageClient()
self.client = VersionAwareStorageClient()
self.cache = diskcache.Cache()

def _data_key(
@@ -36,6 +43,17 @@ def download(
):
"""
Atomically write the latest version of the cloud storage blob to the target path.

Args:
bucket: The GCS bucket name.
key: The blob path within the bucket.
target: The local file path to write the downloaded content to.
version: Optional version string. Can be a GCS generation number,
a metadata version string, or None for latest.
return_version: If True, return the version string of the downloaded blob.

Returns:
The version string if return_version is True, otherwise None.
"""
if version is None:
# If no version is specified, get the latest version from the cache
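A hedged usage sketch of the updated client, with argument names taken from the docstring above (the exact signature may differ from this reconstruction):

```python
from pathlib import Path
from policyengine.utils.data.caching_google_storage_client import (
    CachingGoogleStorageClient,
)

# The class extends AbstractContextManager, so it can be used in a `with` block.
with CachingGoogleStorageClient() as client:
    version = client.download(
        bucket="policyengine-uk-data-private",        # GCS bucket name
        key="constituencies_2024.csv",                # blob path within the bucket
        target=Path("/tmp/constituencies_2024.csv"),  # local write location
        version=None,                                 # None -> latest version
        return_version=True,                          # return the resolved version string
    )
```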