scripts/run_loading_benchmark_on_collection.py: 79 changes (35 additions, 44 deletions)
@@ -14,6 +14,7 @@
 from arrayloader_benchmarks import benchmark_loader, compute_spec

 if TYPE_CHECKING:
+    from collections.abc import Iterable
     from pathlib import Path


@@ -24,7 +25,7 @@ def get_datasets(
     collection = benchmarking_collections.get(key=collection_key)
     if n_datasets == -1:
         n_datasets = collection.artifacts.count()
-    local_shards = [
+    local_paths = [
         artifact.cache(
             batch_size=48
         )  # batch_size during download shouldn't be necessary to set
@@ -34,23 +35,23 @@
         artifact.n_observations
         for artifact in collection.ordered_artifacts.all()[:n_datasets]
     ]
-    return local_shards, sum(n_samples_collection)
+    return local_paths, sum(n_samples_collection)


-def run_scdataset(
-    local_shards: list[Path],
+def get_scdataset_loader(
+    local_paths: list[Path],
     block_size: int = 4,
     fetch_factor: int = 16,
     num_workers: int = 6,
     batch_size: int = 4096,
     n_samples: int = 2_000_000,
-) -> float:
+) -> Iterable:
     # local imports so that it can be run without installing all dependencies
     from scdataset import BlockShuffling, scDataset
     from torch.utils.data import DataLoader

     adata_collection = ad.experimental.AnnCollection(
-        [ad.read_h5ad(shard, backed="r") for shard in local_shards]
+        [ad.read_h5ad(shard, backed="r") for shard in local_paths]
     )

     def fetch_adata(collection, indices):
@@ -70,17 +71,16 @@ def fetch_adata(collection, indices):
         num_workers=num_workers,
         prefetch_factor=fetch_factor + 1,
     )
-    samples_per_sec, _, _ = benchmark_loader(loader, n_samples, batch_size)
-    return samples_per_sec
+    return loader


-def run_mappedcollection(
-    local_shards: list[Path],
+def get_mappedcollection_loader(
+    local_paths: list[Path],
     num_workers: int = 6,
     batch_size: int = 4096,
     n_samples: int = 2_000_000,
-) -> float:
-    mapped_collection = ln.core.MappedCollection(local_shards, parallel=True)
+) -> Iterable:
+    mapped_collection = ln.core.MappedCollection(local_paths, parallel=True)
     loader = DataLoader(
         mapped_collection,
         batch_size=batch_size,
@@ -89,12 +89,11 @@ def run_mappedcollection(
         worker_init_fn=mapped_collection.torch_worker_init_fn,
         drop_last=True,
     )
-    samples_per_sec, _, _ = benchmark_loader(loader, n_samples, batch_size)
-    return samples_per_sec
+    return loader


-def run_annbatch(
-    local_shards: list[Path],
+def get_annbatch_loader(
+    local_paths: list[Path],
     chunk_size: int = 256,
     preload_nchunks: int = 64,
     use_torch_loader: bool = False,  # noqa: FBT001, FBT002
@@ -138,10 +137,10 @@ def collate_fn(elems):
         batch_size=1 if use_torch_loader else batch_size,
     )
     ds.add_datasets(
-        datasets=[ad.io.sparse_dataset(zarr.open(p)["X"]) for p in local_shards],
+        datasets=[ad.io.sparse_dataset(zarr.open(p)["X"]) for p in local_paths],
         obs=[
             ad.io.read_elem(zarr.open(p)["obs"])["cell_line"].to_numpy()
-            for p in local_shards
+            for p in local_paths
         ]
         if include_obs
         else None,
@@ -155,11 +154,9 @@ def collate_fn(elems):
             drop_last=True,
             collate_fn=collate_fn,
         )
-        samples_per_sec, _, _ = benchmark_loader(loader, n_samples, batch_size)
+        return loader
     else:
-        samples_per_sec, _, _ = benchmark_loader(ds, n_samples, batch_size)
-
-    return samples_per_sec
+        return ds


 @click.command()
@@ -202,11 +199,11 @@ def run(
     ln.track("LDSa3IJYQkbm", project=project)

     if tool in {"MappedCollection", "scDataset"}:
-        local_shards, n_samples_collection = get_datasets(
+        local_paths, n_samples_collection = get_datasets(
             collection_key=f"{collection}_h5ad", cache=True, n_datasets=n_datasets
         )
     else:
-        local_shards, n_samples_collection = get_datasets(
+        local_paths, n_samples_collection = get_datasets(
             collection_key=f"{collection}_zarr", cache=True, n_datasets=n_datasets
         )

@@ -217,8 +214,8 @@
     n_samples = min(n_samples, n_samples_collection)

     if tool == "annbatch":
-        n_samples_per_sec = run_annbatch(
-            local_shards,
+        loader = get_annbatch_loader(
+            local_paths,
             chunk_size=chunk_size,
             preload_nchunks=preload_nchunks,
             use_torch_loader=use_torch_loader,
@@ -228,22 +225,25 @@
             include_obs=include_obs,
         )
     elif tool == "MappedCollection":
-        n_samples_per_sec = run_mappedcollection(
-            local_shards,
+        loader = get_mappedcollection_loader(
+            local_paths,
             num_workers=num_workers,
             batch_size=batch_size,
             n_samples=n_samples,
         )
     elif tool == "scDataset":
-        n_samples_per_sec = run_scdataset(
-            local_shards,
+        loader = get_scdataset_loader(
+            local_paths,
             block_size=block_size,
             fetch_factor=fetch_factor,
             num_workers=num_workers,
             batch_size=batch_size,
             n_samples=n_samples,
         )

+    n_samples_per_sec, _, _ = benchmark_loader(loader, n_samples, batch_size)
+
+    # collect results and parameters
     new_result = {
         "tool": tool,
         "collection": collection,
@@ -260,23 +260,14 @@
         "user": ln.setup.settings.user.handle,
     }

     # save new result, appending to existing results if they exist
     results_key = "arrayloader_benchmarks_v2/tahoe100m_benchmark.parquet"

     try:
-        results_af = ln.Artifact.get(key=results_key)
-        results_df = results_af.load()
-        results_df = pd.concat(
-            [results_df, pd.DataFrame([new_result])], ignore_index=True
-        )
+        df = ln.Artifact.get(key=results_key).load()
+        df = pd.concat([df, pd.DataFrame([new_result])], ignore_index=True)
     except ln.Artifact.DoesNotExist:
-        results_df = pd.DataFrame([new_result])
-
-    ln.Artifact.from_dataframe(
-        results_df,
-        key=results_key,
-        description="Results of v2 of the arrayloader benchmarks",
-    ).save()
-
+        df = pd.DataFrame([new_result])
+    ln.Artifact.from_dataframe(df, key=results_key).save()
     ln.finish()
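
Note for reviewers: after this change, each get_*_loader function only constructs an iterable loader, and throughput is measured in one place by the benchmark_loader call in run(). Below is a minimal sketch of how the refactored pieces compose outside the CLI. It is illustrative only: it assumes the script can be imported as a module, that the LaminDB instance holding the benchmarking collections is already configured, and the "tahoe100m" collection key and parameter values are placeholders, not part of the PR.

# Hedged sketch, not part of the diff above.
from arrayloader_benchmarks import benchmark_loader
from run_loading_benchmark_on_collection import (
    get_datasets,
    get_mappedcollection_loader,
)

# Cache a few h5ad shards locally; "tahoe100m" is a placeholder, the script
# builds the key as f"{collection}_h5ad".
local_paths, n_samples_collection = get_datasets(
    collection_key="tahoe100m_h5ad", cache=True, n_datasets=4
)

batch_size = 4096
n_samples = min(2_000_000, n_samples_collection)

# The factory no longer benchmarks anything; it just returns a DataLoader
# over the MappedCollection.
loader = get_mappedcollection_loader(
    local_paths, num_workers=6, batch_size=batch_size
)

# Throughput is measured identically for every tool.
samples_per_sec, _, _ = benchmark_loader(loader, n_samples, batch_size)
print(f"{samples_per_sec:,.0f} samples/sec")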

