From 49a130f7b762f68b5db546790c3fdef65f6f7968 Mon Sep 17 00:00:00 2001
From: Yuxin Chang
Date: Thu, 6 Nov 2025 14:48:09 -0800
Subject: [PATCH] Add link to S2P2 camera-ready paper, and script for preprocessing EHRSHOT dataset.

---
 README.md                                    |   2 +-
 notebooks/s2p2_preprocess_ehrshot_cpt4.ipynb | 410 +++++++++++++++++++
 2 files changed, 411 insertions(+), 1 deletion(-)
 create mode 100644 notebooks/s2p2_preprocess_ehrshot_cpt4.ipynb

diff --git a/README.md b/README.md
index 8dc987f..6b9983c 100644
--- a/README.md
+++ b/README.md
@@ -77,7 +77,7 @@ We provide reference implementations of various state-of-the-art TPP papers:
 | 6 | ICLR'20 | IntensityFree | [Intensity-Free Learning of Temporal Point Processes](https://arxiv.org/abs/1909.12127) | [PyTorch](easy_tpp/model/torch_model/torch_intensity_free.py) |
 | 7 | ICLR'21 | ODETPP | [Neural Spatio-Temporal Point Processes (simplified)](https://arxiv.org/abs/2011.04583) | [PyTorch](easy_tpp/model/torch_model/torch_ode_tpp.py) |
 | 8 | ICLR'22 | AttNHP | [Transformer Embeddings of Irregularly Spaced Events and Their Participants](https://arxiv.org/abs/2201.00044) | [PyTorch](easy_tpp/model/torch_model/torch_attnhp.py) |
-| 9 | NeurIPS'25 | S2P2 | Deep Continuous-Time State-Space Models for Marked Event Sequences | [PyTorch](easy_tpp/model/torch_model/torch_s2p2.py) |
+| 9 | NeurIPS'25 | S2P2 | [Deep Continuous-Time State-Space Models for Marked Event Sequences](https://openreview.net/pdf?id=74SvE2GZwW) | [PyTorch](easy_tpp/model/torch_model/torch_s2p2.py) |

diff --git a/notebooks/s2p2_preprocess_ehrshot_cpt4.ipynb b/notebooks/s2p2_preprocess_ehrshot_cpt4.ipynb
new file mode 100644
index 0000000..1886d09
--- /dev/null
+++ b/notebooks/s2p2_preprocess_ehrshot_cpt4.ipynb
@@ -0,0 +1,410 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "f0c61845",
+   "metadata": {},
+   "source": [
+    "# EHRSHOT Dataset Preprocessing (S2P2 Paper)\n",
+    "This notebook includes code for preparing the EHRSHOT event sequence dataset from the raw [EHRSHOT dataset](https://som-shahlab.github.io/ehrshot-website/), where medical services and procedures, identified by _Current Procedural Terminology_ (CPT-4) codes, are treated as marks.\n",
+    "\n",
+    "This version of the dataset was originally used to evaluate the [State-Space Point Process (S2P2)](https://openreview.net/pdf?id=74SvE2GZwW) model. Note that we cannot distribute the raw data (or any derivative dataset) under the terms of the original EHRSHOT data release. Access to the data can be requested [here](https://stanford.redivis.com/datasets/53gc-8rhx41kgt)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "e3fb9cd3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import pandas as pd\n",
+    "import numpy as np\n",
+    "from collections import defaultdict\n",
+    "import heapq\n",
+    "from tqdm import tqdm\n",
+    "from easy_tpp.utils import set_seed\n",
+    "import random\n",
+    "import json"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "83a40695ae9eeed7",
+   "metadata": {},
+   "source": [
+    "### 0. Load data and check if it's complete"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "d2d2d13a79f2268c",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-05-22T21:52:03.248612Z",
+     "start_time": "2025-05-22T21:51:13.046926Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/var/folders/fz/4gzrrkvs2_77xs43jry5yp9w0000gn/T/ipykernel_26867/1025842674.py:3: DtypeWarning: Columns (3,5) have mixed types. Specify dtype option on import or set low_memory=False.\n",
+      "  df_dataset = pd.read_csv(path_to_data_csv)\n"
+     ]
+    }
+   ],
+   "source": [
+    "path_to_data_csv = '../data/EHRSHOT/EHRSHOT_ASSETS/data/ehrshot.csv'\n",
+    "path_to_splits_csv = '../data/EHRSHOT/EHRSHOT_ASSETS/splits/person_id_map.csv'\n",
+    "df_dataset = pd.read_csv(path_to_data_csv)\n",
+    "df_split = pd.read_csv(path_to_splits_csv)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "18ef70b0ec958120",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-05-22T21:52:04.952046Z",
+     "start_time": "2025-05-22T21:52:04.031668Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "# of events: 41661637\n",
+      "# of patients: 6739\n",
+      "# of visits: 921499\n",
+      "# of train patients 2295\n",
+      "# of val patients 2232\n",
+      "# of test patients 2212\n"
+     ]
+    }
+   ],
+   "source": [
+    "# check that the counts match the original repo: https://github.com/som-shahlab/ehrshot-benchmark/blob/main/ehrshot/stats.ipynb\n",
+    "print(\"# of events:\", df_dataset.shape[0])\n",
+    "print(\"# of patients:\", df_dataset['patient_id'].nunique())\n",
+    "print(\"# of visits:\", df_dataset['visit_id'].nunique())\n",
+    "print(\"# of train patients\", df_split[df_split['split'] == 'train']['omop_person_id'].nunique())\n",
+    "print(\"# of val patients\", df_split[df_split['split'] == 'val']['omop_person_id'].nunique())\n",
+    "print(\"# of test patients\", df_split[df_split['split'] == 'test']['omop_person_id'].nunique())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "bd3aa280bbe0d3e",
+   "metadata": {},
+   "source": [
+    "### 1. Get event times from the `visit_occurrence` table"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "55e9a82ed0910e84",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-05-22T21:55:43.462174Z",
+     "start_time": "2025-05-22T21:55:40.272675Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "df_visit = df_dataset[df_dataset['omop_table'] == 'visit_occurrence'].copy()  # work on a copy of the visit_occurrence rows\n",
+    "df_visit.loc[:, 'start'] = pd.to_datetime(df_visit['start']).apply(lambda x: int(round(x.timestamp())))  # visit start as a Unix timestamp in seconds\n",
+    "df_visit_time = df_visit[['patient_id', 'start']].drop_duplicates(keep=False)  # keep=False drops every (patient_id, start) pair that occurs more than once\n",
+    "df_visit_time = df_visit_time.groupby(['patient_id'])['start'].apply(lambda x: sorted(list(set(x)))).reset_index(name='timestamp')\n",
+    "visit_dict = pd.Series(df_visit_time.timestamp.values, index=df_visit_time.patient_id).to_dict()  # patient_id -> sorted visit timestamps\n",
+    "patient_visit = df_visit_time['patient_id'].to_numpy()"
+   ]
+  },
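+  {
+   "cell_type": "markdown",
+   "id": "added-visit-sanity-check-md",
+   "metadata": {},
+   "source": [
+    "*(Added, optional)* A quick sanity check on the per-patient visit-time sequences built above. This cell is not part of the original S2P2 preprocessing; it only summarizes `visit_dict` and `patient_visit` from the previous cell, and the numbers it prints are not quoted from the paper."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "added-visit-sanity-check",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional sanity check (uses only variables defined in the previous cell).\n",
+    "visit_counts = np.array([len(visit_dict[p]) for p in patient_visit])\n",
+    "print('patients with at least one kept visit time:', len(patient_visit))\n",
+    "print('patients with >= 2 visit times:', int((visit_counts >= 2).sum()))\n",
+    "print('median / max visit times per patient:', int(np.median(visit_counts)), '/', int(visit_counts.max()))"
+   ]
+  },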
+  {
+   "cell_type": "markdown",
+   "id": "7eddf9ac67745f1a",
+   "metadata": {},
+   "source": [
+    "### 2. Get CPT-4 codes that occur at least 100 times"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "23d14c81fec5f549",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-05-22T21:59:11.017201Z",
+     "start_time": "2025-05-22T21:59:00.594692Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Number of marks after filtering: 668\n"
+     ]
+    }
+   ],
+   "source": [
+    "df_cpt4 = df_dataset[df_dataset['code'].str.contains('CPT4', case=False, na=False)]\n",
+    "mark_val, mark_count = np.unique(df_cpt4.loc[:,'code'].to_numpy(), return_counts=True)\n",
+    "\n",
+    "mark_mask = (mark_count >= 100)  # keep codes with at least 100 occurrences\n",
+    "print(f'Number of marks after filtering: {sum(mark_mask)}')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "370a230ac0d5213",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-05-22T21:59:12.675372Z",
+     "start_time": "2025-05-22T21:59:11.025996Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "mark_val = mark_val[mark_mask]\n",
+    "mark_val_set = set(mark_val)\n",
+    "df_cpt4_subset = df_cpt4[df_cpt4['code'].isin(mark_val_set)][['patient_id', 'start', 'code']].copy()\n",
+    "df_cpt4_subset['start'] = pd.to_datetime(df_cpt4_subset.loc[:,'start']).apply(lambda x: int(round(x.timestamp())))\n",
+    "df_cpt4_subset['code'] = df_cpt4_subset['code'].astype('category').cat.codes  # re-index kept CPT-4 codes as integer marks\n",
+    "mark_val_subset, mark_count_subset = np.unique(df_cpt4_subset.loc[:,'code'].to_numpy(), return_counts=True)\n",
+    "mark_count_dict = dict(zip(mark_val_subset, mark_count_subset))  # integer mark -> frequency\n",
+    "patient_cpt4 = df_cpt4_subset['patient_id'].unique()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "10020d5895b3afc9",
+   "metadata": {},
+   "source": [
+    "### 3. Generate event sequences"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "a744ef372bd6dd72",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-05-22T21:59:37.489196Z",
+     "start_time": "2025-05-22T21:59:37.485658Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "def sample_event_times(real_event_time, std, size):\n",
+    "    sampled_times = np.random.normal(real_event_time, scale=std, size=size)\n",
+    "    # resample until all jittered times are positive\n",
+    "    while not np.all(sampled_times > 0):\n",
+    "        sampled_times = np.random.normal(real_event_time, scale=std, size=size)\n",
+    "    return sampled_times"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "id": "5ff4de761c2847c6",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-05-22T22:01:05.315820Z",
+     "start_time": "2025-05-22T22:00:56.872143Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|██████████| 6634/6634 [00:08<00:00, 786.48it/s] "
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "6183\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "time_norm = 60 * 60  # seconds per hour; event times are expressed in hours\n",
+    "min_events = 5\n",
+    "max_marks_per_time = 10\n",
+    "padding_events = 668\n",
+    "all_sequences = []\n",
+    "idx = 0\n",
+    "set_seed(123)\n",
+    "\n",
+    "\n",
+    "for patient in tqdm(patient_cpt4):\n",
+    "    patient = int(patient)\n",
+    "    data = df_cpt4_subset[df_cpt4_subset.patient_id == patient]\n",
+    "    if len(data) < min_events or len(data.start.unique()) < 2:\n",
+    "        continue\n",
+    "    events = list(zip(data['start'], data['code']))\n",
+    "    sorted_unique_times = sorted(data.start.unique())\n",
+    "    if not len(np.diff(sorted_unique_times)):  # debug guard; unreachable given the filter above\n",
+    "        print(len(data))\n",
+    "        print(len(events))\n",
+    "    min_diff = min(np.diff(sorted_unique_times))  # minimum time between two consecutive events\n",
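+    "    # NOTE: below, marks sharing a timestamp are grouped; if a group has more than\n",
+    "    # max_marks_per_time marks, only the most frequent ones are kept. The shared timestamp\n",
+    "    # is then jittered with Normal noise (std = one tenth of min(min_diff in hours, 1 hour))\n",
+    "    # so that co-occurring events receive distinct times close to the recorded one.\n",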
+    "\n",
+    "    event_dict = defaultdict(list)\n",
+    "    base_time = int(sorted_unique_times[0])\n",
+    "    for t, m in events:\n",
+    "        event_dict[(t - base_time)/time_norm].append(m)\n",
+    "\n",
+    "    std = min(min_diff/time_norm, 1) / 10  # std of the Normal distribution used to jitter event times\n",
+    "    event_times = []\n",
+    "    event_marks = []\n",
+    "    for t in sorted_unique_times:\n",
+    "        t = (t - base_time)/time_norm\n",
+    "        v = event_dict[t]\n",
+    "        if len(v) > max_marks_per_time:  # keep the most frequent marks\n",
+    "            h = []\n",
+    "            for mark in v:\n",
+    "                if len(h) < max_marks_per_time:\n",
+    "                    heapq.heappush(h, (mark_count_dict[mark], mark))\n",
+    "                else:\n",
+    "                    heapq.heappushpop(h, (mark_count_dict[mark], mark))\n",
+    "            v = [x[1] for x in h]\n",
+    "        else:\n",
+    "            np.random.shuffle(v)\n",
+    "\n",
+    "        sampled_times = sample_event_times(t, std, min(max_marks_per_time, len(v)) - 1)\n",
+    "        times = sorted([t] + list(sampled_times))\n",
+    "        times = [float(t) for t in times]\n",
+    "        event_times.extend(times)\n",
+    "        event_marks.extend(v)\n",
+    "        assert len(v) <= max_marks_per_time\n",
+    "    assert(len(event_times) == len(event_marks))\n",
+    "    assert(min_events <= len(event_times))\n",
+    "\n",
+    "    # replace the first and last marks with the padding event type\n",
+    "    event_marks[0] = padding_events\n",
+    "    event_marks[-1] = padding_events\n",
+    "\n",
+    "    all_sequences.append(\n",
+    "        {\n",
+    "            'dim_process': padding_events,\n",
+    "            'seq_idx': idx,\n",
+    "            'seq_len': len(event_times),\n",
+    "            'time_since_start': event_times,\n",
+    "            'time_since_last_event': [0] + [event_times[i+1] - event_times[i] for i in range(len(event_times) - 1)],\n",
+    "            'type_event': event_marks,\n",
+    "        }\n",
+    "    )\n",
+    "    idx += 1\n",
+    "print(len(all_sequences))  # 6183"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "id": "61b6d934ea56003b",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-05-22T22:01:15.774095Z",
+     "start_time": "2025-05-22T22:01:15.766903Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "test: 927\n",
+      "valid: 927\n",
+      "train: 4329\n"
+     ]
+    }
+   ],
+   "source": [
+    "test_pct, valid_pct, train_pct = 0.15, 0.15, 0.7\n",
+    "test_seqs, valid_seqs, train_seqs = [], [], []\n",
+    "\n",
+    "random.shuffle(all_sequences)\n",
+    "for i, seq in enumerate(all_sequences):\n",
+    "    progress = (i + 1) / len(all_sequences)\n",
+    "    if progress <= test_pct:\n",
+    "        test_seqs.append(seq)\n",
+    "    elif progress <= test_pct + valid_pct:\n",
+    "        valid_seqs.append(seq)\n",
+    "    else:\n",
+    "        train_seqs.append(seq)\n",
+    "\n",
+    "print(f'test: {len(test_seqs)}')\n",
+    "print(f'valid: {len(valid_seqs)}')\n",
+    "print(f'train: {len(train_seqs)}')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "id": "f9c1d7894a5792bb",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2025-05-22T22:02:34.872771Z",
+     "start_time": "2025-05-22T22:02:32.829332Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "# # Save results\n",
+    "# with open('./ehrshot_cpt4/train.json', 'w') as f:\n",
+    "#     json.dump(train_seqs, f)\n",
+    "#\n",
+    "# with open('./ehrshot_cpt4/dev.json', 'w') as f:\n",
+    "#     json.dump(valid_seqs, f)\n",
+    "#\n",
+    "# with open('./ehrshot_cpt4/test.json', 'w') as f:\n",
+    "#     json.dump(test_seqs, f)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "60c3455349b9fb3d",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
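+  {
+   "cell_type": "markdown",
+   "id": "added-split-check-md",
+   "metadata": {},
+   "source": [
+    "*(Added, optional)* A minimal sketch for checking the saved splits, assuming the commented-out save cell above was uncommented and run so that `./ehrshot_cpt4/{train,dev,test}.json` exist. It only inspects the fields constructed in section 3 and is not part of the original S2P2 preprocessing."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "added-split-check",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional check of the saved splits (assumes the save cell above was run).\n",
+    "for split_name in ['train', 'dev', 'test']:\n",
+    "    with open(f'./ehrshot_cpt4/{split_name}.json') as f:\n",
+    "        seqs = json.load(f)\n",
+    "    # seq_len should match the number of times, marks should lie in [0, dim_process]\n",
+    "    # (dim_process itself is the padding/boundary mark), and times should be non-decreasing.\n",
+    "    n_bad_len = sum(len(s['time_since_start']) != s['seq_len'] for s in seqs)\n",
+    "    n_bad_mark = sum(any(m < 0 or m > s['dim_process'] for m in s['type_event']) for s in seqs)\n",
+    "    n_unsorted = sum(any(b < a for a, b in zip(s['time_since_start'], s['time_since_start'][1:])) for s in seqs)\n",
+    "    print(f'{split_name}: {len(seqs)} sequences, {n_bad_len} length mismatches, {n_bad_mark} out-of-range marks, {n_unsorted} non-monotone time lists')"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {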
"display_name": "easytpp", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}