Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
- bump: minor
changes:
added:
- Firm dataset support with 3 new entities (firm, sector, business_group)
- 16 firm variables for business microdata analysis
- UKFirmSingleYearDataset class for firm data handling
- FIRM_2023_24 dataset constant for firm_2023_24.h5 integration
1 change: 1 addition & 0 deletions policyengine_uk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
CountryTaxBenefitSystem,
Microsimulation,
Simulation,
FirmSimulation,
COUNTRY_DIR,
parameters,
variables,
Expand Down
3 changes: 3 additions & 0 deletions policyengine_uk/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@
UKMultiYearDataset,
UKSingleYearDataset,
)
from policyengine_uk.data.firm_dataset_schema import (
UKFirmSingleYearDataset,
)
138 changes: 138 additions & 0 deletions policyengine_uk/data/firm_dataset_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import pandas as pd
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from policyengine_uk import Microsimulation

from pathlib import Path
import h5py


class UKFirmSingleYearDataset:
    """Single-year firm microdata held as three pandas tables.

    The dataset consists of a ``firm`` table, a ``sector`` table and a
    ``business_group`` table, plus a ``time_period`` stored as a string.
    It can be built from an ``.h5`` file or from three in-memory
    DataFrames.
    """

    firm: pd.DataFrame
    sector: pd.DataFrame
    business_group: pd.DataFrame

    @staticmethod
    def validate_file_path(file_path: str, raise_exception: bool = True):
        """Check that ``file_path`` points at a usable firm dataset file.

        Args:
            file_path: Path to the candidate file. ``str`` and path-like
                objects are both accepted (the value is stringified).
            raise_exception: When True, raise on the first failed check;
                when False, return ``False`` instead.

        Returns:
            ``True`` if the path ends with ``.h5``, exists, and contains
            the ``time_period``, ``firm``, ``sector`` and
            ``business_group`` keys; ``False`` otherwise (only when
            ``raise_exception`` is False).

        Raises:
            ValueError: Wrong extension or a required key is missing.
            FileNotFoundError: The file does not exist.
        """
        # Normalise first so pathlib.Path inputs do not fail on .endswith().
        file_path = str(file_path)

        if not file_path.endswith(".h5"):
            if raise_exception:
                raise ValueError(
                    "File path must end with '.h5' for UKFirmDataset."
                )
            return False
        if not Path(file_path).exists():
            if raise_exception:
                raise FileNotFoundError(f"File not found: {file_path}")
            return False

        # Check if the file contains time_period, firm, sector, and
        # business_group datasets.
        required_datasets = (
            "time_period",
            "firm",
            "sector",
            "business_group",
        )
        with h5py.File(file_path, "r") as f:
            for dataset in required_datasets:
                if dataset not in f:
                    if raise_exception:
                        raise ValueError(
                            f"Dataset '{dataset}' not found in the file: {file_path}"
                        )
                    return False

        return True

    def __init__(
        self,
        file_path: str = None,
        firm: pd.DataFrame = None,
        sector: pd.DataFrame = None,
        business_group: pd.DataFrame = None,
        fiscal_year: int = 2025,
    ):
        """Load the dataset from a file, or wrap three DataFrames.

        Args:
            file_path: Optional path to an ``.h5`` file. When given it
                takes precedence over the DataFrame arguments, and the
                time period is read from the file.
            firm: Firm table (required when ``file_path`` is not given).
            sector: Sector table (required when ``file_path`` is not given).
            business_group: Business-group table (required when
                ``file_path`` is not given).
            fiscal_year: Time period used when building from DataFrames;
                stored as ``str(fiscal_year)``.

        Raises:
            ValueError: Neither a file path nor all three DataFrames were
                supplied, or the file fails validation.
            FileNotFoundError: ``file_path`` does not exist.
        """
        file_path = str(file_path) if file_path else None
        if file_path is not None:
            self.validate_file_path(file_path)
            # Read-only: loading must not need write access to the file
            # or risk modifying it.
            with pd.HDFStore(file_path, mode="r") as f:
                self.firm = f["firm"]
                self.sector = f["sector"]
                self.business_group = f["business_group"]
                self.time_period = str(f["time_period"].iloc[0])
        else:
            if firm is None or sector is None or business_group is None:
                raise ValueError(
                    "Must provide either a file path or all three DataFrames (firm, sector, business_group)."
                )
            self.firm = firm
            self.sector = sector
            self.business_group = business_group
            self.time_period = str(fiscal_year)

        self.data_format = "arrays"
        # Snapshot of the tables as they are at construction time.
        self.tables = (self.firm, self.sector, self.business_group)
        self.table_names = ("firm", "sector", "business_group")

    def save(self, file_path: str):
        """Write the three tables and the time period to an HDF5 store.

        Args:
            file_path: Destination path handed to ``pd.HDFStore``.
        """
        with pd.HDFStore(file_path) as f:
            f.put("firm", self.firm, format="table", data_columns=True)
            f.put("sector", self.sector, format="table", data_columns=True)
            f.put(
                "business_group",
                self.business_group,
                format="table",
                data_columns=True,
            )
            f.put("time_period", pd.Series([self.time_period]), format="table")

    def load(self):
        """Return every column of every table as a flat dict of arrays.

        Returns:
            Mapping of column name to the column's ``.values``. Tables are
            visited in the order firm, sector, business_group, so a column
            name present in more than one table keeps the value from the
            last table that has it.
        """
        return {
            col: df[col].values
            for df in (self.firm, self.sector, self.business_group)
            for col in df.columns
        }

    def copy(self):
        """Return a new dataset holding copies of the three tables."""
        return UKFirmSingleYearDataset(
            firm=self.firm.copy(),
            sector=self.sector.copy(),
            business_group=self.business_group.copy(),
            fiscal_year=self.time_period,
        )

    def validate(self):
        """Raise if any column of any table contains NaN values.

        Raises:
            ValueError: Names the first column found to contain NaN.
        """
        # Inspect the live attributes (as load() and save() do) rather than
        # the construction-time snapshot, so a table reassigned after
        # __init__ is still checked.
        for df in (self.firm, self.sector, self.business_group):
            for col in df.columns:
                if df[col].isna().any():
                    raise ValueError(f"Column '{col}' contains NaN values.")

    @staticmethod
    def from_simulation(
        simulation: "Microsimulation", fiscal_year: int = 2025
    ):
        """Build a dataset from the input variables of a simulation.

        For each of the firm, sector and business_group entities, the
        simulation's input variables belonging to that entity are
        calculated for ``fiscal_year`` and collected into one DataFrame.
        An entity with no input variables gets an empty DataFrame.

        Args:
            simulation: Object exposing ``input_variables``,
                ``tax_benefit_system.variables`` and
                ``calculate_dataframe``.
            fiscal_year: Period to calculate and to record on the dataset.

        Returns:
            A new ``UKFirmSingleYearDataset``.
        """
        variables = simulation.tax_benefit_system.variables
        entity_dfs = {}

        for entity in ("firm", "sector", "business_group"):
            input_variables = [
                variable
                for variable in simulation.input_variables
                if variables[variable].entity.key == entity
            ]
            if input_variables:
                entity_dfs[entity] = simulation.calculate_dataframe(
                    input_variables, period=fiscal_year
                )
            else:
                entity_dfs[entity] = pd.DataFrame()

        return UKFirmSingleYearDataset(
            firm=entity_dfs["firm"],
            sector=entity_dfs["sector"],
            business_group=entity_dfs["business_group"],
            fiscal_year=fiscal_year,
        )
40 changes: 39 additions & 1 deletion policyengine_uk/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,42 @@
is_person=True,
)

entities = [Household, BenUnit, Person]
# Firm: declared with is_person=True, the same flag carried by the entity
# definition that ends immediately above this block.
# NOTE(review): this makes a second is_person entity in the `entities` list
# below — confirm the tax-benefit system in use supports that.
Firm = build_entity(
    key="firm",
    plural="firms",
    label="Firm",
    doc="A business entity with employees and operations.",
    is_person=True,
)

# Sector: a group entity with a single "member" role.
Sector = build_entity(
    key="sector",
    plural="sectors",
    label="Sector",
    doc="An economic sector containing multiple firms.",
    roles=[
        {
            "key": "member",
            "plural": "members",
            "label": "Member",
            "doc": "A sector in the business classification.",
        }
    ],
)

# BusinessGroup: a group entity with a single "member" role, declared the
# same way as Sector.
BusinessGroup = build_entity(
    key="business_group",
    plural="business_groups",
    label="Business Group",
    doc="A business group containing multiple firms.",
    roles=[
        {
            "key": "member",
            "plural": "members",
            "label": "Member",
            "doc": "A business group in the firm dataset.",
        }
    ],
)

# All entities exported by this module: the household-side entities followed
# by the three firm-side entities defined above.
entities = [Household, BenUnit, Person, Firm, Sector, BusinessGroup]
135 changes: 135 additions & 0 deletions policyengine_uk/firm_simulation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""Firm-specific simulation class for business microdata analysis."""

from typing import Dict, Optional, Union, List
import numpy as np
import pandas as pd

from policyengine_core.simulations.simulation import Simulation
from policyengine_uk.data.firm_dataset_schema import UKFirmSingleYearDataset
from policyengine_uk.firm_tax_benefit_system import FirmTaxBenefitSystem


class FirmSimulation(Simulation):
    """Firm-specific simulation class for business microdata.

    This simulation handles firm, sector, and business_group entities
    for business-level policy analysis.
    """

    default_input_period: int = 2025
    default_calculation_period: int = 2025

    def __init__(
        self,
        dataset: Optional[Union[UKFirmSingleYearDataset, str]] = None,
        reform: Optional[Dict] = None,
        trace: bool = False,
    ):
        """Initialize a firm simulation.

        Args:
            dataset: Firm dataset, or a path (``str`` or path-like) to a
                firm dataset file.
            reform: Optional reform. Reforms are not applied yet; passing
                one emits a ``UserWarning`` so the caller knows the
                results are unreformed.
            trace: Whether to enable tracing.

        Raises:
            ValueError: ``dataset`` is None or of an unsupported type.
        """
        # Local imports: the module's import block does not provide these.
        import os
        import warnings

        # Resolve and validate the dataset before anything else, so bad
        # input fails before the tax-benefit system is built.
        if dataset is None:
            raise ValueError("FirmSimulation requires a firm dataset")
        if isinstance(dataset, (str, os.PathLike)):
            dataset = UKFirmSingleYearDataset(file_path=dataset)
        elif not isinstance(dataset, UKFirmSingleYearDataset):
            raise ValueError(f"Unsupported dataset type: {dataset.__class__}")

        # Initialize firm tax benefit system
        tax_benefit_system = FirmTaxBenefitSystem()

        if reform is not None:
            # Reform application is not implemented; make that visible
            # instead of silently returning baseline results.
            warnings.warn(
                "FirmSimulation does not apply reforms yet; the 'reform' "
                "argument is ignored.",
                UserWarning,
                stacklevel=2,
            )

        # Create populations using the dataset
        from policyengine_core.simulations.simulation_builder import (
            SimulationBuilder,
        )

        builder = SimulationBuilder()
        builder.populations = tax_benefit_system.instantiate_entities()

        # Declare entities
        builder.declare_person_entity("firm", dataset.firm.firm_id.values)
        builder.declare_entity("sector", dataset.sector.sector_id.values)
        builder.declare_entity(
            "business_group", dataset.business_group.business_group_id.values
        )

        # Link firms to sectors and business groups; every firm takes the
        # "member" role in both group entities.
        member_roles = np.array(["member"] * len(dataset.firm))
        builder.join_with_persons(
            builder.populations["sector"],
            dataset.firm.firm_sector_id.values,
            member_roles,
        )
        builder.join_with_persons(
            builder.populations["business_group"],
            dataset.firm.firm_business_group_id.values,
            member_roles,
        )

        # Initialize the parent Simulation with populations
        super().__init__(
            tax_benefit_system=tax_benefit_system,
            populations=builder.populations,
        )

        # Set up tracing if requested
        if trace:
            self.trace = True

        # Load variable values from dataset; columns that are not
        # variables of the tax-benefit system are skipped.
        known_variables = self.tax_benefit_system.variables
        for table in dataset.tables:
            for variable in table.columns:
                if variable not in known_variables:
                    continue
                self.set_input(
                    variable, dataset.time_period, table[variable].values
                )

        self.dataset = dataset
        self.input_variables = self.get_known_variables()

    def get_known_variables(self) -> List[str]:
        """Get list of variables with known values.

        Returns:
            List of variable names that have values set
        """
        known = []
        for variable in self.tax_benefit_system.variables:
            try:
                if len(self.get_holder(variable).get_known_periods()) > 0:
                    known.append(variable)
            except Exception:
                # Best-effort scan: a variable whose holder cannot be
                # inspected is treated as having no known values.
                # `Exception` (not a bare except) lets KeyboardInterrupt
                # and SystemExit propagate.
                continue
        return known

    def calculate_dataframe(
        self,
        variable_names: List[str],
        period: Optional[int] = None,
    ) -> pd.DataFrame:
        """Calculate multiple variables and return as DataFrame.

        Args:
            variable_names: List of variables to calculate
            period: Time period for calculation; defaults to
                ``default_calculation_period`` when None

        Returns:
            DataFrame with one column per requested variable
        """
        if period is None:
            period = self.default_calculation_period

        return pd.DataFrame(
            {
                variable: self.calculate(variable, period)
                for variable in variable_names
            }
        )
Loading
Loading