Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions ezyrb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
__all__ = [
'Database', 'Snapshot', 'Reduction', 'POD', 'Approximation', 'RBF', 'Linear', 'GPR',
'ANN', 'KNeighborsRegressor', 'RadiusNeighborsRegressor', 'AE',
'ReducedOrderModel', 'PODAE', 'RegularGrid'
'ReducedOrderModel', 'PODAE', 'RegularGrid',
'MultiReducedOrderModel'
]

from .database import Database
from .snapshot import Snapshot
from .parameter import Parameter
from .reducedordermodel import ReducedOrderModel
from .reducedordermodel import ReducedOrderModel, MultiReducedOrderModel
from .reduction import *
from .approximation import *
from .regular_grid import RegularGrid
5 changes: 4 additions & 1 deletion ezyrb/approximation/ann.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,10 @@ def _build_model(self, points, values):
layers.insert(0, points.shape[1])
layers.append(values.shape[1])

self.model = self._list_to_sequential(layers, self.function)
if self.model is None:
self.model = self._list_to_sequential(layers, self.function)
else:
self.model = self.model

def fit(self, points, values):
"""
Expand Down
43 changes: 39 additions & 4 deletions ezyrb/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,33 @@ class Database():
None meaning no scaling.
:param array_like space: the input spatial data
"""
def __init__(self, parameters=None, snapshots=None):
def __init__(self, parameters=None, snapshots=None, space=None):
self._pairs = []

if parameters is None and snapshots is None:
return

if parameters is None:
parameters = [None] * len(snapshots)
elif snapshots is None:
snapshots = [None] * len(parameters)

if len(parameters) != len(snapshots):
raise ValueError
raise ValueError('parameters and snapshots must have the same length')

for param, snap in zip(parameters, snapshots):
self.add(Parameter(param), Snapshot(snap))
param = Parameter(param)
if isinstance(space, dict):
snap_space = space.get(tuple(param.values), None)
# print('snap_space', snap_space)
else:
snap_space = space
snap = Snapshot(snap, space=snap_space)

self.add(param, snap)

# TODO: eventually improve the `space` assignment in the snapshots,
# snapshots can have different space coordinates

@property
def parameters_matrix(self):
Expand Down Expand Up @@ -74,7 +90,9 @@ def __len__(self):

def __str__(self):
""" Print minimal info about the Database """
return str(self.parameters_matrix)
s = 'Database with {} snapshots and {} parameters'.format(
self.snapshots_matrix.shape[1], self.parameters_matrix.shape[1])
return s

def add(self, parameter, snapshot):
"""
Expand Down Expand Up @@ -103,6 +121,10 @@ def split(self, chunks, seed=None):
>>> train, test = db.split([80, 20]) # n snapshots

"""

if seed is not None:
np.random.seed(seed)

if all(isinstance(n, int) for n in chunks):
if sum(chunks) != len(self):
raise ValueError('chunk elements are inconsistent')
Expand All @@ -118,6 +140,7 @@ def split(self, chunks, seed=None):
if not np.isclose(sum(chunks), 1.):
raise ValueError('chunk elements are inconsistent')


cum_chunks = np.cumsum(chunks)
cum_chunks = np.insert(cum_chunks, 0, 0.0)
ids = np.ones(len(self)) * -1.
Expand All @@ -137,3 +160,15 @@ def split(self, chunks, seed=None):
new_database[i].add(p, s)

return new_database

def get_snapshot_space(self, index):
"""
Get the space coordinates of a snapshot by its index.

:param int index: The index of the snapshot.
:return: The space coordinates of the snapshot.
:rtype: numpy.ndarray
"""
if index < 0 or index >= len(self._pairs):
raise IndexError("Snapshot index out of range.")
return self._pairs[index][1].space
8 changes: 6 additions & 2 deletions ezyrb/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
class Parameter:

def __init__(self, values):
self.values = values
if isinstance(values, Parameter):
self.values = values.values
else:
self.values = values

@property
def values(self):
Expand All @@ -15,4 +18,5 @@ def values(self):
def values(self, new_values):
if np.asarray(new_values).ndim != 1:
raise ValueError('only 1D array are usable as parameter.')
self._values = new_values

self._values = np.asarray(new_values)
6 changes: 6 additions & 0 deletions ezyrb/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@
'DatabaseScaler',
'ShiftSnapshots',
'AutomaticShiftSnapshots',
'Aggregation',
'DatabaseSplitter',
'DatabaseDictionarySplitter'
]

from .scaler import DatabaseScaler
from .plugin import Plugin
from .shift import ShiftSnapshots
from .automatic_shift import AutomaticShiftSnapshots
from .aggregation import Aggregation
from .database_splitter import DatabaseSplitter
from .database_splitter import DatabaseDictionarySplitter
Loading
Loading