10 changes: 8 additions & 2 deletions dataaug_platform/__init__.py
@@ -7,18 +7,24 @@
local_aug,
global_aug,
)
from .pipeline import Pipeline
from .dataset import SparkIterableDataset
from .ingestion import load_hdf5_group, hdf5_to_rdd, read_hdf5_metadata, write_trajectories_to_hdf5
from .pipeline import Pipeline

__all__ = [
# augmentation
"Augmentation",
"local_aug",
"global_aug",
"Pipeline",
# dataset
"SparkIterableDataset",
# ingestion
"load_hdf5_group",
"hdf5_to_rdd",
"read_hdf5_metadata",
"write_trajectories_to_hdf5",
# Pipeline
"Pipeline",
]

# Optional imports - only available if mimicgen dependencies are installed
52 changes: 52 additions & 0 deletions dataaug_platform/dataset.py
@@ -0,0 +1,52 @@
import torch
from torch.utils.data import IterableDataset

from .pipeline import Pipeline

class SparkIterableDataset(IterableDataset):
def __init__(
self,
spark_pipeline: Pipeline,
to_tensor: bool = True,
infinite: bool = False,
):
super().__init__()
self.spark_pipeline = spark_pipeline
self.to_tensor = to_tensor
self.infinite = infinite

def _convert(self, item):
"""Convert Spark output item to torch.Tensor/dict-of-tensors if needed."""
if not self.to_tensor:
return item

if isinstance(item, torch.Tensor):
return item

if isinstance(item, dict):
out = {}
for k, v in item.items():
if isinstance(v, torch.Tensor):
out[k] = v
else:
out[k] = torch.tensor(v, dtype=torch.float32)
return out

if isinstance(item, (list, tuple)):
return torch.tensor(item, dtype=torch.float32)

if isinstance(item, (int, float)):
return torch.tensor([item], dtype=torch.float32)

return item

def __iter__(self):
if self.infinite:
while True:
rdd = self.spark_pipeline.run()
for item in rdd.toLocalIterator():
yield self._convert(item)
else:
rdd = self.spark_pipeline.run()
for item in rdd.toLocalIterator():
yield self._convert(item)
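Usage sketch (not part of the diff): wiring the new SparkIterableDataset into a PyTorch DataLoader. The two sample trajectories below are made up for illustration; a real pipeline would also have augmentations registered via pipeline.add(...).

from torch.utils.data import DataLoader

from dataaug_platform import Pipeline, SparkIterableDataset

# Two tiny trajectories; _convert() turns the dict values into float32 tensors.
pipeline = Pipeline().set_data([{"x": [1.0, 2.0]}, {"x": [3.0, 4.0]}])
dataset = SparkIterableDataset(pipeline, to_tensor=True)

# Finite mode: run() is invoked once and the resulting RDD is streamed
# through toLocalIterator(); the default collate_fn stacks the dict items.
loader = DataLoader(dataset, batch_size=2)
for batch in loader:
    print(batch)  # {'x': tensor([[1., 2.], [3., 4.]])}

pipeline.spark.stop()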
66 changes: 59 additions & 7 deletions dataaug_platform/pipeline.py
@@ -1,30 +1,82 @@
from abc import ABC, abstractmethod
from typing import Optional

from pyspark.sql import SparkSession
from .augmentations.base_augmentation import Augmentation

from .augmentations.base_augmentation import Augmentation

class Pipeline:
"""Manages a sequence of augmentations."""

def __init__(self, spark=None):
def __init__(self, spark: Optional[SparkSession] = None):
self.spark = (
spark or SparkSession.builder.appName("TrajectoryPipeline").getOrCreate()
)
self.augmentations = []
self._base_rdd = None # stored input RDD (for set_data-style use)

def __enter__(self):
return self

def add(self, aug: Augmentation):
def __exit__(self, exc_type, exc_val, exc_tb):
if self.spark is not None:
self.spark.stop()

def add(self, aug: "Augmentation"):
"""Add an augmentation to the pipeline."""
self.augmentations.append(aug)
return self # enable chaining

def run(self, data):
"""Run all augmentations sequentially."""
def set_data(self, data, cache=False):
"""
Set the base data for the pipeline.

`data` can be a Python list or an existing RDD.
If `cache=True`, the RDD will be cached in memory.
"""
sc = self.spark.sparkContext
if not hasattr(data, "context"): # convert list -> RDD if needed

# list / iterable -> parallelize, RDD -> keep
if hasattr(data, "context"): # looks like an RDD
rdd = data
else:
rdd = sc.parallelize(data)

if cache:
rdd = rdd.cache()

self._base_rdd = rdd
return self

def run(self, data=None, use_stored_if_no_data=True):
"""
Run all augmentations sequentially.

- If `data` is provided, it behaves like the old version: converts a list
to an RDD if needed and does NOT modify the stored base data.
- If `data` is None and `use_stored_if_no_data` is True, it uses the RDD
set via `set_data(...)`.
"""
sc = self.spark.sparkContext

if data is not None:
# old behavior: convert to RDD if needed
if hasattr(data, "context"): # RDD
rdd = data
else: # list / iterable
rdd = sc.parallelize(data)
else:
rdd = data
if not use_stored_if_no_data:
raise ValueError("No data passed to run() and use_stored_if_no_data=False.")
if self._base_rdd is None:
raise ValueError(
"No data passed to run() and no data set via set_data(). "
"Call run(data=...) or set_data(...) first."
)
rdd = self._base_rdd

# apply augmentations
for aug in self.augmentations:
rdd = aug._apply_rdd(rdd)

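Usage sketch (not part of the diff): the two calling conventions now supported by Pipeline.run(), assuming run() returns the final RDD as SparkIterableDataset expects. The sample data is made up and no augmentations are added, so both calls simply return the input trajectories.

from dataaug_platform import Pipeline

data = [{"x": [1, 2, 3]}, {"x": [4, 5, 6]}]

with Pipeline() as p:
    # Old style: pass data directly to run(); the stored base RDD is untouched.
    direct = p.run(data).collect()

    # New style: stage the data once with set_data(), then call run() with no arguments.
    p.set_data(data, cache=True)
    staged = p.run().collect()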
137 changes: 137 additions & 0 deletions examples/explore_pytorch.py
@@ -0,0 +1,137 @@
"""
Example demonstrating how to use user-defined augmentation classes.

This shows how to define augmentation classes with @local_aug and @global_aug decorators.
"""

from torch.utils.data import DataLoader

from pyspark.sql import SparkSession
from dataaug_platform import Augmentation, local_aug, global_aug, Pipeline, SparkIterableDataset


# ============================================================================
# Example 1: User-defined local augmentation class
# ============================================================================

class AddOffsetAugmentation(Augmentation):
"""Add a constant offset to all numeric values."""

def __init__(self, offset=1.0):
self.offset = offset

@local_aug
def apply(self, traj):
"""Process one trajectory at a time."""
import numpy as np

modified = traj.copy()
for key, value in modified.items():
if isinstance(value, np.ndarray):
modified[key] = value + self.offset
return modified


# ============================================================================
# Example 2: User-defined global augmentation class
# ============================================================================

class AverageTrajectoriesAugmentation(Augmentation):
"""Create a new trajectory by averaging all trajectories."""

def __init__(self, times=1, keep_original=True):
"""
Initialize the augmentation.

Args:
times: Number of times to run this augmentation (default: 1).
Each run processes the whole dataset and produces new trajectories.
Multiple runs are parallelized using Spark.
keep_original: Whether to keep original trajectories in output (default: True).
If False, output contains only augmented trajectories.
"""
super().__init__(times=times, keep_original=keep_original)

@global_aug
def apply(self, trajs):
"""Process all trajectories together."""
import numpy as np

if not trajs:
return []

# Average all trajectories
avg_traj = {}
for key in trajs[0].keys():
values = [traj[key] for traj in trajs if key in traj]
if values and isinstance(values[0], np.ndarray):
avg_traj[key] = np.mean(values, axis=0)
else:
avg_traj[key] = values[0] if values else None

return [avg_traj]

# ============================================================================
# Example: Using the pipeline with class-based augmentations
# ============================================================================

def example():
"""Demonstrates the class-based augmentation style."""

spark = SparkSession.builder.appName("AugmentationExample").getOrCreate()
pipeline = Pipeline(spark)

# Add user-defined augmentation classes
pipeline.add(AddOffsetAugmentation(offset=2.0))
# Run global augmentation 3 times in parallel using Spark
# keep_original=True (default): output has 5 (original) + 3 (augmented) = 8 trajectories
pipeline.add(AverageTrajectoriesAugmentation(times=3))

# Example with keep_original=False: output has only the 3 augmented trajectories
# pipeline.add(AverageTrajectoriesAugmentation(times=3, keep_original=False))

# Sample data: list of trajectory dictionaries
sample_data = [
{"x": [1, 2, 3], "y": [4, 5, 6]},
{"x": [2, 3, 4], "y": [5, 6, 7]},
{"x": [3, 4, 5], "y": [6, 7, 8]},
{"x": [7, 8, 9], "y": [10, 11, 12]},
{"x": [11, 12, 13], "y": [14, 15, 16]},
]

pipeline.set_data(sample_data)

print("===== Finite Dataset =====")
spark_dataset = SparkIterableDataset(pipeline)

spark_dataloader = DataLoader(spark_dataset, batch_size=6, num_workers=1)

for i, data in enumerate(spark_dataloader):
print(f'[{i}] {data=}')

print("===== Infinite Dataset =====")
num_batches = 4
batch_count = 0

print(f"Iteration count {num_batches}")

spark_dataset = SparkIterableDataset(pipeline, infinite=True)

spark_dataloader = DataLoader(spark_dataset, batch_size=6, num_workers=1)

for i, data in enumerate(spark_dataloader):
if batch_count == num_batches:
print("Iteration count reached")
break

print(f'[{i}] {data=}')
batch_count += 1

spark.stop()


if __name__ == "__main__":
print("=" * 60)
print("Example: Class-based augmentation style")
print("=" * 60)
example()
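Sanity-check sketch (not part of the diff): the trajectory counts promised by the comments in example() can also be verified by collecting straight from the pipeline, without a DataLoader. This reuses the classes defined above and assumes run() returns the final RDD.

def count_check():
    """Collect pipeline output and count original vs. augmented trajectories."""
    pipeline = Pipeline()
    pipeline.add(AverageTrajectoriesAugmentation(times=3))  # keep_original=True by default
    pipeline.set_data([{"x": [float(i)], "y": [float(i + 1)]} for i in range(5)])

    results = pipeline.run().collect()
    print(len(results))  # expected 8: 5 originals + 3 averaged trajectories

    pipeline.spark.stop()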