diff --git a/scripts/run_loading_benchmark_on_collection.py b/scripts/run_loading_benchmark_on_collection.py
index 571d385..0bc8d2d 100644
--- a/scripts/run_loading_benchmark_on_collection.py
+++ b/scripts/run_loading_benchmark_on_collection.py
@@ -14,6 +14,7 @@
 from arrayloader_benchmarks import benchmark_loader, compute_spec
 
 if TYPE_CHECKING:
+    from collections.abc import Iterable
     from pathlib import Path
 
 
@@ -24,7 +25,7 @@ def get_datasets(
     collection = benchmarking_collections.get(key=collection_key)
     if n_datasets == -1:
         n_datasets = collection.artifacts.count()
-    local_shards = [
+    local_paths = [
         artifact.cache(
             batch_size=48
         )  # batch_size during download shouldn't be necessary to set
@@ -34,23 +35,23 @@ def get_datasets(
         artifact.n_observations
         for artifact in collection.ordered_artifacts.all()[:n_datasets]
     ]
-    return local_shards, sum(n_samples_collection)
+    return local_paths, sum(n_samples_collection)
 
 
-def run_scdataset(
-    local_shards: list[Path],
+def get_scdataset_loader(
+    local_paths: list[Path],
     block_size: int = 4,
     fetch_factor: int = 16,
     num_workers: int = 6,
     batch_size: int = 4096,
     n_samples: int = 2_000_000,
-) -> float:
+) -> Iterable:
     # local imports so that it can be run without installing all dependencies
     from scdataset import BlockShuffling, scDataset
     from torch.utils.data import DataLoader
 
     adata_collection = ad.experimental.AnnCollection(
-        [ad.read_h5ad(shard, backed="r") for shard in local_shards]
+        [ad.read_h5ad(shard, backed="r") for shard in local_paths]
     )
 
     def fetch_adata(collection, indices):
@@ -70,17 +71,16 @@ def fetch_adata(collection, indices):
         num_workers=num_workers,
         prefetch_factor=fetch_factor + 1,
     )
-    samples_per_sec, _, _ = benchmark_loader(loader, n_samples, batch_size)
-    return samples_per_sec
+    return loader
 
 
-def run_mappedcollection(
-    local_shards: list[Path],
+def get_mappedcollection_loader(
+    local_paths: list[Path],
     num_workers: int = 6,
     batch_size: int = 4096,
     n_samples: int = 2_000_000,
-) -> float:
-    mapped_collection = ln.core.MappedCollection(local_shards, parallel=True)
+) -> Iterable:
+    mapped_collection = ln.core.MappedCollection(local_paths, parallel=True)
     loader = DataLoader(
         mapped_collection,
         batch_size=batch_size,
@@ -89,12 +89,11 @@ def run_mappedcollection(
         worker_init_fn=mapped_collection.torch_worker_init_fn,
         drop_last=True,
     )
-    samples_per_sec, _, _ = benchmark_loader(loader, n_samples, batch_size)
-    return samples_per_sec
+    return loader
 
 
-def run_annbatch(
-    local_shards: list[Path],
+def get_annbatch_loader(
+    local_paths: list[Path],
     chunk_size: int = 256,
     preload_nchunks: int = 64,
     use_torch_loader: bool = False,  # noqa: FBT001, FBT002
@@ -138,10 +137,10 @@ def collate_fn(elems):
         batch_size=1 if use_torch_loader else batch_size,
     )
     ds.add_datasets(
-        datasets=[ad.io.sparse_dataset(zarr.open(p)["X"]) for p in local_shards],
+        datasets=[ad.io.sparse_dataset(zarr.open(p)["X"]) for p in local_paths],
         obs=[
             ad.io.read_elem(zarr.open(p)["obs"])["cell_line"].to_numpy()
-            for p in local_shards
+            for p in local_paths
         ]
         if include_obs
         else None,
@@ -155,11 +154,9 @@ def collate_fn(elems):
             drop_last=True,
             collate_fn=collate_fn,
         )
-        samples_per_sec, _, _ = benchmark_loader(loader, n_samples, batch_size)
+        return loader
     else:
-        samples_per_sec, _, _ = benchmark_loader(ds, n_samples, batch_size)
-
-    return samples_per_sec
+        return ds
 
 
 @click.command()
@@ -202,11 +199,11 @@ def run(
     ln.track("LDSa3IJYQkbm", project=project)
 
     if tool in {"MappedCollection", "scDataset"}:
-        local_shards, n_samples_collection = get_datasets(
+        local_paths, n_samples_collection = get_datasets(
             collection_key=f"{collection}_h5ad", cache=True, n_datasets=n_datasets
         )
     else:
-        local_shards, n_samples_collection = get_datasets(
+        local_paths, n_samples_collection = get_datasets(
             collection_key=f"{collection}_zarr", cache=True, n_datasets=n_datasets
         )
 
@@ -217,8 +214,8 @@ def run(
     n_samples = min(n_samples, n_samples_collection)
 
     if tool == "annbatch":
-        n_samples_per_sec = run_annbatch(
-            local_shards,
+        loader = get_annbatch_loader(
+            local_paths,
             chunk_size=chunk_size,
             preload_nchunks=preload_nchunks,
             use_torch_loader=use_torch_loader,
@@ -228,15 +225,15 @@ def run(
             include_obs=include_obs,
         )
     elif tool == "MappedCollection":
-        n_samples_per_sec = run_mappedcollection(
-            local_shards,
+        loader = get_mappedcollection_loader(
+            local_paths,
             num_workers=num_workers,
             batch_size=batch_size,
             n_samples=n_samples,
         )
     elif tool == "scDataset":
-        n_samples_per_sec = run_scdataset(
-            local_shards,
+        loader = get_scdataset_loader(
+            local_paths,
             block_size=block_size,
             fetch_factor=fetch_factor,
             num_workers=num_workers,
@@ -244,6 +241,9 @@ def run(
             n_samples=n_samples,
         )
 
+    n_samples_per_sec, _, _ = benchmark_loader(loader, n_samples, batch_size)
+
+    # collect results and parameters
     new_result = {
         "tool": tool,
         "collection": collection,
@@ -260,23 +260,15 @@ def run(
         "user": ln.setup.settings.user.handle,
     }
 
+    # save new result, appending to existing results if they exist
     results_key = "arrayloader_benchmarks_v2/tahoe100m_benchmark.parquet"
     try:
-        results_af = ln.Artifact.get(key=results_key)
-        results_df = results_af.load()
-        results_df = pd.concat(
-            [results_df, pd.DataFrame([new_result])], ignore_index=True
-        )
+        df = ln.Artifact.get(key=results_key).load()
+        df = pd.concat([df, pd.DataFrame([new_result])], ignore_index=True)
     except ln.Artifact.DoesNotExist:
-        results_df = pd.DataFrame([new_result])
-
-    ln.Artifact.from_dataframe(
-        results_df,
-        key=results_key,
-        description="Results of v2 of the arrayloader benchmarks",
-    ).save()
-
+        df = pd.DataFrame([new_result])
+    ln.Artifact.from_dataframe(df, key=results_key).save()
     ln.finish()
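
For reference, the refactor separates loader construction from measurement: each get_*_loader function now returns an iterable, and benchmark_loader(loader, n_samples, batch_size) is called in exactly one place in run(), whichever tool was selected. A minimal sketch of that call pattern, assuming only the signatures and the 3-tuple return value visible in the diff; the shard paths below are hypothetical placeholders, not real collection artifacts:

    from pathlib import Path

    # Hypothetical cached shards; in the script these come from get_datasets().
    local_paths = [Path("shard_0.h5ad"), Path("shard_1.h5ad")]
    batch_size = 4096
    n_samples = 2_000_000

    # Build a loader with one of the refactored factories ...
    loader = get_mappedcollection_loader(
        local_paths, num_workers=6, batch_size=batch_size, n_samples=n_samples
    )
    # ... then measure it with the shared benchmarking call (first element of
    # the returned 3-tuple is throughput in samples per second).
    samples_per_sec, _, _ = benchmark_loader(loader, n_samples, batch_size)
    print(f"{samples_per_sec:.0f} samples/sec")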