Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions deeprvat/deeprvat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ def __init__(
name: METRICS[name]() for name in self.hparams.metrics["all"]
}

self.test_step_outputs = []
self.validation_step_outputs = []

self.objective_mode = self.hparams.metrics.get("objective_mode", "min")
if self.objective_mode == "max":
self.best_objective = float("-inf")
Expand Down Expand Up @@ -174,11 +177,11 @@ def validation_step(self, batch: dict, batch_idx: int):
and corresponding ground truth values ("y_by_pheno").
"""
y_by_pheno = {pheno: pheno_batch["y"] for pheno, pheno_batch in batch.items()}
return {"y_pred_by_pheno": self(batch), "y_by_pheno": y_by_pheno}
pred = {"y_pred_by_pheno": self(batch), "y_by_pheno": y_by_pheno}
self.validation_step_outputs.append(pred)
return pred

def validation_epoch_end(
self, prediction_y: List[Dict[str, Dict[str, torch.Tensor]]]
):
def on_validation_epoch_end(self):
"""
Evaluate accumulated phenotype predictions at the end of the validation epoch.

Expand All @@ -193,6 +196,7 @@ def validation_epoch_end(
:return: None
:rtype: None
"""
prediction_y = self.validation_step_outputs
y_pred_by_pheno = dict()
y_by_pheno = dict()
for result in prediction_y:
Expand Down Expand Up @@ -233,6 +237,8 @@ def validation_epoch_end(
self.best_objective, results[self.hparams.metrics["objective"]].item()
)

self.validation_step_outputs.clear() # free memory

def test_step(self, batch: dict, batch_idx: int):
"""
During testing, we do not compute backward passes, such that we can accumulate
Expand All @@ -247,16 +253,19 @@ def test_step(self, batch: dict, batch_idx: int):
and corresponding ground truth values ("y").
:rtype: dict
"""
return {"y_pred": self(batch), "y": batch["y"]}
pred = {"y_pred": self(batch), "y": batch["y"]}
self.test_step_outputs.append(pred)
return pred

def test_epoch_end(self, prediction_y: List[Dict[str, torch.Tensor]]):
def on_test_epoch_end(self):
"""
Evaluate accumulated phenotype predictions at the end of the testing epoch.

:param prediction_y: A list of dictionaries containing accumulated phenotype predictions
and corresponding ground truth values obtained during the testing process.
:type prediction_y: List[Dict[str, Dict[str, torch.Tensor]]]
"""
prediction_y = self.test_step_outputs
y_pred = torch.cat([p["y_pred"] for p in prediction_y])
y = torch.cat([p["y"] for p in prediction_y])

Expand All @@ -269,6 +278,8 @@ def test_epoch_end(self, prediction_y: List[Dict[str, torch.Tensor]]):
self.best_objective, results[self.hparams.metrics["objective"]].item()
)

self.test_step_outputs.clear() # free memory

def configure_callbacks(self):
return [ModelSummary()]

Expand Down
17 changes: 9 additions & 8 deletions deeprvat_env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,24 @@ channels:
dependencies:
- click=8.1
- cudatoolkit=11.8
- cython=0.29
- dask=2023.5
- fastparquet=0.5
- h5py=3.1
- mkl==2022.1.0
- numcodecs=0.11
- numpy=1.21
- numpy=1.24
- optuna=2.10
- pandas=1.5
- parallel=20230922
- pip=23.0.1
- plotnine=0.10.1
- pyarrow=11.0
- pyranges=0.0.129
- python=3.8
- pytorch=1.13
- pytorch-cuda=11
- pytorch-lightning=1.5
- pytorch=2.4
- pytorch-cuda=11.8
- pytorch-lightning=2.4
- pyyaml=5.4
- regenie=3.4.1
- scikit-learn=1.1
Expand All @@ -29,12 +33,9 @@ dependencies:
- snakemake=7.17
- sqlalchemy=1.4
- statsmodels=0.13
- tensorboard=2.14
- tqdm=4.59
- zarr=2.13
- Cython=0.29
- parallel=20230922
- pip=23.0.1
- plotnine=0.10.1
- pip:
- git+https://github.com/HealthML/seak@v0.4.3
- bgen==1.6.3
7 changes: 4 additions & 3 deletions deeprvat_env_no_gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ dependencies:
- h5py=3.1
- mkl==2022.1.0
- numcodecs=0.11
- numpy=1.21
- numpy=1.24
- optuna=2.10
- pandas=1.5
- pyarrow=11.0
- pyranges=0.0.129
- python=3.8
- pytorch=1.13
- pytorch-lightning=1.5
- pytorch=2.4
- pytorch-lightning=2.4
- pyyaml=5.4
- regenie=3.4.1
- scikit-learn=1.1
Expand All @@ -26,6 +26,7 @@ dependencies:
- snakemake=7.17
- sqlalchemy=1.4
- statsmodels=0.13
- tensorboard=2.14
- tqdm=4.59
- zarr=2.13
- Cython=0.29
Expand Down
1 change: 0 additions & 1 deletion docs/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ git clone git@github.com:PMBio/deeprvat.git
1. Change directory to the repository: `cd deeprvat`
1. Install the conda environment. We recommend using [mamba](https://mamba.readthedocs.io/en/latest/index.html), though you may also replace `mamba` with `conda`

*Note: [the current deeprvat env does not support cuda when installed with conda](https://github.com/PMBio/deeprvat/issues/16). Install using mamba for cuda support.*
```shell
mamba env create -n deeprvat -f deeprvat_env.yaml
```
Expand Down
3 changes: 2 additions & 1 deletion example/config/deeprvat_input_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ seed_gene_results: #baseline_results
# DeepRVAT training settings
training:
pl_trainer: #PyTorch Lightening trainer settings
gpus: 1
accelerator: gpu
devices: 1
precision: 16
min_epochs: 50
max_epochs: 1000
Expand Down
3 changes: 2 additions & 1 deletion example/config/deeprvat_input_config_regenie.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ seed_gene_results: #baseline_results
# DeepRVAT training settings
training:
pl_trainer: #PyTorch Lightening trainer settings
gpus: 1
accelerator: gpu
devices: 1
precision: 16
min_epochs: 50
max_epochs: 1000
Expand Down
3 changes: 2 additions & 1 deletion example/config/deeprvat_input_training_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ seed_gene_results: #baseline_results
# DeepRVAT training settings
training:
pl_trainer: #PyTorch Lightening trainer settings
gpus: 1
accelerator: gpu
devices: 1
precision: 16
min_epochs: 50
max_epochs: 1000
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ training:
Platelet_count: {}
pl_trainer:
check_val_every_n_epoch: 1
gpus: 0
accelerator: cpu
log_every_n_steps: 1
max_epochs: 1000
min_epochs: 50
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ training:
Apolipoprotein_B: {}
pl_trainer:
check_val_every_n_epoch: 1
gpus: 1
accelerator: gpu
devices: 1
log_every_n_steps: 1
max_epochs: 1000
min_epochs: 50
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ seed_gene_results: #baseline_results
# DeepRVAT training settings
training:
pl_trainer: #PyTorch Lightening trainer settings
gpus: 1
accelerator: gpu
devices: 1
precision: 16
min_epochs: 50
max_epochs: 1000
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ training:
Apolipoprotein_B: {}
pl_trainer:
check_val_every_n_epoch: 1
gpus: 1
accelerator: gpu
devices: 1
log_every_n_steps: 1
max_epochs: 1000
min_epochs: 50
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ seed_gene_results: #baseline_results
# DeepRVAT training settings
training:
pl_trainer: #PyTorch Lightening trainer settings
gpus: 1
accelerator: gpu
devices: 1
precision: 16
min_epochs: 50
max_epochs: 1000
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ training:
Apolipoprotein_B: {}
pl_trainer:
check_val_every_n_epoch: 1
gpus: 1
accelerator: gpu
devices: 1
log_every_n_steps: 1
max_epochs: 1000
min_epochs: 50
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ seed_gene_results: #baseline_results
# DeepRVAT training settings
training:
pl_trainer: #PyTorch Lightening trainer settings
gpus: 1
accelerator: gpu
devices: 1
precision: 16
min_epochs: 50
max_epochs: 1000
Expand Down
3 changes: 2 additions & 1 deletion tests/deeprvat/test_data/training/deeprvat_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ training:
Urate: {}
pl_trainer:
check_val_every_n_epoch: 1
gpus: 1
accelerator: gpu
devices: 1
log_every_n_steps: 1
max_epochs: 1000
min_epochs: 50
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ training:
Platelet_count: {}
pl_trainer:
check_val_every_n_epoch: 1
gpus: 0
accelerator: cpu
log_every_n_steps: 1
max_epochs: 1000
min_epochs: 50
Expand Down
Loading