What I've hinted at in the two recent PRs could look something like this, @felix0097.

Essentially, rather than 3 scripts and one notebook, just create all results end-to-end in a single script, which can then be run with different param combinations & easily be reviewed, understood & iterated on by anyone. The results dataframe could be as I drafted below and visualized in all kinds of ways:
```python
results_af = ln.Artifact.get(key="dataloader_v2_benchmark_results/tahoe100m_benchmark.parquet")
results_df = results_af.load()
new_results = pd.DataFrame({
    "data loader": ["arrayloaders", "MappedCollection", "scDataset"],
    "samples/sec": [arrayloaders_w_obs, mapped_collection, sc_dataset],
    "collection_size_in_mio": 3 * [100],
    "chunk_size": 3 * [chunk_size],
    "run_uid": 3 * [ln.context.run.uid],
    "timestamp": 3 * [datetime.datetime.now(datetime.UTC)],
    # ...
})
# DataFrame.append was removed in pandas 2.0 and never mutated in place
results_df = pd.concat([results_df, new_results], ignore_index=True)
ln.Artifact.from_dataframe(
    results_df,
    key="dataloader_v2_benchmark_results/tahoe100m_benchmark.parquet",
    description="Results of the dataloader v2 benchmarks",
).save()
```

Because we'll run each function sequentially, garbage collection should be fine; but one could also create prod data by calling single methods if one is worried about this:
```bash
python run_tahoe100_data_loader_benchmarks.py --method annbatch
```
I just added together the snippets from your existing scripts -- with AI this could be finalized quickly. I'd also be happy to do this if you agree it's cleaner overall. It'll also show nice data lineage and could enable feeding a collection uid as a parameter to any of the benchmarks, e.g. for different collection sizes.
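All three run functions call a shared `benchmark_loader` helper (and the torch path a `collate_fn`) that isn't part of the snippets; a minimal sketch of what they could look like -- the module name `benchmarking_utils.py` and the exact signatures are my assumption, not existing code:

```python
# benchmarking_utils.py -- hypothetical helper module; name and signatures are assumptions
import time

from torch.utils.data import default_collate


def collate_fn(batch):
    # stack a list of single samples into one batch
    return default_collate(batch)


def benchmark_loader(loader, n_samples: int, batch_size: int):
    # iterate until roughly n_samples samples are consumed and time it
    n_batches = max(1, n_samples // batch_size)
    start = time.perf_counter()
    n_seen = 0
    for _batch in loader:
        n_seen += 1
        if n_seen >= n_batches:
            break
    elapsed = time.perf_counter() - start
    return n_seen * batch_size / elapsed, elapsed, n_seen
```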
```python
# run_tahoe100_data_loader_benchmarks.py
import datetime

import anndata as ad
import click
import lamindb as ln
import pandas as pd
import zarr
from annbatch import ZarrSparseDataset  # assumption: ZarrSparseDataset ships with the annbatch package
from lamindb.core import MappedCollection
from torch.utils.data import DataLoader

from benchmarking_utils import benchmark_loader, collate_fn  # hypothetical helpers, sketched above


def run_scdataset(
    block_size: int = 4,
    fetch_factor: int = 16,
    num_workers: int = 6,
    batch_size: int = 4096,
    n_samples: int = 2_000_000,
):
    from scdataset import BlockShuffling, scDataset  # local import so that it can be run without installing all dependencies

    benchmarking_collections = ln.Collection.using("laminlabs/arrayloader-benchmarks")
    h5ad_shards = benchmarking_collections.get("eAgoduHMxuDs5Wem0000").cache()
    adata_collection = ad.experimental.AnnCollection(
        [ad.read_h5ad(shard, backed="r") for shard in h5ad_shards]
    )

    def fetch_adata(collection, indices):
        return collection[indices].X

    strategy = BlockShuffling(block_size=block_size)
    dataset = scDataset(
        adata_collection,
        strategy,
        batch_size=batch_size,
        fetch_factor=fetch_factor,
        fetch_callback=fetch_adata,
    )
    loader = DataLoader(
        dataset,
        batch_size=None,
        num_workers=num_workers,
        prefetch_factor=fetch_factor + 1,
    )
    samples_per_sec, _, _ = benchmark_loader(loader, n_samples, batch_size)
    return samples_per_sec


def run_mapped_collection(
    num_workers: int = 6,
    batch_size: int = 4096,
    n_samples: int = 2_000_000,
):
    benchmarking_collections = ln.Collection.using("laminlabs/arrayloader-benchmarks")
    h5ad_shards = benchmarking_collections.get("eAgoduHMxuDs5Wem0000").cache()
    mapped_collection = MappedCollection(h5ad_shards)
    loader = DataLoader(
        mapped_collection,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )
    samples_per_sec, _, _ = benchmark_loader(loader, n_samples, batch_size)
    return samples_per_sec


def run_annbatch(  # noqa: PLR0917
    chunk_size: int = 256,
    preload_nchunks: int = 64,
    use_torch_loader: bool = False,  # noqa: FBT001, FBT002
    num_workers: int = 6,
    batch_size: int = 4096,
    n_samples: int = 2_000_000,
    include_obs: bool = True,  # noqa: FBT001, FBT002
):
    benchmarking_collections = ln.Collection.using("laminlabs/arrayloader-benchmarks")
    collection = benchmarking_collections.get("LaJOdLd0xZ3v5ZBw0000")
    store_shards = [
        artifact.cache(batch_size=48) for artifact in collection.ordered_artifacts.all()
    ]
    ds = ZarrSparseDataset(
        shuffle=True,
        chunk_size=chunk_size,
        preload_nchunks=preload_nchunks,
        # when the torch DataLoader does the batching, the dataset yields single samples
        batch_size=1 if use_torch_loader else batch_size,
    )
    ds.add_datasets(
        datasets=[ad.io.sparse_dataset(zarr.open(p)["X"]) for p in store_shards],
        obs=[
            ad.io.read_elem(zarr.open(p)["obs"])["cell_line"].to_numpy()
            for p in store_shards
        ]
        if include_obs
        else None,
    )
    n_samples = n_samples if n_samples != -1 else len(ds)
    if use_torch_loader:
        loader = DataLoader(
            ds,
            batch_size=batch_size,
            num_workers=num_workers,
            drop_last=True,
            collate_fn=collate_fn,
        )
        samples_per_sec, _, _ = benchmark_loader(loader, n_samples, batch_size)
    else:
        samples_per_sec, _, _ = benchmark_loader(ds, n_samples, batch_size)
    return samples_per_sec


@click.command()
@click.option("--method", type=click.Choice(["all", "annbatch", "MappedCollection", "scDataset"]), default="all")
@click.option("--chunk_size", type=int, default=256)
@click.option("--preload_nchunks", type=int, default=8)
@click.option("--use_torch_loader", type=bool, default=True)
@click.option("--num_workers", type=int, default=6)
@click.option("--batch_size", type=int, default=4096)
@click.option("--n_samples", type=int, default=2_000_000)
@click.option("--include_obs", type=bool, default=True)
def benchmark(  # noqa: PLR0917
    method: str,
    chunk_size: int,
    preload_nchunks: int,
    use_torch_loader: bool,  # noqa: FBT001
    num_workers: int,
    batch_size: int,
    n_samples: int,
    include_obs: bool,  # noqa: FBT001
):
    # ln.save(ln.Feature.from_dict(locals()))  # only needed a single time to define valid params
    ln.track(project="zjQ6EYzMXif4", params=locals())  # param tracking optional
    if method in ("all", "annbatch"):
        arrayloaders_w_obs = run_annbatch(
            chunk_size=chunk_size,
            preload_nchunks=preload_nchunks,
            use_torch_loader=use_torch_loader,
            num_workers=num_workers,
            batch_size=batch_size,
            n_samples=n_samples,
            include_obs=include_obs,
        )
    # etc. for MappedCollection & scDataset
    results_af = ln.Artifact.get(key="dataloader_v2_benchmark_results/tahoe100m_benchmark.parquet")
    results_df = results_af.load()
    new_results = pd.DataFrame({
        "data loader": ["arrayloaders", "MappedCollection", "scDataset"],
        "samples/sec": [arrayloaders_w_obs, mapped_collection, sc_dataset],
        "collection_size_in_mio": 3 * [100],
        "chunk_size": 3 * [chunk_size],
        "run_uid": 3 * [ln.context.run.uid],
        "timestamp": 3 * [datetime.datetime.now(datetime.UTC)],
        # ...
    })
    results_df = pd.concat([results_df, new_results], ignore_index=True)
    ln.Artifact.from_dataframe(
        results_df,
        key="dataloader_v2_benchmark_results/tahoe100m_benchmark.parquet",
        description="Results of the dataloader v2 benchmarks",
    ).save()


if __name__ == "__main__":
    benchmark()
```
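The accumulated parquet could then be visualized in all kinds of ways; a minimal sketch plotting via pandas/matplotlib, with column names as in the draft above:

```python
import lamindb as ln
import matplotlib.pyplot as plt

results_df = ln.Artifact.get(
    key="dataloader_v2_benchmark_results/tahoe100m_benchmark.parquet"
).load()

# e.g. compare loaders at a fixed chunk size across all recorded runs
subset = results_df[results_df["chunk_size"] == 256]
subset.plot.bar(x="data loader", y="samples/sec", legend=False)
plt.ylabel("samples/sec")
plt.tight_layout()
plt.show()
```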