From f69b9bd4f05078818128320b3014d77a4a9b853c Mon Sep 17 00:00:00 2001 From: Christopher Cave-Ayland Date: Thu, 27 Jun 2024 17:20:23 +0100 Subject: [PATCH 01/10] First pass at duckdb data interface --- src/muse/new_input/readers.py | 76 +++++++++++++++++ tests/test_readers.py | 149 ++++++++++++++++++++++++++++++++++ 2 files changed, 225 insertions(+) create mode 100644 src/muse/new_input/readers.py diff --git a/src/muse/new_input/readers.py b/src/muse/new_input/readers.py new file mode 100644 index 000000000..eafa4fb07 --- /dev/null +++ b/src/muse/new_input/readers.py @@ -0,0 +1,76 @@ +import duckdb +import numpy as np +import xarray as xr + + +def read_inputs(data_dir): + data = {} + con = duckdb.connect(":memory:") + + with open(data_dir / "regions.csv") as f: + regions = read_regions_csv(f, con) # noqa: F841 + + with open(data_dir / "commodities.csv") as f: + commodities = read_commodities_csv(f, con) + + with open(data_dir / "demand.csv") as f: + demand = read_demand_csv(f, con) # noqa: F841 + + data["global_commodities"] = calculate_global_commodities(commodities) + return data + + +def read_regions_csv(buffer_, con): + sql = """CREATE TABLE regions ( + name VARCHAR PRIMARY KEY, + ); + """ + con.sql(sql) + rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 + con.sql("INSERT INTO regions SELECT name FROM rel;") + return con.sql("SELECT name from regions").fetchnumpy() + + +def read_commodities_csv(buffer_, con): + sql = """CREATE TABLE commodities ( + name VARCHAR PRIMARY KEY, + type VARCHAR CHECK (type IN ('energy', 'service', 'material', 'environmental')), + unit VARCHAR, + ); + """ + con.sql(sql) + rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 + con.sql("INSERT INTO commodities SELECT name, type, unit FROM rel;") + + return con.sql("select name, type, unit from commodities").fetchnumpy() + + +def calculate_global_commodities(commodities): + names = commodities["name"].astype(np.dtype("str")) + types = commodities["type"].astype(np.dtype("str")) + units = commodities["unit"].astype(np.dtype("str")) + + type_array = xr.DataArray( + data=types, dims=["commodity"], coords=dict(commodity=names) + ) + + unit_array = xr.DataArray( + data=units, dims=["commodity"], coords=dict(commodity=names) + ) + + data = xr.Dataset(data_vars=dict(type=type_array, unit=unit_array)) + return data + + +def read_demand_csv(buffer_, con): + sql = """CREATE TABLE demand ( + year BIGINT, + commodity VARCHAR REFERENCES commodities(name), + region VARCHAR REFERENCES regions(name), + demand DOUBLE, + ); + """ + con.sql(sql) + rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 + con.sql("INSERT INTO demand SELECT year, commodity_name, region, demand FROM rel;") + return con.sql("SELECT * from demand").fetchnumpy() diff --git a/tests/test_readers.py b/tests/test_readers.py index 924dcacff..107a5bfb1 100644 --- a/tests/test_readers.py +++ b/tests/test_readers.py @@ -1,6 +1,9 @@ +from io import StringIO from itertools import chain, permutations from pathlib import Path +import duckdb +import numpy as np import pandas as pd import toml import xarray as xr @@ -314,3 +317,149 @@ def test_get_nan_coordinates(): dataset3 = xr.Dataset.from_dataframe(df3.set_index(["region", "year"])) nan_coords3 = get_nan_coordinates(dataset3) assert nan_coords3 == [] + + +@fixture +def default_new_input(tmp_path): + from muse.examples import copy_model + + copy_model("default_new_input", tmp_path) + return tmp_path / "model" + + +@fixture +def con(): + return 
duckdb.connect(":memory:") + + +@fixture +def populate_regions(default_new_input, con): + from muse.new_input.readers import read_regions_csv + + with open(default_new_input / "regions.csv") as f: + return read_regions_csv(f, con) + + +@fixture +def populate_commodities(default_new_input, con): + from muse.new_input.readers import read_commodities_csv + + with open(default_new_input / "commodities.csv") as f: + return read_commodities_csv(f, con) + + +@fixture +def populate_demand(default_new_input, con, populate_regions, populate_commodities): + from muse.new_input.readers import read_demand_csv + + with open(default_new_input / "demand.csv") as f: + return read_demand_csv(f, con) + + +def test_read_regions(populate_regions): + assert populate_regions["name"] == np.array(["R1"]) + + +def test_read_new_global_commodities(populate_commodities): + data = populate_commodities + assert list(data["name"]) == ["electricity", "gas", "heat", "wind", "CO2f"] + assert list(data["type"]) == ["energy"] * 5 + assert list(data["unit"]) == ["PJ"] * 4 + ["kt"] + + +def test_calculate_global_commodities(populate_commodities): + from muse.new_input.readers import calculate_global_commodities + + data = calculate_global_commodities(populate_commodities) + + assert isinstance(data, xr.Dataset) + assert set(data.dims) == {"commodity"} + for dt in data.dtypes.values(): + assert np.issubdtype(dt, np.dtype("str")) + + assert list(data.coords["commodity"].values) == list(populate_commodities["name"]) + assert list(data.data_vars["type"].values) == list(populate_commodities["type"]) + assert list(data.data_vars["unit"].values) == list(populate_commodities["unit"]) + + +def test_read_new_global_commodities_type_constraint(default_new_input, con): + from muse.new_input.readers import read_commodities_csv + + csv = StringIO("name,type,unit\nfoo,invalid,bar\n") + with raises(duckdb.ConstraintException): + read_commodities_csv(csv, con) + + +def test_new_read_demand_csv(populate_demand): + data = populate_demand + assert np.all(data["year"] == np.array([2020, 2050])) + assert np.all(data["commodity"] == np.array(["heat", "heat"])) + assert np.all(data["region"] == np.array(["R1", "R1"])) + assert np.all(data["demand"] == np.array([10, 30])) + + +def test_new_read_demand_csv_commodity_constraint( + default_new_input, con, populate_commodities, populate_regions +): + from muse.new_input.readers import read_demand_csv + + csv = StringIO("year,commodity_name,region,demand\n2020,invalid,R1,0\n") + with raises(duckdb.ConstraintException, match=".*foreign key.*"): + read_demand_csv(csv, con) + + +def test_new_read_demand_csv_region_constraint( + default_new_input, con, populate_commodities, populate_regions +): + from muse.new_input.readers import read_demand_csv + + csv = StringIO("year,commodity_name,region,demand\n2020,heat,invalid,0\n") + with raises(duckdb.ConstraintException, match=".*foreign key.*"): + read_demand_csv(csv, con) + + +@mark.xfail +def test_demand_dataset(default_new_input): + import duckdb + + from muse.new_input.readers import read_commodities, read_demand, read_regions + + con = duckdb.connect(":memory:") + + read_regions(default_new_input, con) + read_commodities(default_new_input, con) + data = read_demand(default_new_input, con) + + assert isinstance(data, xr.DataArray) + assert data.dtype == np.float64 + + assert set(data.dims) == {"year", "commodity", "region", "timeslice"} + assert list(data.coords["region"].values) == ["R1"] + assert list(data.coords["timeslice"].values) == list(range(1, 7)) + 
assert list(data.coords["year"].values) == [2020, 2050] + assert set(data.coords["commodity"].values) == { + "electricity", + "gas", + "heat", + "wind", + "CO2f", + } + + assert data.sel(year=2020, commodity="electricity", region="R1", timeslice=0) == 1 + + +@mark.xfail +def test_new_read_initial_market(default_new_input): + from muse.new_input.readers import read_inputs + + all_data = read_inputs(default_new_input) + data = all_data["initial_market"] + + assert isinstance(data, xr.Dataset) + assert set(data.dims) == {"region", "year", "commodity", "timeslice"} + assert dict(data.dtypes) == dict( + prices=np.float64, + exports=np.float64, + imports=np.float64, + static_trade=np.float64, + ) From 2685eb09d7f94c81777c096b90c9cdfcf9cbf2dd Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Wed, 3 Jul 2024 16:03:19 +0100 Subject: [PATCH 02/10] New db tables --- src/muse/new_input/readers.py | 111 +++++++++++++++++++++++++--------- 1 file changed, 83 insertions(+), 28 deletions(-) diff --git a/src/muse/new_input/readers.py b/src/muse/new_input/readers.py index eafa4fb07..a02f40a84 100644 --- a/src/muse/new_input/readers.py +++ b/src/muse/new_input/readers.py @@ -7,28 +7,26 @@ def read_inputs(data_dir): data = {} con = duckdb.connect(":memory:") - with open(data_dir / "regions.csv") as f: - regions = read_regions_csv(f, con) # noqa: F841 - with open(data_dir / "commodities.csv") as f: commodities = read_commodities_csv(f, con) + with open(data_dir / "commodity_trade.csv") as f: + commodity_trade = read_commodity_trade_csv(f, con) # noqa: F841 + + with open(data_dir / "commodity_costs.csv") as f: + commodity_costs = read_commodity_costs_csv(f, con) # noqa: F841 + with open(data_dir / "demand.csv") as f: demand = read_demand_csv(f, con) # noqa: F841 - data["global_commodities"] = calculate_global_commodities(commodities) - return data + with open(data_dir / "demand_slicing.csv") as f: + demand_slicing = read_demand_slicing_csv(f, con) # noqa: F841 + with open(data_dir / "regions.csv") as f: + regions = read_regions_csv(f, con) # noqa: F841 -def read_regions_csv(buffer_, con): - sql = """CREATE TABLE regions ( - name VARCHAR PRIMARY KEY, - ); - """ - con.sql(sql) - rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 - con.sql("INSERT INTO regions SELECT name FROM rel;") - return con.sql("SELECT name from regions").fetchnumpy() + data["global_commodities"] = calculate_global_commodities(commodities) + return data def read_commodities_csv(buffer_, con): @@ -41,25 +39,38 @@ def read_commodities_csv(buffer_, con): con.sql(sql) rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 con.sql("INSERT INTO commodities SELECT name, type, unit FROM rel;") - return con.sql("select name, type, unit from commodities").fetchnumpy() -def calculate_global_commodities(commodities): - names = commodities["name"].astype(np.dtype("str")) - types = commodities["type"].astype(np.dtype("str")) - units = commodities["unit"].astype(np.dtype("str")) - - type_array = xr.DataArray( - data=types, dims=["commodity"], coords=dict(commodity=names) - ) +def read_commodity_trade_csv(buffer_, con): + sql = """CREATE TABLE commodity_trade ( + commodity VARCHAR REFERENCES commodities(name), + region VARCHAR REFERENCES regions(name), + year BIGINT, + import DOUBLE, + export DOUBLE, + ); + """ + con.sql(sql) + rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 + con.sql("""INSERT INTO commodity_trade SELECT + commodity, region, year, import, export FROM rel;""") + return con.sql("SELECT * from 
commodity_trade").fetchnumpy() - unit_array = xr.DataArray( - data=units, dims=["commodity"], coords=dict(commodity=names) - ) - data = xr.Dataset(data_vars=dict(type=type_array, unit=unit_array)) - return data +def read_commodity_costs_csv(buffer_, con): + sql = """CREATE TABLE commodity_costs ( + year BIGINT, + region VARCHAR REFERENCES regions(name), + commodity VARCHAR REFERENCES commodities(name), + value DOUBLE, + ); + """ + con.sql(sql) + rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 + con.sql("""INSERT INTO commodity_costs SELECT + year, region, commodity_name, value FROM rel;""") + return con.sql("SELECT * from commodity_costs").fetchnumpy() def read_demand_csv(buffer_, con): @@ -74,3 +85,47 @@ def read_demand_csv(buffer_, con): rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 con.sql("INSERT INTO demand SELECT year, commodity_name, region, demand FROM rel;") return con.sql("SELECT * from demand").fetchnumpy() + + +def read_demand_slicing_csv(buffer_, con): + sql = """CREATE TABLE demand_slicing ( + commodity VARCHAR REFERENCES commodities(name), + region VARCHAR REFERENCES regions(name), + timeslice VARCHAR, + fraction DOUBLE CHECK (fraction >= 0 AND fraction <= 1), + year BIGINT, + ); + """ + con.sql(sql) + rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 + con.sql("""INSERT INTO demand_slicing SELECT + commodity, region, timeslice, fraction, year FROM rel;""") + return con.sql("SELECT * from demand_slicing").fetchnumpy() + + +def read_regions_csv(buffer_, con): + sql = """CREATE TABLE regions ( + name VARCHAR PRIMARY KEY, + ); + """ + con.sql(sql) + rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 + con.sql("INSERT INTO regions SELECT name FROM rel;") + return con.sql("SELECT name from regions").fetchnumpy() + + +def calculate_global_commodities(commodities): + names = commodities["name"].astype(np.dtype("str")) + types = commodities["type"].astype(np.dtype("str")) + units = commodities["unit"].astype(np.dtype("str")) + + type_array = xr.DataArray( + data=types, dims=["commodity"], coords=dict(commodity=names) + ) + + unit_array = xr.DataArray( + data=units, dims=["commodity"], coords=dict(commodity=names) + ) + + data = xr.Dataset(data_vars=dict(type=type_array, unit=unit_array)) + return data From 71d6388cab4c886140397a5608c1940930a41c7b Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 4 Jul 2024 09:29:16 +0100 Subject: [PATCH 03/10] Update tables for new csv columns --- src/muse/new_input/readers.py | 40 +++++++++++++++++------------------ tests/test_readers.py | 28 ++++++++++++------------ 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/src/muse/new_input/readers.py b/src/muse/new_input/readers.py index a02f40a84..b9228f5bf 100644 --- a/src/muse/new_input/readers.py +++ b/src/muse/new_input/readers.py @@ -31,21 +31,21 @@ def read_inputs(data_dir): def read_commodities_csv(buffer_, con): sql = """CREATE TABLE commodities ( - name VARCHAR PRIMARY KEY, + id VARCHAR PRIMARY KEY, type VARCHAR CHECK (type IN ('energy', 'service', 'material', 'environmental')), unit VARCHAR, ); """ con.sql(sql) rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 - con.sql("INSERT INTO commodities SELECT name, type, unit FROM rel;") - return con.sql("select name, type, unit from commodities").fetchnumpy() + con.sql("INSERT INTO commodities SELECT id, type, unit FROM rel;") + return con.sql("select * from commodities").fetchnumpy() def read_commodity_trade_csv(buffer_, con): sql = 
"""CREATE TABLE commodity_trade ( - commodity VARCHAR REFERENCES commodities(name), - region VARCHAR REFERENCES regions(name), + commodity VARCHAR REFERENCES commodities(id), + region VARCHAR REFERENCES regions(id), year BIGINT, import DOUBLE, export DOUBLE, @@ -54,68 +54,68 @@ def read_commodity_trade_csv(buffer_, con): con.sql(sql) rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 con.sql("""INSERT INTO commodity_trade SELECT - commodity, region, year, import, export FROM rel;""") + commodity_id, region_id, year, import, export FROM rel;""") return con.sql("SELECT * from commodity_trade").fetchnumpy() def read_commodity_costs_csv(buffer_, con): sql = """CREATE TABLE commodity_costs ( + commodity VARCHAR REFERENCES commodities(id), + region VARCHAR REFERENCES regions(id), year BIGINT, - region VARCHAR REFERENCES regions(name), - commodity VARCHAR REFERENCES commodities(name), value DOUBLE, ); """ con.sql(sql) rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 con.sql("""INSERT INTO commodity_costs SELECT - year, region, commodity_name, value FROM rel;""") + commidity_id, region_id, year, value FROM rel;""") return con.sql("SELECT * from commodity_costs").fetchnumpy() def read_demand_csv(buffer_, con): sql = """CREATE TABLE demand ( + commodity VARCHAR REFERENCES commodities(id), + region VARCHAR REFERENCES regions(id), year BIGINT, - commodity VARCHAR REFERENCES commodities(name), - region VARCHAR REFERENCES regions(name), demand DOUBLE, ); """ con.sql(sql) rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 - con.sql("INSERT INTO demand SELECT year, commodity_name, region, demand FROM rel;") + con.sql("INSERT INTO demand SELECT commodity_id, region_id, year, demand FROM rel;") return con.sql("SELECT * from demand").fetchnumpy() def read_demand_slicing_csv(buffer_, con): sql = """CREATE TABLE demand_slicing ( - commodity VARCHAR REFERENCES commodities(name), - region VARCHAR REFERENCES regions(name), + commodity VARCHAR REFERENCES commodities(id), + region VARCHAR REFERENCES regions(id), + year BIGINT, timeslice VARCHAR, fraction DOUBLE CHECK (fraction >= 0 AND fraction <= 1), - year BIGINT, ); """ con.sql(sql) rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 con.sql("""INSERT INTO demand_slicing SELECT - commodity, region, timeslice, fraction, year FROM rel;""") + commodity_id, region_id, year, timeslice, fraction FROM rel;""") return con.sql("SELECT * from demand_slicing").fetchnumpy() def read_regions_csv(buffer_, con): sql = """CREATE TABLE regions ( - name VARCHAR PRIMARY KEY, + id VARCHAR PRIMARY KEY, ); """ con.sql(sql) rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 - con.sql("INSERT INTO regions SELECT name FROM rel;") - return con.sql("SELECT name from regions").fetchnumpy() + con.sql("INSERT INTO regions SELECT id FROM rel;") + return con.sql("SELECT * from regions").fetchnumpy() def calculate_global_commodities(commodities): - names = commodities["name"].astype(np.dtype("str")) + names = commodities["id"].astype(np.dtype("str")) types = commodities["type"].astype(np.dtype("str")) units = commodities["unit"].astype(np.dtype("str")) diff --git a/tests/test_readers.py b/tests/test_readers.py index 107a5bfb1..221ec097b 100644 --- a/tests/test_readers.py +++ b/tests/test_readers.py @@ -332,14 +332,6 @@ def con(): return duckdb.connect(":memory:") -@fixture -def populate_regions(default_new_input, con): - from muse.new_input.readers import read_regions_csv - - with open(default_new_input / 
"regions.csv") as f: - return read_regions_csv(f, con) - - @fixture def populate_commodities(default_new_input, con): from muse.new_input.readers import read_commodities_csv @@ -356,13 +348,21 @@ def populate_demand(default_new_input, con, populate_regions, populate_commoditi return read_demand_csv(f, con) +@fixture +def populate_regions(default_new_input, con): + from muse.new_input.readers import read_regions_csv + + with open(default_new_input / "regions.csv") as f: + return read_regions_csv(f, con) + + def test_read_regions(populate_regions): - assert populate_regions["name"] == np.array(["R1"]) + assert populate_regions["id"] == np.array(["R1"]) def test_read_new_global_commodities(populate_commodities): data = populate_commodities - assert list(data["name"]) == ["electricity", "gas", "heat", "wind", "CO2f"] + assert list(data["id"]) == ["electricity", "gas", "heat", "wind", "CO2f"] assert list(data["type"]) == ["energy"] * 5 assert list(data["unit"]) == ["PJ"] * 4 + ["kt"] @@ -377,7 +377,7 @@ def test_calculate_global_commodities(populate_commodities): for dt in data.dtypes.values(): assert np.issubdtype(dt, np.dtype("str")) - assert list(data.coords["commodity"].values) == list(populate_commodities["name"]) + assert list(data.coords["commodity"].values) == list(populate_commodities["id"]) assert list(data.data_vars["type"].values) == list(populate_commodities["type"]) assert list(data.data_vars["unit"].values) == list(populate_commodities["unit"]) @@ -385,7 +385,7 @@ def test_calculate_global_commodities(populate_commodities): def test_read_new_global_commodities_type_constraint(default_new_input, con): from muse.new_input.readers import read_commodities_csv - csv = StringIO("name,type,unit\nfoo,invalid,bar\n") + csv = StringIO("id,type,unit\nfoo,invalid,bar\n") with raises(duckdb.ConstraintException): read_commodities_csv(csv, con) @@ -403,7 +403,7 @@ def test_new_read_demand_csv_commodity_constraint( ): from muse.new_input.readers import read_demand_csv - csv = StringIO("year,commodity_name,region,demand\n2020,invalid,R1,0\n") + csv = StringIO("year,commodity_id,region_id,demand\n2020,invalid,R1,0\n") with raises(duckdb.ConstraintException, match=".*foreign key.*"): read_demand_csv(csv, con) @@ -413,7 +413,7 @@ def test_new_read_demand_csv_region_constraint( ): from muse.new_input.readers import read_demand_csv - csv = StringIO("year,commodity_name,region,demand\n2020,heat,invalid,0\n") + csv = StringIO("year,commodity_id,region_id,demand\n2020,heat,invalid,0\n") with raises(duckdb.ConstraintException, match=".*foreign key.*"): read_demand_csv(csv, con) From e3b6ece3b450e8263b5c6edc50629381d630aa8d Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Fri, 5 Jul 2024 14:15:29 +0100 Subject: [PATCH 04/10] Split new tests into new file --- tests/test_new_readers.py | 222 ++++++++++++++++++++++++++++++++++++++ tests/test_readers.py | 173 ----------------------------- 2 files changed, 222 insertions(+), 173 deletions(-) create mode 100644 tests/test_new_readers.py diff --git a/tests/test_new_readers.py b/tests/test_new_readers.py new file mode 100644 index 000000000..6c82434d6 --- /dev/null +++ b/tests/test_new_readers.py @@ -0,0 +1,222 @@ +from io import StringIO + +import duckdb +import numpy as np +import xarray as xr +from pytest import approx, fixture, mark, raises + + +@fixture +def default_new_input(tmp_path): + from muse.examples import copy_model + + copy_model("default_new_input", tmp_path) + return tmp_path / "model" + + +@fixture +def con(): + return duckdb.connect(":memory:") + 
+ +@fixture +def populate_commodities(default_new_input, con): + from muse.new_input.readers import read_commodities_csv + + with open(default_new_input / "commodities.csv") as f: + return read_commodities_csv(f, con) + + +@fixture +def populate_demand(default_new_input, con, populate_regions, populate_commodities): + from muse.new_input.readers import read_demand_csv + + with open(default_new_input / "demand.csv") as f: + return read_demand_csv(f, con) + + +@fixture +def populate_regions(default_new_input, con): + from muse.new_input.readers import read_regions_csv + + with open(default_new_input / "regions.csv") as f: + return read_regions_csv(f, con) + + +def test_read_regions(populate_regions): + assert populate_regions["id"] == np.array(["R1"]) + + +def test_read_new_global_commodities(populate_commodities): + data = populate_commodities + assert list(data["id"]) == ["electricity", "gas", "heat", "wind", "CO2f"] + assert list(data["type"]) == ["energy"] * 5 + assert list(data["unit"]) == ["PJ"] * 4 + ["kt"] + + +def test_calculate_global_commodities(populate_commodities): + from muse.new_input.readers import calculate_global_commodities + + data = calculate_global_commodities(populate_commodities) + + assert isinstance(data, xr.Dataset) + assert set(data.dims) == {"commodity"} + for dt in data.dtypes.values(): + assert np.issubdtype(dt, np.dtype("str")) + + assert list(data.coords["commodity"].values) == list(populate_commodities["id"]) + assert list(data.data_vars["type"].values) == list(populate_commodities["type"]) + assert list(data.data_vars["unit"].values) == list(populate_commodities["unit"]) + + +def test_read_new_global_commodities_type_constraint(default_new_input, con): + from muse.new_input.readers import read_commodities_csv + + csv = StringIO("id,type,unit\nfoo,invalid,bar\n") + with raises(duckdb.ConstraintException): + read_commodities_csv(csv, con) + + +def test_new_read_demand_csv(populate_demand): + data = populate_demand + assert np.all(data["year"] == np.array([2020, 2050])) + assert np.all(data["commodity"] == np.array(["heat", "heat"])) + assert np.all(data["region"] == np.array(["R1", "R1"])) + assert np.all(data["demand"] == np.array([10, 30])) + + +def test_new_read_demand_csv_commodity_constraint( + default_new_input, con, populate_commodities, populate_regions +): + from muse.new_input.readers import read_demand_csv + + csv = StringIO("year,commodity_id,region_id,demand\n2020,invalid,R1,0\n") + with raises(duckdb.ConstraintException, match=".*foreign key.*"): + read_demand_csv(csv, con) + + +def test_new_read_demand_csv_region_constraint( + default_new_input, con, populate_commodities, populate_regions +): + from muse.new_input.readers import read_demand_csv + + csv = StringIO("year,commodity_id,region_id,demand\n2020,heat,invalid,0\n") + with raises(duckdb.ConstraintException, match=".*foreign key.*"): + read_demand_csv(csv, con) + + +@mark.xfail +def test_demand_dataset(default_new_input): + import duckdb + + from muse.new_input.readers import read_commodities, read_demand, read_regions + + con = duckdb.connect(":memory:") + + read_regions(default_new_input, con) + read_commodities(default_new_input, con) + data = read_demand(default_new_input, con) + + assert isinstance(data, xr.DataArray) + assert data.dtype == np.float64 + + assert set(data.dims) == {"year", "commodity", "region", "timeslice"} + assert list(data.coords["region"].values) == ["R1"] + assert list(data.coords["timeslice"].values) == list(range(1, 7)) + assert 
list(data.coords["year"].values) == [2020, 2050] + assert set(data.coords["commodity"].values) == { + "electricity", + "gas", + "heat", + "wind", + "CO2f", + } + + assert data.sel(year=2020, commodity="electricity", region="R1", timeslice=0) == 1 + + +@mark.xfail +def test_new_read_initial_market(default_new_input): + from muse.new_input.readers import read_inputs + + all_data = read_inputs(default_new_input) + data = all_data["initial_market"] + + assert isinstance(data, xr.Dataset) + assert set(data.dims) == {"region", "year", "commodity", "timeslice"} + assert dict(data.dtypes) == dict( + prices=np.float64, + exports=np.float64, + imports=np.float64, + static_trade=np.float64, + ) + assert list(data.coords["region"].values) == ["R1"] + assert list(data.coords["year"].values) == list(range(2010, 2105, 5)) + assert list(data.coords["commodity"].values) == [ + "electricity", + "gas", + "heat", + "CO2f", + "wind", + ] + month_values = ["all-year"] * 6 + day_values = ["all-week"] * 6 + hour_values = [ + "night", + "morning", + "afternoon", + "early-peak", + "late-peak", + "evening", + ] + + assert list(data.coords["timeslice"].values) == list( + zip(month_values, day_values, hour_values) + ) + assert list(data.coords["month"]) == month_values + assert list(data.coords["day"]) == day_values + assert list(data.coords["hour"]) == hour_values + + assert all(var.coords.equals(data.coords) for var in data.data_vars.values()) + + prices = data.data_vars["prices"] + assert approx( + prices.sel( + year=2010, + region="R1", + commodity="electricity", + timeslice=("all-year", "all-week", "night"), + ) + - 14.81481, + abs=1e-4, + ) + + exports = data.data_vars["exports"] + assert ( + exports.sel( + year=2010, + region="R1", + commodity="electricity", + timeslice=("all-year", "all-week", "night"), + ) + ) == 0 + + imports = data.data_vars["imports"] + assert ( + imports.sel( + year=2010, + region="R1", + commodity="electricity", + timeslice=("all-year", "all-week", "night"), + ) + ) == 0 + + static_trade = data.data_vars["static_trade"] + assert ( + static_trade.sel( + year=2010, + region="R1", + commodity="electricity", + timeslice=("all-year", "all-week", "night"), + ) + ) == 0 diff --git a/tests/test_readers.py b/tests/test_readers.py index 221ec097b..e0b4bcd63 100644 --- a/tests/test_readers.py +++ b/tests/test_readers.py @@ -1,9 +1,6 @@ -from io import StringIO from itertools import chain, permutations from pathlib import Path -import duckdb -import numpy as np import pandas as pd import toml import xarray as xr @@ -293,173 +290,3 @@ def test_get_nan_coordinates(): dataset1 = xr.Dataset.from_dataframe(df1.set_index(["region", "year"])) nan_coords1 = get_nan_coordinates(dataset1) assert nan_coords1 == [("R1", 2021)] - - # Test 2: Missing coordinate combinations - df2 = pd.DataFrame( - { - "region": ["R1", "R1", "R2"], # Missing R2-2021 - "year": [2020, 2021, 2020], - "value": [1.0, 2.0, 3.0], - } - ) - dataset2 = xr.Dataset.from_dataframe(df2.set_index(["region", "year"])) - nan_coords2 = get_nan_coordinates(dataset2) - assert nan_coords2 == [("R2", 2021)] - - # Test 3: No NaN values - df3 = pd.DataFrame( - { - "region": ["R1", "R1", "R2", "R2"], - "year": [2020, 2021, 2020, 2021], - "value": [1.0, 2.0, 3.0, 4.0], - } - ) - dataset3 = xr.Dataset.from_dataframe(df3.set_index(["region", "year"])) - nan_coords3 = get_nan_coordinates(dataset3) - assert nan_coords3 == [] - - -@fixture -def default_new_input(tmp_path): - from muse.examples import copy_model - - copy_model("default_new_input", tmp_path) - 
return tmp_path / "model" - - -@fixture -def con(): - return duckdb.connect(":memory:") - - -@fixture -def populate_commodities(default_new_input, con): - from muse.new_input.readers import read_commodities_csv - - with open(default_new_input / "commodities.csv") as f: - return read_commodities_csv(f, con) - - -@fixture -def populate_demand(default_new_input, con, populate_regions, populate_commodities): - from muse.new_input.readers import read_demand_csv - - with open(default_new_input / "demand.csv") as f: - return read_demand_csv(f, con) - - -@fixture -def populate_regions(default_new_input, con): - from muse.new_input.readers import read_regions_csv - - with open(default_new_input / "regions.csv") as f: - return read_regions_csv(f, con) - - -def test_read_regions(populate_regions): - assert populate_regions["id"] == np.array(["R1"]) - - -def test_read_new_global_commodities(populate_commodities): - data = populate_commodities - assert list(data["id"]) == ["electricity", "gas", "heat", "wind", "CO2f"] - assert list(data["type"]) == ["energy"] * 5 - assert list(data["unit"]) == ["PJ"] * 4 + ["kt"] - - -def test_calculate_global_commodities(populate_commodities): - from muse.new_input.readers import calculate_global_commodities - - data = calculate_global_commodities(populate_commodities) - - assert isinstance(data, xr.Dataset) - assert set(data.dims) == {"commodity"} - for dt in data.dtypes.values(): - assert np.issubdtype(dt, np.dtype("str")) - - assert list(data.coords["commodity"].values) == list(populate_commodities["id"]) - assert list(data.data_vars["type"].values) == list(populate_commodities["type"]) - assert list(data.data_vars["unit"].values) == list(populate_commodities["unit"]) - - -def test_read_new_global_commodities_type_constraint(default_new_input, con): - from muse.new_input.readers import read_commodities_csv - - csv = StringIO("id,type,unit\nfoo,invalid,bar\n") - with raises(duckdb.ConstraintException): - read_commodities_csv(csv, con) - - -def test_new_read_demand_csv(populate_demand): - data = populate_demand - assert np.all(data["year"] == np.array([2020, 2050])) - assert np.all(data["commodity"] == np.array(["heat", "heat"])) - assert np.all(data["region"] == np.array(["R1", "R1"])) - assert np.all(data["demand"] == np.array([10, 30])) - - -def test_new_read_demand_csv_commodity_constraint( - default_new_input, con, populate_commodities, populate_regions -): - from muse.new_input.readers import read_demand_csv - - csv = StringIO("year,commodity_id,region_id,demand\n2020,invalid,R1,0\n") - with raises(duckdb.ConstraintException, match=".*foreign key.*"): - read_demand_csv(csv, con) - - -def test_new_read_demand_csv_region_constraint( - default_new_input, con, populate_commodities, populate_regions -): - from muse.new_input.readers import read_demand_csv - - csv = StringIO("year,commodity_id,region_id,demand\n2020,heat,invalid,0\n") - with raises(duckdb.ConstraintException, match=".*foreign key.*"): - read_demand_csv(csv, con) - - -@mark.xfail -def test_demand_dataset(default_new_input): - import duckdb - - from muse.new_input.readers import read_commodities, read_demand, read_regions - - con = duckdb.connect(":memory:") - - read_regions(default_new_input, con) - read_commodities(default_new_input, con) - data = read_demand(default_new_input, con) - - assert isinstance(data, xr.DataArray) - assert data.dtype == np.float64 - - assert set(data.dims) == {"year", "commodity", "region", "timeslice"} - assert list(data.coords["region"].values) == ["R1"] - assert 
list(data.coords["timeslice"].values) == list(range(1, 7)) - assert list(data.coords["year"].values) == [2020, 2050] - assert set(data.coords["commodity"].values) == { - "electricity", - "gas", - "heat", - "wind", - "CO2f", - } - - assert data.sel(year=2020, commodity="electricity", region="R1", timeslice=0) == 1 - - -@mark.xfail -def test_new_read_initial_market(default_new_input): - from muse.new_input.readers import read_inputs - - all_data = read_inputs(default_new_input) - data = all_data["initial_market"] - - assert isinstance(data, xr.Dataset) - assert set(data.dims) == {"region", "year", "commodity", "timeslice"} - assert dict(data.dtypes) == dict( - prices=np.float64, - exports=np.float64, - imports=np.float64, - static_trade=np.float64, - ) From 5b827cde3601235da2130d2a2636404fd3c1a608 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Mon, 8 Jul 2024 10:28:08 +0100 Subject: [PATCH 05/10] Tests for new tables --- src/muse/new_input/readers.py | 2 +- tests/test_new_readers.py | 80 ++++++++++++++++++++++++++++------- 2 files changed, 65 insertions(+), 17 deletions(-) diff --git a/src/muse/new_input/readers.py b/src/muse/new_input/readers.py index b9228f5bf..4e2c09de0 100644 --- a/src/muse/new_input/readers.py +++ b/src/muse/new_input/readers.py @@ -69,7 +69,7 @@ def read_commodity_costs_csv(buffer_, con): con.sql(sql) rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 con.sql("""INSERT INTO commodity_costs SELECT - commidity_id, region_id, year, value FROM rel;""") + commodity_id, region_id, year, value FROM rel;""") return con.sql("SELECT * from commodity_costs").fetchnumpy() diff --git a/tests/test_new_readers.py b/tests/test_new_readers.py index 6c82434d6..467215bba 100644 --- a/tests/test_new_readers.py +++ b/tests/test_new_readers.py @@ -27,6 +27,26 @@ def populate_commodities(default_new_input, con): return read_commodities_csv(f, con) +@fixture +def populate_commodity_trade( + default_new_input, con, populate_commodities, populate_regions +): + from muse.new_input.readers import read_commodity_trade_csv + + with open(default_new_input / "commodity_trade.csv") as f: + return read_commodity_trade_csv(f, con) + + +@fixture +def populate_commodity_costs( + default_new_input, con, populate_commodities, populate_regions +): + from muse.new_input.readers import read_commodity_costs_csv + + with open(default_new_input / "commodity_costs.csv") as f: + return read_commodity_costs_csv(f, con) + + @fixture def populate_demand(default_new_input, con, populate_regions, populate_commodities): from muse.new_input.readers import read_demand_csv @@ -35,6 +55,16 @@ def populate_demand(default_new_input, con, populate_regions, populate_commoditi return read_demand_csv(f, con) +@fixture +def populate_demand_slicing( + default_new_input, con, populate_regions, populate_commodities +): + from muse.new_input.readers import read_demand_slicing_csv + + with open(default_new_input / "demand_slicing.csv") as f: + return read_demand_slicing_csv(f, con) + + @fixture def populate_regions(default_new_input, con): from muse.new_input.readers import read_regions_csv @@ -43,17 +73,43 @@ def populate_regions(default_new_input, con): return read_regions_csv(f, con) -def test_read_regions(populate_regions): - assert populate_regions["id"] == np.array(["R1"]) - - -def test_read_new_global_commodities(populate_commodities): +def test_read_commodities_csv(populate_commodities): data = populate_commodities assert list(data["id"]) == ["electricity", "gas", "heat", "wind", "CO2f"] assert 
list(data["type"]) == ["energy"] * 5 assert list(data["unit"]) == ["PJ"] * 4 + ["kt"] +def test_read_commodity_trade_csv(populate_commodity_trade): + data = populate_commodity_trade + assert data["commodity"].size == 0 + assert data["region"].size == 0 + assert data["year"].size == 0 + assert data["import"].size == 0 + assert data["export"].size == 0 + + +def test_read_commodity_costs_csv(populate_commodity_costs): + data = populate_commodity_costs + # Only checking the first element of each array, as the table is large + assert next(iter(data["commodity"])) == "electricity" + assert next(iter(data["region"])) == "R1" + assert next(iter(data["year"])) == 2010 + assert next(iter(data["value"])) == approx(14.81481) + + +def test_read_demand_csv(populate_demand): + data = populate_demand + assert np.all(data["year"] == np.array([2020, 2050])) + assert np.all(data["commodity"] == np.array(["heat", "heat"])) + assert np.all(data["region"] == np.array(["R1", "R1"])) + assert np.all(data["demand"] == np.array([10, 30])) + + +def test_read_regions_csv(populate_regions): + assert populate_regions["id"] == np.array(["R1"]) + + def test_calculate_global_commodities(populate_commodities): from muse.new_input.readers import calculate_global_commodities @@ -69,7 +125,7 @@ def test_calculate_global_commodities(populate_commodities): assert list(data.data_vars["unit"].values) == list(populate_commodities["unit"]) -def test_read_new_global_commodities_type_constraint(default_new_input, con): +def test_read_global_commodities_type_constraint(default_new_input, con): from muse.new_input.readers import read_commodities_csv csv = StringIO("id,type,unit\nfoo,invalid,bar\n") @@ -77,15 +133,7 @@ def test_read_new_global_commodities_type_constraint(default_new_input, con): read_commodities_csv(csv, con) -def test_new_read_demand_csv(populate_demand): - data = populate_demand - assert np.all(data["year"] == np.array([2020, 2050])) - assert np.all(data["commodity"] == np.array(["heat", "heat"])) - assert np.all(data["region"] == np.array(["R1", "R1"])) - assert np.all(data["demand"] == np.array([10, 30])) - - -def test_new_read_demand_csv_commodity_constraint( +def test_read_demand_csv_commodity_constraint( default_new_input, con, populate_commodities, populate_regions ): from muse.new_input.readers import read_demand_csv @@ -95,7 +143,7 @@ def test_new_read_demand_csv_commodity_constraint( read_demand_csv(csv, con) -def test_new_read_demand_csv_region_constraint( +def test_read_demand_csv_region_constraint( default_new_input, con, populate_commodities, populate_regions ): from muse.new_input.readers import read_demand_csv From 73014f16607c6b2073308f84a0b84f2cd7d602a9 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 15 Aug 2024 10:39:58 +0100 Subject: [PATCH 06/10] Add functions for demand data and (in progress) initial market --- src/muse/new_input/readers.py | 264 +++++++++++++++++++++++++++++++--- tests/test_new_readers.py | 127 ++++++++++------ 2 files changed, 329 insertions(+), 62 deletions(-) diff --git a/src/muse/new_input/readers.py b/src/muse/new_input/readers.py index 4e2c09de0..b216ba0d4 100644 --- a/src/muse/new_input/readers.py +++ b/src/muse/new_input/readers.py @@ -1,5 +1,6 @@ import duckdb import numpy as np +import pandas as pd import xarray as xr @@ -7,28 +8,54 @@ def read_inputs(data_dir): data = {} con = duckdb.connect(":memory:") + with open(data_dir / "timeslices.csv") as f: + timeslices = read_timeslices_csv(f, con) + with open(data_dir / "commodities.csv") as f: commodities = 
read_commodities_csv(f, con) + with open(data_dir / "regions.csv") as f: + regions = read_regions_csv(f, con) + with open(data_dir / "commodity_trade.csv") as f: - commodity_trade = read_commodity_trade_csv(f, con) # noqa: F841 + commodity_trade = read_commodity_trade_csv(f, con) with open(data_dir / "commodity_costs.csv") as f: - commodity_costs = read_commodity_costs_csv(f, con) # noqa: F841 + commodity_costs = read_commodity_costs_csv(f, con) with open(data_dir / "demand.csv") as f: - demand = read_demand_csv(f, con) # noqa: F841 + demand = read_demand_csv(f, con) with open(data_dir / "demand_slicing.csv") as f: - demand_slicing = read_demand_slicing_csv(f, con) # noqa: F841 - - with open(data_dir / "regions.csv") as f: - regions = read_regions_csv(f, con) # noqa: F841 + demand_slicing = read_demand_slicing_csv(f, con) data["global_commodities"] = calculate_global_commodities(commodities) + data["demand"] = calculate_demand( + commodities, regions, timeslices, demand, demand_slicing + ) + data["initial_market"] = calculate_initial_market( + commodities, regions, timeslices, commodity_trade, commodity_costs + ) return data +def read_timeslices_csv(buffer_, con): + sql = """CREATE TABLE timeslices ( + id VARCHAR PRIMARY KEY, + season VARCHAR, + day VARCHAR, + time_of_day VARCHAR, + fraction DOUBLE CHECK (fraction >= 0 AND fraction <= 1), + ); + """ + con.sql(sql) + rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 + con.sql( + "INSERT INTO timeslices SELECT id, season, day, time_of_day, fraction FROM rel;" + ) + return con.sql("SELECT * from timeslices").fetchnumpy() + + def read_commodities_csv(buffer_, con): sql = """CREATE TABLE commodities ( id VARCHAR PRIMARY KEY, @@ -42,6 +69,17 @@ def read_commodities_csv(buffer_, con): return con.sql("select * from commodities").fetchnumpy() +def read_regions_csv(buffer_, con): + sql = """CREATE TABLE regions ( + id VARCHAR PRIMARY KEY, + ); + """ + con.sql(sql) + rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 + con.sql("INSERT INTO regions SELECT id FROM rel;") + return con.sql("SELECT * from regions").fetchnumpy() + + def read_commodity_trade_csv(buffer_, con): sql = """CREATE TABLE commodity_trade ( commodity VARCHAR REFERENCES commodities(id), @@ -49,6 +87,7 @@ def read_commodity_trade_csv(buffer_, con): year BIGINT, import DOUBLE, export DOUBLE, + PRIMARY KEY (commodity, region, year) ); """ con.sql(sql) @@ -64,6 +103,7 @@ def read_commodity_costs_csv(buffer_, con): region VARCHAR REFERENCES regions(id), year BIGINT, value DOUBLE, + PRIMARY KEY (commodity, region, year) ); """ con.sql(sql) @@ -79,6 +119,7 @@ def read_demand_csv(buffer_, con): region VARCHAR REFERENCES regions(id), year BIGINT, demand DOUBLE, + PRIMARY KEY (commodity, region, year) ); """ con.sql(sql) @@ -92,28 +133,19 @@ def read_demand_slicing_csv(buffer_, con): commodity VARCHAR REFERENCES commodities(id), region VARCHAR REFERENCES regions(id), year BIGINT, - timeslice VARCHAR, + timeslice VARCHAR REFERENCES timeslices(id), fraction DOUBLE CHECK (fraction >= 0 AND fraction <= 1), + PRIMARY KEY (commodity, region, year, timeslice), + FOREIGN KEY (commodity, region, year) REFERENCES demand(commodity, region, year) ); """ con.sql(sql) rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 con.sql("""INSERT INTO demand_slicing SELECT - commodity_id, region_id, year, timeslice, fraction FROM rel;""") + commodity_id, region_id, year, timeslice_id, fraction FROM rel;""") return con.sql("SELECT * from 
demand_slicing").fetchnumpy() -def read_regions_csv(buffer_, con): - sql = """CREATE TABLE regions ( - id VARCHAR PRIMARY KEY, - ); - """ - con.sql(sql) - rel = con.read_csv(buffer_, header=True, delimiter=",") # noqa: F841 - con.sql("INSERT INTO regions SELECT id FROM rel;") - return con.sql("SELECT * from regions").fetchnumpy() - - def calculate_global_commodities(commodities): names = commodities["id"].astype(np.dtype("str")) types = commodities["type"].astype(np.dtype("str")) @@ -129,3 +161,195 @@ def calculate_global_commodities(commodities): data = xr.Dataset(data_vars=dict(type=type_array, unit=unit_array)) return data + + +def calculate_demand( + commodities, regions, timeslices, demand, demand_slicing +) -> xr.DataArray: + """Calculate demand data for all commodities, regions, years, and timeslices. + + Result: A DataArray with a demand value for every combination of: + - commodity: all commodities specified in the commodities table + - region: all regions specified in the regions table + - year: all years specified in the demand table + - timeslice: all timeslices specified in the timeslices table + + Checks: + - If demand data is specified for one year, it must be specified for all years. + - If demand is nonzero, slicing data must be present. + - If slicing data is specified for a commodity/region/year, the sum of + the fractions must be 1, and all timeslices must be present. + + Fills: + - If demand data is not specified for a commodity/region combination, the demand is + 0 for all years and timeslices. + + Todo: + - Interpolation to allow for missing years in demand data. + - Ability to leave the year field blank in both tables to indicate all years + - Allow slicing data to be missing -> demand is spread equally across timeslices + - Allow more flexibility for timeslices (e.g. 
can specify "winter" to apply to all + winter timeslices, or "all" to apply to all timeslices) + """ + # Prepare dataframes + df_demand = pd.DataFrame(demand).set_index(["commodity", "region", "year"]) + df_slicing = pd.DataFrame(demand_slicing).set_index( + ["commodity", "region", "year", "timeslice"] + ) + + # DataArray dimensions + all_commodities = commodities["id"].astype(np.dtype("str")) + all_regions = regions["id"].astype(np.dtype("str")) + all_years = df_demand.index.get_level_values("year").unique() + all_timeslices = timeslices["id"].astype(np.dtype("str")) + + # CHECK: all years are specified for each commodity/region combination + check_all_values_specified(df_demand, ["commodity", "region"], "year", all_years) + + # CHECK: if slicing data is present, all timeslices must be specified + check_all_values_specified( + df_slicing, ["commodity", "region", "year"], "timeslice", all_timeslices + ) + + # CHECK: timeslice fractions sum to 1 + check_timeslice_sum = df_slicing.groupby(["commodity", "region", "year"]).apply( + lambda x: np.isclose(x["fraction"].sum(), 1) + ) + if not check_timeslice_sum.all(): + raise DataValidationError + + # CHECK: if demand data >0, fraction data must be specified + check_fraction_data_present = ( + df_demand[df_demand["demand"] > 0] + .index.isin(df_slicing.droplevel("timeslice").index) + .all() + ) + if not check_fraction_data_present.all(): + raise DataValidationError + + # FILL: demand is zero if unspecified + df_demand = df_demand.reindex( + pd.MultiIndex.from_product( + [all_commodities, all_regions, all_years], + names=["commodity", "region", "year"], + ), + fill_value=0, + ) + + # FILL: slice data is zero if unspecified + df_slicing = df_slicing.reindex( + pd.MultiIndex.from_product( + [all_commodities, all_regions, all_years, all_timeslices], + names=["commodity", "region", "year", "timeslice"], + ), + fill_value=0, + ) + + # Create DataArray + da_demand = df_demand.to_xarray()["demand"] + da_slicing = df_slicing.to_xarray()["fraction"] + data = da_demand * da_slicing + return data + + +def calculate_initial_market( + commodities, regions, timeslices, commodity_trade, commodity_costs +) -> xr.Dataset: + """Calculate trade and price data for all commodities, regions and years. + + Result: A Dataset with variables: + - prices + - exports + - imports + - static_trade + For every combination of: + - commodity: all commodities specified in the commodities table + - region: all regions specified in the regions table + - year: all years specified in the commodity_costs table + - timeslice (multiindex): all timeslices specified in the timeslices table + + Checks: + - If trade data is specified for one year, it must be specified for all years. + - If price data is specified for one year, it must be specified for all years. 
+ + Fills: + - If trade data is not specified for a commodity/region combination, imports and + exports are both zero + - If price data is not specified for a commodity/region combination, the price is + zero + + """ + from muse.timeslices import QuantityType, convert_timeslice + + # Prepare dataframes + df_trade = pd.DataFrame(commodity_trade).set_index(["commodity", "region", "year"]) + df_costs = ( + pd.DataFrame(commodity_costs) + .set_index(["commodity", "region", "year"]) + .rename(columns={"value": "prices"}) + ) + df_timeslices = pd.DataFrame(timeslices).set_index(["season", "day", "time_of_day"]) + + # DataArray dimensions + all_commodities = commodities["id"].astype(np.dtype("str")) + all_regions = regions["id"].astype(np.dtype("str")) + all_years = df_costs.index.get_level_values("year").unique() + + # CHECK: all years are specified for each commodity/region combination + check_all_values_specified(df_trade, ["commodity", "region"], "year", all_years) + check_all_values_specified(df_costs, ["commodity", "region"], "year", all_years) + + # FILL: price is zero if unspecified + df_costs = df_costs.reindex( + pd.MultiIndex.from_product( + [all_commodities, all_regions, all_years], + names=["commodity", "region", "year"], + ), + fill_value=0, + ) + + # FILL: trade is zero if unspecified + df_trade = df_trade.reindex( + pd.MultiIndex.from_product( + [all_commodities, all_regions, all_years], + names=["commodity", "region", "year"], + ), + fill_value=0, + ) + + # Calculate static trade + df_trade["static_trade"] = df_trade["export"] - df_trade["import"] + + # Create Data + df_full = df_costs.join(df_trade) + data = df_full.to_xarray() + ts = df_timeslices.to_xarray()["fraction"] + ts = ts.stack(timeslice=("season", "day", "time_of_day")) + convert_timeslice(data, ts, QuantityType.EXTENSIVE) + + return data + + +class DataValidationError(ValueError): + pass + + +def check_all_values_specified( + df: pd.DataFrame, group_by_cols: list[str], column_name: str, values: list +) -> None: + """Check that the required values are specified in a dataframe. + + Checks that a row exists for all specified values of column_name for each + group in the grouped dataframe. 
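+
+    A minimal sketch of the intended use (hypothetical data, not from the
+    bundled model)::
+
+        df = pd.DataFrame(
+            {"region": ["R1", "R1"], "year": [2020, 2050], "demand": [10.0, 30.0]}
+        ).set_index(["region", "year"])
+        check_all_values_specified(df, ["region"], "year", [2020, 2050])  # passes
+        # raises DataValidationError; 2030 is missing for region R1:
+        check_all_values_specified(df, ["region"], "year", [2020, 2030, 2050])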
+ """ + if not ( + df.groupby(group_by_cols) + .apply( + lambda x: ( + set(x.index.get_level_values(column_name).unique()) == set(values) + ) + ) + .all() + ).all(): + msg = "" # TODO + raise DataValidationError(msg) diff --git a/tests/test_new_readers.py b/tests/test_new_readers.py index 467215bba..483d42627 100644 --- a/tests/test_new_readers.py +++ b/tests/test_new_readers.py @@ -57,7 +57,7 @@ def populate_demand(default_new_input, con, populate_regions, populate_commoditi @fixture def populate_demand_slicing( - default_new_input, con, populate_regions, populate_commodities + default_new_input, con, populate_regions, populate_commodities, populate_demand ): from muse.new_input.readers import read_demand_slicing_csv @@ -73,6 +73,28 @@ def populate_regions(default_new_input, con): return read_regions_csv(f, con) +@fixture +def populate_timeslices(default_new_input, con): + from muse.new_input.readers import read_timeslices_csv + + with open(default_new_input / "timeslices.csv") as f: + return read_timeslices_csv(f, con) + + +def test_read_timeslices_csv(populate_timeslices): + data = populate_timeslices + assert len(data["id"]) == 6 + assert next(iter(data["id"])) == "1" + assert next(iter(data["season"])) == "all" + assert next(iter(data["day"])) == "all" + assert next(iter(data["time_of_day"])) == "night" + assert next(iter(data["fraction"])) == approx(0.1667) + + +def test_read_regions_csv(populate_regions): + assert populate_regions["id"] == np.array(["R1"]) + + def test_read_commodities_csv(populate_commodities): data = populate_commodities assert list(data["id"]) == ["electricity", "gas", "heat", "wind", "CO2f"] @@ -106,26 +128,18 @@ def test_read_demand_csv(populate_demand): assert np.all(data["demand"] == np.array([10, 30])) -def test_read_regions_csv(populate_regions): - assert populate_regions["id"] == np.array(["R1"]) - - -def test_calculate_global_commodities(populate_commodities): - from muse.new_input.readers import calculate_global_commodities - - data = calculate_global_commodities(populate_commodities) - - assert isinstance(data, xr.Dataset) - assert set(data.dims) == {"commodity"} - for dt in data.dtypes.values(): - assert np.issubdtype(dt, np.dtype("str")) - - assert list(data.coords["commodity"].values) == list(populate_commodities["id"]) - assert list(data.data_vars["type"].values) == list(populate_commodities["type"]) - assert list(data.data_vars["unit"].values) == list(populate_commodities["unit"]) +def test_read_demand_slicing_csv(populate_demand_slicing): + data = populate_demand_slicing + assert np.all(data["commodity"] == "heat") + assert np.all(data["region"] == "R1") + # assert np.all(data["timeslice"] == np.array([0, 1])) + assert np.all( + data["fraction"] + == np.array([0.1, 0.15, 0.1, 0.15, 0.3, 0.2, 0.1, 0.15, 0.1, 0.15, 0.3, 0.2]) + ) -def test_read_global_commodities_type_constraint(default_new_input, con): +def test_read_commodities_csv_type_constraint(con): from muse.new_input.readers import read_commodities_csv csv = StringIO("id,type,unit\nfoo,invalid,bar\n") @@ -134,7 +148,7 @@ def test_read_global_commodities_type_constraint(default_new_input, con): def test_read_demand_csv_commodity_constraint( - default_new_input, con, populate_commodities, populate_regions + con, populate_commodities, populate_regions ): from muse.new_input.readers import read_demand_csv @@ -143,9 +157,7 @@ def test_read_demand_csv_commodity_constraint( read_demand_csv(csv, con) -def test_read_demand_csv_region_constraint( - default_new_input, con, populate_commodities, 
populate_regions -): +def test_read_demand_csv_region_constraint(con, populate_commodities, populate_regions): from muse.new_input.readers import read_demand_csv csv = StringIO("year,commodity_id,region_id,demand\n2020,heat,invalid,0\n") @@ -153,24 +165,44 @@ def test_read_demand_csv_region_constraint( read_demand_csv(csv, con) -@mark.xfail -def test_demand_dataset(default_new_input): - import duckdb +def test_calculate_global_commodities(populate_commodities): + from muse.new_input.readers import calculate_global_commodities + + data = calculate_global_commodities(populate_commodities) + + assert isinstance(data, xr.Dataset) + assert set(data.dims) == {"commodity"} + for dt in data.dtypes.values(): + assert np.issubdtype(dt, np.dtype("str")) - from muse.new_input.readers import read_commodities, read_demand, read_regions + assert list(data.coords["commodity"].values) == list(populate_commodities["id"]) + assert list(data.data_vars["type"].values) == list(populate_commodities["type"]) + assert list(data.data_vars["unit"].values) == list(populate_commodities["unit"]) - con = duckdb.connect(":memory:") - read_regions(default_new_input, con) - read_commodities(default_new_input, con) - data = read_demand(default_new_input, con) +def test_calculate_demand( + populate_commodities, + populate_regions, + populate_timeslices, + populate_demand, + populate_demand_slicing, +): + from muse.new_input.readers import calculate_demand + + data = calculate_demand( + populate_commodities, + populate_regions, + populate_timeslices, + populate_demand, + populate_demand_slicing, + ) assert isinstance(data, xr.DataArray) assert data.dtype == np.float64 assert set(data.dims) == {"year", "commodity", "region", "timeslice"} assert list(data.coords["region"].values) == ["R1"] - assert list(data.coords["timeslice"].values) == list(range(1, 7)) + assert list(data.coords["timeslice"].values) == ["1", "2", "3", "4", "5", "6"] assert list(data.coords["year"].values) == [2020, 2050] assert set(data.coords["commodity"].values) == { "electricity", @@ -180,15 +212,26 @@ def test_demand_dataset(default_new_input): "CO2f", } - assert data.sel(year=2020, commodity="electricity", region="R1", timeslice=0) == 1 + assert data.sel(year=2020, commodity="heat", region="R1", timeslice="1") == 1 @mark.xfail -def test_new_read_initial_market(default_new_input): - from muse.new_input.readers import read_inputs - - all_data = read_inputs(default_new_input) - data = all_data["initial_market"] +def test_calculate_initial_market( + populate_commodities, + populate_regions, + populate_timeslices, + populate_commodity_trade, + populate_commodity_costs, +): + from muse.new_input.readers import calculate_initial_market + + data = calculate_initial_market( + populate_commodities, + populate_regions, + populate_timeslices, + populate_commodity_trade, + populate_commodity_costs, + ) assert isinstance(data, xr.Dataset) assert set(data.dims) == {"region", "year", "commodity", "timeslice"} @@ -198,15 +241,15 @@ def test_new_read_initial_market(default_new_input): imports=np.float64, static_trade=np.float64, ) - assert list(data.coords["region"].values) == ["R1"] - assert list(data.coords["year"].values) == list(range(2010, 2105, 5)) - assert list(data.coords["commodity"].values) == [ + assert set(data.coords["region"].values) == {"R1"} + assert set(data.coords["year"].values) == set(range(2010, 2105, 5)) + assert set(data.coords["commodity"].values) == { "electricity", "gas", "heat", "CO2f", "wind", - ] + } month_values = ["all-year"] * 6 day_values 
= ["all-week"] * 6 hour_values = [ From 1a6652c21c4559329a35ec3752225340cc9ed7fb Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Mon, 19 Aug 2024 12:25:32 +0100 Subject: [PATCH 07/10] Convert timeslice id to int, fix failing test --- src/muse/new_input/readers.py | 6 +++--- tests/test_new_readers.py | 13 +++++++++---- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/muse/new_input/readers.py b/src/muse/new_input/readers.py index b216ba0d4..67c5dd9aa 100644 --- a/src/muse/new_input/readers.py +++ b/src/muse/new_input/readers.py @@ -41,7 +41,7 @@ def read_inputs(data_dir): def read_timeslices_csv(buffer_, con): sql = """CREATE TABLE timeslices ( - id VARCHAR PRIMARY KEY, + id BIGINT PRIMARY KEY, season VARCHAR, day VARCHAR, time_of_day VARCHAR, @@ -133,7 +133,7 @@ def read_demand_slicing_csv(buffer_, con): commodity VARCHAR REFERENCES commodities(id), region VARCHAR REFERENCES regions(id), year BIGINT, - timeslice VARCHAR REFERENCES timeslices(id), + timeslice BIGINT REFERENCES timeslices(id), fraction DOUBLE CHECK (fraction >= 0 AND fraction <= 1), PRIMARY KEY (commodity, region, year, timeslice), FOREIGN KEY (commodity, region, year) REFERENCES demand(commodity, region, year) @@ -201,7 +201,7 @@ def calculate_demand( all_commodities = commodities["id"].astype(np.dtype("str")) all_regions = regions["id"].astype(np.dtype("str")) all_years = df_demand.index.get_level_values("year").unique() - all_timeslices = timeslices["id"].astype(np.dtype("str")) + all_timeslices = timeslices["id"].astype(np.dtype("int")) # CHECK: all years are specified for each commodity/region combination check_all_values_specified(df_demand, ["commodity", "region"], "year", all_years) diff --git a/tests/test_new_readers.py b/tests/test_new_readers.py index 483d42627..68f715d55 100644 --- a/tests/test_new_readers.py +++ b/tests/test_new_readers.py @@ -57,7 +57,12 @@ def populate_demand(default_new_input, con, populate_regions, populate_commoditi @fixture def populate_demand_slicing( - default_new_input, con, populate_regions, populate_commodities, populate_demand + default_new_input, + con, + populate_regions, + populate_commodities, + populate_demand, + populate_timeslices, ): from muse.new_input.readers import read_demand_slicing_csv @@ -84,7 +89,7 @@ def populate_timeslices(default_new_input, con): def test_read_timeslices_csv(populate_timeslices): data = populate_timeslices assert len(data["id"]) == 6 - assert next(iter(data["id"])) == "1" + assert next(iter(data["id"])) == 1 assert next(iter(data["season"])) == "all" assert next(iter(data["day"])) == "all" assert next(iter(data["time_of_day"])) == "night" @@ -202,7 +207,7 @@ def test_calculate_demand( assert set(data.dims) == {"year", "commodity", "region", "timeslice"} assert list(data.coords["region"].values) == ["R1"] - assert list(data.coords["timeslice"].values) == ["1", "2", "3", "4", "5", "6"] + assert set(data.coords["timeslice"].values) == set(range(1, 7)) assert list(data.coords["year"].values) == [2020, 2050] assert set(data.coords["commodity"].values) == { "electricity", @@ -212,7 +217,7 @@ def test_calculate_demand( "CO2f", } - assert data.sel(year=2020, commodity="heat", region="R1", timeslice="1") == 1 + assert data.sel(year=2020, commodity="heat", region="R1", timeslice=1) == 1 @mark.xfail From 3de85ddea6dd672eeef748cbf895a3b529d6d67b Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Mon, 19 Aug 2024 14:13:27 +0100 Subject: [PATCH 08/10] Finish initial market reader --- src/muse/new_input/readers.py | 61 
From 3de85ddea6dd672eeef748cbf895a3b529d6d67b Mon Sep 17 00:00:00 2001
From: Tom Bland
Date: Mon, 19 Aug 2024 14:13:27 +0100
Subject: [PATCH 08/10] Finish initial market reader

---
 src/muse/new_input/readers.py | 61 +++++++++++++++++++++++++++--------
 tests/test_new_readers.py    | 45 ++++++++++++--------------
 2 files changed, 68 insertions(+), 38 deletions(-)

diff --git a/src/muse/new_input/readers.py b/src/muse/new_input/readers.py
index 67c5dd9aa..c8833e902 100644
--- a/src/muse/new_input/readers.py
+++ b/src/muse/new_input/readers.py
@@ -3,6 +3,8 @@
 import pandas as pd
 import xarray as xr
 
+from muse.timeslices import QuantityType
+
 
 def read_inputs(data_dir):
     data = {}
@@ -42,17 +44,15 @@ def read_inputs(data_dir):
 
 def read_timeslices_csv(buffer_, con):
     sql = """CREATE TABLE timeslices (
         id BIGINT PRIMARY KEY,
-        season VARCHAR,
+        month VARCHAR,
         day VARCHAR,
-        time_of_day VARCHAR,
+        hour VARCHAR,
         fraction DOUBLE CHECK (fraction >= 0 AND fraction <= 1),
     );
     """
     con.sql(sql)
     rel = con.read_csv(buffer_, header=True, delimiter=",")  # noqa: F841
-    con.sql(
-        "INSERT INTO timeslices SELECT id, season, day, time_of_day, fraction FROM rel;"
-    )
+    con.sql("INSERT INTO timeslices SELECT id, month, day, hour, fraction FROM rel;")
     return con.sql("SELECT * from timeslices").fetchnumpy()
@@ -278,9 +278,11 @@ def calculate_initial_market(
     - If price data is not specified for a commodity/region combination, the price
       is zero
-    """
-    from muse.timeslices import QuantityType, convert_timeslice
 
+    Todo:
+    - Allow data to be specified on a timeslice level (optional)
+    - Interpolation, missing year field, flexible timeslice specification as above
+    """
     # Prepare dataframes
     df_trade = pd.DataFrame(commodity_trade).set_index(["commodity", "region", "year"])
     df_costs = (
@@ -288,7 +290,7 @@ def calculate_initial_market(
         .set_index(["commodity", "region", "year"])
         .rename(columns={"value": "prices"})
     )
-    df_timeslices = pd.DataFrame(timeslices).set_index(["season", "day", "time_of_day"])
+    df_timeslices = pd.DataFrame(timeslices).set_index(["month", "day", "hour"])
 
     # DataArray dimensions
     all_commodities = commodities["id"].astype(np.dtype("str"))
@@ -320,13 +322,17 @@ def calculate_initial_market(
     # Calculate static trade
     df_trade["static_trade"] = df_trade["export"] - df_trade["import"]
 
-    # Create Data
-    df_full = df_costs.join(df_trade)
-    data = df_full.to_xarray()
-    ts = df_timeslices.to_xarray()["fraction"]
-    ts = ts.stack(timeslice=("season", "day", "time_of_day"))
-    convert_timeslice(data, ts, QuantityType.EXTENSIVE)
+    # Create xarray datasets
+    xr_costs = df_costs.to_xarray()
+    xr_trade = df_trade.to_xarray()
+
+    # Project over timeslices
+    ts = df_timeslices.to_xarray()["fraction"].stack(timeslice=("month", "day", "hour"))
+    xr_costs = project_timeslice(xr_costs, ts, QuantityType.EXTENSIVE)
+    xr_trade = project_timeslice(xr_trade, ts, QuantityType.INTENSIVE)
 
+    # Combine data
+    data = xr.merge([xr_costs, xr_trade])
     return data
@@ -353,3 +359,30 @@ def check_all_values_specified(
     ).all():
         msg = ""  # TODO
         raise DataValidationError(msg)
+
+
+def project_timeslice(
+    data: xr.Dataset, timeslices: xr.DataArray, quantity_type: QuantityType
+) -> xr.Dataset:
+    """Project a dataset over a new timeslice dimension.
+
+    The projection can be done in one of two ways, depending on whether the
+    quantity type is extensive or intensive. See `QuantityType`.
+
+    Args:
+        data: Dataset to project
+        timeslices: DataArray of timeslice levels, with values between 0 and 1
+            representing the timeslice length (fraction of the year)
+        quantity_type: Type of projection to perform. QuantityType.EXTENSIVE or
+            QuantityType.INTENSIVE
+
+    Returns:
+        Projected dataset
+    """
+    assert "timeslice" in timeslices.dims
+    assert "timeslice" not in data.dims
+
+    if quantity_type is QuantityType.INTENSIVE:
+        return data * timeslices
+    if quantity_type is QuantityType.EXTENSIVE:
+        return data * xr.ones_like(timeslices)
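The two return branches of project_timeslice reduce to plain xarray broadcasting. A standalone sketch of the same arithmetic, recreated here so it runs without muse installed (names and values are illustrative):

    import xarray as xr

    # "ts" plays the role of the stacked fraction array in the patch above.
    ts = xr.DataArray([0.25, 0.75], dims=["timeslice"], coords={"timeslice": [1, 2]})
    prices = xr.Dataset({"prices": ("region", [100.0])}, coords={"region": ["R1"]})

    copied = prices * xr.ones_like(ts)  # EXTENSIVE branch: value repeated per slice
    scaled = prices * ts                # INTENSIVE branch: value split by fraction
    assert float(copied["prices"].sel(region="R1", timeslice=1)) == 100.0
    assert float(scaled["prices"].sel(region="R1", timeslice=2)) == 75.0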
diff --git a/tests/test_new_readers.py b/tests/test_new_readers.py
index 68f715d55..07a3889e3 100644
--- a/tests/test_new_readers.py
+++ b/tests/test_new_readers.py
@@ -3,7 +3,7 @@
 import duckdb
 import numpy as np
 import xarray as xr
-from pytest import approx, fixture, mark, raises
+from pytest import approx, fixture, raises
 
 
 @fixture
@@ -206,9 +206,9 @@ def test_calculate_demand(
     assert data.dtype == np.float64
 
     assert set(data.dims) == {"year", "commodity", "region", "timeslice"}
-    assert list(data.coords["region"].values) == ["R1"]
+    assert set(data.coords["region"].values) == {"R1"}
     assert set(data.coords["timeslice"].values) == set(range(1, 7))
-    assert list(data.coords["year"].values) == [2020, 2050]
+    assert set(data.coords["year"].values) == {2020, 2050}
     assert set(data.coords["commodity"].values) == {
         "electricity",
         "gas",
@@ -220,7 +220,6 @@ def test_calculate_demand(
 
     assert data.sel(year=2020, commodity="heat", region="R1", timeslice=1) == 1
 
 
-@mark.xfail
 def test_calculate_initial_market(
     populate_commodities,
     populate_regions,
@@ -240,12 +239,8 @@ def test_calculate_initial_market(
 
     assert isinstance(data, xr.Dataset)
     assert set(data.dims) == {"region", "year", "commodity", "timeslice"}
-    assert dict(data.dtypes) == dict(
-        prices=np.float64,
-        exports=np.float64,
-        imports=np.float64,
-        static_trade=np.float64,
-    )
+    for dt in data.dtypes.values():
+        assert dt == np.dtype("float64")
     assert set(data.coords["region"].values) == {"R1"}
     assert set(data.coords["year"].values) == set(range(2010, 2105, 5))
     assert set(data.coords["commodity"].values) == {
         "electricity",
@@ -266,28 +261,30 @@ def test_calculate_initial_market(
         "evening",
     ]
 
-    assert list(data.coords["timeslice"].values) == list(
+    assert set(data.coords["timeslice"].values) == set(
         zip(month_values, day_values, hour_values)
     )
-    assert list(data.coords["month"]) == month_values
-    assert list(data.coords["day"]) == day_values
-    assert list(data.coords["hour"]) == hour_values
+    assert set(data.coords["month"].values) == set(month_values)
+    assert set(data.coords["day"].values) == set(day_values)
+    assert set(data.coords["hour"].values) == set(hour_values)
 
     assert all(var.coords.equals(data.coords) for var in data.data_vars.values())
 
     prices = data.data_vars["prices"]
-    assert approx(
-        prices.sel(
-            year=2010,
-            region="R1",
-            commodity="electricity",
-            timeslice=("all-year", "all-week", "night"),
+    assert (
+        approx(
+            prices.sel(
+                year=2010,
+                region="R1",
+                commodity="electricity",
+                timeslice=("all-year", "all-week", "night"),
+            ),
+            abs=1e-4,
         )
-        - 14.81481,
-        abs=1e-4,
+        == 14.81481
     )
 
-    exports = data.data_vars["exports"]
+    exports = data.data_vars["export"]
     assert (
         exports.sel(
             year=2010,
             region="R1",
             commodity="electricity",
             timeslice=("all-year", "all-week", "night"),
         )
     ) == 0
 
-    imports = data.data_vars["imports"]
+    imports = data.data_vars["import"]
     assert (
         imports.sel(
             year=2010,
             region="R1",
             commodity="electricity",
             timeslice=("all-year", "all-week", "night"),
         )
     ) == 0
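The tuple-valued timeslice selections in the updated test come from stacking the month/day/hour levels, as calculate_initial_market now does. A minimal sketch using only pandas and xarray (one toy row):

    import pandas as pd

    df = pd.DataFrame(
        {
            "month": ["all-year"],
            "day": ["all-week"],
            "hour": ["night"],
            "fraction": [0.1667],
        }
    ).set_index(["month", "day", "hour"])
    # Stacking the three index levels yields tuple-valued timeslice coordinates,
    # which is why the test selects with a 3-tuple.
    ts = df.to_xarray()["fraction"].stack(timeslice=("month", "day", "hour"))
    assert ts.sel(timeslice=("all-year", "all-week", "night")) == 0.1667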
From cd3d3dc39e5cb75a95c0e4f70f2ba6fd75aad18d Mon Sep 17 00:00:00 2001
From: Tom Bland
Date: Mon, 19 Aug 2024 14:24:04 +0100
Subject: [PATCH 09/10] Fix test

---
 tests/test_new_readers.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/test_new_readers.py b/tests/test_new_readers.py
index 07a3889e3..e7d7e31d9 100644
--- a/tests/test_new_readers.py
+++ b/tests/test_new_readers.py
@@ -90,9 +90,9 @@ def test_read_timeslices_csv(populate_timeslices):
     data = populate_timeslices
     assert len(data["id"]) == 6
     assert next(iter(data["id"])) == 1
-    assert next(iter(data["season"])) == "all"
-    assert next(iter(data["day"])) == "all"
-    assert next(iter(data["time_of_day"])) == "night"
+    assert next(iter(data["month"])) == "all-year"
+    assert next(iter(data["day"])) == "all-week"
+    assert next(iter(data["hour"])) == "night"
     assert next(iter(data["fraction"])) == approx(0.1667)

From e82cb11d644987db44126b0375cb58b1c8159151 Mon Sep 17 00:00:00 2001
From: Tom Bland
Date: Fri, 8 Aug 2025 12:27:43 +0100
Subject: [PATCH 10/10] Undo rebase mistake

---
 tests/test_readers.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/tests/test_readers.py b/tests/test_readers.py
index e0b4bcd63..924dcacff 100644
--- a/tests/test_readers.py
+++ b/tests/test_readers.py
@@ -290,3 +290,27 @@ def test_get_nan_coordinates():
     dataset1 = xr.Dataset.from_dataframe(df1.set_index(["region", "year"]))
     nan_coords1 = get_nan_coordinates(dataset1)
     assert nan_coords1 == [("R1", 2021)]
+
+    # Test 2: Missing coordinate combinations
+    df2 = pd.DataFrame(
+        {
+            "region": ["R1", "R1", "R2"],  # Missing R2-2021
+            "year": [2020, 2021, 2020],
+            "value": [1.0, 2.0, 3.0],
+        }
+    )
+    dataset2 = xr.Dataset.from_dataframe(df2.set_index(["region", "year"]))
+    nan_coords2 = get_nan_coordinates(dataset2)
+    assert nan_coords2 == [("R2", 2021)]
+
+    # Test 3: No NaN values
+    df3 = pd.DataFrame(
+        {
+            "region": ["R1", "R1", "R2", "R2"],
+            "year": [2020, 2021, 2020, 2021],
+            "value": [1.0, 2.0, 3.0, 4.0],
+        }
+    )
+    dataset3 = xr.Dataset.from_dataframe(df3.set_index(["region", "year"]))
+    nan_coords3 = get_nan_coordinates(dataset3)
+    assert nan_coords3 == []
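get_nan_coordinates itself is defined elsewhere in the codebase and is not shown in this series; one possible implementation consistent with the restored tests, sketched here for illustration only:

    import pandas as pd
    import xarray as xr

    # Hypothetical sketch, not the project's actual helper: Dataset.from_dataframe
    # fills missing index combinations with NaN, so scanning the full grid for
    # NaN rows recovers the missing coordinate tuples.
    def get_nan_coordinates(dataset: xr.Dataset) -> list:
        """Return index tuples at which any data variable is NaN."""
        df = dataset.to_dataframe()
        return [idx for idx, row in df.iterrows() if row.isna().any()]

    df2 = pd.DataFrame(
        {
            "region": ["R1", "R1", "R2"],
            "year": [2020, 2021, 2020],
            "value": [1.0, 2.0, 3.0],
        }
    )
    dataset2 = xr.Dataset.from_dataframe(df2.set_index(["region", "year"]))
    assert get_nan_coordinates(dataset2) == [("R2", 2021)]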